In [26]:
import pandas as pd
from rectools import Columns
import tqdm
import torch
import os
from lightning_fabric import seed_everything

In [25]:
from rectools.model_selection import TimeRangeSplitter
from rectools.dataset import Interactions

## Dataset

### MovieLens 1M

In [17]:
# Time-based hold-out: one fold whose test window is "7D" long.
# All three filters are enabled; per the rectools docs they drop test rows
# for users/items missing from train and (user, item) pairs already in train.
splitter = TimeRangeSplitter(
    n_splits=1,
    test_size="7D",
    filter_already_seen=True,
    filter_cold_items=True,
    filter_cold_users=True,
)

In [14]:
# Location of the MovieLens-1M `ratings.dat` file.
# Set the ML1M_RATINGS_PATH environment variable to point at your own copy;
# the fallback is the path this notebook was originally written against.
ML1M_RATINGS_PATH = os.environ.get(
    "ML1M_RATINGS_PATH",
    "/Users/mayyaspirina/Desktop/vkrgrid/repo/prepare_vkr/Datasets/ratings.dat",
)

# Load the raw ratings and reshape them into the interaction frame used below:
#   1. read the "::"-separated file (multi-char separator needs the python engine),
#   2. keep rows with rating >= 1 and drop the rating column,
#   3. rename the id/timestamp columns to the rectools `Columns` names,
#   4. parse the timestamp as seconds since epoch and give every row weight 1.
# Built as one chain so the cell produces `ml1m` without in-place mutation.
ml1m = (
    pd.read_csv(
        ML1M_RATINGS_PATH,
        sep="::",
        names=["userId", "movieId", "rating", "timestamp"],
        engine="python",
    )
    .loc[lambda df: df["rating"] >= 1]
    .drop(columns=["rating"])
    .rename(
        columns={
            "userId": Columns.User,
            "movieId": Columns.Item,
            "timestamp": Columns.Datetime,
        }
    )
    .assign(
        **{
            Columns.Datetime: lambda df: pd.to_datetime(df[Columns.Datetime], unit="s"),
            Columns.Weight: 1,
        }
    )
)

In [23]:
# Wrap the frame in `Interactions` and ask the splitter for its folds.
split_iterator = splitter.split(Interactions(ml1m))

# Only the first fold is consumed; the third element of the fold is ignored.
train_ids, test_ids, _ = next(iter(split_iterator))

# The ids are used as positional row indices into `ml1m`.
train, test = (ml1m.iloc[ids] for ids in (train_ids, test_ids))

# Modification comparison

## Set seed

In [None]:
# cuBLAS workspace setting required for deterministic behaviour with CUDA >= 10.2.
# Exported first so it is already in the environment before anything else runs.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Ask torch to use deterministic algorithms.
torch.use_deterministic_algorithms(True)

# Fix the global seed to 42 (workers=True is passed through to seed_everything).
seed_everything(42, workers=True)

## Determine model parameters

Each modification requires the following options to be set:

### Training objective
* Shifted sequence: choose model `SASRecModel`
* MLM: choose model `BERT4RecModel`
* All action: choose model `BERT4RecModel`
    * `data_preparator_type`: `modifications.objectives.all_action.AllActionDataPreparator`
    * `lightning_module_type`: `modifications.objectives.all_action.AllActionLightningModule`
    * `backbone_type`: `modifications.objectives.all_action.AllActionTransformerTorchBackbone`
* Dense all action: choose model `SASRecModel`
    * `data_preparator_type`: `modifications.objectives.dense_all_action.DenseAllActionDataPreparator`

### Transformer layers
* SASRec: `rectools.models.nn.transformers.sasrec.SASRecTransformerLayers`
* BERT4Rec: `rectools.models.nn.transformers.net_blocks.PreLNTransformerLayers`
* ALBERT: `src.models.transformers.transformer_layers.albert.AlbertLayers`
