# Prepare environment

In [1]:
# allow "hot-reloading" of modules
%load_ext autoreload
%autoreload 2
# needed for inline plots in some contexts
%matplotlib inline

import os

# Repository / bootstrap settings exported as environment variables
# (presumably consumed by bootstrap.py, downloaded in the next cell — verify there).
os.environ["REPO"] = "caser-pytorch"
os.environ["BRANCH"] = "main"
os.environ["GHPROFILE"] = "peterleviant"
# NOTE: never commit a real token here; provide it through the environment or a secrets manager.
os.environ["GHTOKEN"] = ""
os.environ["WORKDIR"] = "."

In [2]:
# codeblock adapted from https://github.com/the-full-stack/fsdl-text-recognizer-2022-labs
# One-time environment setup: skipped on re-runs once `bootstrap.run` has been set to False in this kernel.
if "bootstrap" not in locals() or bootstrap.run:
    # path management for Python: make sure the working directory is on PYTHONPATH
    pythonpath, = !echo $PYTHONPATH
    if "." not in pythonpath.split(":"):
        pythonpath = ".:" + pythonpath
        %env PYTHONPATH={pythonpath}
        !echo $PYTHONPATH

    # get both Colab and local notebooks into the same state by downloading and importing bootstrap.py
    # NOTE(review): this executes code fetched from a moving branch (main); consider pinning a commit hash.
    !wget --quiet https://raw.githubusercontent.com/peterleviant/dev-utils/main/bootstrap.py -O bootstrap.py
    import bootstrap

    bootstrap.run = False  # change to True to re-run setup

# show the resulting working directory and its contents
!pwd
%ls

