# Example of SasRec training/inference with Parquet Module

In [1]:
from typing import Optional

import lightning as L
import pandas as pd

L.seed_everything(42)

import warnings
warnings.filterwarnings("ignore")

Seed set to 42


## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1M subset. This is a simple demonstration, so only item ids will be used as model input.

---
**NOTE**

The current implementation of SasRec is able to handle item and interaction features. It does not take user features into account.

---

In [2]:
# Read the tab-separated ratings file and discard the "rating" column right away:
# only user ids, item ids and timestamps are used in the rest of the notebook.
interactions = pd.read_csv(
    "./data/ml1m_ratings.dat",
    sep="\t",
    names=["user_id", "item_id", "rating", "timestamp"],
).drop(columns=["rating"])

In [3]:
# Order the log chronologically (timestamps compared as int64) ...
interactions = interactions.astype({"timestamp": "int64"}).sort_values(by="timestamp")
# ... then replace the raw timestamp with the position of the event inside its user's history (0, 1, 2, ...).
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode categorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorical column in the dataset that will be used in the model. Note that ids of users and ids of items are always used.

In [4]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule

# One encoding rule per id column, both configured with default_value="last".
encoder = LabelEncoder(
    [LabelEncodingRule(column, default_value="last") for column in ("user_id", "item_id")]
)

# The frame is ordered by item id before the encoder is fitted on it.
interactions = interactions.sort_values("item_id", ascending=True)
encoded_interactions = encoder.fit_transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,12,0,0
1,68,1,0
2,67,2,0
3,12,3,0
4,140,4,0
...,...,...,...
1000204,14,4555,3705
1000205,90,2813,3705
1000206,70,2404,3705
1000207,25,5835,3705


### Split interactions into the train, validation and test datasets using LastNSplitter
We use the widespread Last-One-Out splitting strategy. For simplicity, we filter out cold items and users.

In [5]:
from replay.splitters import LastNSplitter

# Last-One-Out splitter: N=1 interaction per user is held out; cold users/items are dropped.
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
    drop_cold_users=True,
    drop_cold_items=True,
)

# First split: histories used as test input vs. the held-out test ground truth.
test_events, test_gt = splitter.split(encoded_interactions)
# Second split, applied to the test input: training histories vs. the validation ground truth.
train_events, validation_gt = splitter.split(test_events)
# The validation input is the very same set of histories the model is trained on.
validation_events = train_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [6]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data):
    """
    Turn an interaction log into per-user sequences.

    Thin wrapper around RePlay's ``groupby_sequences`` that groups by ``user_id`` and
    orders the events of each group by ``timestamp``.

    :param full_data: interactions with ``user_id`` and ``timestamp`` columns.
    :return: the result of ``groupby_sequences`` for ``full_data``.
    """
    return groupby_sequences(events=full_data, groupby_col="user_id", sort_col="timestamp")

In [7]:
# Bake every split in one pass; the order matches the unpacking targets on the left.
train_events, validation_events, validation_gt, test_events, test_gt = map(
    bake_data,
    (train_events, validation_events, validation_gt, test_events, test_gt),
)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2426, 822, 2733, 2587, 2937, 3618, 2943, 708,..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3272, 3026, 2760, 851, 346, 3393, 1107, 515, ..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[579, 1140, 1154, 2426, 1524, 1260, 2160, 2621..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1781, 2940, 2468, 890, 948, 106, 593, 309, 49..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1108, 2229, 21, 2435, 2142, 106, 1167, 593, 1..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2426, 1279, 3151, 3321, 1178, 3301, 2501, 278..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1592, 2302, 1633, 1813, 2879, 1482, 2651, 250..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1971, 3500, 2077, 1666, 1399, 2651, 2748, 283..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1486, 1485, 3384, 3512, 3302, 3126, 3650, 330..."


To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones.  

In [8]:
def add_gt_to_events(events_df, gt_df):
    """
    Attach each user's ground-truth items to that user's event sequence.

    The ``item_id`` column of ``gt_df`` is renamed to ``ground_truth`` and inner-joined to
    ``events_df`` on ``user_id``, so users missing from either frame are dropped.

    A ``ground_truth`` column that is already present in ``events_df`` (for example when the
    notebook cell calling this function is executed a second time) is replaced rather than
    duplicated as ``ground_truth_x`` / ``ground_truth_y``, which keeps the call idempotent.

    :param events_df: frame with a ``user_id`` column holding the input sequences.
    :param gt_df: frame with ``user_id`` and ``item_id`` columns holding the held-out items.
    :return: a new frame with the columns of ``events_df`` plus ``ground_truth``.
    """
    gt_to_join = gt_df[["user_id", "item_id"]].rename(columns={"item_id": "ground_truth"})

    # Drop a stale ground-truth column so pandas does not add _x/_y suffixes on re-runs.
    events_without_gt = events_df.drop(columns=["ground_truth"], errors="ignore")
    return events_without_gt.merge(gt_to_join, on="user_id", how="inner")

# Attach the held-out items to the validation and test sequences. The inner join performed by
# add_gt_to_events keeps only the users present in both the events and the ground truth.
validation_events = add_gt_to_events(validation_events, validation_gt)
test_events = add_gt_to_events(test_events, test_gt)

In [9]:
from pathlib import Path

# Working directory for every artifact written by this example; created on demand.
data_dir = Path("temp", "data")
data_dir.mkdir(parents=True, exist_ok=True)

# Parquet files with the baked train / validation / test splits.
TRAIN_PATH = data_dir.joinpath("train.parquet")
VAL_PATH = data_dir.joinpath("val.parquet")
PREDICT_PATH = data_dir.joinpath("test.parquet")

# Location where the fitted label encoder is stored.
ENCODER_PATH = data_dir.joinpath("encoder")

In [10]:
# Persist each baked split as a parquet file at the paths defined in the previous cell.
train_events.to_parquet(TRAIN_PATH)
validation_events.to_parquet(VAL_PATH)
test_events.to_parquet(PREDICT_PATH)

# Save the fitted label encoder as well; it is loaded back when the tensor schema is built.
encoder.save(ENCODER_PATH)

# Preparing for model training
### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings for every source column. Note that user_id is not required in `TensorSchema`.

Note that the **padding value** is the next value (item_id) after the last one. **Cardinality** is the number of unique values, taking the padding value into account.

In [11]:
from replay.data import FeatureHint, FeatureType
from replay.data.nn import TensorFeatureInfo, TensorSchema


# Size of the item embedding; the same constant is later passed to the model builder.
EMBEDDING_DIM = 64

# Reload the encoder saved earlier and count the encoded item ids from its mapping.
encoder = encoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

# The schema describes a single sequential categorical feature: the item id.
tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            # Padding uses the first id after the encoded range 0..NUM_UNIQUE_ITEMS-1
            # (assumes the encoder assigns contiguous ids starting at 0 - TODO confirm).
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS + 1,  # all item ids plus one extra slot for the padding id
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
        )
    ]
)

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data in a batch-wise way. This class is initialized with **paths to every data split, a metadata dict containing information about the shape and padding value of every column, and a dict of transforms**. `ParquetModule`'s "transform pipelines" are stage-specific modules implementing additional preprocessing to be performed at batch level right before the forward pass.

For SasRec model, RePlay provides a function that generates a sequence of appropriate transforms for each data split named **make_default_sasrec_transforms**.

Internally this function creates the following transforms:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Rename features to match the format expected by the model during training.
    3. Unsqueeze the target (*positive_labels*) and its padding mask (*target_padding_mask*) to give these tensors the shape required for loss computation.
    4. Group the input features to be embedded into the expected format.

2) Validation/Inference:
    1. Rename/group features to match the format expected by the model during validation/inference.

If a different set of transforms is required, you can create them yourself and pass them to the ParquetModule in the form of a dictionary where the key is the name of the split and the value is the list of transforms. Available transforms are in `replay/nn/transform/`.

**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1.

In [12]:
import copy

import torch

from replay.data.nn import TensorSchema
from replay.nn.transform import GroupTransform, NextTokenTransform, RenameTransform, UnsqueezeTransform, UniformNegativeSamplingTransform

In [20]:
def make_sasrec_transforms(
    tensor_schema: TensorSchema, query_column: str = "query_id", num_negative_samples: int = 128,
) -> dict[str, list[torch.nn.Module]]:
    """
    Build the per-stage batch transforms used for SasRec with negative sampling.

    The training pipeline samples negatives, creates next-token targets (shift of 1),
    renames the query / mask columns to the names used downstream, unsqueezes the target
    and its mask on the last axis and groups the item column under ``feature_tensors``.
    The validate / test / predict pipelines only rename and group; test and predict get
    deep copies of the validation list.

    :param tensor_schema: schema whose item-id feature supplies the item column name and
        the cardinality handed to the negative sampler.
    :param query_column: name of the query (user) id column in the incoming batches.
    :param num_negative_samples: second argument of ``UniformNegativeSamplingTransform``.
    :return: dict with the keys ``train``, ``validate``, ``test`` and ``predict``.
    """
    item_column = tensor_schema.item_id_feature_name
    mask_column = f"{item_column}_mask"
    # NOTE(review): the sampler receives the schema cardinality; in this notebook that number
    # also counts the padding id - confirm the sampler cannot draw the padding id as a negative.
    vocab_size = tensor_schema[item_column].cardinality

    train_transforms = [
        UniformNegativeSamplingTransform(vocab_size, num_negative_samples),
        NextTokenTransform(label_field=item_column, query_features=query_column, shift=1),
        RenameTransform(
            {
                query_column: "query_id",
                mask_column: "padding_mask",
                "positive_labels_mask": "target_padding_mask",
            }
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    # Evaluation-time stages share one recipe; each stage owns an independent copy of it.
    val_transforms = [
        RenameTransform({query_column: "query_id", mask_column: "padding_mask"}),
        GroupTransform({"feature_tensors": [item_column]}),
    ]

    return {
        "train": train_transforms,
        "validate": val_transforms,
        "test": copy.deepcopy(val_transforms),
        "predict": copy.deepcopy(val_transforms),
    }

transforms = make_sasrec_transforms(tensor_schema, query_column="user_id")

In [21]:
MAX_SEQ_LEN = 50


def create_meta(shape: int, gt_shape: Optional[int] = None):
    """
    Describe the columns of one data split for ``ParquetModule``.

    :param shape: value stored as the ``shape`` of the ``item_id`` column.
    :param gt_shape: when given, a ``ground_truth`` entry with this shape and padding ``-1``
        is added to the description.
    :return: dict keyed by column name; ``item_id`` is padded with the padding value taken
        from the notebook-level ``tensor_schema``.
    """
    item_meta = {"shape": shape, "padding": tensor_schema["item_id"].padding_value}
    if gt_shape is None:
        return {"user_id": {}, "item_id": item_meta}
    return {
        "user_id": {},
        "item_id": item_meta,
        "ground_truth": {"shape": gt_shape, "padding": -1},
    }


# Training sequences carry one extra position because the training transform shifts them by 1.
train_metadata = dict(
    train=create_meta(shape=MAX_SEQ_LEN + 1),
    validate=create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
)

In [22]:
from replay.data.nn import ParquetModule

BATCH_SIZE = 32

# Data module for the training stage: only the train and validation parquet files are wired in
# here; the test file written to PREDICT_PATH is not used by this cell.
parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=train_metadata,  # per-split column shapes and padding values
    transforms=transforms,  # per-stage batch transform pipelines
)

## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend.

#### Default Configuration

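Default SasRec model may be created quickly via method **from_params**. Default model instance has CE loss, original SasRec transformer layers, and embeddings are aggregated via sum. In this notebook the model is instead assembled block by block in the `make_sasrec` function below, using the sampled cross-entropy loss (`CESampled`) together with the negative-sampling transform configured above.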

In [23]:
from replay.nn.sequential import SasRec
from typing import Literal
def make_sasrec(
    schema: TensorSchema,
    embedding_dim: int = 192,
    num_heads: int = 4,
    num_blocks: int = 2,
    max_sequence_length: int = 50,
    dropout: float = 0.3,
    excluded_features: Optional[list[str]] = None,
    categorical_list_feature_aggregation_method: Literal["sum", "mean", "max"] = "sum",
) -> SasRec:
    """
    Assemble a SasRec model from individual blocks with a sampled cross-entropy loss.

    The body is built from a ``SequenceEmbedding`` embedder, a ``PositionAwareAggregator``
    wrapping a ``SumAggregator``, a ``DefaultAttentionMask`` keyed on the item-id feature,
    a ``SasRecTransformerLayer`` encoder with ReLU activation and a final ``LayerNorm``.

    :param schema: tensor schema describing the model inputs.
    :param embedding_dim: size used by the aggregator, the encoder and the output normalization.
    :param num_heads: number of heads passed to the attention mask builder and the encoder.
    :param num_blocks: number of blocks passed to the encoder.
    :param max_sequence_length: maximum sequence length passed to the positional aggregator.
    :param dropout: dropout rate passed to the positional aggregator and the encoder.
    :param excluded_features: extra feature names to exclude from embedding; the schema's
        query-id and timestamp feature names are always added to this list.
    :param categorical_list_feature_aggregation_method: forwarded to ``SequenceEmbedding``.
    :return: ``SasRec`` whose loss is ``CESampled`` with the item feature's padding value
        as ``padding_idx``.
    """
    from replay.nn.agg import SumAggregator
    from replay.nn.embedding import SequenceEmbedding
    from replay.nn.loss import CESampled
    from replay.nn.mask import DefaultAttentionMask
    from replay.nn.sequential.sasrec import SasRecBody
    from replay.nn.sequential.sasrec.agg import PositionAwareAggregator
    from replay.nn.sequential.sasrec.transformer import SasRecTransformerLayer

    # Query id and timestamp are never embedded; duplicates are removed through a set,
    # so the order of the resulting list is not defined.
    excluded_features = [
        schema.query_id_feature_name,
        schema.timestamp_feature_name,
        *(excluded_features or []),
    ]
    excluded_features = list(set(excluded_features))
    body = SasRecBody(
        embedder=SequenceEmbedding(
            schema=schema,
            categorical_list_feature_aggregation_method=categorical_list_feature_aggregation_method,
            excluded_features=excluded_features,
        ),
        embedding_aggregator=PositionAwareAggregator(
            embedding_aggregator=SumAggregator(embedding_dim=embedding_dim),
            max_sequence_length=max_sequence_length,
            dropout=dropout,
        ),
        attn_mask_builder=DefaultAttentionMask(
            reference_feature_name=schema.item_id_feature_name,
            num_heads=num_heads,
        ),
        encoder=SasRecTransformerLayer(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_blocks=num_blocks,
            dropout=dropout,
            activation="relu",
        ),
        output_normalization=torch.nn.LayerNorm(embedding_dim),
    )
    # The loss needs the padding id of the (single) item-id feature.
    padding_idx = schema.item_id_features.item().padding_value
    return SasRec(
        body=body,
        loss=CESampled(padding_idx=padding_idx),
    )

In [24]:
# Architecture hyperparameters for this example.
NUM_BLOCKS = 2
NUM_HEADS = 2
DROPOUT = 0.3

# EMBEDDING_DIM and MAX_SEQ_LEN are reused from the schema / metadata cells, so the model is
# built with the same embedding size and sequence length as the data pipeline.
sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
)

A universal PyTorch Lightning module is provided. It can work with any NN model.

In [25]:
from replay.nn.lightning.optimizer import OptimizerFactory
from replay.nn.lightning.scheduler import LRSchedulerFactory
from replay.nn.lightning import LightningModule

# Wrap the SasRec network into the Lightning module; both factories are created without
# arguments, i.e. with whatever defaults the library defines.
model = LightningModule(
    sasrec,
    optimizer_factory=OptimizerFactory(),
    lr_scheduler_factory=LRSchedulerFactory(),
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its `recall@10` validation metric. It's a default Lightning Callback.
2) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation and test stages.


In [26]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


# Keep only the single best checkpoint, ranked by the highest monitored "recall@10".
checkpoint_callback = ModelCheckpoint(
    dirpath="sasrec/checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Computes MAP / NDCG / Recall at the listed cut-offs; the monitored "recall@10" above
# corresponds to the "recall" metric at k=10.
validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

# Metrics are logged as CSV under sasrec/logs/train/SasRec-example.
csv_logger = CSVLogger(save_dir="sasrec/logs/train", name="SasRec-example")

# Train for at most 100 epochs with both callbacks attached.
trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

# Train and validation data are supplied by the ParquetModule created earlier.
trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | Mode  | FLOPs
-------------------------------------------------
0 | model | SasRec | 291 K  | train | 0    
-------------------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.164     Total estimated model params size (MB)
39        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.01507 (best 0.01507), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=0-step=189-v3.ckpt' as top 1


k              1        10        20         5
map     0.001821  0.005005  0.005654  0.004220
ndcg    0.001821  0.007341  0.009693  0.005424
recall  0.001821  0.015069  0.024342  0.009107



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.02302 (best 0.02302), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=1-step=378-v2.ckpt' as top 1


k              1        10        20         5
map     0.002981  0.007374  0.008710  0.006044
ndcg    0.002981  0.010976  0.016031  0.007759
recall  0.002981  0.023017  0.043385  0.013082



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.02981 (best 0.02981), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.003809  0.009334  0.011016  0.007628
ndcg    0.003809  0.014037  0.020154  0.009877
recall  0.003809  0.029806  0.053982  0.016890



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.04554 (best 0.04554), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=3-step=756-v3.ckpt' as top 1


k              1        10        20         5
map     0.004637  0.013212  0.015986  0.010153
ndcg    0.004637  0.020583  0.030981  0.013059
recall  0.004637  0.045537  0.087266  0.022024



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.06143 (best 0.06143), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=4-step=945-v4.ckpt' as top 1


k              1        10        20         5
map     0.008611  0.020454  0.024102  0.016998
ndcg    0.008611  0.029885  0.043547  0.021429
recall  0.008611  0.061434  0.116244  0.035105



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 1134: 'recall@10' reached 0.08147 (best 0.08147), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=5-step=1134-v1.ckpt' as top 1


k              1        10        20         5
map     0.006789  0.023326  0.027358  0.018168
ndcg    0.006789  0.036654  0.051677  0.023948
recall  0.006789  0.081470  0.141580  0.041729



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 1323: 'recall@10' reached 0.09141 (best 0.09141), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=6-step=1323-v2.ckpt' as top 1


k              1        10        20         5
map     0.008776  0.027117  0.031929  0.021949
ndcg    0.008776  0.041864  0.059872  0.028990
recall  0.008776  0.091406  0.163603  0.050836



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 1512: 'recall@10' reached 0.10383 (best 0.10383), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=7-step=1512.ckpt' as top 1


k             1        10        20         5
map     0.01126  0.032126  0.037192  0.026006
ndcg    0.01126  0.048599  0.067473  0.033582
recall  0.01126  0.103825  0.179334  0.056963



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 1701: 'recall@10' reached 0.11244 (best 0.11244), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=8-step=1701.ckpt' as top 1


k              1        10        20         5
map     0.012088  0.033863  0.039145  0.026379
ndcg    0.012088  0.051835  0.071279  0.033566
recall  0.012088  0.112436  0.189767  0.055638



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 1890: 'recall@10' reached 0.12038 (best 0.12038), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=9-step=1890-v2.ckpt' as top 1


k              1        10        20         5
map     0.014075  0.037768  0.043720  0.030035
ndcg    0.014075  0.056689  0.078432  0.037731
recall  0.014075  0.120384  0.206491  0.061268



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 2079: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.038940  0.045264  0.032039
ndcg    0.014572  0.057550  0.080690  0.040461
recall  0.014572  0.119887  0.211624  0.066236



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 2268: 'recall@10' reached 0.13313 (best 0.13313), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=11-step=2268.ckpt' as top 1


k              1        10        20         5
map     0.014572  0.041118  0.047172  0.032944
ndcg    0.014572  0.062227  0.084644  0.042165
recall  0.014572  0.133135  0.222553  0.070541



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 2457: 'recall@10' reached 0.13578 (best 0.13578), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=12-step=2457-v1.ckpt' as top 1


k              1        10        20         5
map     0.012916  0.041002  0.048065  0.032950
ndcg    0.012916  0.062827  0.088817  0.043172
recall  0.012916  0.135784  0.239112  0.074681



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 2646: 'recall@10' reached 0.13761 (best 0.13761), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=13-step=2646.ckpt' as top 1


k              1        10        20         5
map     0.016393  0.043745  0.049922  0.035814
ndcg    0.016393  0.065313  0.088050  0.045814
recall  0.016393  0.137606  0.228018  0.076668



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 2835: 'recall@10' reached 0.14572 (best 0.14572), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=14-step=2835.ckpt' as top 1


k              1        10        20         5
map     0.013744  0.044572  0.051397  0.036082
ndcg    0.013744  0.067861  0.093062  0.046887
recall  0.013744  0.145719  0.246067  0.079980



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 3024: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014572  0.043037  0.049347  0.033872
ndcg    0.014572  0.066414  0.089676  0.043945
recall  0.014572  0.145057  0.237622  0.075012



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 3213: 'recall@10' reached 0.15301 (best 0.15301), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=16-step=3213.ckpt' as top 1


k              1        10        20         5
map     0.020533  0.050175  0.056921  0.041307
ndcg    0.020533  0.073800  0.098705  0.052068
recall  0.020533  0.153005  0.252194  0.085279



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 3402: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.044922  0.052359  0.035591
ndcg    0.016228  0.068935  0.096304  0.045990
recall  0.016228  0.149859  0.258652  0.078159



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 3591: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.048151  0.055104  0.039250
ndcg    0.017056  0.072077  0.097874  0.050346
recall  0.017056  0.152012  0.255009  0.084451



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 3780: 'recall@10' reached 0.16443 (best 0.16443), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=19-step=3780.ckpt' as top 1


k              1        10        20         5
map     0.017553  0.050482  0.057555  0.040688
ndcg    0.017553  0.076663  0.102664  0.052659
recall  0.017553  0.164431  0.267760  0.089584



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 3969: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.048458  0.055842  0.039204
ndcg    0.015731  0.074221  0.101518  0.051256
recall  0.015731  0.160623  0.269415  0.088425



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 4158: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.049611  0.056758  0.040399
ndcg    0.018049  0.074154  0.100495  0.051515
recall  0.018049  0.156317  0.261136  0.085610



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 4347: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.050516  0.057873  0.040887
ndcg    0.016393  0.076624  0.103727  0.053190
recall  0.016393  0.163603  0.271403  0.090909



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 4536: 'recall@10' reached 0.16559 (best 0.16559), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=23-step=4536.ckpt' as top 1


k              1        10        20         5
map     0.014903  0.050339  0.057592  0.040376
ndcg    0.014903  0.076891  0.103476  0.052462
recall  0.014903  0.165590  0.271071  0.089419



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 4725: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01689  0.049176  0.056484  0.040109
ndcg    0.01689  0.074093  0.100858  0.052004
recall  0.01689  0.157145  0.263289  0.088591



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 4914: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.051505  0.058838  0.042118
ndcg    0.018049  0.077604  0.104511  0.054477
recall  0.018049  0.164928  0.271734  0.092565



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 5103: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.051380  0.059348  0.042295
ndcg    0.017718  0.076988  0.106090  0.054788
recall  0.017718  0.162279  0.277529  0.093227



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 5292: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.048906  0.057236  0.039648
ndcg    0.017221  0.074121  0.104692  0.051467
recall  0.017221  0.158470  0.279848  0.087928



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 5481: 'recall@10' was not in top 1


k             1        10        20         5
map     0.01954  0.051349  0.059358  0.041872
ndcg    0.01954  0.076578  0.106111  0.053326
recall  0.01954  0.161119  0.278689  0.088591



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 5670: 'recall@10' reached 0.17221 (best 0.17221), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=29-step=5670.ckpt' as top 1


k              1        10        20         5
map     0.021527  0.055773  0.063681  0.046164
ndcg    0.021527  0.082563  0.111508  0.058804
recall  0.021527  0.172214  0.286968  0.097698



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 30, global step 5859: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018877  0.053691  0.061561  0.044171
ndcg    0.018877  0.080600  0.109718  0.057056
recall  0.018877  0.170558  0.286637  0.096705



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 31, global step 6048: 'recall@10' reached 0.17404 (best 0.17404), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=31-step=6048.ckpt' as top 1


k              1        10        20         5
map     0.014903  0.050917  0.058470  0.040520
ndcg    0.014903  0.079214  0.107110  0.053566
recall  0.014903  0.174035  0.285147  0.093724



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 32, global step 6237: 'recall@10' reached 0.17553 (best 0.17553), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=32-step=6237.ckpt' as top 1


k              1        10        20         5
map     0.018049  0.053705  0.061464  0.043180
ndcg    0.018049  0.081750  0.110451  0.056098
recall  0.018049  0.175526  0.289949  0.095877



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 33, global step 6426: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019208  0.053874  0.061629  0.043597
ndcg    0.019208  0.081296  0.109911  0.056049
recall  0.019208  0.173207  0.287134  0.094386



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 34, global step 6615: 'recall@10' reached 0.18016 (best 0.18016), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=34-step=6615.ckpt' as top 1


k              1        10        20         5
map     0.017553  0.054073  0.061606  0.043559
ndcg    0.017553  0.083073  0.111055  0.057211
recall  0.017553  0.180162  0.291936  0.099354



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 35, global step 6804: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.055078  0.062641  0.044364
ndcg    0.018215  0.082983  0.110672  0.056984
recall  0.018215  0.176023  0.285809  0.095546



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 36, global step 6993: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017221  0.052581  0.060386  0.042071
ndcg    0.017221  0.080544  0.109221  0.054559
recall  0.017221  0.174367  0.288293  0.092896



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 37, global step 7182: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017884  0.053871  0.061073  0.043766
ndcg    0.017884  0.082064  0.108563  0.057241
recall  0.017884  0.176354  0.281669  0.098857



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 38, global step 7371: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.054212  0.061996  0.044508
ndcg    0.018546  0.081737  0.110419  0.057725
recall  0.018546  0.173704  0.287796  0.098361



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 39, global step 7560: 'recall@10' reached 0.18082 (best 0.18082), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=39-step=7560.ckpt' as top 1


k             1        10        20         5
map     0.02103  0.057426  0.064922  0.046829
ndcg    0.02103  0.085795  0.113317  0.059729
recall  0.02103  0.180825  0.290114  0.099354



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 40, global step 7749: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020699  0.055658  0.063230  0.044966
ndcg    0.020699  0.083080  0.111028  0.057106
recall  0.020699  0.174863  0.286140  0.094386



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 41, global step 7938: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.052379  0.060527  0.042126
ndcg    0.016725  0.079931  0.110088  0.054617
recall  0.016725  0.172214  0.292433  0.092896



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 42, global step 8127: 'recall@10' was not in top 1


k              1        10        20         5
map     0.021692  0.058423  0.066561  0.048173
ndcg    0.021692  0.086242  0.116025  0.061053
recall  0.021692  0.179169  0.297235  0.100513



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 43, global step 8316: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.055943  0.063840  0.045391
ndcg    0.019705  0.084424  0.113300  0.058431
recall  0.019705  0.179831  0.294254  0.098526



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 44, global step 8505: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015069  0.053308  0.061374  0.043291
ndcg    0.015069  0.081830  0.111383  0.057192
recall  0.015069  0.176850  0.294088  0.099851



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 45, global step 8694: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.055185  0.062849  0.044538
ndcg    0.018049  0.083015  0.111301  0.057084
recall  0.018049  0.175857  0.288458  0.095380



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 46, global step 8883: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.054915  0.062717  0.044329
ndcg    0.017387  0.083518  0.112327  0.057648
recall  0.017387  0.179003  0.293757  0.098526



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 47, global step 9072: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.055809  0.064044  0.045537
ndcg    0.017553  0.084333  0.114317  0.059376
recall  0.017553  0.179169  0.297731  0.101838



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 48, global step 9261: 'recall@10' reached 0.18596 (best 0.18596), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=48-step=9261.ckpt' as top 1


k              1        10        20         5
map     0.022355  0.060448  0.068135  0.049738
ndcg    0.022355  0.089385  0.117671  0.063264
recall  0.022355  0.185958  0.298394  0.104819



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 49, global step 9450: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020533  0.058365  0.065847  0.047737
ndcg    0.020533  0.087662  0.115249  0.061533
recall  0.020533  0.185627  0.295413  0.103991



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 50, global step 9639: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.057252  0.065311  0.046550
ndcg    0.019705  0.086161  0.115841  0.059957
recall  0.019705  0.182812  0.300878  0.101176



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 51, global step 9828: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019871  0.056373  0.064782  0.046465
ndcg    0.019871  0.084487  0.115268  0.060018
recall  0.019871  0.178341  0.300381  0.101672



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 52, global step 10017: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017056  0.055635  0.063238  0.044657
ndcg    0.017056  0.085148  0.113217  0.058444
recall  0.017056  0.183640  0.295413  0.100845



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 53, global step 10206: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.056399  0.064312  0.046448
ndcg    0.016725  0.085609  0.114614  0.061160
recall  0.016725  0.182646  0.297731  0.106309



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 54, global step 10395: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020864  0.056358  0.064559  0.045976
ndcg    0.020864  0.084551  0.114658  0.059012
recall  0.020864  0.179003  0.298559  0.099189



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 55, global step 10584: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.052936  0.060713  0.042005
ndcg    0.015897  0.082446  0.111144  0.055636
recall  0.015897  0.181321  0.295579  0.097698



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 56, global step 10773: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.055970  0.064171  0.045659
ndcg    0.018215  0.084831  0.115129  0.059707
recall  0.018215  0.180990  0.301706  0.102997



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 57, global step 10962: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.057038  0.065075  0.046796
ndcg    0.019705  0.085973  0.115791  0.060577
recall  0.019705  0.182812  0.301871  0.102997



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 58, global step 11151: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.054501  0.062389  0.043956
ndcg    0.015731  0.084609  0.113613  0.058502
recall  0.015731  0.185296  0.300546  0.103328



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 59, global step 11340: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019705  0.057285  0.065306  0.046415
ndcg    0.019705  0.086774  0.116052  0.060147
recall  0.019705  0.185461  0.301374  0.102500



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 60, global step 11529: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.056947  0.064405  0.046498
ndcg    0.018049  0.086052  0.113448  0.060260
recall  0.018049  0.183143  0.291936  0.102335



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 61, global step 11718: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057223  0.064880  0.046870
ndcg    0.017553  0.086840  0.115219  0.061361
recall  0.017553  0.185461  0.298725  0.105812



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 62, global step 11907: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.056284  0.064414  0.045866
ndcg    0.017718  0.085691  0.115490  0.060021
recall  0.017718  0.183805  0.302037  0.103494



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 63, global step 12096: 'recall@10' was not in top 1


k              1        10        20         5
map     0.014075  0.053323  0.061450  0.042706
ndcg    0.014075  0.083044  0.113033  0.057016
recall  0.014075  0.182149  0.301540  0.101010



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 64, global step 12285: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.055950  0.063635  0.045606
ndcg    0.016559  0.085525  0.113860  0.060119
recall  0.016559  0.183971  0.296738  0.104653



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 65, global step 12474: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.056836  0.065167  0.046782
ndcg    0.017718  0.085591  0.116301  0.060821
recall  0.017718  0.181321  0.303527  0.103825



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 66, global step 12663: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015897  0.054146  0.063006  0.043884
ndcg    0.015897  0.083296  0.116044  0.058099
recall  0.015897  0.180493  0.310979  0.101838



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 67, global step 12852: 'recall@10' reached 0.19159 (best 0.19159), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=67-step=12852.ckpt' as top 1


k              1        10        20         5
map     0.017056  0.057459  0.065661  0.045971
ndcg    0.017056  0.088389  0.118391  0.060327
recall  0.017056  0.191588  0.310482  0.104322



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 68, global step 13041: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017718  0.058559  0.066859  0.048270
ndcg    0.017718  0.088819  0.119630  0.063497
recall  0.017718  0.189435  0.312469  0.110283



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 69, global step 13230: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.057117  0.065430  0.046705
ndcg    0.017387  0.087054  0.117606  0.061578
recall  0.017387  0.186620  0.307998  0.107303



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 70, global step 13419: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019043  0.057870  0.066302  0.047017
ndcg    0.019043  0.087867  0.118650  0.061382
recall  0.019043  0.187945  0.309820  0.105647



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 71, global step 13608: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.054928  0.063489  0.044635
ndcg    0.015731  0.084751  0.116133  0.059363
recall  0.015731  0.184136  0.308660  0.104653



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 72, global step 13797: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.056898  0.065084  0.046131
ndcg    0.018381  0.086764  0.116853  0.060242
recall  0.018381  0.186620  0.306177  0.103660



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 73, global step 13986: 'recall@10' reached 0.19308 (best 0.19308), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=73-step=13986.ckpt' as top 1


k             1        10        20         5
map     0.01954  0.060165  0.068179  0.048888
ndcg    0.01954  0.090827  0.120453  0.063258
recall  0.01954  0.193078  0.311144  0.107303



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 74, global step 14175: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020202  0.059898  0.068066  0.049208
ndcg    0.020202  0.090335  0.120156  0.063797
recall  0.020202  0.192085  0.310151  0.108627



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 75, global step 14364: 'recall@10' reached 0.19540 (best 0.19540), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=75-step=14364.ckpt' as top 1


k              1        10        20         5
map     0.017553  0.059239  0.067168  0.047867
ndcg    0.017553  0.090614  0.119642  0.062577
recall  0.017553  0.195397  0.310482  0.107634



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 76, global step 14553: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.057571  0.066031  0.046285
ndcg    0.017553  0.088607  0.119752  0.060844
recall  0.017553  0.192416  0.316278  0.105647



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 77, global step 14742: 'recall@10' was not in top 1


k             1        10        20         5
map     0.02252  0.059148  0.067676  0.048570
ndcg    0.02252  0.087980  0.119156  0.062026
recall  0.02252  0.184468  0.307998  0.103494



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 78, global step 14931: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020864  0.059902  0.068102  0.048949
ndcg    0.020864  0.089627  0.119711  0.062921
recall  0.020864  0.188773  0.308164  0.105812



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 79, global step 15120: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018049  0.057458  0.065510  0.046583
ndcg    0.018049  0.087580  0.117393  0.060762
recall  0.018049  0.188276  0.307170  0.104322



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 80, global step 15309: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018546  0.057481  0.065625  0.046603
ndcg    0.018546  0.087740  0.117695  0.061046
recall  0.018546  0.188773  0.307832  0.105481



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 81, global step 15498: 'recall@10' reached 0.19672 (best 0.19672), saving model to '/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=81-step=15498.ckpt' as top 1


k              1        10        20         5
map     0.019705  0.060155  0.067919  0.048319
ndcg    0.019705  0.091545  0.120256  0.062412
recall  0.019705  0.196721  0.311144  0.105647



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 82, global step 15687: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018381  0.057389  0.065948  0.046868
ndcg    0.018381  0.087341  0.118833  0.061256
recall  0.018381  0.187448  0.312635  0.105481



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 83, global step 15876: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017387  0.057521  0.065933  0.046556
ndcg    0.017387  0.088367  0.119428  0.061458
recall  0.017387  0.191257  0.314953  0.107303



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 84, global step 16065: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019374  0.058125  0.066462  0.047210
ndcg    0.019374  0.087957  0.118482  0.061159
recall  0.019374  0.187614  0.308660  0.103991



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 85, global step 16254: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.055607  0.064210  0.044947
ndcg    0.016725  0.085456  0.117166  0.059262
recall  0.016725  0.185130  0.311310  0.103328



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 86, global step 16443: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016559  0.057457  0.065991  0.046661
ndcg    0.016559  0.087773  0.119155  0.061410
recall  0.016559  0.188607  0.313297  0.106640



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 87, global step 16632: 'recall@10' was not in top 1


k              1        10        20         5
map     0.015731  0.057037  0.065072  0.045678
ndcg    0.015731  0.088516  0.118281  0.060867
recall  0.015731  0.193410  0.312138  0.107634



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 88, global step 16821: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018712  0.059290  0.067419  0.048380
ndcg    0.018712  0.090193  0.120127  0.063372
recall  0.018712  0.193244  0.312303  0.109455



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 89, global step 17010: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019208  0.060027  0.067680  0.048468
ndcg    0.019208  0.091291  0.119452  0.062846
recall  0.019208  0.195893  0.307832  0.106971



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 90, global step 17199: 'recall@10' was not in top 1


k              1        10        20         5
map     0.017553  0.059004  0.066841  0.047811
ndcg    0.017553  0.089908  0.118707  0.062456
recall  0.017553  0.192913  0.307336  0.107303



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 91, global step 17388: 'recall@10' was not in top 1


k              1        10        20         5
map     0.019871  0.059618  0.068323  0.048620
ndcg    0.019871  0.089395  0.121484  0.062390
recall  0.019871  0.188773  0.316443  0.104488



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 92, global step 17577: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.056838  0.065549  0.046489
ndcg    0.016393  0.087088  0.118991  0.061682
recall  0.016393  0.187779  0.314290  0.108462



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 93, global step 17766: 'recall@10' was not in top 1


k              1        10        20         5
map     0.020864  0.060046  0.068195  0.049431
ndcg    0.020864  0.090098  0.120281  0.063916
recall  0.020864  0.190429  0.310813  0.108462



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 94, global step 17955: 'recall@10' was not in top 1


k              1        10        20         5
map     0.018215  0.059519  0.067797  0.048775
ndcg    0.018215  0.090521  0.120919  0.064276
recall  0.018215  0.193575  0.314290  0.111939



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 95, global step 18144: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.057420  0.065561  0.045932
ndcg    0.016393  0.088958  0.118821  0.060803
recall  0.016393  0.194237  0.312800  0.106475



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 96, global step 18333: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016725  0.056762  0.065223  0.045708
ndcg    0.016725  0.087184  0.118141  0.060252
recall  0.016725  0.188607  0.311310  0.104984



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 97, global step 18522: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016228  0.056993  0.065326  0.045982
ndcg    0.016228  0.087804  0.118434  0.060850
recall  0.016228  0.190429  0.312138  0.106475



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 98, global step 18711: 'recall@10' was not in top 1


k              1        10        20         5
map     0.016393  0.056134  0.064101  0.044497
ndcg    0.016393  0.087227  0.116437  0.058923
recall  0.016393  0.191091  0.307004  0.103328



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 99, global step 18900: 'recall@10' was not in top 1
`Trainer.fit` stopped: `max_epochs=100` reached.


k              1        10        20         5
map     0.017884  0.058572  0.066894  0.048223
ndcg    0.017884  0.088838  0.119484  0.063391
recall  0.017884  0.189435  0.311310  0.109952



Now we can get the best model path stored in the checkpoint callback.

In [27]:
# Path of the best checkpoint tracked by the checkpoint callback
# (the training log above reports it keeping the top-1 checkpoint by 'recall@10').
best_model_path = checkpoint_callback.best_model_path
# Bare last expression so the notebook displays the path.
best_model_path

'/home/nkulikov/RePlay/examples/sasrec/checkpoints/epoch=81-step=15498.ckpt'

## Inference

To obtain model scores, we will load the weights from the best checkpoint. To do this, we call `LightningModule.load_from_checkpoint`, passing it the path to the checkpoint and the model instance.

In [33]:
# Import the submodules explicitly: a bare `import replay` only exposes
# `replay.nn.lightning.optimizer` / `.scheduler` as attributes if some earlier
# cell happened to import them, which makes this cell depend on hidden state.
# Both statements still bind the top-level name `replay`.
import replay.nn.lightning.optimizer
import replay.nn.lightning.scheduler

# Allow-list the RePlay factory classes for torch's restricted (weights-only)
# unpickling, so that loading the checkpoint in the next cell is permitted to
# reconstruct them.
# NOTE(review): presumably the checkpoint stores instances of these factories
# among its saved hyperparameters — confirm against the training cell.
torch.serialization.add_safe_globals(
    [
        replay.nn.lightning.optimizer.OptimizerFactory,
        replay.nn.lightning.scheduler.LRSchedulerFactory,
    ]
)

In [34]:
# Build a fresh SasRec network from the notebook's hyperparameter constants.
# NOTE(review): presumably these are the same constants used for the trained
# model, so that the checkpoint weights fit this instance — confirm against
# the training cell.
sasrec = make_sasrec(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
)

# Restore the Lightning wrapper from the best checkpoint, handing it the
# freshly built model instance.
best_model = LightningModule.load_from_checkpoint(best_model_path, model=sasrec)
# Switch to evaluation mode; the trailing `;` suppresses the cell's output.
best_model.eval();

Configure `ParquetModule` for inference

In [44]:
# Metadata for the "predict" split only, produced by `create_meta` with
# shape=MAX_SEQ_LEN.
inference_metadata = {"predict": create_meta(shape=MAX_SEQ_LEN)}

# Data module configured with just a prediction path (no train/validation
# paths are passed here).
# NOTE(review): if an earlier cell bound `parquet_module` for training, this
# assignment replaces it — re-run that cell before training again.
parquet_module = ParquetModule(
    predict_path=PREDICT_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

During inference, we can use a `TopItemsCallback`. This callback computes scores over the entire item catalog for each user and returns recommendations as the ids of the items with the highest scores.


Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each format has a corresponding callback. In this example, we'll be using the `PandasTopItemsCallback`.

In [46]:
from replay.nn.lightning.callback import PandasTopItemsCallback

# Separate logger directory for the prediction run.
csv_logger = CSVLogger(save_dir="sasrec/logs/test", name="SasRec-example")

# Cut-offs at which the metrics are computed later in the notebook.
TOPK = [1, 5, 10, 20]

# Request max(TOPK) items per user so that a single prediction pass can be
# evaluated at every cut-off in TOPK.
pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
)

# return_predictions=False: the result is read from the callback below rather
# than from the return value of `predict`.
trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

# Recommendations gathered by the callback during prediction.
pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [47]:
# Display the recommendations; the columns follow the callback configuration
# (user_id, item_id, score).
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,224,7.242342
0,0,572,6.818249
0,0,486,6.81148
0,0,1371,6.534966
0,0,210,6.52649
...,...,...,...
6037,6039,2497,10.457304
6037,6039,3503,10.305973
6037,6039,2601,10.280416
6037,6039,2750,10.01198


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [48]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [49]:
# Build the metric calculator first: Recall, Precision and MAP, each evaluated
# at every cut-off in TOPK, using the column names of the recommendations frame.
offline_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score",
)

# Score the recommendations against the test ground truth; `explode` expands
# the `item_id` column to one row per item.
result_metrics = offline_metrics(pandas_res, test_gt.explode("item_id"))

In [50]:
# Render the computed metrics as a table (one row per metric, one column per k,
# as shown in the output below).
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.016065,0.054039,0.061749,0.043969
Precision,0.016065,0.017655,0.014458,0.020073
Recall,0.016065,0.176549,0.289169,0.100364


Let's call the encoder's `inverse_transform` method to get the final dataframe with recommendations, with user and item ids mapped back to their original values.

In [51]:
# Map the encoded user and item ids in the recommendations back through the
# fitted encoder; the scores are left as they are (compare the output below
# with the encoded frame above).
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,2012,231,7.242342
0,2012,586,6.818249
0,2012,500,6.81148
0,2012,1485,6.534966
0,2012,216,6.52649
...,...,...,...
6037,5727,2702,10.457304
6037,5727,3745,10.305973
6037,5727,2806,10.280416
6037,5727,2961,10.01198
