In [None]:
import os
import pickle
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import torch
import torch.nn as nn

from sklearn.preprocessing import LabelEncoder

In [None]:
# Seed every RNG this notebook draws from so a Restart & Run All is repeatable:
# torch (CPU and all CUDA devices) drives weight init, dropout and DataLoader
# shuffling; NumPy's global RNG drives the target shuffling done in the dataset.
seed=22232
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

In [None]:
def load_obj(filepath):
    """Unpickle and return the object stored at ``filepath``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    files from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

def save_obj(obj, filepath):
    """Pickle ``obj`` to ``filepath``, overwriting any existing file."""
    with open(filepath, 'wb') as handle:
        pickle.dump(obj, handle)

# Loading co-occurrence matrices

In [None]:
%%time
# Pre-computed co-occurrence tensors, one per article attribute. The dataset
# code indexes each of them as [attribute_id, week_offset, :].
(
    prod_group_cooccur_matrix,
    prod_type_cooccur_matrix,
    garment_cooccur_matrix,
    section_cooccur_matrix,
) = (
    np.load(matrix_path)
    for matrix_path in (
        "../input/hm-cooccurance-matrix-dataset/prod_group_cooccur_matrix.npy",
        "../input/hm-cooccurance-matrix-dataset/prod_type_cooccur_matrix.npy",
        "../input/hm-cooccurance-matrix-dataset/garment_cooccur_matrix.npy",
        "../input/hm-cooccurance-matrix-dataset/section_cooccur_matrix.npy",
    )
)



In [None]:
%%time
# Per-customer purchase sequences, the article table, and the per-article lookup dict.
train_df = pd.read_pickle("../input/hmdataset-candidate-generation/train_df.pkl")
article_df = pd.read_pickle("../input/hmdataset-candidate-generation/article_df.pkl")
article_info = load_obj("../input/hmdataset-candidate-generation/article_info.pkl")

# Number of distinct weeks in which each customer made a purchase.
train_df['num_unique_weeks'] = train_df['week'].map(lambda weeks: len(set(weeks)))
train_df.head()

In [None]:
%%time
# Candidate article ids; the dataset looks rows up by customer_id and feeds
# them to the model as negative samples.
candidate_map = np.load(
    "../input/hm-train-candidate-for-ranking/candidate_map.npy"
)
print("candidate_map:", candidate_map.shape)

# Config

In [None]:
class config:
    """Hyper-parameters, vocabulary sizes and padding ids shared by the notebook."""

    # Training schedule and batch sizes.
    NUM_EPOCHS = 20
    BATCH_SIZE = 1024
    VAL_BATCH_SIZE = 128

    # Sequence / sampling limits.
    NUM_NEGATIVE_SAMPLES = 200
    MAX_TARGETS = 10
    MAX_SEQ_LEN = 64

    # Vocabulary sizes, taken from the article table.
    NUM_ARTICLES = article_df['article_id'].nunique()
    NUM_PROD_TYPES = article_df['product_type_name'].nunique()
    NUM_PROD_GROUPS = article_df['product_group_name'].nunique()
    NUM_SECTIONS = article_df['section_name'].nunique()
    NUM_GARMENTS = article_df['garment_group_name'].nunique()
    NUM_SALE_CHANNELS = 2

    # Padding ids sit one past the last vocabulary id.
    # NOTE(review): this assumes every column is encoded as 0..N-1 — confirm
    # against the preprocessing notebook.
    PAD_ARTICLE_ID = NUM_ARTICLES
    PAD_PROD_TYPE = NUM_PROD_TYPES
    PAD_PROD_GROUP = NUM_PROD_GROUPS
    PAD_SECTION = NUM_SECTIONS
    PAD_GARMENT_GROUP = NUM_GARMENTS
    PAD_SALES_CHANNEL = 2

In [None]:
# Value used to right-pad each per-purchase sequence, keyed by input name.
padding_map = dict(
    price=0,
    sales_channel_id=2,
    article_id=config.PAD_ARTICLE_ID,
    product_type_name=config.PAD_PROD_TYPE,
    product_group_name=config.PAD_PROD_GROUP,
    section_name=config.PAD_SECTION,
    garment_group_name=config.PAD_GARMENT_GROUP,
)

# Dataset

In [None]:
class HMDataset(torch.utils.data.Dataset):
    """Per-customer purchase-history dataset for the ranking model.

    Each row of ``df`` is read for ``customer_id`` and the parallel sequences
    ``article_id``, ``price``, ``sales_channel_id`` and ``week``; evaluation
    rows additionally need ``val_article_ids``.

    In ``phase='train'`` the purchases of the most recent week in the row are
    held out as targets and everything before them is the model input.  In any
    other phase the whole history is the input and ``val_article_ids`` are the
    targets.

    The class reads these module-level objects at call time: ``config``,
    ``padding_map``, ``article_info``, ``candidate_map`` and the four
    ``*_cooccur_matrix`` arrays.

    NOTE(review): the code assumes each row's sequences are in chronological
    order and that ``week`` decreases towards the most recent purchase (the
    smallest value is treated as the newest week) — confirm against the
    preprocessing notebook.
    """

    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase

    # ------------------------------------------------------------------
    # feature helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _summary_feats(values):
        """Return 15 summary features of ``values`` as a float64 array.

        Layout: ``[min, max, mean, |mean-min|, |max-mean|]`` followed by the
        square roots and then the squares of those five numbers, all shifted
        by ``1e-9``.  Raises ``ValueError`` (from NumPy) when ``values`` is empty.
        """
        lowest = np.min(values)
        highest = np.max(values)
        average = np.mean(values)
        base = np.array(
            [lowest, highest, average, np.abs(average - lowest), np.abs(highest - average)],
            dtype=np.float64,
        )
        return np.concatenate([base, np.sqrt(base), np.square(base)]) + 1e-9

    @staticmethod
    def _article_attr(article_ids, field):
        """Look up ``field`` in ``article_info`` for every id in ``article_ids``."""
        return [article_info[article_id][field] for article_id in article_ids]

    def get_article_metadata(self, article_ids):
        """Return (product_type, product_group, section, garment_group) id lists
        for the last ``config.MAX_SEQ_LEN`` entries of ``article_ids``."""
        recent_ids = article_ids[-config.MAX_SEQ_LEN:]
        product_type_name = self._article_attr(recent_ids, 'product_type_name')
        product_group_name = self._article_attr(recent_ids, 'product_group_name')
        section_name = self._article_attr(recent_ids, 'section_name')
        garment_group_name = self._article_attr(recent_ids, 'garment_group_name')
        return (product_type_name, product_group_name, section_name, garment_group_name)

    def get_price_details(self, prices):
        """Return the 15 summary features (see ``_summary_feats``) of ``prices``."""
        return self._summary_feats(prices)

    def _popularity_feats(self, weeks, article_ids):
        """Summary features of each purchase's ``normalized_popularity``.

        Only the last ``config.MAX_SEQ_LEN`` purchases are used; the popularity
        of a purchase is read from ``article_info[article_id][week]``.
        """
        weeks = weeks[-config.MAX_SEQ_LEN:]
        article_ids = article_ids[-config.MAX_SEQ_LEN:]
        popularities = [
            article_info[article_id][week]['normalized_popularity']
            for article_id, week in zip(article_ids, weeks)
        ]
        return self._summary_feats(popularities)

    def get_popularity_details(self, row):
        """Popularity summary features over the *whole* history stored in ``row``."""
        return self._popularity_feats(row.week, row.article_id)

    def get_cooccurrence_feats(self, weeks, article_ids, last_week):
        """Aggregate co-occurrence rows for every purchase in the history.

        For each purchase the row ``matrix[attribute_id, week_offset, :]`` is
        added up per attribute, where ``week_offset = week - last_week`` clipped
        to ``[0, 38]``.  Each attribute vector is divided by ``max(1, its sum)``
        and the four vectors are concatenated (product type, product group,
        section, garment group).
        """
        # 38 is the largest week offset used to index the matrices.
        # NOTE(review): presumably their second axis has 39 entries — confirm.
        week_diff = np.clip(np.array(weeks) - last_week, 0, 38)

        type_ids = self._article_attr(article_ids, 'product_type_name')
        group_ids = self._article_attr(article_ids, 'product_group_name')
        section_ids = self._article_attr(article_ids, 'section_name')
        garment_ids = self._article_attr(article_ids, 'garment_group_name')

        prodtype_feats = np.zeros(config.NUM_PROD_TYPES)
        prodgroup_feats = np.zeros(config.NUM_PROD_GROUPS)
        section_feats = np.zeros(config.NUM_SECTIONS)
        garment_feats = np.zeros(config.NUM_GARMENTS)

        for i, offset in enumerate(week_diff):
            prodtype_feats += prod_type_cooccur_matrix[type_ids[i], offset, :]
            prodgroup_feats += prod_group_cooccur_matrix[group_ids[i], offset, :]
            section_feats += section_cooccur_matrix[section_ids[i], offset, :]
            garment_feats += garment_cooccur_matrix[garment_ids[i], offset, :]

        normalised = [
            feats / max(1, feats.sum())
            for feats in (prodtype_feats, prodgroup_feats, section_feats, garment_feats)
        ]
        return np.concatenate(normalised)

    # ------------------------------------------------------------------
    # sequence helpers
    # ------------------------------------------------------------------
    def pad_sequence(self, x, max_seqlen, padid):
        """Right-pad list ``x`` with ``padid`` up to ``max_seqlen`` items.

        Sequences that are already ``max_seqlen`` long (or longer) are returned
        with their contents unchanged.
        """
        return list(x) + [padid] * (max_seqlen - len(x))

    def _to_model_inputs(self, inputs):
        """Pad every sequence in ``inputs`` and convert it to a long tensor.

        Adds ``'seqlen'`` (the unpadded length of ``inputs['article_id']``) as
        the last key.
        """
        seqlen = len(inputs['article_id'])
        tensors = {
            key: torch.tensor(
                self.pad_sequence(seq, config.MAX_SEQ_LEN, padding_map[key]),
                dtype=torch.long,
            )
            for key, seq in inputs.items()
        }
        tensors['seqlen'] = torch.tensor(seqlen, dtype=torch.long)
        return tensors

    @staticmethod
    def _pad_targets(targets):
        """Keep at most ``config.MAX_TARGETS`` ids, pad the rest, return a long tensor."""
        targets = list(targets)[:config.MAX_TARGETS]
        targets = targets + [config.PAD_ARTICLE_ID] * (config.MAX_TARGETS - len(targets))
        return torch.tensor(targets, dtype=torch.long)

    def get_inputs(self, idx, last_week_idx, last_week):
        """Build the model inputs from the purchases *before* ``last_week_idx``.

        ``last_week_idx`` must index the row's full, untruncated sequences.
        The history is cut to its last ``config.MAX_SEQ_LEN`` purchases before
        any feature is computed, which is the same window the evaluation path
        uses.  Popularity features are computed on that history only, so the
        held-out target purchases never leak into the inputs.
        """
        row = self.df.iloc[idx]

        def history(sequence):
            return list(sequence)[:last_week_idx][-config.MAX_SEQ_LEN:]

        article_id = history(row.article_id)
        price = history(row.price)
        sales_channel_id = history(row.sales_channel_id)
        week = history(row.week)

        price_feats = self.get_price_details(price)
        popularity_feats = self._popularity_feats(week, article_id)
        cooccurrence_feats = self.get_cooccurrence_feats(week, article_id, last_week)
        (product_type_name, product_group_name, section_name, garment_group_name) = self.get_article_metadata(article_id)

        inputs = {
            "sales_channel_id": sales_channel_id,
            'article_id': article_id,
            'product_type_name': product_type_name,
            'product_group_name': product_group_name,
            'section_name': section_name,
            'garment_group_name': garment_group_name,
        }
        return inputs, price_feats, popularity_feats, cooccurrence_feats

    def get_targets(self, article_ids, last_week_idx, week):
        """Return the distinct article ids bought in the week at ``last_week_idx``.

        Scans forward from ``last_week_idx`` while the week value stays the
        same and de-duplicates the article ids in that span.
        """
        next_week_idx = last_week_idx + 1
        seqlen = len(week)
        while next_week_idx < seqlen and week[next_week_idx] == week[last_week_idx]:
            next_week_idx += 1
        return list(set(article_ids[last_week_idx:next_week_idx]))

    def get_inputs_perweekid(self, idx, last_week, week):
        """Build one training example whose targets are the purchases of ``last_week``.

        ``week`` must be the row's full week sequence: the index found in it is
        used to slice the row's other (untruncated) sequences.

        Returns ``(inputs, targets, price_feats, popularity_feats,
        cooccurrence_feats)`` as tensors.
        """
        week = list(week)
        last_week_idx = week.index(last_week)
        if last_week_idx == 0:
            # Every purchase falls in the target week: keep all but the last
            # purchase as history and use the last one as the target.
            last_week_idx = -1

        inputs, price_feats, popularity_feats, cooccurrence_feats = self.get_inputs(idx, last_week_idx, last_week)
        targets = self.get_targets(list(self.df.iloc[idx]['article_id']), last_week_idx, week)

        inputs = self._to_model_inputs(inputs)
        price_feats = torch.tensor(price_feats, dtype=torch.float32)
        popularity_feats = torch.tensor(popularity_feats, dtype=torch.float32)
        cooccurrence_feats = torch.tensor(cooccurrence_feats, dtype=torch.float32)

        # Shuffle so that truncation to MAX_TARGETS keeps a random subset.
        # The global NumPy RNG is used as-is; it is not reseeded per item.
        targets = list(targets)
        np.random.shuffle(targets)
        targets = self._pad_targets(targets)
        return (inputs, targets, price_feats, popularity_feats, cooccurrence_feats)

    # ------------------------------------------------------------------
    # per-phase item builders
    # ------------------------------------------------------------------
    def get_train_inputs(self, idx):
        """Return one training item.

        Every element except ``negsamples`` is a one-element list (one training
        view per customer): ``(all_inputs, all_prices, all_popularities,
        all_cooccurences, all_targets, all_trainable, negsamples)``.
        ``negsamples`` is the raw ``candidate_map`` row of the customer.
        """
        row = self.df.iloc[idx]
        negsamples = candidate_map[row.customer_id]

        # Use the full week list so the target-week index lines up with the
        # row's untruncated sequences; the history window is cut afterwards.
        week = list(row.week)
        target_week = min(week)

        (inputs, targets, price_feats, popularity_feats, cooccurrence_feats) = self.get_inputs_perweekid(idx, target_week, week)
        is_trainable = torch.tensor(1, dtype=torch.long)
        return ([inputs], [price_feats], [popularity_feats], [cooccurrence_feats], [targets], [is_trainable], negsamples)

    def get_val_inputs(self, idx):
        """Return one evaluation item built from the last ``config.MAX_SEQ_LEN`` purchases.

        Targets are the distinct ``val_article_ids`` of the row, truncated and
        padded to ``config.MAX_TARGETS``.
        """
        row = self.df.iloc[idx]
        negsamples = candidate_map[row.customer_id]

        article_id = list(row.article_id)[-config.MAX_SEQ_LEN:]
        sales_channel_id = list(row.sales_channel_id)[-config.MAX_SEQ_LEN:]
        price = list(row.price)[-config.MAX_SEQ_LEN:]
        week = list(row.week)[-config.MAX_SEQ_LEN:]

        price_feats = self.get_price_details(price)
        popularity_feats = self.get_popularity_details(row)
        # Week offsets are measured from week 0 at evaluation time.
        cooccurrence_feats = self.get_cooccurrence_feats(week, article_id, 0)
        (product_type_name, product_group_name, section_name, garment_group_name) = self.get_article_metadata(article_id)

        inputs = self._to_model_inputs({
            "sales_channel_id": sales_channel_id,
            'article_id': article_id,
            'product_type_name': product_type_name,
            'product_group_name': product_group_name,
            'section_name': section_name,
            'garment_group_name': garment_group_name,
        })
        price_feats = torch.tensor(price_feats, dtype=torch.float32)
        popularity_feats = torch.tensor(popularity_feats, dtype=torch.float32)
        cooccurrence_feats = torch.tensor(cooccurrence_feats, dtype=torch.float32)

        negsamples = torch.tensor(negsamples, dtype=torch.long)
        targets = self._pad_targets(list(set(row.val_article_ids)))
        return (inputs, targets, price_feats, popularity_feats, cooccurrence_feats, negsamples)

    def __getitem__(self, idx):
        if self.phase == 'train':
            (all_inputs, all_prices, all_popularities, all_cooccurences, all_targets, all_trainable, negsamples) = self.get_train_inputs(idx)
            negsamples = torch.tensor(negsamples, dtype=torch.long)
            return (all_inputs, all_prices, all_popularities, all_cooccurences, all_targets, all_trainable, negsamples)

        (inputs, targets, price_feats, popularity_feats, cooccurrence_feats, negsamples) = self.get_val_inputs(idx)
        return (inputs, price_feats, popularity_feats, cooccurrence_feats, negsamples, targets)

    def __len__(self):
        return len(self.df)

# model

In [None]:
class Embedding(nn.Module):
    """Embeds every per-purchase id sequence and sum-pools it over the sequence axis.

    The article embedding is 256 wide and each of the five side embeddings is
    32 wide, so ``forward`` returns a tensor whose last dimension is 416.
    Every table has one extra row that serves as its padding index.
    """

    def __init__(self):
        super().__init__()
        self.article_embeddings = nn.Embedding(config.NUM_ARTICLES + 1, 256, padding_idx=config.PAD_ARTICLE_ID)
        self.prod_type_embeddings = nn.Embedding(config.NUM_PROD_TYPES + 1, 32, padding_idx=config.PAD_PROD_TYPE)
        self.prodgroup_embeddings = nn.Embedding(config.NUM_PROD_GROUPS + 1, 32, padding_idx=config.PAD_PROD_GROUP)

        self.section_embeddings = nn.Embedding(config.NUM_SECTIONS + 1, 32, padding_idx=config.PAD_SECTION)
        self.garment_embeddings = nn.Embedding(config.NUM_GARMENTS + 1, 32, padding_idx=config.PAD_GARMENT_GROUP)
        self.sales_embeddings = nn.Embedding(config.NUM_SALE_CHANNELS + 1, 32, padding_idx=config.PAD_SALES_CHANNEL)

        self.dropout = nn.Dropout(0.1)

    def get_article_embeddings(self, article_ids):
        """Return the article-embedding rows for ``article_ids``."""
        return self.article_embeddings(article_ids)

    def forward(self, inputs):
        """Concatenate the six embeddings per position, apply dropout, sum over dim 1."""
        per_position = torch.cat(
            [
                self.article_embeddings(inputs['article_id']),
                self.prod_type_embeddings(inputs['product_type_name']),
                self.prodgroup_embeddings(inputs['product_group_name']),
                self.section_embeddings(inputs['section_name']),
                self.garment_embeddings(inputs['garment_group_name']),
                self.sales_embeddings(inputs['sales_channel_id']),
            ],
            dim=-1,
        )
        per_position = self.dropout(per_position)
        return torch.sum(per_position, dim=1)

In [None]:
class HistoryEncoder(nn.Module):
    """MLP mapping pooled history embeddings plus summary features to a 256-d vector.

    The first linear layer expects ``446 + 217`` inputs:

    * 446 = 416 (sum-pooled embeddings: 256 + 5 * 32) + 15 price features
      + 15 popularity features;
    * 217 = width of the co-occurrence feature vector.
      NOTE(review): this must equal the length of ``cooccur_feats`` produced by
      the dataset — confirm if the vocabularies change.
    """

    def __init__(self):
        super(HistoryEncoder, self).__init__()
        self.mlp = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(446+217, 1024),
            nn.LeakyReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.2),

            nn.Linear(1024, 512),
            nn.LeakyReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.15),

            nn.Linear(512, 256),
            nn.LeakyReLU()
        )

    def forward(self, x, inputs, price_feats, popularity_feats, cooccur_feats):
        """Mean-pool ``x`` by sequence length, append the feature vectors, run the MLP.

        Args:
            x: sum-pooled embeddings, one row per example.
            inputs: dict holding ``'seqlen'``, the unpadded length per example.
            price_feats, popularity_feats, cooccur_feats: per-example feature rows.

        Returns:
            The MLP output, 256 values per example.
        """
        # Clamp to 1 so an empty history (seqlen == 0) yields a zero vector
        # instead of inf/NaN that would poison the batch-norm statistics.
        seqlen = inputs['seqlen'].unsqueeze(dim=-1).clamp(min=1)
        x = x * (1/seqlen)
        x = torch.cat([x, price_feats, popularity_feats, cooccur_feats], dim=-1)
        x = self.mlp(x)
        return x

In [None]:
class Model(nn.Module):
    """Two-tower scorer: a history vector is dotted with article embeddings."""

    def __init__(self):
        super().__init__()
        self.embeddings = Embedding()
        self.history_encoder = HistoryEncoder()
        # Dot products are scaled by sqrt(256), the article-embedding width.
        self.n_sqrt = np.sqrt(256)

    def get_article_embeddings(self, article_ids):
        """Return the article-embedding rows for ``article_ids``."""
        return self.embeddings.article_embeddings(article_ids)

    def get_history_embeddings(self, inputs, price_feats, popularity_feats, cooccur_feats):
        """Encode a batch of purchase histories into one vector per example."""
        pooled = self.embeddings(inputs)
        return self.history_encoder(pooled, inputs, price_feats, popularity_feats, cooccur_feats)

    def get_logits(self, h, v):
        """Scaled dot product between each history vector in ``h`` and its rows of ``v``.

        ``h`` gets a new axis at position 1 and is broadcast against ``v``; the
        product is summed over the last axis and divided by ``self.n_sqrt``.
        """
        scores = (h.unsqueeze(dim=1) * v).sum(dim=-1)
        return scores / self.n_sqrt

    def forward(self, inputs, price_feats, popularity_feats, cooccur_feats, targets):
        """Return ``(h, v, z)``: history vectors, target embeddings and their logits."""
        h = self.get_history_embeddings(inputs, price_feats, popularity_feats, cooccur_feats)
        v = self.get_article_embeddings(targets)
        return (h, v, self.get_logits(h, v))

# evaluation

In [None]:
def evaluate(val_dataloader, model):
    """Compute hit-rate@5/10/50 of the validation targets against the candidates.

    For every non-padding target the number of candidate articles
    (``negsamples``) that score strictly higher is counted.  A target is a hit
    at ``k`` when fewer than ``k`` candidates outrank it, i.e. when its rank
    among the candidates is at most ``k``.

    Args:
        val_dataloader: yields ``(inputs, price_feats, popularity_feats,
            cooccurrence_feats, negsamples, targets)`` batches.
        model: object exposing ``get_article_embeddings``, ``get_logits`` and a
            forward call returning ``(h, v, zpos)``.

    Returns:
        ``(rank_5, rank_10, rank_50)`` as fractions of all valid targets; all
        zeros when there is no valid target to score.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    cutoffs = (5, 10, 50)
    hits = [0.0 for _ in cutoffs]
    total_pos = 0.0

    with torch.no_grad():
        for (inputs, price_feats, popularity_feats, cooccurrence_feats, negsamples, targets) in val_dataloader:
            negsamples = negsamples.to(device)
            targets = targets.to(device)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            price_feats = price_feats.to(device)
            popularity_feats = popularity_feats.to(device)
            cooccurrence_feats = cooccurrence_feats.to(device)

            u = model.get_article_embeddings(negsamples)
            (h, v, zpos) = model(inputs, price_feats, popularity_feats, cooccurrence_feats, targets)
            # Score candidates with the model's own logit function so they are
            # scaled exactly like the target logits in ``zpos``.
            z = model.get_logits(h, u)

            for i in range(config.MAX_TARGETS):
                valid = targets[:, i] != config.PAD_ARTICLE_ID
                pos_score = zpos[valid, i].unsqueeze(dim=-1)
                # Number of candidates that strictly outscore the target.
                num_better = (z[valid, :] > pos_score).sum(dim=-1)

                for slot, k in enumerate(cutoffs):
                    # rank = num_better + 1, so rank <= k  <=>  num_better < k.
                    hits[slot] += (num_better < k).sum().item()
                total_pos += len(num_better)

    if total_pos == 0:
        return 0.0, 0.0, 0.0
    rank_5, rank_10, rank_50 = (h_count / total_pos for h_count in hits)
    return rank_5, rank_10, rank_50

