# SasRec training/inference with stream dataset example

## Imports and session initialization

In [1]:
import copy

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from pyspark.sql import functions as F
from pyspark.sql.window import Window

from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.data import (
    FeatureHint,
    FeatureSource,
    FeatureType,
)
from replay.data.nn import (
    TensorFeatureInfo,
    TensorFeatureSource,
    TensorSchema,
)
from replay.experimental.nn.sequential.postprocessors import RemoveSeenItems
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.models.nn.sequential import SasRec
from replay.experimental.nn.sequential.callbacks import (
    PandasPredictionCallback,
    ValidationMetricsCallback
)
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionBatch,
    SasRecTrainingBatch,
    SasRecValidationBatch,
)
from replay.splitters import LastNSplitter, RatioSplitter
from replay.utils.session_handler import get_spark_session

# Fix seed to ensure reproducibility
L.seed_everything(42)

Seed set to 42


42

In [2]:
spark_session = get_spark_session()

25/12/18 09:47:10 WARN Utils: Your hostname, ecs-vagolubenko resolves to a loopback address: 127.0.1.1; using 10.11.10.197 instead (on interface eth0)
25/12/18 09:47:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/12/18 09:47:11 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/12/18 09:47:11 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


## Preparing data
In this example, we will be using the MovieLens dataset, namely the 100k subset.  
Begin by loading interactions, item features and user features using the created session.

---
**NOTE**

Current implementation of SasRec does not take into account user/item features. As such, they are only used in this example to get complete lists of users and items.

---

In [3]:
!pip install rs-datasets

Looking in indexes: https://pypi.org/simple, https://__token__:****@gitlab.amazmetest.ru/api/v4/projects/359/packages/pypi/simple, https://__token__:****@gitlab.amazmetest.ru/api/v4/projects/315/packages/pypi/simple, https://__token__:****@gitlab.amazmetest.ru/api/v4/projects/277/packages/pypi/simple, https://__token__:****@gitlab.amazmetest.ru/api/v4/projects/381/packages/pypi/simple, https://download.pytorch.org/whl/cu113

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.1.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
from rs_datasets import MovieLens

movielens = MovieLens("100k")

In [12]:
int_win = Window.partitionBy("user_id").orderBy("item_id")
interactions = spark_session.createDataFrame(movielens.ratings)

# NOTE: The following code block is optional and is used
# to counteract the issue of identical timestamps in the dataset.
# Uncomment if you wish to use it.
interactions = (
    interactions.select(["user_id", "item_id", "timestamp"])
    .withColumn("ts", F.col("timestamp").cast("long") * 1000)
    .withColumn("row_num", F.row_number().over(int_win))
    .withColumn("timestamp", (F.col("ts") + F.col("row_num")).cast("string"))
    .drop("ts", "row_num")
)

interactions.show(n=5)



+-------+-------+------------+
|user_id|item_id|   timestamp|
+-------+-------+------------+
|      1|      1|874965758001|
|      1|      2|876893171002|
|      1|      3|878542960003|
|      1|      4|876893119004|
|      1|      5|889751712005|
+-------+-------+------------+
only showing top 5 rows



                                                                                

In [13]:
user_features = spark_session.createDataFrame(movielens.users)
user_features.show(n=5)

+-------+------+---+----------+--------+
|user_id|gender|age|occupation|zip_code|
+-------+------+---+----------+--------+
|      1|    24|  M|technician|   85711|
|      2|    53|  F|     other|   94043|
|      3|    23|  M|    writer|   32067|
|      4|    24|  M|technician|   43537|
|      5|    33|  F|     other|   15213|
+-------+------+---+----------+--------+
only showing top 5 rows



In [14]:
item_features = spark_session.createDataFrame(movielens.items)
item_features.show(n=5)

+-------+-----------------+------------+--------------------+-------+------+---------+---------+----------+------+-----+-----------+-----+-------+---------+------+-------+-------+-------+------+--------+-----+-------+
|item_id|            title|release_date|            imdb_url|unknown|Action|Adventure|Animation|Children's|Comedy|Crime|Documentary|Drama|Fantasy|Film-Noir|Horror|Musical|Mystery|Romance|Sci-Fi|Thriller|  War|Western|
+-------+-----------------+------------+--------------------+-------+------+---------+---------+----------+------+-----+-----------+-----+-------+---------+------+-------+-------+-------+------+--------+-----+-------+
|      1| Toy Story (1995)| 01-Jan-1995|http://us.imdb.co...|  false| false|    false|     true|      true|  true|false|      false|false|  false|    false| false|  false|  false|  false| false|   false|false|  false|
|      2| GoldenEye (1995)| 01-Jan-1995|http://us.imdb.co...|  false|  true|     true|    false|     false| false|false|      fa