env: PYTHONPATH=.:/env/python
.:/env/python
/content/caser-pytorch
LICENSE  README.md  [0m[01;34mrequirements[0m/


# CASER recommendation system implementation

Let's start with an initial end-to-end (e2e) training sketch on a small dataset.

### Dataloader

In [43]:
import pytorch_lightning as pl
import zipfile
import requests
import torch
import random
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
import pandas as pd

class TrainDataset(torch.utils.data.Dataset):
    """Training windows for Caser: (user, L input items, T target items, sampled negatives).

    Args:
        users: length-N user indices, one per training window.
        sequences: N x L input item indices.
        target_items: N x T target (positive) item indices.
        nonnegative_samples: pandas Series mapping user index -> set of items the user interacted
            with in training; these items are never drawn as negatives for that user.
        all_items: set of candidate item indices for negative sampling.
        num_negative_samples: number of negatives drawn per window.

    ``negative_samples`` is an all-zero N x num_negative_samples tensor until
    ``generate_negative_samples`` is called.
    """

    def __init__(self, users, sequences, target_items, nonnegative_samples, all_items, num_negative_samples=3):
        self.users_pandas_series = users
        self.U = torch.LongTensor(users)  # N - number of training windows
        self.I = torch.LongTensor(sequences)  # N x L
        self.T = torch.LongTensor(target_items)  # N x T
        self.all_items = all_items  # todo - try with padding in the negatives
        self.nonnegative_samples = nonnegative_samples.to_dict()  # user -> {items seen in train}
        self.negative_samples = torch.zeros((self.U.size(0), num_negative_samples), dtype=torch.int64)  # N x num_negative_samples

    def generate_negative_samples(self):
        """(Re)draw negatives for every window using the global ``random`` module.

        For each user, negatives are drawn uniformly with replacement from ``all_items`` minus the
        user's training items, and written to exactly the rows that belong to that user. The result
        is therefore aligned with ``self.U`` whatever the row order. Users are processed in ascending
        order, matching the order of random draws of the previous implementation.

        Raises:
            ValueError: if some user has no candidate negative item.
        """
        num_negative_samples = self.negative_samples.size(-1)
        users = self.U.numpy()
        negatives = np.zeros((len(users), num_negative_samples), dtype=np.int64)

        # Group row positions by user; a stable sort keeps each user's rows in their original order.
        order = np.argsort(users, kind="stable")
        unique_users, starts, counts = np.unique(users[order], return_index=True, return_counts=True)
        for user, start, count in zip(unique_users.tolist(), starts.tolist(), counts.tolist()):
            candidates = list(self.all_items - self.nonnegative_samples[user])
            if not candidates:
                raise ValueError(f"user {user} has no candidate negative items to sample from")
            draws = random.choices(candidates, k=count * num_negative_samples)
            rows = order[start:start + count]
            negatives[rows] = np.asarray(draws, dtype=np.int64).reshape(count, num_negative_samples)

        self.negative_samples = torch.from_numpy(negatives)

    def __getitem__(self, idx):
        return self.U[idx], self.I[idx], self.T[idx], self.negative_samples[idx]

    def __len__(self):
        return self.U.size(0)

class TestDataset(torch.utils.data.Dataset):
    """Evaluation rows, one per user, with multi-hot item masks.

    Each item is ``(user, input sequence, test_items, new_items, rated_items)``. The last three are
    float vectors of length ``total_num_items`` holding 1.0 at the item indices listed for that user.

    Args:
        users: user indices (U).
        sequences: U x L input item indices.
        target_items, new_items, rated_items: pandas Series of per-user item collections.
        total_num_items: length of each mask (catalogue size, including padding index 0).
    """

    def __init__(self, users, sequences, target_items, new_items, rated_items, total_num_items):
        self.U = torch.LongTensor(users)
        self.I = torch.LongTensor(sequences)
        # keep per-user collections as lists so they can index a tensor directly
        self.test_items, self.new_items, self.rated_items = (
            per_user.apply(list) for per_user in (target_items, new_items, rated_items)
        )
        self.total_num_items = total_num_items

    def _multi_hot(self, per_user_lists, idx):
        # float mask over the catalogue with ones at the listed item indices
        mask = torch.zeros(self.total_num_items)
        mask[per_user_lists.iloc[idx]] = 1
        return mask

    def __getitem__(self, idx):
        masks = [self._multi_hot(lists, idx) for lists in (self.test_items, self.new_items, self.rated_items)]
        return (self.U[idx], self.I[idx], *masks)

    def __len__(self):
        return len(self.U)

class InteractionsDataModuleBase(pl.LightningDataModule):
    """Base LightningDataModule turning a user-item interaction log into Caser datasets.

    Subclasses must set ``self.interactions`` before ``setup`` runs, typically to the output of
    ``_prepare_interactions_data`` (cold-start filtered, with per-user position counters). Each
    user's interactions are split by position in that user's history. The first ``train_split``
    fraction produces training windows and the following fraction produces evaluation targets.
    Item index 0 is reserved for padding.

    Args:
        L: number of input items per sequence.
        T: number of target items per training window.
        data_dir: root directory for downloaded data.
        train_split, val_split, test_split: per-user fractions of the interaction history.
        batch_size: DataLoader batch size.
        cold_start_n: minimum interaction count both a user and an item need to be kept.
        num_negative_samples: negatives drawn per training window.
        num_workers: DataLoader worker processes.

    NOTE: with the class defaults (TRAIN_SPLIT + VAL_SPLIT = 1.0, TEST_SPLIT = 0) the "test" stage
    has no held-out interactions, so every user gets an empty target set.
    """
    DATA_DIR = "data"
    DATASET_NAME = "default"
    TRAIN_SPLIT = 0.8
    TEST_SPLIT = 0
    VAL_SPLIT = 0.2

    def __init__(self, L, T, data_dir: str = DATA_DIR, train_split: float = TRAIN_SPLIT,
                 val_split: float = VAL_SPLIT, test_split: float = TEST_SPLIT,
                 batch_size: int = 32, cold_start_n: int = 5, num_negative_samples: int = 3, num_workers=0
                 ):
        super().__init__()
        self.L = L
        self.T = T
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.cold_start_n = cold_start_n
        self.user_id_to_idx = {}
        self.item_id_to_idx = {}
        self.stage = "uninitialized"
        self.num_negative_samples = num_negative_samples
        self.num_workers = num_workers

    def setup(self, stage: str = "fit"):
        """Build ``train_dataset`` / ``test_dataset`` for a Lightning stage.

        "fit" and "validate" train on the train split and evaluate on the val split. "test" trains
        on train + val and evaluates on the test split. Any other stage raises ValueError.
        """
        if stage in ("fit", "validate"):
            self._setup_train_test_split(self.train_split, self.val_split)
        elif stage == "test":
            self._setup_train_test_split(self.train_split + self.val_split, self.test_split)
        else:
            raise ValueError(f"Unknown stage: {stage}")
        self.stage = stage

    def train_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:
        # Negatives are resampled every time this method is called, i.e. whenever the train
        # dataloader is (re)built.
        self.train_dataset.generate_negative_samples()
        return torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers
            )

    def val_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:
        # There is no separate validation dataset: in the "fit" stage ``test_dataset`` is built from
        # the val split, and for final testing the model is trained on train + val together.
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self: pl.LightningDataModule) -> torch.utils.data.DataLoader:
        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(stage={self.stage}, L={self.L}, batch_size={self.batch_size}, "
            f"distinct users |U|={self.num_users}, distinct items |I|={self.num_items} (0 = padding), "
            f"cold_start_n={self.cold_start_n})"
        )

    def __str__(self):
        return self.__repr__()

    def _add_counts_to_interactions(self, interactions):
        """Add per-user interaction count and 0-based position columns (mutates ``interactions``)."""
        interactions['interactions_per_user'] = interactions.groupby('user_id').user_id.transform('count')
        interactions['interactions_ix_for_user'] = interactions.groupby('user_id').user_id.cumcount()
        return interactions

    def _return_non_cold_start_users_and_items(self, interactions):
        """Repeatedly drop items, then users, with fewer than ``cold_start_n`` interactions until
        nothing changes. Returns the surviving users and items in order of first appearance."""
        prev_len = len(interactions)
        while True:
            interactions = interactions[interactions.groupby('item_id').item_id.transform('count') >= self.cold_start_n]
            interactions = interactions[interactions.groupby('user_id').user_id.transform('count') >= self.cold_start_n]
            if len(interactions) == prev_len:
                break
            prev_len = len(interactions)
        items = list(interactions.item_id.drop_duplicates())
        users = list(interactions.user_id.drop_duplicates())
        return users, items

    def _translate_platform_id_to_idx_(self, interactions, user_id_to_idx, item_id_to_idx):
        """Replace raw ids by dense indices in place and return ``interactions``."""
        interactions.item_id = interactions.item_id.map(item_id_to_idx)
        interactions.user_id = interactions.user_id.map(user_id_to_idx)
        return interactions

    def _split_interactions_into_train_test(self, train_split, test_split):
        """Split ``self.interactions`` per user by relative position in the user's history."""
        position = self.interactions.interactions_ix_for_user
        per_user = self.interactions.interactions_per_user
        train_ix = position < train_split * per_user
        test_ix = (position >= train_split * per_user) & (position < (train_split + test_split) * per_user)
        return self.interactions[train_ix], self.interactions[test_ix]

    def _filter_out_cold_start_users_and_items(self, interactions):
        """Keep only interactions whose user and item both survive cold-start filtering."""
        valid_users, valid_items = self._return_non_cold_start_users_and_items(interactions)
        interactions = interactions[
            (interactions.user_id.isin(valid_users)) & (interactions.item_id.isin(valid_items))
        ].reset_index(drop=True).copy()
        return interactions, valid_users, valid_items

    def _generate_translations_from_platform_id_to_idx(self, users, items, padding_item_platform_id = -1):
        """Map raw ids to dense indices; item index 0 is the padding item."""
        items = [padding_item_platform_id] + items  # add padding item
        user_id_to_idx = {user_id: idx for idx, user_id in enumerate(users)}
        item_id_to_idx = {item_id: idx for idx, item_id in enumerate(items)}
        return user_id_to_idx, item_id_to_idx

    def _translate_platform_id_to_idx(self, train_interactions, test_interactions):
        """Build id->index maps (train ids first, then test-only ids) and apply them to copies."""
        valid_train_users = list(train_interactions.user_id.drop_duplicates())
        valid_train_items = list(train_interactions.item_id.drop_duplicates())
        valid_test_users = list(set(test_interactions.user_id.drop_duplicates()) - set(valid_train_users))
        valid_test_items = list(set(test_interactions.item_id.drop_duplicates()) - set(valid_train_items))
        valid_items = valid_train_items + valid_test_items
        valid_users = valid_train_users + valid_test_users
        user_id_to_idx, item_id_to_idx = self._generate_translations_from_platform_id_to_idx(valid_users, valid_items)
        train_interactions = self._translate_platform_id_to_idx_(train_interactions.copy(), user_id_to_idx, item_id_to_idx)
        test_interactions = self._translate_platform_id_to_idx_(test_interactions.copy(), user_id_to_idx, item_id_to_idx)
        return train_interactions, test_interactions, user_id_to_idx, item_id_to_idx

    def _prepare_interactions_data(self, interactions):
        """Cold-start filter ``interactions`` and add per-user counters."""
        interactions, _, _ = self._filter_out_cold_start_users_and_items(interactions)
        return self._add_counts_to_interactions(interactions)

    def _setup_train_test_split(self, train_split, test_split):
        train_interactions, test_interactions = self._split_interactions_into_train_test(train_split, test_split)
        train_interactions, test_interactions, self.user_id_to_idx, self.item_id_to_idx = \
            self._translate_platform_id_to_idx(train_interactions, test_interactions)

        def to_train_sequence(df, L, T):
            # Every window of L inputs followed by T targets; a user shorter than L + T yields one
            # left-padded window. Padding keeps the item dtype so indices stay integral.
            items = df.values
            window_size = L + T
            if len(items) < window_size:
                padding = np.zeros(window_size - len(items), dtype=items.dtype)
                windows = np.expand_dims(np.concatenate([padding, items]), axis=0)
            else:
                windows = sliding_window_view(items, window_size)
            return pd.DataFrame(windows)

        def to_test_sequence(df, L):
            # The user's last L training items (left-padded) are the input used to rank held-out items.
            items = df.values
            if len(items) < L:
                items = np.concatenate([np.zeros(L - len(items), dtype=items.dtype), items])
            return pd.DataFrame(np.expand_dims(items[-L:], axis=0))

        items_seen_in_training_per_user = train_interactions.groupby('user_id').item_id.apply(set)
        items_seen_in_test_per_user = test_interactions.groupby('user_id').item_id.apply(set)

        train_df = train_interactions.groupby('user_id').item_id.apply(
            lambda df: to_train_sequence(df, self.L, self.T)).reset_index()
        self.train_dataset = TrainDataset(
            users=train_df['user_id'].values,
            sequences=train_df[range(self.L)].values,
            target_items=train_df[range(self.L, self.L+self.T)].values,
            nonnegative_samples=items_seen_in_training_per_user,
            all_items=set(self.item_id_to_idx.values()) - set([0]),
            num_negative_samples=self.num_negative_samples
        )

        test_df = train_interactions.groupby('user_id').item_id.apply(
            lambda df: to_test_sequence(df, self.L)).reset_index()
        # Look the per-user sets up by user id (test_df has a positional RangeIndex, the sets are
        # indexed by user id). Users without held-out interactions get an empty target set.
        test_df['target_items'] = test_df['user_id'].map(lambda user: items_seen_in_test_per_user.get(user, set()))
        test_df['rated_items'] = test_df['user_id'].map(lambda user: items_seen_in_training_per_user.get(user, set()))
        test_df['new_items'] = pd.Series(
            [targets - rated for targets, rated in zip(test_df['target_items'], test_df['rated_items'])],
            index=test_df.index, dtype=object,
        )

        self.test_dataset = TestDataset(
            users=list(test_df['user_id']),
            sequences=test_df[range(self.L)].values,
            target_items=test_df['target_items'],
            new_items=test_df['new_items'],
            rated_items=test_df['rated_items'],
            total_num_items=self.num_items
        )

    @property
    def num_users(self):
        return len(self.user_id_to_idx)

    @property
    def num_items(self):
        """Number of item indices, including the padding index 0."""
        return len(self.item_id_to_idx)

class MovielensDataModule(InteractionsDataModuleBase):
    """MovieLens datamodule (default version ``ml-1m``).

    ``prepare_data`` downloads and extracts the GroupLens archive (skipped when ``ratings.dat`` is
    already present). It then reads ratings ordered by (user_id, timestamp) and applies the base
    class's cold-start filtering.
    """
    VERSION = "ml-1m"
    DATASET_NAME = "movielens"
    DOWNLOAD_TIMEOUT_S = 120

    def __init__(self, L, version: str = VERSION, **kwargs):
        super().__init__(L, **kwargs)
        self.version = version

    def _dataset_path(self, *parts):
        # <data_dir>/<DATASET_NAME>/<parts...>
        return os.path.join(self.data_dir, self.DATASET_NAME, *parts)

    def _set_file_paths(self):
        self.users_file = self._dataset_path(self.version, "users.dat")
        self.movies_file = self._dataset_path(self.version, "movies.dat")
        self.ratings_file = self._dataset_path(self.version, "ratings.dat")
        self.readme_file = self._dataset_path(self.version, "README")

    def _download_data(self):
        """Download and extract ``{version}.zip`` unless the ratings file already exists."""
        self._set_file_paths()
        if os.path.exists(self.ratings_file):
            return
        os.makedirs(self._dataset_path(), exist_ok=True)
        url = f"https://files.grouplens.org/datasets/movielens/{self.version}.zip"
        zip_file_path = self._dataset_path("movielens.zip")

        response = requests.get(url, timeout=self.DOWNLOAD_TIMEOUT_S)
        response.raise_for_status()  # do not write an HTML error page as a "zip"
        with open(zip_file_path, "wb") as f:
            f.write(response.content)

        try:
            with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
                zip_ref.extractall(self._dataset_path())
        finally:
            os.remove(zip_file_path)

    def _read_raw_data(self):
        """Return (interactions, metadata): user_id/item_id rows sorted by user then timestamp; metadata is None."""
        with open(self.ratings_file, "r") as f:
            interactions = pd.read_csv(f, sep="::", names=["user_id", "movie_id", "rating", "timestamp"], engine="python")
        interactions = interactions.sort_values(["user_id", "timestamp"]).rename(columns={"movie_id": "item_id"})[['user_id', 'item_id']]
        return interactions.reset_index(drop=True), None

    def prepare_data(self):
        # NOTE(review): Lightning runs prepare_data on a single process and recommends not assigning
        # state here; this works for single-device training. TODO move state to setup for multi-device.
        self._download_data()
        interactions, self.metadata = self._read_raw_data()
        self.interactions = self._prepare_interactions_data(interactions)

def set_seed(seed, cuda=False):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU always; all GPUs when ``cuda``).

    Previously ``cuda=True`` seeded only the GPU generator, leaving CPU-side torch randomness
    (e.g. DataLoader shuffling, dropout on CPU) unseeded.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        # safe without CUDA: the call is deferred until CUDA is initialised
        torch.cuda.manual_seed_all(seed)