# train epoch

In [None]:
def train_ops(is_trainable, inputs, price_feats, popularity_feats, cooccur_feats, targets, negsamples, model):
    """Compute the training loss for one collated batch.

    Rows whose ``is_trainable`` flag is zero are dropped.  The loss is a
    BPR-style pairwise loss, ``-log(sigmoid(pos - neg))``, averaged over the
    negatives of each example and then over all non-padding targets, plus
    ``1e-3`` times the mean deviation of the history-vector norm from 1.

    Args:
        is_trainable: per-row 0/1 flags.
        inputs: dict of padded id sequences plus ``'seqlen'``; not modified.
        price_feats, popularity_feats, cooccur_feats: per-row feature tensors.
        targets: per-row target article ids, padded with ``config.PAD_ARTICLE_ID``.
        negsamples: per-row candidate article ids used as negatives.
        model: the ranking model.

    Returns:
        ``(loss, h_norm)``.  When five or fewer rows are trainable the batch is
        skipped and both values are zero; the zero loss still requires grad so
        that a caller's ``loss.backward()`` is a harmless no-op.
    """
    is_trainable = is_trainable.to(bool)
    if is_trainable.sum().item() <= 5:
        skipped_loss = torch.zeros((), device=device, requires_grad=True)
        return skipped_loss, torch.zeros((), device=device)

    # Trim padding columns beyond the longest sequence in the batch, then keep
    # only trainable rows.  A new dict is built so the caller's batch is untouched.
    batch_max_seqlen = int(inputs['seqlen'].max())
    batch_inputs = {}
    for key, value in inputs.items():
        if key != 'seqlen':
            value = value[:, :batch_max_seqlen]
        batch_inputs[key] = value[is_trainable].to(device)

    price_feats = price_feats[is_trainable].to(device)
    popularity_feats = popularity_feats[is_trainable].to(device)
    cooccur_feats = cooccur_feats[is_trainable].to(device)
    targets = targets[is_trainable].to(device)
    negsamples = negsamples[is_trainable].to(device)

    vneg = model.get_article_embeddings(negsamples)
    (h, v, zpos) = model(batch_inputs, price_feats, popularity_feats, cooccur_feats, targets)
    zneg = model.get_logits(h, vneg)

    loss = torch.tensor(0.0, device=device)
    total_sample_count = torch.tensor(0, device=device)
    for i in range(config.MAX_TARGETS):
        pos_score = zpos[:, i].unsqueeze(-1)
        # logsigmoid is the numerically stable form of log(sigmoid(x)); the
        # naive form underflows to log(0) = -inf for large negative margins.
        bpr_loss = -torch.nn.functional.logsigmoid(pos_score - zneg).mean(dim=-1)
        mask = (targets[:, i] != config.PAD_ARTICLE_ID)
        loss = loss + bpr_loss[mask].sum()
        total_sample_count = total_sample_count + mask.sum()

    # Guard against a batch in which every target is padding (0 / 0 -> NaN).
    loss = loss / total_sample_count.clamp(min=1)
    h_norm = torch.abs(torch.norm(h, dim=-1) - 1).mean()
    loss = loss + 1e-3 * h_norm
    return loss, h_norm