### Encode catagorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset.

In [15]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule
from replay.utils.types import SparkDataFrame


def encode_data(queries: SparkDataFrame, items: SparkDataFrame, interactions: SparkDataFrame, label_encoder: LabelEncoder):
    full_data = interactions.join(queries, on="user_id").join(items, on="item_id")
    full_data = label_encoder.fit_transform(full_data)

    return full_data

In [16]:
encoder = LabelEncoder([
    LabelEncodingRule("user_id", default_value="last"),
    LabelEncodingRule("item_id", default_value="last")
])
encoded_interactions = encode_data(user_features, item_features, interactions, encoder)

25/12/18 09:49:23 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 09:49:23 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 09:49:23 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 09:49:24 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 09:49:24 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 09:49:24 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
25/12/18 0

### Split interactions into the train, validation and test datasets using RatioSplitter

In order to facilitate the model's training, we split the dataset in the following way:
1) A 60/40 data split of original data for training and subsequent splits
2) A 75/25 split of the leftover data for testing/validation respectively (i.e. 30%/10% of the full dataset)

We also remove cold users/items after each split.

In [17]:
train_events, test_events = RatioSplitter(
    test_size=0.4,
    divide_column="user_id",
    query_column="user_id",
    timestamp_column="timestamp",
    drop_cold_users=True,
    drop_cold_items=True,
).split(encoded_interactions)

print(f"{train_events.count()=}, {test_events.count()=}")

                                                                                

train_events.count()=59623, test_events.count()=40170


In [18]:
test_events, val_events = RatioSplitter(
    test_size=0.25,
    divide_column="user_id",
    query_column="user_id",
    timestamp_column="timestamp",
    drop_cold_users=True,
    drop_cold_items=True,
).split(test_events)

print(f"{test_events.count()=}, {val_events.count()=}")

                                                                                

test_events.count()=29781, val_events.count()=10327


### Split the validation dataset into events and ground_truth

For both validation and testing data, the last N items are split into ground truth, which will be used to calculate metrics.

In [19]:
VALIDATION_GROUND_TRUTH_INTERACTIONS_PER_USER = 3
TEST_GROUND_TRUTH_INTERACTIONS_PER_USER = 3

val_events, val_gt = LastNSplitter(
    N=VALIDATION_GROUND_TRUTH_INTERACTIONS_PER_USER, divide_column="user_id", query_column="user_id", strategy="interactions"
).split(val_events)
print(f"{val_events.count()=}, {val_gt.count()=}")

test_events, test_gt = LastNSplitter(
    N=TEST_GROUND_TRUTH_INTERACTIONS_PER_USER, divide_column="user_id", query_column="user_id", strategy="interactions"
).split(test_events)
print(f"{test_events.count()=}, {test_gt.count()=}")

                                                                                

val_events.count()=7534, val_gt.count()=2793


                                                                                

test_events.count()=26952, test_gt.count()=2829


### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [20]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data: SparkDataFrame):
    grouped_interactions = groupby_sequences(
        events=full_data,
        groupby_col="user_id",
        sort_col="timestamp"
    )

    return grouped_interactions

In [21]:
train_events = bake_data(train_events)
val_events = bake_data(val_events)
val_gt = bake_data(val_gt)
test_events = bake_data(test_events)

25/12/18 09:50:18 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


To ensure we don't validate on unknown users, we join train and validation data by user ids, leaving only the common ones.  
We also pre-package the validation data with its ground truth and train-time events.

In [22]:
# Keep common query ids between val_dataset and val_gt
val_events = val_events.join(val_gt, on="user_id", how="left_semi")
val_gt = val_gt.join(val_events, on="user_id", how="left_semi")

gt_to_join = val_gt.select(["user_id", "item_id"]).withColumnRenamed("item_id", "ground_truth")
train_to_join = train_events.select(["user_id", "item_id"]).withColumnRenamed("item_id", "train")

val_events = val_events.join(gt_to_join, on="user_id", how="left")
val_events = val_events.join(train_to_join, on="user_id", how="left")

