In [12]:
%load_ext autoreload
%autoreload 2


from src.model.configs import CONFIG_IN_IMPLEMENTATION, CONFIG_IN_PAPER, MINIMAL_CONFIG
from src.model.pprec import PPRec
from src.data.dataset import EBNeRDTrainDataset
from src.data.split import EBNeRDSplit
from src.training.loss import BPRPairwiseLoss


import torch
from torch.utils.data import DataLoader
import numpy as np

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Load a data split and the training dataset to use for testing.


In [55]:
# Create the data split with its default arguments, then build the training
# dataset on top of it. Later cells use `dataset.split` to draw random article
# ids and `dataset.collate_fn` to assemble batches.
split = EBNeRDSplit()
dataset = EBNeRDTrainDataset(split=split)

Configure the model.

In [49]:
# Number of clicked articles per user; reused below when building the click
# batch and when preprocessing training batches.
max_clicked = 50

# Instantiate the model on CPU with the MINIMAL_CONFIG preset. To try another
# setup, pass CONFIG_IN_IMPLEMENTATION or CONFIG_IN_PAPER as `config` instead.
model = PPRec(
    config=MINIMAL_CONFIG,
    device=torch.device("cpu"),
    max_clicked=max_clicked,
)

In [50]:
batch_size = 32
candidate_size = 4


def _random_article_ids(n_rows, n_cols):
    """Build an array with `n_rows` rows of `n_cols` random article ids each.

    Every id is drawn with `dataset.split.get_random_article_id()`.
    """
    return np.array(
        [
            [dataset.split.get_random_article_id() for _ in range(n_cols)]
            for _ in range(n_rows)
        ]
    )


# Candidate articles: random ids plus random CTR and recency values.
candidates = PPRec.CandidateBatch(
    ids=_random_article_ids(batch_size, candidate_size),
    ctr=torch.rand(batch_size, candidate_size),
    recencies=torch.rand(batch_size, candidate_size),
)

# Click history: `max_clicked` random ids per user plus random CTR values.
user_clicks = PPRec.ClicksBatch(
    ids=_random_article_ids(batch_size, max_clicked),
    ctr=torch.rand(batch_size, max_clicked),
)

In [51]:
# Smoke test: a forward pass on the randomly generated batch should succeed.
result = model(PPRec.Inputs(candidates=candidates, clicks=user_clicks))
# Show the shape of each of the three score outputs.
tuple(
    scores.shape
    for scores in (
        result.personalized_matching_score,
        result.popularity_score,
        result.score,
    )
)

(torch.Size([32, 4]), torch.Size([32, 4]), torch.Size([32, 4]))

In [52]:
# Total number of elements across the parameters that require gradients.
sum(
    weight.numel()
    for weight in filter(lambda w: w.requires_grad, model.parameters())
)

178416

In [53]:
# Shuffled mini-batches, assembled with the dataset's own collate function.
dataloader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    shuffle=True,
    batch_size=batch_size,
)
# Loss object; it is also used to preprocess raw batches before the forward pass.
loss = BPRPairwiseLoss()

In [54]:
# Smoke test: run a single batch from the dataloader through preprocessing,
# the forward pass, the loss, and the backward pass.
batch = next(iter(dataloader), None)
if batch is not None:  # nothing to do when the dataloader yields no batches
    inputs = loss.preprocess_train_batch(batch, max_clicked=max_clicked)
    predictions = model(inputs)
    loss_value = loss(predictions)
    loss_value.backward()