In [None]:
def train_epoch(train_dataloader, model, optimizer, schedular):
    """Run one training epoch and return ``(mean_loss, mean_h_norm)``.

    Each batch carries a list of training views; ``train_ops`` is evaluated for
    every view and the first view's loss drives the optimizer step.  Gradients
    are clipped to a total norm of 10 and the LR scheduler is stepped once per
    optimizer step.

    A batch that ``train_ops`` skipped (its loss has no autograd graph) is
    ignored: no optimizer or scheduler step is taken and it is left out of the
    running averages.  ``nan`` is returned for both means if every batch was
    skipped.
    """
    epoch_loss = []
    epoch_norm = []
    model.train()
    for it, (all_inputs, all_prices, all_popularities, all_cooccurences, all_targets, all_trainable, negsamples) in enumerate(train_dataloader):
        batch_losses = []
        batch_norms = []

        for train_id in range(len(all_inputs)):
            result = train_ops(
                all_trainable[train_id],
                all_inputs[train_id],
                all_prices[train_id],
                all_popularities[train_id],
                all_cooccurences[train_id],
                all_targets[train_id],
                negsamples,
                model,
            )
            # train_ops may hand back a bare loss tensor for a skipped batch;
            # accept both shapes instead of failing on tuple unpacking.
            if isinstance(result, tuple):
                cur_loss, cur_norm = result
            else:
                cur_loss, cur_norm = result, torch.zeros_like(result)

            batch_losses.append(cur_loss)
            batch_norms.append(cur_norm)

        loss = batch_losses[0]
        h_norm = batch_norms[0]

        # A loss that is not the result of a forward pass has no grad_fn:
        # there is nothing to back-propagate, so skip the update entirely.
        if loss.grad_fn is None:
            continue

        model.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        schedular.step()

        epoch_loss.append(loss.item())
        epoch_norm.append(h_norm.item())

        if it%200 == 0:
            print("iteration:{} | loss: {:.4f} | hnorm:{:.4f}".format(it, np.mean(epoch_loss), np.mean(epoch_norm) ))

    if not epoch_loss:
        return float('nan'), float('nan')
    return np.mean(epoch_loss), np.mean(epoch_norm)

