In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
from utils.data_reader import amazon_dataset_iters
import torch.utils.tensorboard as tb
from os import path
from models.nrt import NRT
import utils.constants as constants
from utils.loss import mask_nll_loss, review_loss
import math
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate import bleu_score

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Loading the dataset
dataset_folder = './data/Musical_Instruments_5/'
text_vocab, tips_vocab, train_iter, val_iter, test_iter = (
    amazon_dataset_iters(dataset_folder)
)



Loading datasets...




datasets loaded
item vocab built
user vocab built
text vocab built
tips vocab built




In [4]:
# Count user and item number
items_count = int(max([i.item.max().cpu().data.numpy() for i in train_iter] + [i.item.max().cpu().data.numpy() for i in test_iter]))
users_count = int(max([i.user.max().cpu().data.numpy() for i in train_iter] + [i.user.max().cpu().data.numpy() for i in test_iter]))
vocab_size = len(text_vocab.itos)



In [5]:
# Load model
model = NRT(
        users_count + 2,
        items_count + 2,
        constants.EBD_SIZE,
        constants.RATER_MLP_SIZES,
        constants.HIDDEN_DIM,
        vocab_size,
        constants.WORD_LF_NUM,
        constants.TG_HIDDEN_LAYERS,
        constants.DROPOUT_RATE,
        constants.RNN_TYPE,
    )

In [6]:
model.to(device)
alpha = constants.RR_LOSS_WEIGHT
beta = constants.WG_LOSS_WEIGHT
gamma = constants.TG_LOSS_WEIGHT
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=constants.REG_WEIGHT)
pad_idx = tips_vocab.stoi['<pad>']

In [37]:
model.eval()
idx = 0
rmse_loss = 0.0
num_ratings = 0
for valid_batch in val_iter:
    tips = valid_batch.tips
    rate_output, tips_output, wd_output = model(valid_batch, tips[:-1], tf_rate=0)
    # compute rmse
#     rmse_loss += F.mse_loss(rate_output, valid_batch.rating, reduction='sum').item()
#     num_ratings += valid_batch.rating.shape[0]
    break

input sequence shape:torch.Size([21, 2])
init_hidden shape:torch.Size([1, 2, 400])
No teacher forcing
max_length:21
decoder_var:tensor([[2, 2]])
output_step: tensor([[[ -9.8530,  -9.9549,  -9.9495,  ...,  -9.9994,  -9.7837,  -9.8035],
         [ -9.8580,  -9.9588,  -9.9538,  ..., -10.0080,  -9.8057,  -9.8045]]],
       grad_fn=<LogSoftmaxBackward>)
output_step shape: torch.Size([1, 2, 19598])
decoder_var:tensor([[1301, 1301]])
decoder_var shape:torch.Size([1, 2])
output_step: tensor([[[-9.6105, -9.8587, -9.9511,  ..., -9.9865, -9.9172, -9.8148],
         [-9.6135, -9.8633, -9.9563,  ..., -9.9924, -9.9279, -9.8150]]],
       grad_fn=<LogSoftmaxBackward>)
output_step shape: torch.Size([1, 2, 19598])
decoder_var:tensor([[7093, 7093]])
decoder_var shape:torch.Size([1, 2])
output_step: tensor([[[ -9.8132, -10.0887, -10.0526,  ..., -10.1143, -10.1154,  -9.7452],
         [ -9.8143, -10.0920, -10.0561,  ..., -10.1177, -10.1202,  -9.7449]]],
       grad_fn=<LogSoftmaxBackward>)
output_step sha

In [38]:
tips_output.shape

torch.Size([21, 2, 19598])

In [10]:
_, generate_idx = tips_output.max(2)
print(generate_idx)

tensor([[ 1301,  1301],
        [ 7093,  7093],
        [ 5974,  5974],
        [11074, 11074],
        [ 3273,  3273],
        [12537, 12537],
        [13256, 13256],
        [ 4673,  4673],
        [ 6902,  6902],
        [10951, 10951],
        [ 3981,  3981],
        [ 2002,  2002],
        [ 3936,  3936],
        [ 6023,  6023],
        [13628, 13628],
        [ 7511,  7511],
        [ 9529,  9529],
        [  756,   756],
        [ 6512,  6512],
        [ 8548,  8548],
        [ 3228,  3228]])


In [13]:
gts = None
for valid_batch in val_iter:
    tips = valid_batch.tips
    gts = tips
    print(tips[1:])
    print(tips.shape)
    break

tensor([[  29,   35],
        [  65,  633],
        [  12, 1722],
        [   5,  404],
        [  64,   65],
        [   3,    3],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1],
        [   1,    1]])
torch.Size([22, 2])


In [22]:
# batch first
gts = torch.transpose(gts, 0, 1)

In [23]:
print(gts)

tensor([[   2,   29,   65,   12,    5,   64,    3,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1],
        [   2,   35,  633, 1722,  404,   65,    3,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1]])


In [24]:
sentence_token = []
for token_ids in gts:
    current_sen = [tips_vocab.itos[id] for id in token_ids.detach().numpy()]
    sentence_token.append(current_sen)

In [30]:
generate_idx = torch.transpose(generate_idx, 0, 1)

In [31]:
generate_idx

tensor([[ 1301,  7093,  5974, 11074,  3273, 12537, 13256,  4673,  6902, 10951,
          3981,  2002,  3936,  6023, 13628,  7511,  9529,   756,  6512,  8548,
          3228],
        [ 1301,  7093,  5974, 11074,  3273, 12537, 13256,  4673,  6902, 10951,
          3981,  2002,  3936,  6023, 13628,  7511,  9529,   756,  6512,  8548,
          3228]])

In [32]:
sentence_generate = []
for token_ids in generate_idx:
    current_sen = [tips_vocab.itos[id] for id in token_ids.detach().numpy()]
    sentence_generate.append(current_sen)

In [33]:
sentence_generate

[['service',
  'shelling',
  'sdhc',
  'ampwas',
  'ratings',
  'dbm',
  'es-57s',
  'organ',
  'outfit',
  'agreat',
  'playback',
  'depends',
  'k',
  'stomping',
  'fluctuations',
  '9vdc',
  'uh',
  'quiet',
  'drumming',
  'klon',
  'hammer'],
 ['service',
  'shelling',
  'sdhc',
  'ampwas',
  'ratings',
  'dbm',
  'es-57s',
  'organ',
  'outfit',
  'agreat',
  'playback',
  'depends',
  'k',
  'stomping',
  'fluctuations',
  '9vdc',
  'uh',
  'quiet',
  'drumming',
  'klon',
  'hammer']]

In [36]:
bleu_score.sentence_bleu(sentence_token[1], sentence_generate[1], weights=[1.0,0.0,0.0,0.0])

0