# Example of SasRec training/inference with Parquet Module

## Imports and session initialization

In [1]:
import copy

import torch
import lightning as L
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.window import Window

from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.data import (
    FeatureHint,
    FeatureSource,
    FeatureType,
)
from replay.data.nn import (
    TensorFeatureInfo,
    TensorFeatureSource,
    TensorSchema,
)
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.splitters import LastNSplitter, RatioSplitter
from replay.utils.session_handler import get_spark_session

# Fix seed to ensure reproducibility
L.seed_everything(42)

import warnings
warnings.filterwarnings("ignore")

Seed set to 42


In [2]:
spark_session = get_spark_session()

26/01/16 13:59:29 WARN Utils: Your hostname, ecs-evtsinovnik-64 resolves to a loopback address: 127.0.1.1; using 10.11.12.49 instead (on interface eth0)
26/01/16 13:59:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/16 13:59:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
26/01/16 13:59:30 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).


## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1m subset.  
Begin by loading interactions, item features and user features using the created session.

---
**NOTE**

Current implementation of SasRec handles only item and interactions features. It does not take into account user features. As such, they are only used in this example to get complete lists of users.

---

In [3]:
schema = T.StructType([
    T.StructField("user_id", T.IntegerType(), True),
    T.StructField("item_id", T.IntegerType(), True),
    T.StructField("rating", T.IntegerType(), True),
    T.StructField("timestamp", T.LongType(), True),
])
interactions = spark_session.read.schema(schema).option("sep", "\t").csv("data/ml1m_ratings.dat")

# NOTE: The following code block is optional and is used
# to counteract the issue of identical timestamps in the dataset.
# Uncomment if you wish to use it.
# int_win = Window.partitionBy("user_id").orderBy("item_id")
# interactions = (
#     interactions.select(["user_id", "item_id", "timestamp"])
#     .withColumn("ts", F.col("timestamp").cast("long") * 1000)
#     .withColumn("row_num", F.row_number().over(int_win))
#     .withColumn("timestamp", (F.col("ts") + F.col("row_num")).cast("string"))
#     .drop("ts", "row_num")
# )
interactions.show(n=5)

26/01/16 13:59:31 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
26/01/16 13:59:31 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.
26/01/16 13:59:31 WARN SQLConf: The SQL config 'spark.sql.execution.arrow.enabled' has been deprecated in Spark v3.0 and may be removed in the future. Use 'spark.sql.execution.arrow.pyspark.enabled' instead of it.


+-------+-------+------+---------+
|user_id|item_id|rating|timestamp|
+-------+-------+------+---------+
|      1|   1193|     5|978300760|
|      1|    661|     3|978302109|
|      1|    914|     3|978301968|
|      1|   3408|     4|978300275|
|      1|   2355|     5|978824291|
+-------+-------+------+---------+
only showing top 5 rows



In [4]:
schema = T.StructType([
    T.StructField("user_id", T.IntegerType(), True),
    T.StructField("age", T.IntegerType(), True),
    T.StructField("gender", T.StringType(), True),
    T.StructField("occupation", T.StringType(), True),
    T.StructField("zip_code", T.LongType(), True),
])
user_features = spark_session.read.schema(schema).option("sep", "\t").csv("data/ml1m_users.dat")
user_features.show(n=5)

+-------+----+------+----------+--------+
|user_id| age|gender|occupation|zip_code|
+-------+----+------+----------+--------+
|      1|null|     1|        10|   48067|
|      2|null|    56|        16|   70072|
|      3|null|    25|        15|   55117|
|      4|null|    45|         7|    2460|
|      5|null|    25|        20|   55455|
+-------+----+------+----------+--------+
only showing top 5 rows



In [None]:
schema = T.StructType([
    T.StructField("item_id", T.IntegerType(), True),
    T.StructField("movie_title", T.StringType(), True),
    T.StructField("genre", T.StringType(), True),
])
item_features = spark_session.read.schema(schema).option("sep", "\t").csv("data/ml1m_items.dat")
item_features.show(n=5)

+-------+--------------------+--------------------+
|item_id|         movie_title|               genre|
+-------+--------------------+--------------------+
|      1|    Toy Story (1995)|Animation|Childre...|
|      2|      Jumanji (1995)|Adventure|Childre...|
|      3|Grumpier Old Men ...|      Comedy|Romance|
|      4|Waiting to Exhale...|        Comedy|Drama|
|      5|Father of the Bri...|              Comedy|
+-------+--------------------+--------------------+
only showing top 5 rows



