In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 23 08:39:11 2021

@author: lpott
"""
import argparse
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

from preprocessing import *
from dataset import *
from metrics import *
from model import *
from utils import bert2dict


In [2]:
# Ask PyTorch's CUDA caching allocator to release cached blocks it is not
# currently using, so a re-run in the same kernel starts from a cleaner GPU state.
torch.cuda.empty_cache()

In [3]:
# variables

# --- input files ---
# Use "/" rather than "\\" as the separator: Python's file APIs accept a forward
# slash on Windows as well as POSIX, whereas a backslash only resolves on Windows.
read_filename = "ml-1m/ratings.dat"
read_bert_filename = "bert_sequence_1m.txt"
# An empty string disables the movie/genre branch of the data initialization.
read_movie_filename = ""#"movies-1m.csv"
size = "1m"

# --- optimization ---
num_epochs = 50
lr = 1e-3
batch_size = 64
# passed to Adam as weight_decay
reg = 1e-4
# "normal" uses a single optimizer; the training loop also handles
# "interleave" and "alternate", which use two optimizers.
train_method = "normal"


# --- model sizes ---
hidden_dim = 256
embedding_dim = 256
# 0 skips loading the BERT plot embeddings
bert_dim= 0

# --- evaluation / sequence handling ---
# cutoff reported as Hits@k
k = 10
# sequence length handed to the split and the datasets
max_length = 200
# passed to filter_df as item_min
min_len = 10

In [4]:
# ------------------Data Initialization----------------------#

# Read the .dat ratings file into a dataframe (time-sorted, per the original
# note), then drop users whose session is shorter than min_len.
ml_1m = filter_df(
    create_df(read_filename, size=size),
    item_min=min_len,
)



  df = pd.read_csv(filename,sep='::',header=None)


user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60
user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60


In [5]:
# ------------------Data Initialization----------------------#
if read_movie_filename == "":
    # No movie metadata configured: only the ratings frame is re-indexed.
    reset_object = reset_df()

    # user ids and item ids are mapped onto a contiguous range starting at 0
    # i.e. [1,7,5] -> [0,2,1]
    ml_1m = reset_object.fit_transform(ml_1m)

    # genre features are switched off
    pad_genre_token, ml_movie_df, genre_dim = None, None, 0

else:
    # Load the movie table and convert its genre column.
    ml_movie_df = convert_genres(
        create_movie_df(read_movie_filename, size=size)
    )

    reset_object = reset_df()

    # user ids, item ids and genres are all mapped onto contiguous ranges from 0
    ml_1m, ml_movie_df = reset_object.fit_transform(ml_1m, ml_movie_df)

    # id that padded genre positions take (the encoded "NULL" genre)
    pad_genre_token = reset_object.genre_enc.transform(["NULL"]).item()

    # distinct genre ids minus one -- presumably excluding the "NULL" padding genre
    genre_dim = len(np.unique(np.concatenate(ml_movie_df.genre))) - 1



In [6]:
# ------------------Data Initialization----------------------#
# Distinct-value count of each column, unpacked in the frame's column order
# (users, items, ratings, timestamps).
n_users, n_items, n_ratings, n_timestamp = ml_1m.nunique()

# n_items doubles as the padding id and as the width of the softmax output layer.
pad_token = output_dim = n_items

# Only build the item id : BERT plot embedding dictionary when BERT features are on.
if bert_dim != 0:
    feature_embed = bert2dict(bert_filename=read_bert_filename)

In [7]:
# Build every user's session (history) from the re-indexed ratings frame.
# Per the author's note the result is a dictionary keyed by user,
# i.e. {user: [user clicks]}
user_history = create_user_history(ml_1m)

# Disabled: create a dictionary of all items a user has not clicked
# i.e. {user: [items not clicked by user]}
# user_noclicks = create_user_noclick(user_history,ml_1m,n_items)

  1%|█▏                                                                             | 90/6040 [00:00<00:06, 893.47it/s]



100%|█████████████████████████████████████████████████████████████████████████████| 6040/6040 [00:06<00:00, 945.60it/s]


In [8]:
#import pickle

#with open('userhistory.pickle', 'wb') as handle:
#    pickle.dump(user_history, handle, protocol=pickle.HIGHEST_PROTOCOL)

#with open('userhistory.pickle', 'rb') as handle:
#    user_history = pickle.load(handle)

In [9]:
# Leave-one-out split of every user's history into train / val / test windows.
# The windowing itself lives in train_val_test_split; the author's example for
# max_length = 4 is [1,2,3,4,5,6] -> [1,2,3,4] , [2,3,4,5] , [3,4,5,6]
# (train stops before the last 2 items, val before the last item, test is the tail).
train_history, val_history, test_history = train_val_test_split(
    user_history,
    max_length=max_length,
)

# One GRUDataset per split, constructed in train -> val -> test order; only the
# history and the mode differ. Per the original note, 'eval' mode pads all items
# except the last token to predict.
train_dataset, val_dataset, test_dataset = (
    GRUDataset(
        split_history,
        genre_df=ml_movie_df,
        mode=split_mode,
        max_length=max_length,
        pad_token=pad_token,
        pad_genre_token=pad_genre_token,
    )
    for split_history, split_mode in (
        (train_history, 'train'),
        (val_history, 'eval'),
        (test_history, 'eval'),
    )
)

# Only the training loader shuffles; the evaluation loaders use a fixed batch size of 64.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=64)
test_dl = DataLoader(test_dataset, batch_size=64)

100%|██████████████████████████████████████████████████████████████████████████| 6040/6040 [00:00<00:00, 208615.14it/s]






In [10]:
# Echo the feature dimensions and padding ids that the model will be built with.
print("\n".join([
    "Bert dim: {:d}".format(bert_dim),
    "Genre dim: {:d}".format(genre_dim),
    "Pad Token: {}".format(pad_token),
    "Pad Genre Token: {}".format(pad_genre_token),
]))

Bert dim: 0
Genre dim: 0
Pad Token: 3706
Pad Genre Token: None


In [11]:
# ------------------Model Initialization----------------------#

# Alternative model classes kept for reference. Both take the embedding, genre
# and BERT arguments in addition to the ones the vanilla model uses:
#
# model = gru4recF(embedding_dim=embedding_dim,
#          hidden_dim=hidden_dim,
#          output_dim=output_dim,
#          genre_dim=genre_dim,
#          batch_first=True,
#          max_length=max_length,
#          pad_token=pad_token,
#          pad_genre_token=pad_genre_token,
#          bert_dim=bert_dim)
#
# model = gru4recFC(embedding_dim=embedding_dim,
#          hidden_dim=hidden_dim,
#          output_dim=output_dim,
#          genre_dim=genre_dim,
#          batch_first=True,
#          max_length=max_length,
#          pad_token=pad_token,
#          pad_genre_token=pad_genre_token,
#          bert_dim=bert_dim)

# Model in use: the vanilla GRU4Rec, configured from the values set earlier.
model = gru4rec_vanilla(
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    batch_first=True,
    max_length=max_length,
    pad_token=pad_token,
)

In [12]:
# Display the freshly constructed model's layer summary.
model

gru4rec_vanilla(
  (encoder_layer): GRU(3707, 256, batch_first=True)
  (output_layer): Linear(in_features=256, out_features=3706, bias=True)
)

In [13]:
# When BERT features are enabled, initialize model weights from the plot
# embedding dictionary (uses the id mapping held by reset_object).
if bert_dim != 0:
    model.init_weight(reset_object,feature_embed)
    
# Move the model's parameters to the GPU; a CUDA device is required from here on.
model = model.cuda()

In [14]:
# Display the model again after the optional weight init and the move to GPU.
model

gru4rec_vanilla(
  (encoder_layer): GRU(3707, 256, batch_first=True)
  (output_layer): Linear(in_features=256, out_features=3706, bias=True)
)

In [15]:
# Names of the parameters that the "features" optimizer would receive:
# everything without "movie" in its name, plus plot-embedding and genre parameters.
[
    pname
    for pname, _ in model.named_parameters()
    if "movie" not in pname or "plot_embedding" in pname or "genre" in pname
]

['encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [16]:
# Names of the parameters that the "ids" optimizer would receive:
# everything whose name mentions neither "plot" nor "genre".
[
    pname
    for pname, _ in model.named_parameters()
    if not ("plot" in pname or "genre" in pname)
]

['encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [17]:
# Create the Adam optimizer(s) for the gru4rec model.
if train_method == "normal":
    # one optimizer over every parameter
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=reg)

else:
    # two optimizers over name-filtered parameter groups; the training loop
    # decides which of them steps on a given batch or epoch
    optimizer_features = torch.optim.Adam(
        [
            weights
            for pname, weights in model.named_parameters()
            if "movie" not in pname or "plot_embedding" in pname or "genre" in pname
        ],
        lr=lr,
        weight_decay=reg,
    )

    optimizer_ids = torch.optim.Adam(
        [
            weights
            for pname, weights in model.named_parameters()
            if not ("plot" in pname or "genre" in pname)
        ],
        lr=lr,
        weight_decay=reg,
    )

In [18]:
# Cross-entropy over item ids; targets equal to n_items (the padding id) are ignored.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=n_items)
# Candidate evaluators (the one in use is created in the next cell):
#Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k)
#Recall_Object = Recall_E_Noprob(ml_1m,user_history,n_users,n_items,k=k)

In [19]:
# Build the evaluator. Later cells call it as Recall_Object(model, dataloader)
# and report the returned value as Hits@k; it also provides popular_baseline().
Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k)



In [20]:
# Sanity check: display the number of distinct users in the ratings frame.
n_users

6040

In [21]:
# Print the evaluator's popularity baseline -- presumably the score obtained by
# always recommending the most popular items -- as a reference for the model's Hits@k.
print("Baseline POP results: ",Recall_Object.popular_baseline())

Baseline POP results:  14.33774834437086


In [22]:
#training_hit = Recall_Object(model,train_dl)
#validation_hit = Recall_Object(model,val_dl)
#testing_hit = Recall_Object(model,test_dl)
#print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
#print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
#print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))

In [None]:

# ------------------Training Initialization----------------------#
# Reject unknown schedules up front: the step logic below only knows these three,
# and any other value would run every epoch without ever updating the weights.
if train_method not in ("normal", "interleave", "alternate"):
    raise ValueError("Unknown train_method: {!r}".format(train_method))

# Metrics recorded at the epoch with the best validation Hits@k so far
# (train/test values are taken from that same epoch, not maximized independently).
max_train_hit = 0
max_val_hit = 0
max_test_hit = 0

for epoch in range(num_epochs):
    print("="*20,"Epoch {}".format(epoch+1),"="*20)

    model.train()

    running_loss = 0

    for j,data in enumerate(tqdm(train_dl,position=0,leave=True)):

        if train_method == "normal":
            optimizer.zero_grad()
        else:
            optimizer_features.zero_grad()
            optimizer_ids.zero_grad()

        # reshape(-1) instead of squeeze(): squeeze() turns the lengths of a
        # single-sample batch into a 0-d tensor, whose tolist() is a bare int
        # rather than a list.
        if genre_dim != 0:
            inputs,genre_inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.cuda(),
                            x_lens=x_lens.reshape(-1).tolist(),
                            x_genre=genre_inputs.cuda())

        else:
            inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.cuda(),x_lens=x_lens.reshape(-1).tolist())

        # flatten to (positions, classes) vs (positions,) for the cross-entropy loss
        loss = loss_fn(outputs.view(-1,outputs.size(-1)),labels.view(-1).cuda())
        loss.backward()

        if train_method == "normal":
            optimizer.step()

        elif train_method == "interleave":
            # switch optimizers every batch: even (1-based) batches step the
            # features optimizer, odd batches step the ids optimizer
            if (j+1) % 2 == 0:
                optimizer_features.step()
            else:
                optimizer_ids.step()

        else:
            # "alternate": switch optimizers every epoch instead of every batch
            if (epoch+1) % 2 == 0:
                optimizer_features.step()
            else:
                optimizer_ids.step()

        running_loss += loss.item()

    # Evaluate in inference mode and without building autograd graphs;
    # model.train() at the top of the next epoch restores training mode.
    model.eval()
    with torch.no_grad():
        training_hit = Recall_Object(model,train_dl)
        validation_hit = Recall_Object(model,val_dl)
        testing_hit = Recall_Object(model,test_dl)

    # model selection on validation: remember all three metrics of the best epoch
    if max_val_hit < validation_hit:
        max_val_hit = validation_hit
        max_test_hit = testing_hit
        max_train_hit = training_hit


    print("Training CE Loss: {:.5f}".format(running_loss/len(train_dl)))
    print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
    print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
    print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))


print("="*100)
print("Maximum Training Hit@{:d}: {:.2f}".format(k,max_train_hit))
print("Maximum Validation Hit@{:d}: {:.2f}".format(k,max_val_hit))
print("Maximum Testing Hit@{:d}: {:.2f}".format(k,max_test_hit))

  0%|                                                                                           | 0/95 [00:00<?, ?it/s]



100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.51it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 7.57304
Training Hits@10: 14.52
Validation Hits@10: 13.97
Testing Hits@10: 14.40


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.33it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.97it/s]

Training CE Loss: 7.39622
Training Hits@10: 21.66
Validation Hits@10: 21.27
Testing Hits@10: 19.74


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.62it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.83it/s]

Training CE Loss: 7.11076
Training Hits@10: 29.16
Validation Hits@10: 26.67
Testing Hits@10: 25.93


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.12it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.83it/s]

Training CE Loss: 6.83142
Training Hits@10: 41.11
Validation Hits@10: 38.49
Testing Hits@10: 35.58


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.43it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.93it/s]

Training CE Loss: 6.58545
Training Hits@10: 48.94
Validation Hits@10: 46.31
Testing Hits@10: 43.05


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.78it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.09it/s]

Training CE Loss: 6.38691
Training Hits@10: 53.56
Validation Hits@10: 51.08
Testing Hits@10: 47.45


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.67it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.93it/s]

Training CE Loss: 6.24436
Training Hits@10: 57.38
Validation Hits@10: 54.44
Testing Hits@10: 50.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.74it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.93it/s]

Training CE Loss: 6.14366
Training Hits@10: 58.77
Validation Hits@10: 56.06
Testing Hits@10: 51.82


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.77it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.83it/s]

Training CE Loss: 6.07187
Training Hits@10: 59.74
Validation Hits@10: 56.95
Testing Hits@10: 53.44


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.77it/s]
  2%|█▋                                                                                 | 2/95 [00:00<00:08, 11.20it/s]

Training CE Loss: 6.00995
Training Hits@10: 62.75
Validation Hits@10: 59.30
Testing Hits@10: 55.36


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.62it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:11,  8.29it/s]

Training CE Loss: 5.96053
Training Hits@10: 62.96
Validation Hits@10: 59.98
Testing Hits@10: 55.71


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.04it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.64it/s]

Training CE Loss: 5.92164
Training Hits@10: 64.01
Validation Hits@10: 61.09
Testing Hits@10: 56.69


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.08it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.20it/s]

Training CE Loss: 5.88431
Training Hits@10: 64.57
Validation Hits@10: 61.31
Testing Hits@10: 57.04


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.09it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.85372
Training Hits@10: 66.27
Validation Hits@10: 62.55
Testing Hits@10: 58.15


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.49it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.37it/s]

Training CE Loss: 5.83168
Training Hits@10: 66.66
Validation Hits@10: 63.25
Testing Hits@10: 58.77


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.56it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:10,  9.19it/s]

Training CE Loss: 5.80202
Training Hits@10: 67.58
Validation Hits@10: 63.21
Testing Hits@10: 58.71


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.53it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.78451
Training Hits@10: 67.95
Validation Hits@10: 63.76
Testing Hits@10: 59.47


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.55it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.54it/s]

Training CE Loss: 5.76282
Training Hits@10: 67.95
Validation Hits@10: 63.99
Testing Hits@10: 59.39


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.58it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.63it/s]

Training CE Loss: 5.74463
Training Hits@10: 68.63
Validation Hits@10: 64.59
Testing Hits@10: 60.02


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 11.97it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.97it/s]

Training CE Loss: 5.72923
Training Hits@10: 68.58
Validation Hits@10: 64.54
Testing Hits@10: 59.98


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.66it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:09,  9.93it/s]

Training CE Loss: 5.71871
Training Hits@10: 69.21
Validation Hits@10: 64.95
Testing Hits@10: 60.35


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:07<00:00, 12.70it/s]