# train setup

In [None]:
# Datasets: every customer is used for training; only rows flagged
# has_val == 1 are used for evaluation.
train_dataset = HMDataset(train_df)
val_dataset = HMDataset(train_df[train_df['has_val'] == 1], phase='eval')

# Loaders: shuffled, fixed-size batches for training; ordered, complete pass
# for validation.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.VAL_BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
)

print("number of train iterations:", len(train_dataloader))
print("number of val iteration:", len(val_dataloader))

# Model, optimizer and a cosine schedule that is stepped once per iteration;
# its horizon is the total iteration count plus 10.
model = Model().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3, weight_decay=1e-4)
total_scheduler_steps = (config.NUM_EPOCHS * len(train_dataloader)) + 10
schedular = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=total_scheduler_steps,
    eta_min=1e-6,
)

# train model

In [None]:
%%time
# Best validation score seen so far for each cutoff (None until the first epoch).
best_rank5 = None
best_rank10 = None
best_rank50 = None

for e in range(config.NUM_EPOCHS):
    epoch_loss, epoch_norm = train_epoch(train_dataloader, model, optimizer, schedular)
    print("epoch: {} | train loss: {:.4f} | norm:{:.4f}".format(e, epoch_loss, epoch_norm))
    rank_5, rank_10, rank_50 = evaluate(val_dataloader, model)

    # Keep a separate checkpoint for the best epoch at every cutoff.
    if best_rank5 is None or best_rank5 < rank_5:
        best_rank5 = rank_5
        torch.save(model, "model_best_rank_05.pt")

    if best_rank10 is None or best_rank10 < rank_10:
        best_rank10 = rank_10
        torch.save(model, "model_best_rank_10.pt")

    if best_rank50 is None or best_rank50 < rank_50:
        best_rank50 = rank_50
        torch.save(model, "model_best_rank_50.pt")

    # Always overwrite the checkpoint of the latest epoch.
    torch.save(model, "model.pt")
    print("====================")
    print()
    print("precision_05:{:.4f} | precision_10:{:.4f} | precision_50:{:.4f}".format(rank_5, rank_10, rank_50))
    print()
    print("best rank05:{:.4f} | best rank10:{:.4f} | best rank50:{:.4f}".format(best_rank5, best_rank10, best_rank50))