In [33]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import torch.nn as nn
import torch.nn.functional as F

import argparse
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

from preprocessing import *
from dataset import *
from metrics import *
from model import *
from utils import bert2dict

In [2]:
# class gru4recFC_encoder(nn.Module):
#     """
#     embedding dim: the dimension of the item-embedding look-up table
#     hidden_dim: the dimension of the hidden state of the GRU-RNN
#     batch_first: whether the batch dimension should be the first dimension of input to GRU-RNN
#     output_dim: the output dimension of the last fully connected layer
#     max_length: the maximum session length for any user, used for packing/padding input to GRU-RNN
#     pad_token: the value that pad tokens should be set to for GRU-RNN and item embedding
#     bert_dim: the dimension of the feature-embedding look-up table
#     ... to do add all comments ... 
#     """
#     def __init__(self,embedding_dim,hidden_dim,output_dim,genre_dim=0,batch_first=True,max_length=200,bert_dim=0,tied=False,dropout=0):
#         super(gru4recFC_encoder,self).__init__()
        
#         self.batch_first =batch_first
        
#         self.embedding_dim = embedding_dim
#         self.hidden_dim =hidden_dim
#         self.output_dim =output_dim
#         self.genre_dim = genre_dim
#         self.bert_dim = bert_dim

#         self.max_length = max_length
# #         self.pad_token = pad_token
# #         self.pad_genre_token = pad_genre_token
#         self.tied = tied
        
#         self.dropout = dropout
    
#         if self.tied:
#             self.hidden_dim = embedding_dim
#         # initialize item-id lookup table
#         # add 1 to output dimension because we have to add a pad token
#         self.movie_embedding = nn.Embedding(output_dim,embedding_dim)
        
#         #  initialize plot lookup table
#         # add 1 to output dimensino because we have to add a pad token
#         if bert_dim != 0:
#             self.plot_embedding = nn.Embedding(output_dim,bert_dim)
#             #self.plot_embedding.requires_grad_(requires_grad=False)
#             #self.plot_embedding = torch.ones(output_dim+1,bert_dim).cuda() #nn.Embedding(output_dim+1,bert_dim,padding_idx=pad_token)
#             #self.plot_embedding[pad_token,:] = 0
        
#         if genre_dim != 0:
#             self.genre_embedding = nn.Embedding(genre_dim,embedding_dim)

#         self.projection_layer = nn.Linear(bert_dim+embedding_dim+genre_dim,embedding_dim)
        
#         self.encoder_layer = nn.GRU(embedding_dim,self.hidden_dim,batch_first=self.batch_first,dropout=self.dropout)

#         # add 1 to the output dimension because we have to add a pad token
# #         if not self.tied:
# #             self.output_layer = nn.Linear(hidden_dim,output_dim)
        
# #         if self.tied:
# #             self.output_layer = nn.Linear(hidden_dim,output_dim+1)
# #             self.output_layer.weight = self.movie_embedding.weight
            
    
#     def forward(self,x,x_lens,x_genre=None,pack=True):
#         # add the plot embedding and movie embedding
#         # do I add non-linearity or not? ... 
#         # concatenate or not? ...
#         # many questions ...
#         if (self.bert_dim != 0) and (self.genre_dim != 0):
#             x = torch.cat( (self.movie_embedding(x),self.plot_embedding(x),self.genre_embedding(x_genre).sum(2)) , 2)
#         elif (self.bert_dim != 0) and (self.genre_dim == 0):
#             x = torch.cat( (self.movie_embedding(x),self.plot_embedding(x) ) , 2)
#         elif (self.bert_dim == 0) and (self.genre_dim != 0):
#             x = torch.cat( (self.movie_embedding(x),self.genre_embedding(x_genre).sum(2)) , 2)
#         else:
#             x = self.movie_embedding(x)
        
#         x = self.projection_layer(x)
#         x = F.leaky_relu(x)
                    
# #         if pack:
# #             x = pack_padded_sequence(x,x_lens,batch_first=True,enforce_sorted=False)
        
# #         output_packed,_ = self.encoder_layer(x)        
#         x, hidden = self.encoder_layer(x)
        
# #         if pack:
# #             x, _ = pad_packed_sequence(output_packed, batch_first=self.batch_first,total_length=self.max_length,padding_value=self.pad_token)
            
# #         x = self.output_layer(x)
                
#         return x, hidden
    
#     def init_weight(self,reset_object,feature_embed):
#         for (item_id,embedding) in feature_embed.items():
#             if item_id not in reset_object.item_enc.classes_:
#                 continue
#             item_id = reset_object.item_enc.transform([item_id]).item()
#             self.plot_embedding.weight.data[item_id,:] = torch.DoubleTensor(embedding)
            
            
class gru4recF_encoder(nn.Module):
    """
    embedding dim: the dimension of the item-embedding look-up table
    hidden_dim: the dimension of the hidden state of the GRU-RNN
    batch_first: whether the batch dimension should be the first dimension of input to GRU-RNN
    output_dim: the output dimension of the last fully connected layer
    max_length: the maximum session length for any user, used for packing/padding input to GRU-RNN
    pad_token: the value that pad tokens should be set to for GRU-RNN and item embedding
    bert_dim: the dimension of the feature-embedding look-up table
    ... to do add all comments ... 
    """
    def __init__(self,embedding_dim,
                 hidden_dim,
                 output_dim,
                 genre_dim=0,
                 batch_first=True,
                 max_length=200,
                 pad_token=0,
                 pad_genre_token=0,
                 bert_dim=0,
                 dropout=0,
                 tied=False):
        
        super(gru4recF_encoder,self).__init__()
        
        self.batch_first =batch_first
        
        self.embedding_dim = embedding_dim
        self.hidden_dim =hidden_dim
        self.output_dim =output_dim
        self.genre_dim = genre_dim
        self.bert_dim = bert_dim

        self.max_length = max_length
        self.pad_token = pad_token
        self.pad_genre_token = pad_genre_token
        
        self.tied = tied
        self.dropout = dropout
        
        if self.tied:
            self.hidden_dim = embedding_dim
    
        # initialize item-id lookup table
        # add 1 to output dimension because we have to add a pad token
        self.movie_embedding = nn.Embedding(output_dim+1,embedding_dim,padding_idx=pad_token)
        
        #  initialize plot lookup table
        # add 1 to output dimensino because we have to add a pad token
        if bert_dim != 0:
            self.plot_embedding = nn.Embedding(output_dim+1,bert_dim,padding_idx=pad_token)
            #self.plot_embedding.requires_grad_(requires_grad=False)
            #self.plot_embedding = torch.ones(output_dim+1,bert_dim).cuda() #nn.Embedding(output_dim+1,bert_dim,padding_idx=pad_token)
            #self.plot_embedding[pad_token,:] = 0
            
            # project plot embedding to same dimensionality as movie embedding
            self.plot_projection = nn.Linear(bert_dim,embedding_dim)
                    
        if genre_dim != 0:
            self.genre_embedding = nn.Embedding(genre_dim+1,embedding_dim,padding_idx=pad_genre_token)


        self.encoder_layer = nn.GRU(embedding_dim,self.hidden_dim,batch_first=self.batch_first,dropout=self.dropout)