### Encode catagorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorcial column in the dataset.

In [6]:
from replay.preprocessing.label_encoder import LabelEncoder, LabelEncodingRule
from replay.utils.types import SparkDataFrame


def encode_data(
    queries: SparkDataFrame, items: SparkDataFrame, interactions: SparkDataFrame, label_encoder: LabelEncoder
):
    full_data = interactions.join(queries, on="user_id").join(items, on="item_id")
    full_data = label_encoder.fit_transform(full_data)

    return full_data

In [7]:
encoder = LabelEncoder(
    [
        LabelEncodingRule("user_id", default_value="last"),
        LabelEncodingRule("item_id", default_value="last"),
        LabelEncodingRule("genre", default_value="last"),
    ]
)
encoded_interactions = encode_data(user_features, item_features, interactions, encoder)

26/01/16 13:59:35 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:35 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:35 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.


26/01/16 13:59:35 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:35 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 13:59:36 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
26/01/16 1

### Split interactions into the train, validation and test datasets using RatioSplitter

In order to facilitate the model's training, we split the dataset in the following way:
1) A 60/40 data split of original data for training and subsequent splits
2) A 75/25 split of the leftover data for testing/validation respectively (i.e. 30%/10% of the full dataset)

We also remove cold users/items after each split.

In [8]:
train_events, test_events = RatioSplitter(
    test_size=0.4,
    divide_column="user_id",
    query_column="user_id",
    timestamp_column="timestamp",
    drop_cold_users=True,
    drop_cold_items=True,
).split(encoded_interactions)

print(f"{train_events.count()=}, {test_events.count()=}")

                                                                                

train_events.count()=597866, test_events.count()=402190


In [9]:
test_events, val_events = RatioSplitter(
    test_size=0.25,
    divide_column="user_id",
    query_column="user_id",
    timestamp_column="timestamp",
    drop_cold_users=True,
    drop_cold_items=True,
).split(test_events)

print(f"{test_events.count()=}, {val_events.count()=}")

                                                                                

test_events.count()=299479, val_events.count()=102645


### Split the validation dataset into events and ground_truth

For both validation and testing data, the last N items are split into ground truth, which will be used to calculate metrics.

In [10]:
VALIDATION_GROUND_TRUTH_INTERACTIONS_PER_USER = 3
TEST_GROUND_TRUTH_INTERACTIONS_PER_USER = 3

val_events, val_gt = LastNSplitter(
    N=VALIDATION_GROUND_TRUTH_INTERACTIONS_PER_USER,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
).split(val_events)
print(f"{val_events.count()=}, {val_gt.count()=}")

test_events, test_gt = LastNSplitter(
    N=TEST_GROUND_TRUTH_INTERACTIONS_PER_USER, divide_column="user_id", query_column="user_id", strategy="interactions"
).split(test_events)
print(f"{test_events.count()=}, {test_gt.count()=}")

                                                                                

val_events.count()=84611, val_gt.count()=18034


                                                                                

test_events.count()=281359, test_gt.count()=18120


### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [11]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data: SparkDataFrame):
    grouped_interactions = groupby_sequences(events=full_data, groupby_col="user_id", sort_col="timestamp")

    return grouped_interactions

In [12]:
train_events = bake_data(train_events)
val_events = bake_data(val_events)
val_gt = bake_data(val_gt)
test_events = bake_data(test_events)

To ensure we don't validate on unknown users, we join train and validation data by user ids, leaving only the common ones.  
We also pre-package the validation data with its ground truth and train-time events.

In [13]:
# Keep common query ids between val_dataset and val_gt
val_events = val_events.join(val_gt, on="user_id", how="left_semi")
val_gt = val_gt.join(val_events, on="user_id", how="left_semi")

gt_to_join = val_gt.select(["user_id", "item_id"]).withColumnRenamed("item_id", "ground_truth")
train_to_join = train_events.select(["user_id", "item_id"]).withColumnRenamed("item_id", "train")

val_events = val_events.join(gt_to_join, on="user_id", how="left")
val_events = val_events.join(train_to_join, on="user_id", how="left")

TRAIN_LEN = val_events.select(F.max(F.size("train")).alias("res")).collect()[0].res
GT_LEN = val_events.select(F.max(F.size("ground_truth")).alias("res")).collect()[0].res

                                                                                

In [14]:
TRAIN_PATH = "temp/data/train.parquet"
VAL_PATH = "temp/data/val.parquet"
TEST_PATH = "temp/data/test.parquet"