### Metrics

In [4]:
from torchmetrics import Metric
from typing import Optional

class TopNRecall(Metric):
    """Recall@N for multi-label ranking.

    For each row, this is the fraction of its positive labels that appear among its ``top_n``
    highest scores, averaged over all rows seen. Rows without positive labels contribute 0. If
    ``top_n`` exceeds the number of classes, every class counts as recommended.
    """
    is_differentiable: Optional[bool] = False
    higher_is_better: Optional[bool] = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(self, top_n, **kwargs):
        super().__init__(**kwargs)
        self.top_n = top_n
        self.add_state("recall", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        preds: tensor of shape (B, C) where B is the batch size and C is the number of classes.
        target: tensor of shape (B, C) with binary values indicating the true labels.
        """
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        # torch.topk raises when k > C, so clamp to the number of classes
        k = min(self.top_n, preds.size(1))
        top_n_indices = torch.topk(preds, k, dim=1).indices
        top_n_preds = torch.zeros_like(preds, dtype=torch.int).scatter_(1, top_n_indices, 1)

        intersection = (top_n_preds * target).sum(dim=1).float()
        true_labels_count = target.sum(dim=1).float()
        recall = torch.where(true_labels_count == 0, torch.tensor(0.0, device=preds.device), intersection / true_labels_count)

        self.recall += recall.sum()
        self.total += target.size(0)

    def plot(self, val = None, ax = None):
        return self._plot(val, ax)

    def compute(self):
        return self.recall / self.total

# NOTE: a second, identical copy of the TopNRecall class used to be defined here; it silently
# shadowed the definition above, so it was removed to keep a single source of truth for the metric.

class TopNPrecision(Metric):
    """Precision@N for multi-label ranking.

    For each row, this is the fraction of its ``top_n`` highest-scored classes that are positive,
    averaged over all rows seen. If ``top_n`` exceeds the number of classes, all C classes are
    recommended and the hits are divided by C.
    """
    is_differentiable: Optional[bool] = False
    higher_is_better: Optional[bool] = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(self, top_n, **kwargs):
        super().__init__(**kwargs)
        self.top_n = top_n
        self.add_state("precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        preds: tensor of shape (B, C) where B is the batch size and C is the number of classes.
        target: tensor of shape (B, C) with binary values indicating the true labels.
        """
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")
        # torch.topk raises when k > C, so clamp to the number of classes
        k = min(self.top_n, preds.size(1))
        top_n_indices = torch.topk(preds, k, dim=1).indices
        top_n_preds = torch.zeros_like(preds, dtype=torch.int).scatter_(1, top_n_indices, 1)

        precision = (top_n_preds * target).sum(dim=1).float() / k

        self.precision += precision.sum()
        self.total += target.size(0)

    def plot(self, val = None, ax = None):
        return self._plot(val, ax)

    def compute(self):
        return self.precision / self.total

class MeanAveragePrecision(Metric):
    """Mean average precision over the full ranking of each row.

    For each row, AP is the mean, over its positive labels, of the precision at that label's rank
    (hits so far / rank). The metric averages AP over all rows seen. Rows with no positive label
    contribute 0 instead of producing NaN.
    """
    is_differentiable: Optional[bool] = False
    higher_is_better: Optional[bool] = True
    full_state_update: bool = False
    plot_lower_bound: float = 0.0
    plot_upper_bound: float = 1.0

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("average_precision", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor):
        """
        preds: tensor of shape (B, C) where B is the batch size and C is the number of classes.
        targets: tensor of shape (B, C) with binary values indicating the true labels.
        """
        if preds.shape != targets.shape:
            raise ValueError("preds and targets must have the same shape")

        # 1-based rank of every class in descending score order
        ranks = torch.argsort(torch.argsort(-preds, dim=1), dim=1) + 1
        is_target = targets > 0
        inf = torch.full_like(preds, float("inf"), dtype=torch.float)
        target_ranks = torch.where(is_target, ranks.float(), inf)
        # how many positives are ranked at or above each positive (non-targets sort last as +inf)
        hits_so_far = torch.argsort(torch.argsort(target_ranks, dim=1), dim=1) + 1
        precision_at_hit = torch.where(is_target, hits_so_far / target_ranks, torch.zeros_like(target_ranks))

        n_targets = targets.sum(dim=1).float()
        average_precision = torch.where(
            n_targets > 0,
            precision_at_hit.sum(dim=1) / n_targets.clamp(min=1),
            torch.zeros_like(n_targets),
        )

        self.average_precision += average_precision.sum()
        self.total += targets.size(0)

    def plot(self, val = None, ax = None):
        return self._plot(val, ax)

    def compute(self):
        return self.average_precision / self.total

### Model

In [92]:
from typing import List
import torch.nn.functional as F
from torchmetrics import MeanMetric

def _identity(x):
    """Identity activation (a named function, unlike a lambda, can be pickled)."""
    return x


activation_getter = {'iden': _identity, 'relu': F.relu, 'tanh': torch.tanh, 'sigm': torch.sigmoid}


class CASER(torch.nn.Module):
    """Caser convolutional sequence embedding model scoring every item for a user.

    Vertical (L x 1) convolutions run over the L x d item-embedding matrix. Horizontal (h x d,
    h = 1..L) convolutions run over the same matrix and are max-pooled over time. Both outputs
    are concatenated, passed through dropout and a fully connected layer, then concatenated with
    the user embedding and projected onto all ``num_items`` items.

    Args:
        num_users, num_items: sizes of the user / item embedding tables.
        d: embedding dimension.
        L: input sequence length.
        F_h: number of horizontal filters per height h.
        F_v: number of vertical filters.
        drop_ratio: dropout probability on the concatenated conv features.
        ac_conv: key of ``activation_getter`` applied to horizontal conv outputs.
        ac_fc: key of ``activation_getter`` applied to the sequence fully connected layer.

    Raises:
        ValueError: if ``ac_conv`` or ``ac_fc`` is not a key of ``activation_getter``.
    """

    def __init__(self, num_users, num_items, d, L, F_h, F_v, drop_ratio, ac_conv = 'relu', ac_fc = 'relu'):
        super().__init__()
        for activation_name in (ac_conv, ac_fc):
            if activation_name not in activation_getter:
                raise ValueError(f"Unknown activation {activation_name!r}; expected one of {sorted(activation_getter)}")
        self.num_users = num_users
        self.num_items = num_items
        self.d = d
        self.L = L
        self.F_h = F_h
        self.F_v = F_v
        self.drop_ratio = drop_ratio

        self.user_embedding = torch.nn.Embedding(num_users, d)
        self.item_embedding = torch.nn.Embedding(num_items, d)

        # vertical filters span all L rows of the embedding matrix, one embedding column at a time
        self.vertical_conv = torch.nn.Conv2d(1, F_v, (L, 1))
        # F_h horizontal filters for every height h = 1..L, each spanning the full embedding width
        self.horizontal_convs = torch.nn.ModuleList([
            torch.nn.Conv2d(1, F_h, kernel_size=(h, d))
            for h in range(1, L + 1)
        ])
        # use the configured activations (the arguments were previously ignored in favour of ReLU)
        self.ac_conv = activation_getter[ac_conv]
        self.ac_fc = activation_getter[ac_fc]

        self.dropout = torch.nn.Dropout(drop_ratio)

        self.sequence_fc = torch.nn.Linear(F_v * d + F_h * L, d)
        self.output_fc = torch.nn.Linear(2 * d, num_items)

        torch.nn.init.kaiming_normal_(self.user_embedding.weight, mode='fan_out', nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.item_embedding.weight, mode='fan_out', nonlinearity='relu')
        torch.nn.init.kaiming_normal_(self.output_fc.weight, mode='fan_in', nonlinearity='relu')
        torch.nn.init.zeros_(self.output_fc.bias)

    def forward(self, U, I):
        """Score all items.

        Args:
            U: (B,) user indices.
            I: (B, L) input item indices.
        Returns:
            (B, num_items) unnormalised scores (logits).
        """
        B = U.size(0)

        user_embs = self.user_embedding(U)  # B x d
        item_embs = self.item_embedding(I).unsqueeze(1)  # B x 1 x L x d
        o_v = self.vertical_conv(item_embs).reshape(B, -1)  # B x (F_v * d)

        o_h = []
        for conv in self.horizontal_convs:
            conv_out = self.ac_conv(conv(item_embs)).squeeze(-1)  # B x F_h x (L + 1 - h)
            pool_out = F.max_pool1d(conv_out, conv_out.size(-1)).squeeze(-1)  # B x F_h
            o_h.append(pool_out)
        o_h = torch.cat(o_h, dim=1)  # B x (F_h * L)

        o = self.dropout(torch.cat([o_v, o_h], dim=1))  # B x (F_v * d + F_h * L)
        z = self.ac_fc(self.sequence_fc(o))  # B x d
        x = torch.cat([z, user_embs], dim=1)  # B x 2d
        return self.output_fc(x).squeeze(-1)

    def __repr__(self):
        return f"CASER(num_users={self.num_users}, num_items={self.num_items}, drop_ratio={self.drop_ratio}, d={self.d}, L={self.L}, F_h={self.F_h}, F_v={self.F_v})"

    def __str__(self):
        return self.__repr__()

class CASERLIT(pl.LightningModule):
    """LightningModule wrapping CASER: BCE training on sampled negatives, top-N evaluation metrics.

    Training batches are ``(U, I, T, N)``: users, input sequences, positive target items and sampled
    negative items. Evaluation batches are ``(U, I, test_items, new_items, rated_items)``, where the
    last three are multi-hot masks over the items. Items rated in training are excluded from ranking.

    Args:
        num_users, num_items, d, L, F_h, F_v, drop_ratio, ac_conv, ac_fc: forwarded to CASER.
        lr, wd: Adam learning rate and weight decay.
        metric_Ns: cut-offs N for recall@N / precision@N (default [1, 5, 10]).
    """

    def __init__(self, num_users, num_items, d, L, F_h, F_v, drop_ratio=0.5,
                 lr: float = 3e-4, wd = 0.0, metric_Ns: Optional[List[int]] = None,
                 ac_conv = 'relu', ac_fc = 'relu'):
        super().__init__()
        # store constructor arguments in checkpoints so load_from_checkpoint can rebuild the module
        self.save_hyperparameters()
        if metric_Ns is None:
            metric_Ns = [1, 5, 10]

        self.model = CASER(num_users, num_items, d, L, F_h, F_v, drop_ratio, ac_conv, ac_fc)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

        self.recall = torch.nn.ModuleList([TopNRecall(n) for n in metric_Ns])
        self.precision = torch.nn.ModuleList([TopNPrecision(n) for n in metric_Ns])
        self.map = MeanAveragePrecision()

        self.lr = lr
        self.wd = wd

    def training_step(self, batch, batch_idx):
        U, I, T, N = batch
        scores = self.model(U, I)  # B x num_items logits
        pos_logits = torch.gather(scores, 1, T)
        neg_logits = torch.gather(scores, 1, N)
        loss = (self.loss_fn(pos_logits, torch.ones_like(pos_logits))
                + self.loss_fn(neg_logits, torch.zeros_like(neg_logits)))
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def _calculate_metrics(self, y, test_items, new_items, rated_items, stage):
        # Items already rated in training can never be recommended again. `new_items` is currently
        # unused: metrics are computed against all held-out items.
        y = y.masked_fill(rated_items == 1, float("-inf"))
        for r in self.recall:
            r(y, test_items)
            self.log(f"{stage}/recall@{r.top_n}", r, on_step=False, on_epoch=True, prog_bar=True)
        for p in self.precision:
            p(y, test_items)
            self.log(f"{stage}/precision@{p.top_n}", p, on_step=False, on_epoch=True, prog_bar=True)
        self.map(y, test_items)
        self.log(f"{stage}/map", self.map, on_step=False, on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        U, I, test_items, new_items, rated_items = batch
        y = self.model(U, I)
        self._calculate_metrics(y, test_items, new_items, rated_items, "val")

    def test_step(self, batch, batch_idx):
        U, I, test_items, new_items, rated_items = batch
        y = self.model(U, I)
        self._calculate_metrics(y, test_items, new_items, rated_items, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.wd)

    def __repr__(self):
        return f"CASERLIT(model={self.model}, lr={self.lr}, wd={self.wd})"

    def __str__(self):
        return self.__repr__()


# Train

In [103]:
import glob

from pytorch_lightning.loggers import WandbLogger
import wandb

torch.set_float32_matmul_precision('medium')

pl.seed_everything(1234)
movielens = MovielensDataModule(L=5, T=3, batch_size=512, num_workers=6, num_negative_samples=3)
movielens.prepare_data()
movielens.setup()  # builds the id -> index maps so num_users / num_items are known below

wandb.login()
wandb.finish()  # close any run left open by a previous execution of this cell
wandb_logger = WandbLogger(project="CASERLIT")
experiment_dir = wandb_logger.experiment.dir
model = CASERLIT(
    num_users=movielens.num_users, num_items=movielens.num_items,
    d=50, L=movielens.L, F_h=16, F_v=4, drop_ratio=0.5, lr=1e-3, wd=1e-6,
)

# Profile training steps only: skip 5 steps, warm up for 5, record 10; repeat twice.
sched = torch.profiler.schedule(wait=5, warmup=5, active=10, repeat=2)
profiler = pl.profilers.PyTorchProfiler(
    export_to_chrome=True, dirpath=experiment_dir,
    schedule=sched,
    )
profiler.STEP_FUNCTIONS = {"training_step"}  # only profile training

trainer = pl.Trainer(max_epochs=50, accelerator="gpu", devices=1, logger=wandb_logger, profiler=profiler)

# fit the model; validation metrics come from the datamodule's val_dataloader
trainer.fit(model=model, datamodule=movielens)

# upload the first exported Chrome trace (if any) as a W&B artifact
trace_files = glob.glob(os.path.join(experiment_dir, "*.pt.trace.json"))
if trace_files:
    trace_at = wandb.Artifact(name=f"trace-{wandb.run.id}", type="trace")
    trace_at.add_file(trace_files[0], name="training_step.pt.trace.json")
    wandb.log_artifact(trace_at)
else:
    print(f"No profiler trace found in {experiment_dir}")
wandb.finish()


INFO:lightning_fabric.utilities.seed:Seed set to 1234


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpiterleviant[0m ([33mpiterleviant-weizmann-institute-of-science[0m). Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                 | Params | Mode 
-----------------------------------------------------------
0 | model     | CASER                | 844 K  | train
1 | loss_fn   | BCEWithLogitsLoss    | 0      | train
2 | recall    | ModuleList           | 0      | train
3 | precision | ModuleList           | 0      | train
4 | map       | MeanAveragePrecision | 0      | train
-----------------------------------------------------------
844 K     Trainable params
0         Non-trainable params
844 K     Total params
3.376     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


VBox(children=(Label(value='1.291 MB of 2.580 MB uploaded\r'), FloatProgress(value=0.5001559696842243, max=1.0…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇█
train/loss,██▇▇▆▆▆▅▄▄▄▄▄▃▄▃▄▄▂▃▂▃▂▇▃▃▃▂▃▃▂▁▃▂▂▂▃▂▃▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val/map,▁▄▆▇▇▇███
val/precision@1,▁▄▆▆▇▇███
val/precision@10,▁▄▆▇▇████
val/precision@5,▁▄▆▇▇▇███
val/recall@1,▁▅▅▆▇▇███
val/recall@10,▁▄▆▇▇▇███
val/recall@5,▁▄▆▆▇▇███

0,1
epoch,9.0
train/loss,0.34649
trainer/global_step,13699.0
val/map,0.15489
val/precision@1,0.27765
val/precision@10,0.2107
val/precision@5,0.23222
val/recall@1,0.01627
val/recall@10,0.11299
val/recall@5,0.06413


# TODO
1. Implement more recommendation models as stated in the paper.
2. Implement more datamodules besides MovieLens.
3. Reproduce the results reported in the paper and provide exploratory examples.