In [62]:
import sys

sys.path.append("../scripts")
from AttentionTransformer.Encoder import Encoder
from bert4rec_dataset import Bert4RecDataset
import torch.nn.modules as nn

In [160]:
class RecModel(nn.Module):
    def __init__(self,
                 vocab_size,
                 heads=4,
                 layers=6,
                 emb_dim=512,
                 pad_id=0,
                 num_pos=120):
        super().__init__()
        self.emb_dim = emb_dim
        self.pad_id = pad_id
        self.num_pos = num_pos
        self.vocab_size = vocab_size
        self.channel_dim = num_pos * emb_dim
        self.encoder = Encoder(source_vocab_size=vocab_size,
                               emb_dim=emb_dim,
                               layers=layers,
                               heads=heads,
                               dim_key=emb_dim,
                               dim_value=emb_dim,
                               dim_model=emb_dim,
                               dim_inner=emb_dim * 2,
                               pad_id=pad_id,
                               num_pos=num_pos)
        self.lin_op = nn.Linear(512, self.vocab_size)

    def forward(self, x):
        bs = x.size(0)
        x = self.encoder(x, None)
        x = x.view(-1, x.size(2))
        print(x.size())
        x = self.lin_op(x)
        x = x.view(bs,int(x.size(0)/bs), -1)
        return x


In [161]:
model = RecModel(vocab_size = 9725)

In [162]:
import pandas as pd
from torch.utils.data import DataLoader

data = pd.read_csv("../data/ratings_mapped.csv")
ds = Bert4RecDataset(data_csv=data,
                     group_by_col="userId",
                     data_col="movieId_mapped")
dl = DataLoader(ds, batch_size=2, shuffle=True)
tnsr = next(iter(dl))

In [163]:
tnsr.keys()

dict_keys(['source', 'target', 'source_mask', 'target_mask'])

In [164]:
tnsr["source"].size()

torch.Size([2, 120])

In [165]:
tnsr["source_mask"].size()

torch.Size([2, 120])

In [166]:
op = model(tnsr["source"])

torch.Size([240, 512])


In [167]:
op.size()

torch.Size([2, 120, 9725])

In [168]:
import torch

In [169]:
src = tnsr["source"]

In [170]:
mask = src == 1

In [171]:
import torch

In [177]:
_, predicted = op.max(2)

In [178]:
op.size()

torch.Size([2, 120, 9725])

In [179]:
src.size()

torch.Size([2, 120])

In [180]:
predicted.size()

torch.Size([2, 120])

In [181]:
y_true = torch.masked_select(src, mask)
predicted = torch.masked_select(predicted, mask)
acc = (y_true.view(-1) == predicted).double().mean()


In [182]:
acc

tensor(0., dtype=torch.float64)

In [183]:
import torch.nn.functional as F

In [184]:
op.view(-1, op.size(2)).size()

torch.Size([240, 9725])

In [185]:
trg = tnsr["target"]

In [186]:
trg.view(-1).size()

torch.Size([240])

In [188]:
y_pred = op.view(-1, op.size(2))
y_true = trg.view(-1)

In [189]:
y_pred.size(), y_true.size()

(torch.Size([240, 9725]), torch.Size([240]))

In [190]:
loss = F.cross_entropy(op.view(-1, op.size(2)), trg.view(-1), reduction="none")

In [196]:
loss = loss * mask.view(-1)

In [197]:
loss.sum() / (mask.sum() + 1e-8)

tensor(9.3686, grad_fn=<DivBackward0>)

In [193]:
loss.size()

torch.Size([240])

In [194]:
mask.size()

torch.Size([2, 120])