train_events.write.mode("overwrite").parquet(TRAIN_PATH)
val_events.write.mode("overwrite").parquet(VAL_PATH)
test_events.write.mode("overwrite").parquet(TEST_PATH)

                                                                                

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the SasRec model to correctly create embeddings at train time.

In [15]:
EMBEDDING_DIM = 64

ITEM_FEATURE_NAME = "item_id"
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])
NUM_UNIQUE_CLASS_VALUES = len(encoder.mapping["genre"])

tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS + 1,  # taking into account padding
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_sources=[TensorFeatureSource(FeatureSource.ITEM_FEATURES, "item_id")],
            feature_hint=FeatureHint.ITEM_ID,
        ),
        TensorFeatureInfo(
            name="genre",
            is_seq=True,
            padding_value=NUM_UNIQUE_CLASS_VALUES,
            cardinality=NUM_UNIQUE_CLASS_VALUES + 1,  # taking into account padding
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_sources=[TensorFeatureSource(FeatureSource.ITEM_FEATURES, "item_id")],
        ),
    ]
)

### Configure ParquetModule and transformation pipelines
The `ParquetModule` class enables training of models on large datasets by reading data in streaming mode. This class initialized with a metadata dict containing information about dataset's features and miscellanious options for initialization (such as shuffling).

Additionally, `ParquetModule` supports "transform pipelines" - stage-specific modules implementing additional preprocessing to be performed on batch level right before the forward pass.  

In our case, we create the following pipelines:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Optionally sample negatives (required only for sampled losses).
    3. Rename features to match it with expected format by the model during training.
    4. Unsqueeze target (`positive_labels`) and it's padding mask (`target_padding_mask`) for getting required shape of this tensors for loss computation.
    5. Group input features to be embed in expected format.

2) Validation/Inference:
    1. Rename/group features to match it with expected format by the model during valdiation/inference.

Then, metadata for ParquetModule should be created. It contains shape and padding value for each feature.

In [30]:
from replay.nn.transforms import (
    UnsqueezeTransform,
    GroupTransform,
    RenameTransform,
    NextTokenTransform,
    UniformNegativeSamplingTransform,
    TrimTransform,
    CopyTransform
)

MAX_SEQ_LEN = 50
BATCH_SIZE = 32
SHIFT = 1

TRANSFORMS = {
    "train": [
        NextTokenTransform(
            label_field="item_id", query_features="user_id", shift=SHIFT, out_feature_name="positive_labels"
        ),
        UniformNegativeSamplingTransform(vocab_size=NUM_UNIQUE_ITEMS, num_negative_samples=200),
        RenameTransform(
            {"user_id": "query_id", "item_id_mask": "padding_mask", "positive_labels_mask": "target_padding_mask"}
        ),
        UnsqueezeTransform("target_padding_mask", -1),
        UnsqueezeTransform("positive_labels", -1),
        GroupTransform({"feature_tensors": ["item_id", "genre"]}),
    ],
    "validate": [
        RenameTransform({"user_id": "query_id", "item_id_mask": "padding_mask"}),
        CopyTransform({"item_id": "seen_ids"}),
        TrimTransform(seq_len=MAX_SEQ_LEN, feature_names=["item_id", "padding_mask"]),
        GroupTransform({"feature_tensors": ["item_id", "genre"]}),
    ],
    "predict": [
        RenameTransform({"user_id": "query_id", "item_id_mask": "padding_mask"}),
        CopyTransform({"item_id": "seen_ids"}),
        TrimTransform(seq_len=MAX_SEQ_LEN, feature_names=["item_id", "padding_mask"]),
        GroupTransform({"feature_tensors": ["item_id", "genre"]}),
    ],
}

def create_meta(feature_names, common_seq_len=50, gt_mapping=None, item_id_shape=None):
    meta = {"user_id": {}}
    for feature in feature_names:
        
        meta.update({feature: {"shape": common_seq_len, "padding": tensor_schema[feature].padding_value}})
        if feature == "item_id" and item_id_shape is not None:
            meta["item_id"]["shape"] = item_id_shape
    
    if gt_mapping:
        meta.update({
        "train": {"shape": gt_mapping["train"], "padding": -2},
        "ground_truth": {"shape": gt_mapping["ground_truth"], "padding": -1}
        })

    return meta