#         # add 1 to the output dimension because we have to add a pad token
#         if not self.tied:
#             self.output_layer = nn.Linear(hidden_dim,output_dim)
        
#         if self.tied:
#             self.output_layer = nn.Linear(hidden_dim,output_dim+1)
#             self.output_layer.weight = self.movie_embedding.weight
    
    def forward(self,x,x_lens,x_genre=None,pack=True):
        # add the plot embedding and movie embedding
        # do I add non-linearity or not? ... 
        # concatenate or not? ...
        # many questions ...
        if (self.bert_dim != 0) and (self.genre_dim != 0):
            x = self.movie_embedding(x) + self.plot_projection(F.leaky_relu(self.plot_embedding(x))) + self.genre_embedding(x_genre).sum(2)
        elif (self.bert_dim != 0) and (self.genre_dim == 0):
            x = self.movie_embedding(x) + self.plot_projection(F.leaky_relu(self.plot_embedding(x)))
        elif (self.bert_dim == 0) and (self.genre_dim != 0):
            x = self.movie_embedding(x) + self.genre_embedding(x_genre).sum(2)
        else:
            x = self.movie_embedding(x)
        
#         print("Embedder Dimension: ")
#         print(x.size())
        
        if pack:
            x = pack_padded_sequence(x,x_lens,batch_first=True,enforce_sorted=False)
        
        output_packed,hidden_state = self.encoder_layer(x) 
        
        if pack:
            x, _ = pad_packed_sequence(output_packed, batch_first=self.batch_first,total_length=self.max_length,padding_value=self.pad_token)
        
        return x,hidden_state
        
#         if pack:
#             x, _ = pad_packed_sequence(output_packed, batch_first=self.batch_first,total_length=self.max_length,padding_value=self.pad_token)
            
#         x = self.output_layer(x)
        
                
#         return x
    
    def init_weight(self,reset_object,feature_embed):
        for (item_id,embedding) in feature_embed.items():
            if item_id not in reset_object.item_enc.classes_:
                continue
            item_id = reset_object.item_enc.transform([item_id]).item()
            self.plot_embedding.weight.data[item_id,:] = torch.DoubleTensor(embedding)

In [3]:
MAX_LENGTH = 200
# class gru4recFC_decoder(nn.Module):
#     def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
#         super(gru4recFC_decoder, self).__init__()
#         self.hidden_size = hidden_size
#         self.output_size = output_size
#         self.dropout_p = dropout_p
#         self.max_length = max_length

#         self.embedding = nn.Embedding(self.output_size, self.hidden_size)
#         self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
#         self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
#         self.dropout = nn.Dropout(self.dropout_p)
#         self.gru = nn.GRU(self.hidden_size, self.hidden_size)
#         self.out = nn.Linear(self.hidden_size, self.output_size)

#     def forward(self, input, hidden, encoder_outputs):
#         embedded = self.embedding(input).view(1, 1, -1)
#         embedded = self.dropout(embedded)

#         attn_weights = F.softmax(
#             self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
#         attn_applied = torch.bmm(attn_weights.unsqueeze(0),
#                                  encoder_outputs.unsqueeze(0))

#         output = torch.cat((embedded[0], attn_applied[0]), 1)
#         output = self.attn_combine(output).unsqueeze(0)

#         output = F.relu(output)
#         output, hidden = self.gru(output, hidden)

#         output = F.log_softmax(self.out(output[0]), dim=1)
#         return output, hidden, attn_weights

#     def initHidden(self):
#         return torch.zeros(1, 1, self.hidden_size, device=device)
    
# # class EncoderRNN(nn.Module):
# #     def __init__(self, input_size, hidden_size):
# #         super(EncoderRNN, self).__init__()
# #         self.hidden_size = hidden_size

# #         self.embedding = nn.Embedding(input_size, hidden_size)
# #         self.gru = nn.GRU(hidden_size, hidden_size)

# #     def forward(self, input, hidden):
# #         embedded = self.embedding(input).view(1, 1, -1)
# #         output = embedded
# #         output, hidden = self.gru(output, hidden)
# #         return output, hidden

# #     def initHidden(self):
# #         return torch.zeros(1, 1, self.hidden_size, device=device)

class gru4recF_decoder(nn.Module):
    def __init__(self, hidden_dim, output_dim, dropout=0, max_length=MAX_LENGTH):
        super(gru4recF_decoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_dim+1, self.hidden_dim)
        self.attn = nn.Linear(self.hidden_dim * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_dim * 2, self.hidden_dim)
        self.dropout = nn.Dropout(self.dropout)
        self.gru = nn.GRU(self.hidden_dim, self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, input.size()[1], -1)
        embedded = self.dropout(embedded)
        
#         print("Pre-Embedded Tensor: ")
#         print(input[0])
        
#         print("Embedded Dimension: ")
#         print(embedded.size())
        
#         print("Hidden Dimension: ")
#         print(hidden.size())

#         print("Embedded Tensor: ")
#         print(embedded[0])
        
#         print("Hidden Tensor: ")
#         print(hidden[0])
        
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        
#         print("Attention Dimension: ")
#         print(attn_weights.size())
        
#         print("Encoder Outputs Dimension: ")
#         print(encoder_outputs.size())
        
        attn_applied = torch.bmm(attn_weights.unsqueeze(1),
                                 encoder_outputs).squeeze(1)
        