TRAIN_LEN = val_events.select(F.max(F.size("train")).alias("res")).collect()[0].res
GT_LEN = val_events.select(F.max(F.size("ground_truth")).alias("res")).collect()[0].res

                                                                                

In [23]:
TRAIN_PATH = "temp/data/train.parquet"
VAL_PATH = "temp/data/val.parquet"
TEST_PATH = "temp/data/test.parquet"

train_events.write.mode("overwrite").parquet(TRAIN_PATH)
val_events.write.mode("overwrite").parquet(VAL_PATH)
test_events.write.mode("overwrite").parquet(TEST_PATH)

                                                                                

## Initialize model

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is requiredby the SasRec model to correctly perform operations such as padding and embeddings aggregation at train time.

In [24]:
EMBEDDING_DIM = 128

ITEM_FEATURE_NAME = "item_id"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name="item_id",
        is_seq=True,
        cardinality=len(encoder.mapping["item_id"]),
        embedding_dim=EMBEDDING_DIM,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")],
        feature_hint=FeatureHint.ITEM_ID,
    )
)

### Configure ParquetModule and transformation pipelines
The `ParquetModule` class enables training of models on large datasets by reading data in streaming mode. This class initialized with a metadata dict containing information about dataset's features and miscellanious options for initialization (such as shuffling).

Additionally, `ParquetModule` supports "transform pipelines" - stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass.  

In out case, we create the following pipelines:
1) Training:
    1. Create a label mask, which contains the shifted item sequence that represents the next item in the sequence. This mask is unpadded, meaning we have to fetch an extra item for each sequence and then slice the data appropriately;
    2. Rename/group columns to match it with the `NamedTuple` expected by the model during training.
    3. Compose columns into the expected `NamedTuple`
2) Validation/Inference:
    1. Rename/group columns to match it with the `NamedTuple` expected by the model during valdiation/inference.
    2. Compose columns into the expected `NamedTuple`


In [26]:
from replay.nn.transforms import (
    BatchingTransform,
    GroupTransform,
    RenameTransform,
    NextTokenTransform,
)

MAX_SEQ_LEN = 50
BATCH_SIZE = 128
SHIFT=1

TRANSFORMS = {
    "train": [
        NextTokenTransform(label_field="item_id", query_features="user_id", shift=SHIFT),
        RenameTransform({"user_id": "query_id", "item_id_mask": "padding_mask", "labels_mask": "labels_padding_mask"}),
        GroupTransform({"features": ["item_id"]}),
        BatchingTransform(SasRecTrainingBatch)
    ],
    "val": [
        RenameTransform({"user_id": "query_id", "item_id_mask": "padding_mask"}),
        GroupTransform({"features": ["item_id"]}),
        BatchingTransform(SasRecValidationBatch)
    ],
    "test": [
        RenameTransform({"user_id": "query_id", "item_id_mask": "padding_mask"}),
        GroupTransform({"features": ["item_id"]}),
        BatchingTransform(SasRecPredictionBatch)
    ]
}

shared_meta = {
    "user_id": {},
    "item_id": {
        "shape": MAX_SEQ_LEN,
        "padding": tensor_schema["item_id"].padding_value
    }
}

METADATA = {
    "train": copy.deepcopy(shared_meta),
    "val": {
        **copy.deepcopy(shared_meta),
        "train": {
            "shape": TRAIN_LEN,
            "padding": -2
        },
        "ground_truth": {
            "shape": GT_LEN,
            "padding": -1
        },
    },
    "test": copy.deepcopy(shared_meta)
}

In [28]:
from replay.data.nn import ParquetModule

streaming_dataset = ParquetModule(
    train_path=TRAIN_PATH,
    val_path=VAL_PATH,
    test_path=TEST_PATH,
    batch_size=BATCH_SIZE,
    metadata=METADATA,
    transforms=TRANSFORMS
)

# NOTE: You can also create a module specifically for training/inference by providing only their respective datapaths
# streaming_dataset_train_only = ParquetModule(
#     train_path=TRAIN_PATH,
#     val_path=VAL_PATH,
#     batch_size=BATCH_SIZE,
#     metadata=METADATA,
#     transforms=TRANSFORMS
# )

## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 
To facilitate training, we add the following callbacks:
1) `ValidationMetricsCallback` - to display a detailed validation metric matrix after each epoch.
2) `ModelCheckpoint` - to save the best trained model based on its Recall metric.