METADATA = {
    "train": create_meta(feature_names=["item_id", "genre"], common_seq_len=MAX_SEQ_LEN+1),
    "validate": create_meta(feature_names=["item_id", "genre"], common_seq_len=MAX_SEQ_LEN, item_id_shape=1000, gt_mapping={"train":1000, "ground_truth":500}),
    "predict": create_meta(feature_names=["item_id", "genre"], common_seq_len=MAX_SEQ_LEN, item_id_shape=1000),
}

In [31]:
from replay.data.nn import ParquetModule

parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    predict_path=TEST_PATH,
    batch_size=BATCH_SIZE,
    metadata=METADATA,
    transforms=TRANSFORMS,
)

In [18]:
parquet_module.setup("validate")
batch = parquet_module.compiled_transforms["validate"](next(iter(parquet_module.val_dataloader())))

**NOTE**: 
You can also create a module specifically for training/inference by providing only their respective datapaths.
In such cases it's possible to pass to ParquetModule either all transforms or transforms for used data splits only.

For example:

In [20]:
parquet_module_train_val = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=METADATA,
    transforms=TRANSFORMS
)

## Train model
### Create SasRec model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of SasRec is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, SasRec is now built by providing fully initialized components that makes the model more flexible and easier to extend. SasRec consists of the body and loss. Body consits of the following components: embedder, aggregator, encoder, mask, output_normalization, loss.

#### Components of SasRec
* `Body` - The body component defines the full model excluding loss.
* `Loss` - The loss component defines how the training loss is computed. All available loss implementations are located in nn/loss.

#### Components of SasRecBody

* `Embedder` -The embedder is responsible for converting input features into embeddings. The default implementation is `SequenceEmbedding`, which supports the following feature types: categorical, categorical_list, numerical, numerical_list

* `Aggregator` - The aggregator combines all embeddings produced by the embedder and adds positional embeddings.
Currently, `SasRecAggregator` is supported. It internally uses one of the following embedding aggregation strategies: `SumAggregator`, `ConcatAggregator`.

* `Encoder` - The encoder represents the core transformer block of the model. The following implementations are currently available: `SasRecTransformerLayer` (default one), `DiffAttentionLayer` (a modified version with differential attention).

* `Mask` - The mask is an object that creates attention mask by input. RePlay supports `DefaultAttentionMask` creating a lower-triangular attention mask.

* `Output Normalization` - Any suitable PyTorch normalization layer may be used as output_normalization, for example: torch.nn.LayerNorm or torch.nn.RMSNorm

In [32]:
from replay.nn import DefaultAttentionMask, SequenceEmbedding, SumAggregator
from replay.nn.loss import CESampled
from replay.nn.sequential import SasRec, SasRecBody, PositionAwareAggregator, SasRecTransformerLayer


NUM_BLOCKS = 1
NUM_HEADS = 1
DROPOUT = 0.0

body = SasRecBody(
    embedder=SequenceEmbedding(
        schema=tensor_schema,
        categorical_list_feature_aggregation_method="sum",
    ),
    embedding_aggregator=PositionAwareAggregator(
        embedding_aggregator=SumAggregator(embedding_dim=EMBEDDING_DIM),
        max_sequence_length=MAX_SEQ_LEN,
        dropout=DROPOUT,
    ),
    attn_mask_builder=DefaultAttentionMask(
        reference_feature_name=tensor_schema.item_id_feature_name,
        num_heads=NUM_HEADS,
    ),
    encoder=SasRecTransformerLayer(
        embedding_dim=EMBEDDING_DIM,
        num_heads=NUM_HEADS,
        num_blocks=NUM_BLOCKS,
        dropout=DROPOUT,
        activation="relu",
    ),
    output_normalization=torch.nn.LayerNorm(EMBEDDING_DIM),
)
sasrec = SasRec(
    body=body,
    loss=CESampled(padding_idx=tensor_schema.item_id_features.item().padding_value),
)

#### Default Configuration

Default SasRec model may be created quickly via method *from_params*. Default model instance has CE loss, original SasRec transformer layes, and embeddings are aggregated via sum.

In [33]:
default_sasrec = SasRec.from_params(
    schema=tensor_schema, 
    embedding_dim=EMBEDDING_DIM, 
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    excluded_features=None
    )

A universal PyTorch Lightning module is provided that can work with any RePlay NN model.

In [34]:
from replay.nn.lightning import LightningModule
from replay.models.nn.optimizer_utils import FatOptimizerFactory, FatLRSchedulerFactory

