In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 23 08:39:11 2021

@author: lpott
"""
import argparse
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

from preprocessing import *
from dataset import *
from metrics import *
from model import *
from utils import bert2dict


In [3]:
torch.cuda.empty_cache()

In [4]:
# variables

read_filename ="ml-1m\\ratings.dat"
read_bert_filename = "bert_sequence_20m.txt"
read_movie_filename = ""#"movies-1m.csv"
size = "1m"

num_epochs = 100
lr = 1e-3
batch_size = 64
reg = 1e-4
train_method = "alternate"


hidden_dim = 256
embedding_dim = 256
bert_dim= 768

freeze_plot = False
tied = False
dropout= 0

k = 10
max_length = 200
min_len = 10

model_type = "feature_add"

In [5]:
# ------------------Data Initialization----------------------#

# convert .dat file to time-sorted pandas dataframe
ml_1m = create_df(read_filename,size=size)

# remove users who have sessions lengths less than min_len
ml_1m = filter_df(ml_1m,item_min=min_len)



  df = pd.read_csv(filename,sep='::',header=None)


user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60
user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60


In [6]:
# ------------------Data Initialization----------------------#
if read_movie_filename != "":
    ml_movie_df = create_movie_df(read_movie_filename,size=size)
    ml_movie_df = convert_genres(ml_movie_df)
    
    # initialize reset object
    reset_object = reset_df()
    
    # map all user ids, item ids, and genres to range 0 - number of users/items/genres
    ml_1m,ml_movie_df = reset_object.fit_transform(ml_1m,ml_movie_df)
    
    # value that padded genre tokens shall take
    pad_genre_token = reset_object.genre_enc.transform(["NULL"]).item()
    
    genre_dim = len(np.unique(np.concatenate(ml_movie_df.genre))) - 1

else:
    # initialize reset object
    reset_object = reset_df()
    
    # map all user ids and item ids to range 0 - Number of Users/Items 
    # i.e. [1,7,5] -> [0,2,1]
    ml_1m = reset_object.fit_transform(ml_1m)
    
    pad_genre_token = None
    ml_movie_df = None
    genre_dim = 0



In [7]:
# ------------------Data Initialization----------------------#
# how many unique users, items, ratings and timestamps are there
n_users,n_items,n_ratings,n_timestamp = ml_1m.nunique()

# value that padded tokens shall take
pad_token = n_items

# the output dimension for softmax layer
output_dim = n_items


# get the item id : bert plot embedding dictionary
if bert_dim != 0:
    feature_embed = bert2dict(bert_filename=read_bert_filename)



In [8]:
# create a dictionary of every user's session (history)
# i.e. {user: [user clicks]}
if size == "1m":
    user_history = create_user_history(ml_1m)

elif size == "20m":
    import pickle
    with open('userhistory.pickle', 'rb') as handle:
        user_history = pickle.load(handle)
# create a dictionary of all items a user has not clicked
# i.e. {user: [items not clicked by user]}
# user_noclicks = create_user_noclick(user_history,ml_1m,n_items)

  1%|█                                                                              | 83/6040 [00:00<00:08, 721.94it/s]



100%|█████████████████████████████████████████████████████████████████████████████| 6040/6040 [00:08<00:00, 708.89it/s]


In [9]:
#import pickle

#with open('userhistory.pickle', 'wb') as handle:
#    pickle.dump(user_history, handle, protocol=pickle.HIGHEST_PROTOCOL)

#with open('userhistory.pickle', 'rb') as handle:
#    user_history = pickle.load(handle)

In [10]:
# split data by leave-one-out strategy
# have train dictionary {user: [last 41 items prior to last 2 items in user session]}
# have val dictionary {user: [last 41 items prior to last item in user session]}
# have test dictionary {user: [last 41 items]}
# i.e. if max_length = 4, [1,2,3,4,5,6] -> [1,2,3,4] , [2,3,4,5] , [3,4,5,6]
train_history,val_history,test_history = train_val_test_split(user_history,max_length=max_length)

# initialize the train,validation, and test pytorch dataset objects
# eval pads all items except last token to predict
train_dataset = GRUDataset(train_history,genre_df=ml_movie_df,mode='train',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)
val_dataset = GRUDataset(val_history,genre_df=ml_movie_df,mode='eval',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)
test_dataset = GRUDataset(test_history,genre_df=ml_movie_df,mode='eval',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)

# create the train,validation, and test pytorch dataloader objects
train_dl = DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
val_dl = DataLoader(val_dataset,batch_size=64)
test_dl = DataLoader(test_dataset,batch_size=64)

100%|██████████████████████████████████████████████████████████████████████████| 6040/6040 [00:00<00:00, 147715.18it/s]






In [11]:
print("Bert dim: {:d}".format(bert_dim))
print("Genre dim: {:d}".format(genre_dim))
print("Pad Token: {}".format(pad_token))
print("Pad Genre Token: {}".format(pad_genre_token))

Bert dim: 768
Genre dim: 0
Pad Token: 3706
Pad Genre Token: None


In [12]:
# ------------------Model Initialization----------------------#

# initialize gru4rec model with arguments specified earlier
if model_type == "feature_add":
    model = gru4recF(embedding_dim=embedding_dim,
             hidden_dim=hidden_dim,
             output_dim=output_dim,
             genre_dim=genre_dim,
             batch_first=True,
             max_length=max_length,
             pad_token=pad_token,
             pad_genre_token=pad_genre_token,
             bert_dim=bert_dim,
             tied = tied,
             dropout=dropout)


if model_type == "feature_concat":
    model = gru4recFC(embedding_dim=embedding_dim,
             hidden_dim=hidden_dim,
             output_dim=output_dim,
             genre_dim=genre_dim,
             batch_first=True,
             max_length=max_length,
             pad_token=pad_token,
             pad_genre_token=pad_genre_token,
             bert_dim=bert_dim,
             tied = tied,
             dropout=dropout)

if model_type == "vanilla":
    model = gru4rec_vanilla(hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            batch_first=True,
                            max_length=max_length,
                            pad_token=pad_token)

if model_type =="feature_only":
    model = gru4rec_feature(hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            batch_first=True,
                            max_length=max_length,
                            pad_token=pad_token,bert_dim=bert_dim)


In [13]:
model

gru4recF(
  (movie_embedding): Embedding(3707, 256, padding_idx=3706)
  (plot_embedding): Embedding(3707, 768, padding_idx=3706)
  (plot_projection): Linear(in_features=768, out_features=256, bias=True)
  (encoder_layer): GRU(256, 256, batch_first=True)
  (output_layer): Linear(in_features=256, out_features=3706, bias=True)
)

In [14]:
if bert_dim != 0:
    model.init_weight(reset_object,feature_embed)
    
model = model.cuda()

In [15]:
[name for name,param in model.named_parameters() if (("movie" not in name) or ("plot_embedding" in name) or ("genre" in name))]

['plot_embedding.weight',
 'plot_projection.weight',
 'plot_projection.bias',
 'encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [16]:
[name for name,param in model.named_parameters() if ("plot" not in name) and ("genre" not in name)]

['movie_embedding.weight',
 'encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [17]:
# initialize Adam optimizer with gru4rec model parameters
if train_method != "normal":
    optimizer_features = torch.optim.Adam([param for name,param in model.named_parameters() if (("movie" not in name) or ("plot_embedding" in name) or ("genre" in name)) ],
                                          lr=lr/10,weight_decay=reg)
    
    optimizer_ids = torch.optim.Adam([param for name,param in model.named_parameters() if ("plot" not in name) and ("genre" not in name)],
                                     lr=lr,weight_decay=reg)

elif train_method == "normal":
    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=reg)
    
if freeze_plot and bert_dim !=0:
    model.plot_embedding.weight.requires_grad = False

In [18]:
loss_fn = nn.CrossEntropyLoss(ignore_index=n_items)
#Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k)
#Recall_Object = Recall_E_Noprob(ml_1m,user_history,n_users,n_items,k=k)

In [19]:
Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k)



In [20]:
#print("Baseline POP results: ",Recall_Object.popular_baseline())

In [21]:
#training_hit = Recall_Object(model,train_dl)
#validation_hit = Recall_Object(model,val_dl)
#testing_hit = Recall_Object(model,test_dl)
#print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
#print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
#print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))

In [None]:
# ------------------Training Initialization----------------------#
max_train_hit = 0
max_val_hit = 0
max_test_hit = 0

for epoch in range(num_epochs):
    print("="*20,"Epoch {}".format(epoch+1),"="*20)
    
    model.train()  
    
    running_loss = 0

    for j,data in enumerate(tqdm(train_dl,position=0,leave=True)):
        
        if train_method != "normal":
            optimizer_features.zero_grad()
            optimizer_ids.zero_grad()
            
        elif train_method == "normal": 
            optimizer.zero_grad()
        
        if genre_dim != 0:            
            inputs,genre_inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.cuda(),x_lens=x_lens.squeeze().tolist(),x_genre=genre_inputs.cuda())
        
        else:
            inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.cuda(),x_lens=x_lens.squeeze().tolist())
       
        if tied:
            outputs_ignore_pad = outputs[:,:,:-1]
            loss = loss_fn(outputs_ignore_pad.view(-1,outputs_ignore_pad.size(-1)),labels.view(-1).cuda())
            
        else:
            loss = loss_fn(outputs.view(-1,outputs.size(-1)),labels.view(-1).cuda())
            
        loss.backward()
        
        if train_method != "normal":
            if train_method == "interleave":
                # interleave on the epochs
                if (j+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()

            elif train_method == "alternate":
                if (epoch+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()

        elif train_method == "normal":
            optimizer.step()

        running_loss += loss.detach().cpu().item()

    del outputs
    torch.cuda.empty_cache()
    training_hit = Recall_Object(model,train_dl,"train")
    validation_hit = Recall_Object(model,val_dl,"validation")
    testing_hit = Recall_Object(model,test_dl,"test")
    
    if max_val_hit < validation_hit:
        max_val_hit = validation_hit
        max_test_hit = testing_hit
        max_train_hit = training_hit
    
    torch.cuda.empty_cache()
    print("Training CE Loss: {:.5f}".format(running_loss/len(train_dl)))
    print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
    print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
    print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))


print("="*100)
print("Maximum Training Hit@{:d}: {:.2f}".format(k,max_train_hit))
print("Maximum Validation Hit@{:d}: {:.2f}".format(k,max_val_hit))
print("Maximum Testing Hit@{:d}: {:.2f}".format(k,max_test_hit))

  0%|                                                                                           | 0/95 [00:00<?, ?it/s]



100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 11.00it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.05it/s]

Training CE Loss: 7.46793
Training Hits@10: 27.81
Validation Hits@10: 26.19
Testing Hits@10: 25.84


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.86it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:13,  6.73it/s]

Training CE Loss: 7.03688
Training Hits@10: 31.11
Validation Hits@10: 29.83
Testing Hits@10: 28.56


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:09<00:00, 10.31it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.29it/s]

Training CE Loss: 6.60358
Training Hits@10: 52.75
Validation Hits@10: 50.05
Testing Hits@10: 46.97


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.79it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.29it/s]

Training CE Loss: 6.26038
Training Hits@10: 54.59
Validation Hits@10: 52.27
Testing Hits@10: 48.64


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.91it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:12,  7.44it/s]

Training CE Loss: 6.14068
Training Hits@10: 61.14
Validation Hits@10: 58.33
Testing Hits@10: 54.74


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:10<00:00,  9.10it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.00it/s]

Training CE Loss: 5.98321
Training Hits@10: 62.45
Validation Hits@10: 58.87
Testing Hits@10: 54.77


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.87it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.93694
Training Hits@10: 65.84
Validation Hits@10: 62.12
Testing Hits@10: 58.48


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.83it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:13,  7.07it/s]

Training CE Loss: 5.82231
Training Hits@10: 66.19
Validation Hits@10: 62.35
Testing Hits@10: 58.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:10<00:00,  8.89it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.27it/s]

Training CE Loss: 5.80980
Training Hits@10: 67.40
Validation Hits@10: 63.64
Testing Hits@10: 58.99


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.82it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.20it/s]

Training CE Loss: 5.72095
Training Hits@10: 68.34
Validation Hits@10: 64.62
Testing Hits@10: 59.98


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.99it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.25it/s]

Training CE Loss: 5.72303
Training Hits@10: 68.91
Validation Hits@10: 64.67
Testing Hits@10: 60.35


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:09<00:00,  9.60it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.05it/s]

Training CE Loss: 5.64546
Training Hits@10: 69.80
Validation Hits@10: 65.25
Testing Hits@10: 60.93


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 11.01it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  8.98it/s]

Training CE Loss: 5.65823
Training Hits@10: 70.28
Validation Hits@10: 65.55
Testing Hits@10: 61.67


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.85it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.28it/s]

Training CE Loss: 5.58959
Training Hits@10: 70.50
Validation Hits@10: 65.65
Testing Hits@10: 61.21


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 11.02it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.29it/s]

Training CE Loss: 5.60875
Training Hits@10: 71.57
Validation Hits@10: 65.89
Testing Hits@10: 62.12


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.78it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:12,  7.50it/s]

Training CE Loss: 5.54207
Training Hits@10: 71.62
Validation Hits@10: 66.64
Testing Hits@10: 62.48


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:09<00:00,  9.56it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.87it/s]

Training CE Loss: 5.56858
Training Hits@10: 71.39
Validation Hits@10: 66.06
Testing Hits@10: 62.27


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.69it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.91it/s]

Training CE Loss: 5.50177
Training Hits@10: 72.81
Validation Hits@10: 66.84
Testing Hits@10: 62.37


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:08<00:00, 10.95it/s]