In [29]:
model = SasRec(tensor_schema, max_seq_len=MAX_SEQ_LEN, dropout_rate=0.0)

In [30]:
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=len(encoder.mapping["item_id"]),
)

csv_logger = CSVLogger(save_dir=".logs/train", name="GPT_example")

trainer = L.Trainer(
    max_epochs=20,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

trainer.fit(model, datamodule=streaming_dataset)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/root/replay/RePlay/.venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:751: Checkpoint directory /root/replay/RePlay/examples/.checkpoints exists and is not empty.

  | Name   | Type             | Params | Mode 
----------------------------------------------------
0 | _model | SasRecModel      | 117 K  | train
1 | _loss  | CrossEntropyLoss | 0      | train
----------------------------------------------------
117 K     Trainable params
0         Non-trainable params
117 K     Total params
0.471     Total estimated model params size (MB)
35        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/root/replay/RePlay/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


k         1        10        20         5
map     0.0  0.002707  0.003277  0.002214
ndcg    0.0  0.006342  0.009477  0.004601
recall  0.0  0.010417  0.019531  0.006510



/root/replay/RePlay/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
/root/replay/RePlay/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 8: 'recall@10' reached 0.01333 (best 0.01333), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=0-step=8.ckpt' as top 1


k              1        10        20         5
map     0.001379  0.003661  0.004645  0.002621
ndcg    0.001379  0.008308  0.013262  0.004876
recall  0.000460  0.013333  0.027126  0.005977



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 16: 'recall@10' reached 0.01655 (best 0.01655), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=1-step=16.ckpt' as top 1


k              1        10        20         5
map     0.002759  0.003970  0.005116  0.002835
ndcg    0.002759  0.009617  0.015453  0.005683
recall  0.000920  0.016552  0.033563  0.007816



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 24: 'recall@10' was not in top 1


k              1        10        20         5
map     0.006897  0.005240  0.006886  0.004130
ndcg    0.006897  0.010993  0.018811  0.007093
recall  0.002299  0.016552  0.039080  0.007816



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 32: 'recall@10' reached 0.01977 (best 0.01977), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=3-step=32.ckpt' as top 1


k              1        10        20         5
map     0.011034  0.006542  0.008011  0.005218
ndcg    0.011034  0.013396  0.020575  0.008703
recall  0.003678  0.019770  0.040460  0.009195



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 40: 'recall@10' reached 0.02299 (best 0.02299), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=4-step=40.ckpt' as top 1


k              1        10        20         5
map     0.008276  0.006884  0.008335  0.005333
ndcg    0.008276  0.014907  0.022292  0.009773
recall  0.002759  0.022989  0.044138  0.011954



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 48: 'recall@10' reached 0.02943 (best 0.02943), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=5-step=48.ckpt' as top 1


k              1        10        20         5
map     0.008276  0.008469  0.010061  0.006713
ndcg    0.008276  0.018635  0.026410  0.012663
recall  0.002759  0.029425  0.051494  0.016092



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 56: 'recall@10' reached 0.03310 (best 0.03310), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=6-step=56.ckpt' as top 1


k              1        10        20         5
map     0.009655  0.009902  0.012182  0.007379
ndcg    0.009655  0.021121  0.031973  0.013387
recall  0.003218  0.033103  0.063908  0.016092



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 64: 'recall@10' reached 0.03816 (best 0.03816), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=7-step=64.ckpt' as top 1


k              1        10        20         5
map     0.008276  0.010709  0.012531  0.008077
ndcg    0.008276  0.023570  0.031873  0.015128
recall  0.002759  0.038161  0.062069  0.019310



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 72: 'recall@10' was not in top 1


k              1        10        20         5
map     0.009655  0.011645  0.013846  0.008966
ndcg    0.009655  0.024417  0.034202  0.016141
recall  0.003218  0.038161  0.066207  0.019770



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 80: 'recall@10' reached 0.04230 (best 0.04230), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=9-step=80.ckpt' as top 1


k              1        10        20         5
map     0.012414  0.013926  0.016703  0.010674
ndcg    0.012414  0.027765  0.040384  0.018640
recall  0.004138  0.042299  0.078621  0.022529



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 88: 'recall@10' reached 0.04460 (best 0.04460), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=10-step=88.ckpt' as top 1


k              1        10        20         5
map     0.017931  0.015878  0.018796  0.013356
ndcg    0.017931  0.030595  0.043104  0.023309
recall  0.005977  0.044598  0.080460  0.028506



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 96: 'recall@10' reached 0.04460 (best 0.04460), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=11-step=96.ckpt' as top 1


k              1        10        20         5
map     0.016552  0.015271  0.018759  0.012092
ndcg    0.016552  0.029896  0.044930  0.020808
recall  0.005517  0.044598  0.087356  0.024828



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 104: 'recall@10' reached 0.04782 (best 0.04782), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=12-step=104.ckpt' as top 1


k              1        10        20         5
map     0.012414  0.015035  0.018235  0.011195
ndcg    0.012414  0.030481  0.044741  0.019434
recall  0.004138  0.047816  0.088736  0.023448



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 112: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016552  0.015766  0.019039  0.012268
ndcg    0.016552  0.031083  0.045698  0.021059
recall  0.005517  0.046897  0.088736  0.024368



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 120: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015172  0.015242  0.018826  0.011847
ndcg    0.015172  0.030643  0.045776  0.020601
recall  0.005057  0.046897  0.090115  0.024368



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 128: 'recall@10' reached 0.05011 (best 0.05011), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=15-step=128.ckpt' as top 1


k              1        10        20         5
map     0.016552  0.015647  0.019261  0.011977
ndcg    0.016552  0.032116  0.046968  0.020869
recall  0.005517  0.050115  0.092414  0.024828



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 136: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019310  0.015955  0.019780  0.012115
ndcg    0.019310  0.032391  0.047883  0.020978
recall  0.006437  0.049655  0.093793  0.024368



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 144: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016552  0.015907  0.020235  0.012238
ndcg    0.016552  0.032098  0.049018  0.021310
recall  0.005517  0.048736  0.097011  0.024828



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 152: 'recall@10' reached 0.05057 (best 0.05057), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=18-step=152.ckpt' as top 1


k              1        10        20         5
map     0.017931  0.016487  0.020537  0.012874
ndcg    0.017931  0.033225  0.049172  0.022156
recall  0.005977  0.050575  0.096092  0.025747



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 160: 'recall@10' reached 0.05149 (best 0.05149), saving model to '/root/replay/RePlay/examples/.checkpoints/epoch=19-step=160.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=20` reached.


k              1        10        20         5
map     0.020690  0.017529  0.021282  0.013724
ndcg    0.020690  0.034374  0.049122  0.022727
recall  0.006897  0.051494  0.093333  0.025287



We can now laod the best model usingthe path stored in the callback.

In [31]:
best_model = SasRec.load_from_checkpoint(checkpoint_callback.best_model_path)

## Inference stage

### Run inference
We can now perform inference using the data module we created earlier. Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. Inthis example, we'll be using the `PandasPredictionCallback`.

Prediction callbacks can filter results using postprocessors. In our case, we apply the `RemoveSeenItems` postprocessor to filter out items already present in the test dataset.

In [32]:
csv_logger = CSVLogger(save_dir=".logs/test", name="GPT_example")

TOPK = [1, 2, 3]

postprocessors = [
    RemoveSeenItems(
        seen_path=TEST_PATH,
        item_count=tensor_schema[ITEM_FEATURE_NAME].cardinality,
        query_column="user_id",
        item_column=tensor_schema.item_id_feature_name
    )
]

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

trainer = L.Trainer(
    callbacks=[pandas_prediction_callback],
    logger=csv_logger,
    inference_mode=True
)
best_model.eval()
trainer.predict(best_model, datamodule=streaming_dataset, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/root/replay/RePlay/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]

In [33]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,3,203,3.197812
0,3,215,2.702926
0,3,95,2.678969
1,17,117,2.713183
1,17,755,2.623668
...,...,...,...
941,923,482,3.044874
941,923,173,3.040446
942,941,173,3.509326
942,941,97,3.448085


Suppose we want to get the recomendations in PySpark format. 
Let's get the inverse representation of labels using inverse_transform method.

Note that the reverse representation can only be obtained for PySpark and Pandas formats. When working with PyTorch tensors, the reverse representation must be done manually

In [36]:
test_mapped_gt = encoder.inverse_transform(test_gt)

### Calculating metrics

In [38]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score"
)(pandas_res, test_mapped_gt.toPandas())



In [39]:
metrics_to_df(result_metrics)

k,1,2,3
MAP,0.004242,0.002386,0.00218
Precision,0.004242,0.002651,0.003535
Recall,0.001414,0.001767,0.003535
