# Example of the SASRec training with using RandomTargetNextNSplitter

In [None]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter, RandomTargetNextNSplitter, TimeSplitter
from replay.preprocessing.filters import MinCountFilter
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import SequenceTokenizer, SequentialDataset, TensorFeatureSource, TensorSchema, TensorFeatureInfo
from replay.models.nn.sequential import SasRec
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionDataset,
    SasRecTrainingDataset,
    SasRecValidationDataset,
    SasRecPredictionBatch,
    SasRecModel,
)
import pandas as pd

## Prepare data
### Load raw movielens-1M interactions, item features and user features.

In [None]:
!pip install rs-datasets

In [2]:
from rs_datasets import MovieLens

In [14]:
movielens = MovieLens("1m")
interactions = movielens.ratings
user_features = movielens.users
item_features = movielens.items

In [15]:
interactions.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [16]:
user_features.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [17]:
item_features.head()

Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


Removing duplicates in the timestamp column without changing the original items order where timestamp is the same

In [20]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")

### Split interactions into the train, validation and test datasets using RandomTargetNextNSplitter

In [107]:

random_splitter = RandomTargetNextNSplitter(
    N=1,
    seed=42,
    query_column="user_id"
)
time_splitter = TimeSplitter(
    time_threshold=0.1,
    query_column="user_id"
)

train_val, test_holdout = time_splitter.split(interactions)
train, val_holdout = time_splitter.split(train_val)

Remove users with less than 1 interaction

In [108]:
min_cnt_filter = MinCountFilter(num_entries=1, groupby_column="user_id")
train = min_cnt_filter.transform(train)

Remove cold items from test_holdout, val_holdout based on filtered train

In [109]:
test_holdout = test_holdout[test_holdout["item_id"].isin(train["item_id"].unique())]
val_holdout = val_holdout[val_holdout["item_id"].isin(train["item_id"].unique())]

train_val = train_val[train_val["item_id"].isin(train["item_id"].unique())]

Create input and target for validation and test subsets

In [110]:
val_input, val_target = random_splitter.split(val_holdout)
test_input, test_target = random_splitter.split(test_holdout)

test_input = pd.concat([train_val, test_input], axis=0)
val_input = pd.concat([train, val_input], axis=0)

Remove targets with no input

In [111]:
test_input = test_input[test_input["user_id"].isin(test_target["user_id"])]
val_input = val_input[val_input["user_id"].isin(val_target["user_id"])]

In [112]:
print(f"train: {len(train)}")
print(f"val_input: {len(val_input)}")
print(f"val_target: {len(val_target)}")
print(f"test_input: {len(test_input)}")
print(f"test_target: {len(test_target)}")


train: 810169
val_input: 170613
val_target: 1045
test_input: 314858
test_target: 1208


### Prepare FeatureSchema required to create Dataset

In [113]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

### Create Dataset for the training stage

In [114]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=train,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the validation stage

In [115]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=val_input,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=val_target,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the testing stage

In [116]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=test_input,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=test_target,
    check_consistency=True,
    categorical_encoded=False,
)

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model

In [117]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
    )
)

### Create sequential datasets using SequenceTokenizer

In [118]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_test_dataset = tokenizer.transform(test_dataset)
sequential_test_gt = tokenizer.transform(test_gt, [tensor_schema.item_id_feature_name])


## Train model
### Create SASRec model instance and run the training stage using lightning
After each epoch validation metrics are shown. You can change the list of validation metrics in ValidationMetricsCallback
The model is determined to be the best and is saved if the metric updates its maximum during validation (see the ModelCheckpoint)

In [None]:
MAX_SEQ_LEN = 200
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 10

model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)

csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example")

checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)],
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

In [138]:
best_model = SasRec.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    tensor_schema=tensor_schema,
)

### Test model

In [None]:
test_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_test_dataset,
        sequential_test_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE
)

In [159]:
test_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_test_dataset)],
)

trainer_test = L.Trainer(callbacks=[test_metrics_callback], logger=csv_logger)
trainer_test.validate(best_model, test_dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.020695  0.045578  0.049658  0.038562
ndcg    0.020695  0.064293  0.079709  0.047215
recall  0.020695  0.126656  0.188742  0.073675



[{'recall@1': 0.020695364102721214,
  'ndcg@1': 0.020695364102721214,
  'map@1': 0.020695364102721214,
  'recall@5': 0.0736754983663559,
  'ndcg@5': 0.04721491411328316,
  'map@5': 0.038562361150979996,
  'recall@10': 0.12665562331676483,
  'ndcg@10': 0.06429333984851837,
  'map@10': 0.04557809233665466,
  'recall@20': 0.18874172866344452,
  'ndcg@20': 0.0797085240483284,
  'map@20': 0.04965835064649582}]

## Comparing with LOO splitter 

In this section, we will train the same SASRec model using Leave-One-Out (LOO) splitter and compare results.

In [166]:
from re import T


splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

test_input, test_target = splitter.split(interactions)
train, val_target = splitter.split(test_input)

In [167]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=train,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)


validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=train,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=val_target,
    check_consistency=True,
    categorical_encoded=False,
)

test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=test_input,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=test_target,
    check_consistency=True,
    categorical_encoded=False,
)


In [168]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
    )
)

tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

sequential_test_dataset = tokenizer.transform(test_dataset)
sequential_test_gt = tokenizer.transform(test_gt, [tensor_schema.item_id_feature_name])

sequential_test_dataset, sequential_test_gt = SequentialDataset.keep_common_query_ids(
    sequential_test_dataset, sequential_test_gt
)

In [None]:
MAX_SEQ_LEN = 200
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 10

model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)

csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example")

checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)],
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

In [171]:
best_model = SasRec.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    tensor_schema=tensor_schema,
)

### Test Model

In [172]:
test_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_test_dataset)],
)

trainer_test = L.Trainer(callbacks=[test_metrics_callback], logger=csv_logger)
trainer_test.validate(best_model, test_dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.001656  0.003494  0.004011  0.002718
ndcg    0.001656  0.005145  0.007026  0.003269
recall  0.001656  0.010762  0.018212  0.004967



[{'recall@1': 0.0016556291375309229,
  'ndcg@1': 0.0016556291375309229,
  'map@1': 0.0016556291375309229,
  'recall@5': 0.004966887645423412,
  'ndcg@5': 0.003268592059612274,
  'map@5': 0.0027179911267012358,
  'recall@10': 0.010761589743196964,
  'ndcg@10': 0.0051450589671730995,
  'map@10': 0.0034939032047986984,
  'recall@20': 0.01821191981434822,
  'ndcg@20': 0.007026043254882097,
  'map@20': 0.00401055533438921}]