model = LightningModule(
    sasrec,
    optimizer_factory=FatOptimizerFactory(),
    lr_scheduler_factory=FatLRSchedulerFactory(),
)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - to save the best trained model based on its Recall metric. It's a default Lightning Callback.
1) `ComputeMetricsCallback` - to display a detailed validation metric matrix after each epoch. It's a custom RePlay callback for computing recsys metrics on validation. It supports model's logits postpocessing (before metrics computing), we will use RePlay `SeenItemsFilter` in order to compute metrics on unseen ground truth items only.


In [35]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from replay.nn.lightning.callbacks import ComputeMetricsCallback
from replay.nn.lightning.postprocessors import SeenItemsFilter

checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

postprocessors = [
    SeenItemsFilter(
        item_count=NUM_UNIQUE_ITEMS,
        seen_items_column="seen_ids"
    )
]
validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
    postprocessors=postprocessors
)

csv_logger = CSVLogger(save_dir=".logs/train", name="SasRec-example")

trainer = L.Trainer(
    max_epochs=5,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

trainer.fit(model, datamodule=parquet_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type   | Params | Mode 
-----------------------------------------
0 | model | SasRec | 285 K  | train
-----------------------------------------
285 K     Trainable params
0         Non-trainable params
285 K     Total params
1.140     Total estimated model params size (MB)
31        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.02537 (best 0.02537), saving model to '/home/evtsinovnik/replay/examples/.checkpoints/epoch=0-step=189.ckpt' as top 1


k              1        10        20         5
map     0.004971  0.006787  0.008945  0.004851
ndcg    0.004971  0.015234  0.024739  0.008931
recall  0.001657  0.025366  0.052390  0.011217



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.05022 (best 0.05022), saving model to '/home/evtsinovnik/replay/examples/.checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.018738  0.017200  0.020406  0.013859
ndcg    0.018738  0.033770  0.046851  0.023970
recall  0.006246  0.050223  0.087380  0.028362



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.07387 (best 0.07387), saving model to '/home/evtsinovnik/replay/examples/.checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.029637  0.026177  0.031039  0.020665
ndcg    0.029637  0.049957  0.068529  0.034731
recall  0.009879  0.073869  0.126577  0.039898



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.09675 (best 0.09675), saving model to '/home/evtsinovnik/replay/examples/.checkpoints/epoch=3-step=756.ckpt' as top 1


k              1        10        20         5
map     0.042065  0.036253  0.042210  0.029414
ndcg    0.042065  0.067027  0.088544  0.048898
recall  0.014022  0.096750  0.157680  0.056597



Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.10688 (best 0.10688), saving model to '/home/evtsinovnik/replay/examples/.checkpoints/epoch=4-step=945-v12.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=5` reached.


k              1        10        20         5
map     0.048375  0.041845  0.048315  0.034324
ndcg    0.048375  0.075353  0.098168  0.055858
recall  0.016125  0.106883  0.171638  0.063607



We can now laod the best model using the path stored in the callback.

In [36]:
best_model = LightningModule.load_from_checkpoint(checkpoint_callback.best_model_path, model=sasrec)

## Inference stage

### Run inference
We can now perform inference using the data module we created earlier. Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each of the types corresponds a callback. Inthis example, we'll be using the `PandasTopItemsCallback`.
Prediction callbacks also can filter results using postprocessors.

In [37]:
from replay.nn.lightning.callbacks import PandasTopItemsCallback

csv_logger = CSVLogger(save_dir=".logs/test", name="SasRec-example")

TOPK = [1, 5, 10, 20]

postprocessors = [
    SeenItemsFilter(
        item_count=NUM_UNIQUE_ITEMS,
        seen_items_column="seen_ids"
    )
]

pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)

trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

pandas_res = pandas_prediction_callback.get_result()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Predicting: |          | 0/? [00:00<?, ?it/s]

In [48]:
candidates is not None and (torch.unique(candidates).numel() != candidates.numel())

False

In [42]:
candidates = torch.LongTensor([0, 1])
candidates and (torch.unique(candidates).numel() != candidates.numel()).item()

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [38]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,3,1820,6.561905
0,3,2013,6.529581
0,3,443,6.375957
0,3,1148,6.199466
0,3,3186,5.923382
...,...,...,...
6039,6025,1110,5.289825
6039,6025,1148,5.1394
6039,6025,2651,5.095589
6039,6025,1106,5.070946


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [39]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)], query_column="user_id", rating_column="score"
)(pandas_res, test_gt.toPandas())

                                                                                

In [40]:
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.063742,0.059596,0.070069,0.047866
Precision,0.063742,0.046407,0.037674,0.054007
Recall,0.021247,0.154691,0.251159,0.090011