#         print("Attention Applied Dimension: ")
#         print(attn_applied.size())

        output = torch.cat((embedded[0], attn_applied), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.gru(output, hidden)

        output = self.out(output[0])
        
#         print("Output Dimension: ")
#         print(output.size())
        
#         print("New Hidden Dimension: ")
#         print(hidden.size())
        
        return output, hidden, attn_weights

    def initHidden(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [None]:
# class gru4recF_attention(nn.Module):
#         def __init__(self,embedding_dim,
#                  hidden_dim,
#                  output_dim,
#                  genre_dim=0,
#                  batch_first=True,
#                  max_length=200,
#                  pad_token=0,
#                  pad_genre_token=0,
#                  bert_dim=0,
#                  dropout=0,
#                  tied=False):
#             self.bert_dim = bert_dim
#             self.genre_dim = genre_dim
            
#             self.model = gru4recF_encoder(embedding_dim=embedding_dim,
#              hidden_dim=hidden_dim,
#              output_dim=output_dim,
#              genre_dim=genre_dim,
#              batch_first=True,
#              max_length=max_length,
#              bert_dim=bert_dim,
#              tied = tied,
#              dropout=dropout)

#             self.modelD = gru4recF_decoder(hidden_dim=hidden_dim, output_dim=output_dim, dropout=0, max_length=max_length)
            
#         def forward(self,x,x_lens,labels,x_genre=None,pack=True):
#             encoder_outputs, hidden_states = self.model(x=inputs.to(device),x_lens=x_lens.squeeze().tolist())
#             decoder_inputs = inputs[:,0].view(1,-1).to(device)
#             decoder_hidden = hidden_states
            
#             outputs = torch.zeros(inputs.size()[0],max_length,output_dim, device=device)
            
#             for i in range(max_length):
#                 decoder_outputs, decoder_hidden, decoder_attention = self.modelD(decoder_inputs, decoder_hidden, encoder_outputs)
#                 outputs[:,i,:] = decoder_outputs
#                 decoder_inputs = labels[:,i].view(1,-1)

In [48]:
class gru4recF_attention(nn.Module):
    """
    embedding dim: the dimension of the item-embedding look-up table
    hidden_dim: the dimension of the hidden state of the GRU-RNN
    batch_first: whether the batch dimension should be the first dimension of input to GRU-RNN
    output_dim: the output dimension of the last fully connected layer
    max_length: the maximum session length for any user, used for packing/padding input to GRU-RNN
    pad_token: the value that pad tokens should be set to for GRU-RNN and item embedding
    bert_dim: the dimension of the feature-embedding look-up table
    ... to do add all comments ... 
    """
    def __init__(self,embedding_dim,
                 hidden_dim,
                 output_dim,
                 genre_dim=0,
                 batch_first=True,
                 max_length=200,
                 pad_token=0,
                 pad_genre_token=0,
                 bert_dim=0,
                 dropout=0,
                 tied=False):
        
        super(gru4recF_attention,self).__init__()
        
        self.batch_first =batch_first
        
        self.embedding_dim = embedding_dim
        self.hidden_dim =hidden_dim
        self.output_dim =output_dim
        self.genre_dim = genre_dim
        self.bert_dim = bert_dim

        self.max_length = max_length
        self.pad_token = pad_token
        self.pad_genre_token = pad_genre_token
        
        self.tied = tied
        self.dropout = dropout
        
        if self.tied:
            self.hidden_dim = embedding_dim
    
        # initialize item-id lookup table
        # add 1 to output dimension because we have to add a pad token
        self.movie_embedding = nn.Embedding(output_dim+1,embedding_dim,padding_idx=pad_token)
        
        #  initialize plot lookup table
        # add 1 to output dimensino because we have to add a pad token
        if bert_dim != 0:
            self.plot_embedding = nn.Embedding(output_dim+1,bert_dim,padding_idx=pad_token)
            #self.plot_embedding.requires_grad_(requires_grad=False)
            #self.plot_embedding = torch.ones(output_dim+1,bert_dim).cuda() #nn.Embedding(output_dim+1,bert_dim,padding_idx=pad_token)
            #self.plot_embedding[pad_token,:] = 0
            
            # project plot embedding to same dimensionality as movie embedding
            self.plot_projection = nn.Linear(bert_dim,embedding_dim)
                    
        if genre_dim != 0:
            self.genre_embedding = nn.Embedding(genre_dim+1,embedding_dim,padding_idx=pad_genre_token)


        self.encoder_layer = nn.GRU(embedding_dim,self.hidden_dim,batch_first=self.batch_first,dropout=self.dropout)

#         # add 1 to the output dimension because we have to add a pad token
        if not self.tied:
            self.output_layer = nn.Linear(hidden_dim,output_dim)
        
        if self.tied:
            self.output_layer = nn.Linear(hidden_dim,output_dim+1)
            self.output_layer.weight = self.movie_embedding.weight
    
    def forward(self,x,x_lens,x_genre=None,pack=True):
        # add the plot embedding and movie embedding
        # do I add non-linearity or not? ... 
        # concatenate or not? ...
        # many questions ...
        batch_size = x.size()[0]
        if (self.bert_dim != 0) and (self.genre_dim != 0):
            x = self.movie_embedding(x) + self.plot_projection(F.leaky_relu(self.plot_embedding(x))) + self.genre_embedding(x_genre).sum(2)
        elif (self.bert_dim != 0) and (self.genre_dim == 0):
            x = self.movie_embedding(x) + self.plot_projection(F.leaky_relu(self.plot_embedding(x)))
        elif (self.bert_dim == 0) and (self.genre_dim != 0):
            x = self.movie_embedding(x) + self.genre_embedding(x_genre).sum(2)
        else:
            x = self.movie_embedding(x)
        
#         print("Embedder Dimension: ")
#         print(x.size())
        
        if pack:
            x = pack_padded_sequence(x,x_lens,batch_first=True,enforce_sorted=False)
        
        output_packed,hidden_state = self.encoder_layer(x) 
        
        if pack:
            encoder_states, _ = pad_packed_sequence(output_packed, batch_first=self.batch_first,total_length=self.max_length,padding_value=self.pad_token)
        
        # CCs = BS x MS x 2HS
        combined_contexts = torch.zeros(batch_size,max_length,self.hidden_dim*2)
        
        for t in range(max_length):
            # CF = BS x (t+1) x HS
            context_frame = encoder_states[:,:t+1,:]
            # CH = BS x HS x 1
            current_hidden = encoder_states[:,t,:].squeeze(1).unsqueeze(2)
            # AS = BS x (t+1) x 1
            attention_score = torch.bmm(context_frame,current_hidden)
            # CFT = BS x HS x (t+1)
            context_frame_transposed = torch.transpose(context_frame,1,2)
            # CV = BS x HS
            context_vector = torch.bmm(context_frame_transposed,attention_score).squeeze(2)
            # CH = BS x HS
            current_hidden = current_hidden.squeeze(2)
            # CC = BS x 1 x 2HS
            # combined_context = torch.cat((current_hidden,context_vector),1).unsqueeze(1)
            combined_contexts[:,t,:] = context_vector
            
        # CCs = BS x MS x 2HS
        # O = BS x MS x V
        x = self.output_layer(combined_contexts)
        
        return x

In [4]:
teacher_forcing_ratio = 0.5


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

    loss = 0

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(
            input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    decoder_input = torch.tensor([[SOS_token]], device=device)

    decoder_hidden = encoder_hidden

    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: Feed the target as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di]  # Teacher forcing

    else:
        # Without teacher forcing: use its own predictions as the next input
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()  # detach from history as input

            loss += criterion(decoder_output, target_tensor[di])
            if decoder_input.item() == EOS_token:
                break

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length

In [5]:
def trainIters(encoder, decoder, n_iters, print_every=1000, plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                      for i in range(n_iters)]
    criterion = nn.NLLLoss()

    for iter in range(1, n_iters + 1):
        training_pair = training_pairs[iter - 1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]

        loss = train(input_tensor, target_tensor, encoder,
                     decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

In [6]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np


def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

Questions:

What is the current gru4recFC outputting?
How far in the future do we need to predict?
Is our problem formulation compatible enough  w/ seq2seq?
Max length vs. set length of input?
How much to give encoder, how much to expect from decoder

Resources:
https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
https://towardsdatascience.com/intuitive-understanding-of-attention-mechanism-in-deep-learning-6c9482aecf4f

In [35]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [36]:
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 23 08:39:11 2021

@author: lpott
"""
import argparse
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

from preprocessing import *
from dataset import *
from metrics import *
from model import *
from utils import bert2dict


In [37]:
# variables

read_filename ="ml-1m\\ratings.dat"
read_bert_filename = "bert_sequence_20m.txt"
read_movie_filename = ""#"movies-1m.csv"
size = "1m"

num_epochs = 100
lr = 1e-2
batch_size = 64
reg = 1e-4
train_method = "normal"


hidden_dim = 256
embedding_dim = 128
bert_dim= 768
window = 0

freeze_plot = False
tied = False
dropout= 0

k = 10
max_length = 200
min_len = 10


# nextitnet options...
hidden_layers = 3
dilations = [1,2,2,4]

model_type = "attention"
# model_type = "feature_add"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [38]:
torch.cuda.empty_cache()

In [39]:
# ------------------Data Initialization----------------------#

# convert .dat file to time-sorted pandas dataframe
ml_1m = create_df(read_filename,size=size)

# remove users who have sessions lengths less than min_len
ml_1m = filter_df(ml_1m,item_min=min_len)

user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60
user_id        6040
item_id        3706
rating            5
timestamp    458455
dtype: int64
(1000209, 4)
Minimum Session Length: 20
Maximum Session Length: 2314
Average Session Length: 165.60


In [40]:
# ------------------Data Initialization----------------------#
if read_movie_filename != "":
    ml_movie_df = create_movie_df(read_movie_filename,size=size)
    ml_movie_df = convert_genres(ml_movie_df)
    
    # initialize reset object
    reset_object = reset_df()
    
    # map all user ids, item ids, and genres to range 0 - number of users/items/genres
    ml_1m,ml_movie_df = reset_object.fit_transform(ml_1m,ml_movie_df)
    
    # value that padded genre tokens shall take
    pad_genre_token = reset_object.genre_enc.transform(["NULL"]).item()
    
    genre_dim = len(np.unique(np.concatenate(ml_movie_df.genre))) - 1

else:
    # initialize reset object
    reset_object = reset_df()
    
    # map all user ids and item ids to range 0 - Number of Users/Items 
    # i.e. [1,7,5] -> [0,2,1]
    ml_1m = reset_object.fit_transform(ml_1m)
    
    pad_genre_token = None
    ml_movie_df = None
    genre_dim = 0



In [41]:
# ------------------Data Initialization----------------------#
# how many unique users, items, ratings and timestamps are there
n_users,n_items,n_ratings,n_timestamp = ml_1m.nunique()

# value that padded tokens shall take
pad_token = n_items

# the output dimension for softmax layer
output_dim = n_items


# get the item id : bert plot embedding dictionary
if bert_dim != 0:
    feature_embed = bert2dict(bert_filename=read_bert_filename)



In [42]:
# create a dictionary of every user's session (history)
# i.e. {user: [user clicks]}
if size == "1m":
    user_history = create_user_history(ml_1m)

elif size == "20m":
    import pickle
    with open('userhistory.pickle', 'rb') as handle:
        user_history = pickle.load(handle)
# create a dictionary of all items a user has not clicked
# i.e. {user: [items not clicked by user]}
# user_noclicks = create_user_noclick(user_history,ml_1m,n_items)

  1%|▋                                                                              | 57/6040 [00:00<00:10, 564.35it/s]



100%|█████████████████████████████████████████████████████████████████████████████| 6040/6040 [00:10<00:00, 569.49it/s]


In [43]:
# split data by leave-one-out strategy
# have train dictionary {user: [last 41 items prior to last 2 items in user session]}
# have val dictionary {user: [last 41 items prior to last item in user session]}
# have test dictionary {user: [last 41 items]}
# i.e. if max_length = 4, [1,2,3,4,5,6] -> [1,2,3,4] , [2,3,4,5] , [3,4,5,6]
train_history,val_history,test_history = train_val_test_split(user_history,max_length=max_length)

# initialize the train,validation, and test pytorch dataset objects
# eval pads all items except last token to predict
train_dataset = GRUDataset(train_history,genre_df=ml_movie_df,mode='train',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)
val_dataset = GRUDataset(val_history,genre_df=ml_movie_df,mode='eval',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)
test_dataset = GRUDataset(test_history,genre_df=ml_movie_df,mode='eval',max_length=max_length,pad_token=pad_token,pad_genre_token=pad_genre_token)

# create the train,validation, and test pytorch dataloader objects
train_dl = DataLoader(train_dataset,batch_size = batch_size,shuffle=True)
val_dl = DataLoader(val_dataset,batch_size=64)
test_dl = DataLoader(test_dataset,batch_size=64)

100%|██████████████████████████████████████████████████████████████████████████| 6040/6040 [00:00<00:00, 147317.47it/s]






In [44]:
print("Bert dim: {:d}".format(bert_dim))
print("Genre dim: {:d}".format(genre_dim))
print("Pad Token: {}".format(pad_token))
print("Pad Genre Token: {}".format(pad_genre_token))

Bert dim: 768
Genre dim: 0
Pad Token: 3706
Pad Genre Token: None


In [49]:
# ------------------Model Initialization----------------------#

# initialize gru4rec model with arguments specified earlier
if model_type == "feature_add":
    model = gru4recF(embedding_dim=embedding_dim,
             hidden_dim=hidden_dim,
             output_dim=output_dim,
             genre_dim=genre_dim,
             batch_first=True,
             max_length=max_length,
             pad_token=pad_token,
             pad_genre_token=pad_genre_token,
             bert_dim=bert_dim,
             tied = tied,
             dropout=dropout)


if model_type == "feature_concat":
    model = gru4recFC(embedding_dim=embedding_dim,
             hidden_dim=hidden_dim,
             output_dim=output_dim,
             genre_dim=genre_dim,
             batch_first=True,
             max_length=max_length,
             pad_token=pad_token,
             pad_genre_token=pad_genre_token,
             bert_dim=bert_dim,
             tied = tied,
             dropout=dropout)

if model_type == "vanilla":
    model = gru4rec_vanilla(hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            batch_first=True,
                            max_length=max_length,
                            pad_token=pad_token,
                            tied=tied,
                            embedding_dim=embedding_dim)

if model_type =="feature_only":
    model = gru4rec_feature(hidden_dim=hidden_dim,
                            output_dim=output_dim,
                            batch_first=True,
                            max_length=max_length,
                            pad_token=pad_token,
                            bert_dim=bert_dim)

if model_type == "conv":
    model = gru4rec_conv(embedding_dim,
                 hidden_dim,
                 output_dim,
                 batch_first=True,
                 max_length=200,
                 pad_token=0,
                 dropout=0,
                 window=3,
                 tied=tied)
    
if model_type == "nextitnet":
    model = NextItNet(embedding_dim=embedding_dim,
                      output_dim=output_dim,
                      hidden_layers=hidden_layers,
                      dilations=dilations,
                      pad_token=n_items,
                      max_len=max_length)

if model_type == "attention":
    model = gru4recF(embedding_dim=embedding_dim,
             hidden_dim=hidden_dim,
             output_dim=output_dim,
             genre_dim=genre_dim,
             batch_first=True,
             max_length=max_length,
             pad_token=pad_token,
             pad_genre_token=pad_genre_token,
             bert_dim=bert_dim,
             tied = tied,
             dropout=dropout)
#     modelD = gru4recF_decoder(hidden_dim=hidden_dim, output_dim=output_dim, dropout=0, max_length=max_length)
#     modelD = modelD.to(device)

In [50]:
if bert_dim != 0:
    model.init_weight(reset_object,feature_embed)
    
model = model.to(device)

In [51]:
[name for name,param in model.named_parameters() if (("movie" not in name) or ("plot_embedding" in name) or ("genre" in name))]

['plot_embedding.weight',
 'plot_projection.weight',
 'plot_projection.bias',
 'encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [52]:
[name for name,param in model.named_parameters() if ("plot" not in name) and ("genre" not in name)]

['movie_embedding.weight',
 'encoder_layer.weight_ih_l0',
 'encoder_layer.weight_hh_l0',
 'encoder_layer.bias_ih_l0',
 'encoder_layer.bias_hh_l0',
 'output_layer.weight',
 'output_layer.bias']

In [53]:
# initialize Adam optimizer with gru4rec model parameters
if train_method != "normal":
    optimizer_features = torch.optim.Adam([param for name,param in model.named_parameters() if (("movie" not in name) or ("plot_embedding" in name) or ("genre" in name)) ],
                                          lr=lr/10,weight_decay=reg)
    
    optimizer_ids = torch.optim.Adam([param for name,param in model.named_parameters() if ("plot" not in name) and ("genre" not in name)],
                                     lr=lr,weight_decay=reg)

elif train_method == "normal":
    optimizer = torch.optim.Adam(model.parameters(),lr=lr,weight_decay=reg)
    decoder_optimizer = torch.optim.Adam(model.parameters(), lr=lr,weight_decay=reg)
if freeze_plot and bert_dim !=0:
    model.plot_embedding.weight.requires_grad = False

In [54]:
loss_fn = nn.CrossEntropyLoss(ignore_index=n_items)
#Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k)
#Recall_Object = Recall_E_Noprob(ml_1m,user_history,n_users,n_items,k=k)

In [55]:
Recall_Object = Recall_E_prob(ml_1m,user_history,n_users,n_items,k=k,device=device)



In [56]:
# ------------------Training Initialization----------------------#
max_train_hit = 0
max_val_hit = 0
max_test_hit = 0

i = 0;
for epoch in range(num_epochs):
    print("="*20,"Epoch {}".format(epoch+1),"="*20)
    
    model.train()  
    
    running_loss = 0

    for j,data in enumerate(tqdm(train_dl,position=0,leave=True)):
        
        if train_method != "normal":
            optimizer_features.zero_grad()
            optimizer_ids.zero_grad()
            
        elif train_method == "normal": 
            optimizer.zero_grad()
        
        if genre_dim != 0:            
            inputs,genre_inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.to(device),x_lens=x_lens.squeeze().tolist(),x_genre=genre_inputs.to(device))
        
        elif genre_dim == 0:
            inputs,labels,x_lens,uid = data 
            outputs = model(x=inputs.to(device),x_lens=x_lens.squeeze().tolist())
       
        if tied:
            outputs_ignore_pad = outputs[:,:,:-1]
            loss = loss_fn(outputs_ignore_pad.view(-1,outputs_ignore_pad.size(-1)),labels.view(-1).to(device))
            
        else:
            loss = loss_fn(outputs.view(-1,outputs.size(-1)),labels.view(-1).to(device))
            
        loss.backward()
        
        
        if train_method != "normal":
            if train_method == "interleave":
                # interleave on the epochs
                if (j+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()

            elif train_method == "alternate":
                if (epoch+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()
        
    
                    
        elif train_method == "normal":
            optimizer.step()

        running_loss += loss.detach().cpu().item()

    del outputs
    torch.cuda.empty_cache()
    training_hit = Recall_Object(model,train_dl,"train")
    validation_hit = Recall_Object(model,val_dl,"validation")
    testing_hit = Recall_Object(model,test_dl,"test")
    
    if max_val_hit < validation_hit:
        max_val_hit = validation_hit
        max_test_hit = testing_hit
        max_train_hit = training_hit
    
    torch.cuda.empty_cache()
    print("Training CE Loss: {:.5f}".format(running_loss/len(train_dl)))
    print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
    print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
    print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))


print("="*100)
print("Maximum Training Hit@{:d}: {:.2f}".format(k,max_train_hit))
print("Maximum Validation Hit@{:d}: {:.2f}".format(k,max_val_hit))
print("Maximum Testing Hit@{:d}: {:.2f}".format(k,max_test_hit))

  0%|                                                                                           | 0/95 [00:00<?, ?it/s]



100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.28it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 7.39208
Training Hits@10: 34.64
Validation Hits@10: 32.25
Testing Hits@10: 30.15


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.36it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 6.46878
Training Hits@10: 59.64
Validation Hits@10: 55.41
Testing Hits@10: 52.50


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.32it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 6.04423
Training Hits@10: 66.04
Validation Hits@10: 62.48
Testing Hits@10: 58.15


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.43it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.88014
Training Hits@10: 68.34
Validation Hits@10: 64.39
Testing Hits@10: 60.33


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.37it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.79427
Training Hits@10: 70.66
Validation Hits@10: 66.85
Testing Hits@10: 61.72


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.42it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.73500
Training Hits@10: 70.81
Validation Hits@10: 66.95
Testing Hits@10: 61.97


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.50it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.69409
Training Hits@10: 72.24
Validation Hits@10: 67.57
Testing Hits@10: 62.40


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.24it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.68849
Training Hits@10: 72.20
Validation Hits@10: 67.68
Testing Hits@10: 63.49


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.35it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.64308
Training Hits@10: 73.25
Validation Hits@10: 68.36
Testing Hits@10: 65.02


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.25it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.62958
Training Hits@10: 73.20
Validation Hits@10: 68.66
Testing Hits@10: 64.39


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.43it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.61060
Training Hits@10: 72.63
Validation Hits@10: 68.58
Testing Hits@10: 63.66


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.37it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.60677
Training Hits@10: 73.56
Validation Hits@10: 68.49
Testing Hits@10: 64.65


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.56it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.58658
Training Hits@10: 74.07
Validation Hits@10: 68.79
Testing Hits@10: 64.82


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.54it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.60954
Training Hits@10: 74.19
Validation Hits@10: 69.57
Testing Hits@10: 65.07


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.49it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.58192
Training Hits@10: 73.61
Validation Hits@10: 68.66
Testing Hits@10: 64.82


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.60it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.57142
Training Hits@10: 73.66
Validation Hits@10: 69.16
Testing Hits@10: 64.70


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.55it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.57898
Training Hits@10: 74.29
Validation Hits@10: 69.17
Testing Hits@10: 64.97


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.58it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.56743
Training Hits@10: 75.13
Validation Hits@10: 69.64
Testing Hits@10: 65.26


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.58it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.55689
Training Hits@10: 74.93
Validation Hits@10: 69.62
Testing Hits@10: 64.98


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.58it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.55075
Training Hits@10: 74.65
Validation Hits@10: 69.29
Testing Hits@10: 65.23


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.60it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.54799
Training Hits@10: 74.65
Validation Hits@10: 69.40
Testing Hits@10: 65.18


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.39it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.55011
Training Hits@10: 73.66
Validation Hits@10: 69.45
Testing Hits@10: 65.00


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.56it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.55809
Training Hits@10: 75.00
Validation Hits@10: 68.82
Testing Hits@10: 65.40


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.26it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.54755
Training Hits@10: 75.55
Validation Hits@10: 70.35
Testing Hits@10: 66.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.29it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.54243
Training Hits@10: 74.98
Validation Hits@10: 69.67
Testing Hits@10: 66.56


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.23it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.53240
Training Hits@10: 75.20
Validation Hits@10: 69.50
Testing Hits@10: 64.93


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.24it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.54912
Training Hits@10: 74.87
Validation Hits@10: 69.80
Testing Hits@10: 64.92


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  5.95it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.53671
Training Hits@10: 75.26
Validation Hits@10: 69.45
Testing Hits@10: 65.58


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.23it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.53026
Training Hits@10: 75.05
Validation Hits@10: 70.17
Testing Hits@10: 66.24


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.20it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.52495
Training Hits@10: 74.64
Validation Hits@10: 68.94
Testing Hits@10: 65.02


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.33it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.53050
Training Hits@10: 75.45
Validation Hits@10: 69.60
Testing Hits@10: 65.46


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.26it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.52330
Training Hits@10: 75.65
Validation Hits@10: 69.65
Testing Hits@10: 65.81


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.66it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.52111
Training Hits@10: 75.68
Validation Hits@10: 70.38
Testing Hits@10: 66.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.84it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  5.88it/s]

Training CE Loss: 5.52108
Training Hits@10: 75.03
Validation Hits@10: 69.85
Testing Hits@10: 65.94


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.10it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.56840
Training Hits@10: 75.55
Validation Hits@10: 70.02
Testing Hits@10: 65.89


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.97it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.81it/s]

Training CE Loss: 5.52141
Training Hits@10: 75.68
Validation Hits@10: 69.92
Testing Hits@10: 65.71


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.86it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.59it/s]

Training CE Loss: 5.52395
Training Hits@10: 75.07
Validation Hits@10: 69.27
Testing Hits@10: 65.48


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.83it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51950
Training Hits@10: 75.41
Validation Hits@10: 69.44
Testing Hits@10: 65.55


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.82it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51309
Training Hits@10: 74.92
Validation Hits@10: 69.75
Testing Hits@10: 65.46


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.79it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.59it/s]

Training CE Loss: 5.51056
Training Hits@10: 76.11
Validation Hits@10: 69.95
Testing Hits@10: 65.83


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.82it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51380
Training Hits@10: 75.75
Validation Hits@10: 70.07
Testing Hits@10: 65.23


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.86it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51063
Training Hits@10: 76.24
Validation Hits@10: 70.08
Testing Hits@10: 66.01


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.88it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.50113
Training Hits@10: 75.33
Validation Hits@10: 70.33
Testing Hits@10: 65.96


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.20it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  5.88it/s]

Training CE Loss: 5.50498
Training Hits@10: 74.88
Validation Hits@10: 69.83
Testing Hits@10: 65.53


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.19it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.50315
Training Hits@10: 75.75
Validation Hits@10: 69.93
Testing Hits@10: 65.07


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.97it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.85it/s]

Training CE Loss: 5.50567
Training Hits@10: 75.46
Validation Hits@10: 69.62
Testing Hits@10: 65.33


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.89it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51121
Training Hits@10: 75.56
Validation Hits@10: 70.08
Testing Hits@10: 66.26


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.56it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.50526
Training Hits@10: 76.13
Validation Hits@10: 70.07
Testing Hits@10: 66.14


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.52it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49937
Training Hits@10: 75.13
Validation Hits@10: 69.35
Testing Hits@10: 65.40


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.56it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.50510
Training Hits@10: 76.09
Validation Hits@10: 69.39
Testing Hits@10: 65.83


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.54it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.50155
Training Hits@10: 74.97
Validation Hits@10: 70.36
Testing Hits@10: 66.03


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.53it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49650
Training Hits@10: 76.37
Validation Hits@10: 69.87
Testing Hits@10: 66.46


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.51it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49661
Training Hits@10: 75.96
Validation Hits@10: 69.02
Testing Hits@10: 65.30


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.49it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49907
Training Hits@10: 75.45
Validation Hits@10: 70.13
Testing Hits@10: 66.46


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.54it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.58630
Training Hits@10: 74.98
Validation Hits@10: 69.67
Testing Hits@10: 65.60


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.56it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.53040
Training Hits@10: 75.81
Validation Hits@10: 69.67
Testing Hits@10: 65.41


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.51it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.51143
Training Hits@10: 76.08
Validation Hits@10: 71.03
Testing Hits@10: 65.53


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.54it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49596
Training Hits@10: 75.83
Validation Hits@10: 70.10
Testing Hits@10: 65.75


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:15<00:00,  6.20it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49405
Training Hits@10: 76.39
Validation Hits@10: 70.75
Testing Hits@10: 66.67


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.53it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.49625
Training Hits@10: 76.32
Validation Hits@10: 70.43
Testing Hits@10: 66.41


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:17<00:00,  5.34it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48792
Training Hits@10: 76.57
Validation Hits@10: 70.46
Testing Hits@10: 65.71


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.85it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48465
Training Hits@10: 76.08
Validation Hits@10: 69.98
Testing Hits@10: 66.18


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.77it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.68it/s]

Training CE Loss: 5.48627
Training Hits@10: 75.99
Validation Hits@10: 70.17
Testing Hits@10: 66.11


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.15it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48797
Training Hits@10: 76.42
Validation Hits@10: 70.20
Testing Hits@10: 65.93


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.86it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48814
Training Hits@10: 76.27
Validation Hits@10: 69.97
Testing Hits@10: 66.49


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.17it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.56it/s]

Training CE Loss: 5.48587
Training Hits@10: 75.35
Validation Hits@10: 70.08
Testing Hits@10: 65.55


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.15it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.75it/s]

Training CE Loss: 5.48347
Training Hits@10: 76.39
Validation Hits@10: 69.64
Testing Hits@10: 66.13


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.06it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48615
Training Hits@10: 75.78
Validation Hits@10: 70.31
Testing Hits@10: 66.34


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.84it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48812
Training Hits@10: 76.03
Validation Hits@10: 69.83
Testing Hits@10: 66.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.57it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48153
Training Hits@10: 76.39
Validation Hits@10: 70.05
Testing Hits@10: 65.58


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.81it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48709
Training Hits@10: 76.67
Validation Hits@10: 69.83
Testing Hits@10: 66.13


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.80it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48052
Training Hits@10: 75.40
Validation Hits@10: 69.64
Testing Hits@10: 65.91


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.76it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48918
Training Hits@10: 76.04
Validation Hits@10: 70.18
Testing Hits@10: 66.44


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.80it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48417
Training Hits@10: 77.00
Validation Hits@10: 70.75
Testing Hits@10: 66.89


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.77it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48203
Training Hits@10: 75.84
Validation Hits@10: 70.40
Testing Hits@10: 65.98


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.84it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48471
Training Hits@10: 76.13
Validation Hits@10: 70.35
Testing Hits@10: 66.57


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.87it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47922
Training Hits@10: 75.81
Validation Hits@10: 69.90
Testing Hits@10: 65.73


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.80it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48569
Training Hits@10: 76.14
Validation Hits@10: 70.05
Testing Hits@10: 66.29


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:14<00:00,  6.73it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:17,  5.52it/s]

Training CE Loss: 5.47859
Training Hits@10: 75.94
Validation Hits@10: 70.05
Testing Hits@10: 66.46


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  6.99it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48490
Training Hits@10: 76.16
Validation Hits@10: 70.00
Testing Hits@10: 66.09


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.21it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48052
Training Hits@10: 76.08
Validation Hits@10: 69.55
Testing Hits@10: 65.79


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.18it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.78it/s]

Training CE Loss: 5.47571
Training Hits@10: 75.53
Validation Hits@10: 70.26
Testing Hits@10: 65.75


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.19it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47641
Training Hits@10: 76.23
Validation Hits@10: 70.08
Testing Hits@10: 66.44


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.18it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47502
Training Hits@10: 75.75
Validation Hits@10: 69.72
Testing Hits@10: 66.18


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.10it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.75it/s]

Training CE Loss: 5.47840
Training Hits@10: 75.84
Validation Hits@10: 69.83
Testing Hits@10: 65.66


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.11it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.81it/s]

Training CE Loss: 5.47968
Training Hits@10: 76.44
Validation Hits@10: 70.08
Testing Hits@10: 65.79


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.14it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  5.92it/s]

Training CE Loss: 5.48088
Training Hits@10: 75.10
Validation Hits@10: 68.77
Testing Hits@10: 65.08


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.12it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48868
Training Hits@10: 76.54
Validation Hits@10: 70.30
Testing Hits@10: 66.23


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.13it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  5.88it/s]

Training CE Loss: 5.47456
Training Hits@10: 76.26
Validation Hits@10: 69.92
Testing Hits@10: 65.99


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.07it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.78it/s]

Training CE Loss: 5.51109
Training Hits@10: 76.75
Validation Hits@10: 70.00
Testing Hits@10: 65.58


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.15it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.48741
Training Hits@10: 76.52
Validation Hits@10: 70.50
Testing Hits@10: 66.27


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.14it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  6.06it/s]

Training CE Loss: 5.47361
Training Hits@10: 76.59
Validation Hits@10: 70.13
Testing Hits@10: 67.00


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.16it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47312
Training Hits@10: 76.08
Validation Hits@10: 70.28
Testing Hits@10: 66.39


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.15it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.75it/s]

Training CE Loss: 5.47848
Training Hits@10: 76.31
Validation Hits@10: 70.46
Testing Hits@10: 66.13


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.12it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  6.10it/s]

Training CE Loss: 5.47321
Training Hits@10: 76.03
Validation Hits@10: 69.67
Testing Hits@10: 65.86


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.11it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47571
Training Hits@10: 76.85
Validation Hits@10: 69.69
Testing Hits@10: 66.21


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.16it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:15,  5.89it/s]

Training CE Loss: 5.46896
Training Hits@10: 76.23
Validation Hits@10: 69.12
Testing Hits@10: 65.43


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.20it/s]
  1%|▊                                                                                  | 1/95 [00:00<00:16,  5.59it/s]

Training CE Loss: 5.46607
Training Hits@10: 76.64
Validation Hits@10: 70.00
Testing Hits@10: 66.41


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.16it/s]
  0%|                                                                                           | 0/95 [00:00<?, ?it/s]

Training CE Loss: 5.47178
Training Hits@10: 75.79
Validation Hits@10: 70.63
Testing Hits@10: 66.61


100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [00:13<00:00,  7.16it/s]


Training CE Loss: 5.47857
Training Hits@10: 76.85
Validation Hits@10: 70.56
Testing Hits@10: 66.24
Maximum Training Hit@10: 76.08
Maximum Validation Hit@10: 71.03
Maximum Testing Hit@10: 65.53


In [28]:
# ------------------Training Initialization----------------------#
max_train_hit = 0
max_val_hit = 0
max_test_hit = 0

#device = 'cpu'

for epoch in range(num_epochs):
    print("="*20,"Epoch {}".format(epoch+1),"="*20)
    
    model.train()  
    
    running_loss = 0

    for j,data in enumerate(tqdm(train_dl,position=0,leave=True)):
        
        if train_method != "normal":
            optimizer_features.zero_grad()
            optimizer_ids.zero_grad()
            
        elif train_method == "normal": 
            optimizer.zero_grad()
            if model_type == "attention":
                decoder_optimizer.zero_grad()
        
        if genre_dim != 0:            
            inputs,genre_inputs,labels,x_lens,uid = data
            outputs = model(x=inputs.cuda(),x_lens=x_lens.squeeze().tolist(),x_genre=genre_inputs.cuda())
        
        elif genre_dim == 0:
            # input = B * max seq
            # labels = B * max seq
            # x_lens = B, length of each input
            
            inputs,labels,x_lens,uid = data
            # pseudo code
            # iterate thru inputs to generate encoder outputs (first for loop)
            # use encoder outputs to create initial hidden context vector, use SOS token to create initial decoder input
            # use second for loop to generate M number of outputs, then save those to the outputs variable
            # questions:
            # how should I iterate thru the given data set
            # how many outputs should I use?
            # format of input / output is useful
            
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            encoder_outputs, hidden_states = model(x=inputs.to(device),x_lens=x_lens.squeeze().tolist())
            decoder_inputs = inputs[:,0].view(1,-1).to(device)
            decoder_hidden = hidden_states
            
            outputs = torch.zeros(inputs.size()[0],max_length,output_dim, device=device)
            
            for i in range(max_length):
#                 print("Step {}".format(i))
#                 print("Decoder Input Dimension: ")
#                 print(decoder_inputs.size())
#                 print("Decoder Hidden States Dimension: ")
#                 print(decoder_hidden.size())
                decoder_outputs, decoder_hidden, decoder_attention = modelD(decoder_inputs, decoder_hidden, encoder_outputs)
                outputs[:,i,:] = decoder_outputs
                decoder_inputs = labels[:,i].view(1,-1)
#                 print(decoder_inputs)
    
        
            
       
        if tied:
            outputs_ignore_pad = outputs[:,:,:-1]
            loss = loss_fn(outputs_ignore_pad.view(-1,outputs_ignore_pad.size(-1)),labels.view(-1).cuda())
            
        else:
            loss = loss_fn(outputs.view(-1,outputs.size(-1)),labels.view(-1).cuda())
            
        loss.backward()
        
        if train_method != "normal":
            if train_method == "interleave":
                # interleave on the epochs
                if (j+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()

            elif train_method == "alternate":
                if (epoch+1) % 2 == 0:
                    optimizer_features.step()
                else:
                    optimizer_ids.step()

        elif train_method == "normal":
            optimizer.step()

        running_loss += loss.detach().cpu().item()

    del outputs
    torch.cuda.empty_cache()
    training_hit = Recall_Object(model,train_dl,"train")
    validation_hit = Recall_Object(model,val_dl,"validation")
    testing_hit = Recall_Object(model,test_dl,"test")
    
    if max_val_hit < validation_hit:
        max_val_hit = validation_hit
        max_test_hit = testing_hit
        max_train_hit = training_hit
    
    torch.cuda.empty_cache()
    print("Training CE Loss: {:.5f}".format(running_loss/len(train_dl)))
    print("Training Hits@{:d}: {:.2f}".format(k,training_hit))
    print("Validation Hits@{:d}: {:.2f}".format(k,validation_hit))
    print("Testing Hits@{:d}: {:.2f}".format(k,testing_hit))


print("="*100)
print("Maximum Training Hit@{:d}: {:.2f}".format(k,max_train_hit))
print("Maximum Validation Hit@{:d}: {:.2f}".format(k,max_val_hit))
print("Maximum Testing Hit@{:d}: {:.2f}".format(k,max_test_hit))

  0%|                                                                                           | 0/95 [00:00<?, ?it/s]



100%|██████████████████████████████████████████████████████████████████████████████████| 95/95 [02:24<00:00,  1.52s/it]


TypeError: tuple indices must be integers or slices, not tuple

In [None]:
print("="*100)
print("Maximum Training Hit@{:d}: {:.2f}".format(k,max_train_hit))
print("Maximum Validation Hit@{:d}: {:.2f}".format(k,max_val_hit))
print("Maximum Testing Hit@{:d}: {:.2f}".format(k,max_test_hit))

In [30]:
# input and label
x = [5,3,10,11]
y = [3,10,11,13]
print(torch.cuda.is_available())
print(y[0:5])
# notes: gru layer stores hidden layer when using sequence input
# use final hidden state from packed output
# cross entropy loss used


True
[3, 10, 11, 13]
