# Example of TwoTower training/inference

In [1]:
import lightning as L
import pandas as pd

# Fix the global random seed through Lightning's helper so repeated runs start from the same state.
L.seed_everything(42)

Seed set to 42


42

## Preparing data
In this example, we will be using the MovieLens dataset, namely the 1m subset. This is a simple demonstration, so only item ids will be used as model input.

---
**NOTE**

The current implementation of TwoTower can handle item and interaction features. It does not take user features into account.

---

In [2]:
# Interaction log: the rating column is read and immediately discarded, leaving
# user ids, item ids and timestamps.
interactions = pd.read_csv(
    "./data/ml1m_ratings.dat",
    sep="\t",
    names=["user_id", "item_id", "rating", "timestamp"],
).drop(columns=["rating"])

# Item catalogue: title and genres are dropped, so only the item ids remain.
item_features = pd.read_csv(
    "./data/ml1m_items.dat",
    sep="\t",
    names=["item_id", "title", "genres"],
).drop(columns=["title", "genres"])

In [3]:
# Cast timestamps to int64 and order the whole log chronologically.
interactions = interactions.astype({"timestamp": "int64"}).sort_values(by="timestamp")
# Replace the raw timestamp with the 0-based position of the event within its user's history.
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,timestamp
1000138,6040,858,0
1000153,6040,2384,1
999873,6040,593,2
1000007,6040,1961,3
1000192,6040,2019,4
...,...,...,...
825793,4958,2399,446
825438,4958,1407,447
825724,4958,3264,448
825731,4958,2634,449


### Encode categorical data.
To ensure all categorical data is fit for training, it needs to be encoded using the `LabelEncoder` class. Create an instance of the encoder, providing a `LabelEncodingRule` for each categorical column in the dataset that will be used in the model. Note that ids of users and ids of items are always used.

Let's train the encoder for the `item_id` column using the item features, because some items may be missing from interactions. The `user_id` column will be trained using interactions.

In [4]:
from replay.preprocessing import LabelEncoder, LabelEncodingRule

# One encoding rule per categorical column; the rules are addressed by position below
# (index 0 -> user_id, index 1 -> item_id).
encoder = LabelEncoder(
    [
        LabelEncodingRule("user_id"),
        LabelEncodingRule("item_id"),
    ]
)
# Fit user ids on the interactions, but item ids on the item catalogue,
# because some items may be missing from the interactions.
encoder.rules[0].fit(interactions)
encoder.rules[1].fit(item_features)
# Apply both fitted rules to the interaction log.
encoded_interactions = encoder.transform(interactions)
encoded_interactions

Unnamed: 0,timestamp,user_id,item_id
0,0,6039,847
1,1,6039,2315
2,2,6039,589
3,3,6039,1892
4,4,6039,1950
...,...,...,...
1000204,446,4957,2330
1000205,447,4957,1384
1000206,448,4957,3195
1000207,449,4957,2565


In [5]:
# Encode the item catalogue with the item_id rule (index 1 of the encoder's rules).
item_features_encoded = encoder.rules[1].transform(item_features)
item_features_encoded

Unnamed: 0,item_id
0,0
1,1
2,2
3,3
4,4
...,...
3878,3878
3879,3879
3880,3880
3881,3881


### Split interactions into the train, validation and test datasets using LastNSplitter
We use the widespread Last-One-Out splitting strategy. For simplicity, we filter out cold items and users.

In [6]:
from replay.splitters import LastNSplitter

# N=1 with the "interactions" strategy: the Last-One-Out setup described above,
# with cold users and cold items dropped from the held-out part.
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
    drop_cold_users=True,
    drop_cold_items=True
)

# First split: separate the test ground truth from the full interaction log.
test_events, test_gt = splitter.split(encoded_interactions)
# Second split: separate the validation ground truth from the remaining test input.
validation_events, validation_gt = splitter.split(test_events)
# Training uses the validation input as is (same object, not a copy).
train_events = validation_events

### Dataset preprocessing ("baking")
SasRec expects each user in the batch to provide their events in the form of a sequence. For this reason, the event splits must be properly processed using the `groupby_sequences` function provided by RePlay.

In [7]:
from replay.data.nn.utils import groupby_sequences


def bake_data(full_data):
    """Group an event log into per-user sequences.

    Delegates to ``groupby_sequences``, grouping by ``user_id`` and ordering
    each group's events by ``timestamp``; the grouped result is returned as is.
    """
    return groupby_sequences(events=full_data, groupby_col="user_id", sort_col="timestamp")


# `train_events` still aliases the un-grouped `validation_events` here, so both
# names are baked from the same rows.
train_events = bake_data(train_events)

validation_events = bake_data(validation_events)
validation_gt = bake_data(validation_gt)

test_events = bake_data(test_events)
test_gt = bake_data(test_gt)

train_events

Unnamed: 0,user_id,timestamp,item_id
0,0,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3117, 1672, 1250, 1009, 2271, 1768, 3339, 118..."
1,1,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1180, 1199, 1192, 2648, 1273, 2874, 1207, 117..."
2,2,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[589, 2789, 3465, 1899, 1892, 1407, 1246, 3602..."
3,3,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1192, 1081, 476, 3399, 3458, 1178, 257, 1180,..."
4,4,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[2648, 907, 896, 352, 1230, 2119, 2789, 1111, ..."
...,...,...,...
6035,6035,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1672, 1814, 3369, 2307, 2359, 2614, 2503, 263..."
6036,6036,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[1813, 693, 1247, 1959, 3439, 3079, 558, 847, ..."
6037,6037,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[3327, 908, 1192, 2077, 1366, 352, 1063, 1132,..."
6038,6038,"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[109, 279, 1998, 1211, 918, 935, 3019, 2953, 3..."


To ensure we don't have unknown users in ground truth, we join validation events and validation ground truth (also join test events and test ground truth correspondingly) by user ids to leave only the common ones.  

In [8]:
def add_gt_to_events(events_df, gt_df):
    """Attach each user's ground-truth items to that user's event row.

    The ``item_id`` column of ``gt_df`` is renamed to ``ground_truth`` and
    inner-joined onto ``events_df`` by ``user_id``, so users that are missing
    from either frame are dropped from the result.
    """
    ground_truth = gt_df.rename(columns={"item_id": "ground_truth"})[["user_id", "ground_truth"]]
    return events_df.merge(ground_truth, how="inner", on="user_id")

# Add a `ground_truth` column to the validation and test inputs; the inner join
# keeps only users present in both the events and the ground truth.
validation_events = add_gt_to_events(validation_events, validation_gt)
test_events = add_gt_to_events(test_events, test_gt)

In [9]:
from pathlib import Path

# Working directory for every artefact written by this notebook; created on demand.
data_dir = Path("temp/data/")
data_dir.mkdir(exist_ok=True, parents=True)

# Baked train / validation / test splits.
TRAIN_PATH = data_dir.joinpath("train.parquet")
VAL_PATH = data_dir.joinpath("val.parquet")
TEST_PATH = data_dir.joinpath("test.parquet")

# Encoded item features.
PATH_ENCODED_FEATURES = data_dir.joinpath("item_features_encoded.parquet")

# Location of the saved label encoder.
ENCODER_PATH = data_dir.joinpath("encoder")

In [10]:
# Persist the baked splits as parquet files; they are read back from these paths later.
train_events.to_parquet(TRAIN_PATH)
validation_events.to_parquet(VAL_PATH)
test_events.to_parquet(TEST_PATH)

# Only the item_id column of the encoded item features is written.
item_features_encoded[["item_id"]].to_parquet(PATH_ENCODED_FEATURES)

# Save the fitted encoder so it can be reloaded from disk.
encoder.save(ENCODER_PATH)



# Preparing for model training
### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model. It is required by the NN models to correctly create embeddings for every source column. Note that user_id is not required in `TensorSchema`.

Note that **cardinality** is the number of unique values in the item catalog (vocabulary). The **padding value** is the next value after the last encoded item id, i.e. it equals the cardinality.

In [11]:
from replay.data import FeatureHint, FeatureType, FeatureSource
from replay.data.nn import TensorFeatureInfo, TensorFeatureSource, TensorSchema


# Embedding size of the item_id feature; the same constant is passed to the model later.
EMBEDDING_DIM = 64

# Reload the encoder from disk and count the entries of its item_id mapping.
encoder = encoder.load(ENCODER_PATH)
NUM_UNIQUE_ITEMS = len(encoder.mapping["item_id"])

# The schema holds a single sequential categorical feature: the item id taken from interactions.
tensor_schema = TensorSchema(
    [
        TensorFeatureInfo(
            name="item_id",
            is_seq=True,
            # Padding uses the first value after the encoded ids (presumably 0..NUM_UNIQUE_ITEMS-1).
            padding_value=NUM_UNIQUE_ITEMS,
            cardinality=NUM_UNIQUE_ITEMS,
            embedding_dim=EMBEDDING_DIM,
            feature_type=FeatureType.CATEGORICAL,
            feature_hint=FeatureHint.ITEM_ID,
            feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "item_id")]
        )
    ]
)

### Configure ParquetModule and transformation pipelines

The `ParquetModule` class enables training of models on large datasets by reading data in a batch-wise way. This class is initialized with **paths to every data split, a metadata dict containing information about the shape and padding value of every column, and a dict of transforms**. `ParquetModule`'s "transform pipelines" are stage-specific modules implementing additional preprocessing to be performed at the batch level right before the forward pass.

For the TwoTower model (whose user tower is SasRec), RePlay provides a function named **make_default_twotower_transforms** that generates a sequence of appropriate transforms for each data split.

Internally this function creates the following transforms:
1) Training:
    1. Create a target, which contains the shifted item sequence that represents the next item in the sequence (for the next item prediction task).
    2. Rename features to match the format expected by the model during training.
    3. Unsqueeze the target (*positive_labels*) and its padding mask (*target_padding_mask*) to get the shape of these tensors required for loss computation.
    4. Group input features to be embedded in the expected format.

2) Validation/Inference:
    1. Rename/group features to match the format expected by the model during validation/inference.

If a different set of transforms is required, you can create them yourself and pass them to the ParquetModule as a dictionary where the key is the name of the split and the value is the list of transforms. Available transforms can be found in `replay/nn/transform/`.

**Note:** One of the transforms for the training data prepares the initial sequence for the task of Next Item Prediction so it shifts the sequence of items. For the final sequence length to be correct, you need to set shape of item_id in metadata as **model sequence length + shift**. Default shift value is 1.

In [12]:
from replay.nn.transform.template import make_default_twotower_transforms

# Build the default TwoTower transform pipelines from the tensor schema.
transforms = make_default_twotower_transforms(tensor_schema, query_column="user_id")

In [13]:
from typing import Optional

# Sequence length used for the item_id shape in the split metadata and as the model's max_sequence_length.
MAX_SEQ_LEN = 50

def create_meta(shape: int, gt_shape: Optional[int] = None):
    """Build the per-column metadata dict for one data split.

    ``user_id`` gets an empty entry. ``item_id`` gets the given ``shape`` and
    the padding value stored for ``item_id`` in the module-level
    ``tensor_schema``. When ``gt_shape`` is given, a ``ground_truth`` entry
    with that shape and a padding of ``-1`` is added.
    """
    item_padding = tensor_schema["item_id"].padding_value
    columns = {
        "user_id": {},
        "item_id": {"shape": shape, "padding": item_padding},
    }
    if gt_shape is None:
        return columns

    columns["ground_truth"] = {"shape": gt_shape, "padding": -1}
    return columns

# Train sequences are one position longer than the model's sequence length to
# leave room for the next-item shift (default shift is 1); validation also
# declares a single ground-truth item per user.
train_metadata = {
    "train": create_meta(shape=MAX_SEQ_LEN+1),
    "validate": create_meta(shape=MAX_SEQ_LEN, gt_shape=1),
}

In [14]:
from replay.data.nn import ParquetModule

# Batch size; reused by the inference data module later in the notebook.
BATCH_SIZE = 32

# Data module for fitting: reads the train and validation parquet files and
# applies the stage-specific transforms.
parquet_module = ParquetModule(
    train_path=TRAIN_PATH,
    validate_path=VAL_PATH,
    batch_size=BATCH_SIZE,
    metadata=train_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


## Train model
### Create TwoTower model instance and run the training stage using lightning
We may now train the model using the Lightning trainer class. 

RePlay's implementation of TwoTower is designed in a modular, **block-based approach**. Instead of passing configuration parameters to the constructor, TwoTower is built by providing fully initialized components, which makes the model more flexible and easier to extend.


#### Default Configuration

A default TwoTower model can be created quickly via the *from_params* method. The default model instance uses CE loss; the user tower is SasRec with the original SasRec transformer layers and sum-aggregated embeddings, and the item tower is a SwiGLU MLP block. Both towers use the same features.

In [15]:
from replay.nn.sequential import TwoTower
from replay.nn.sequential.twotower import FeaturesReader

# Hyper-parameters passed to TwoTower.from_params, both here and when the model
# is rebuilt for inference.
NUM_BLOCKS = 2
NUM_HEADS = 2
DROPOUT = 0.3

twotower = TwoTower.from_params(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    # Item features are read from the encoded-features parquet file written earlier.
    item_features_reader=FeaturesReader(
        schema=tensor_schema,
        metadata={"item_id": {}},
        path=PATH_ENCODED_FEATURES,
    )
)

A universal PyTorch Lightning module is provided that can work with any NN model.

In [16]:
from replay.nn.lightning import LightningModule

# Wrap the TwoTower network in RePlay's universal Lightning module.
model = LightningModule(twotower)

To facilitate training, we add the following callbacks:
1) `ModelCheckpoint` - saves the best trained model according to its Recall metric. It is a default Lightning callback.
2) `ComputeMetricsCallback` - displays a detailed validation metric matrix after each epoch. It is a custom RePlay callback for computing recsys metrics during the validation and test stages.


In [17]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from replay.nn.lightning.callback import ComputeMetricsCallback


# Keep only the single best checkpoint, ranked by the logged "recall@10" value (higher is better).
checkpoint_callback = ModelCheckpoint(
    dirpath="twotower/checkpoints",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Compute MAP, NDCG and Recall at k = 1, 5, 10 and 20 over the whole item catalogue.
validation_metrics_callback = ComputeMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=NUM_UNIQUE_ITEMS,
)

# Training logs are written as CSV under twotower/.logs/train.
csv_logger = CSVLogger(save_dir="twotower/.logs/train", name="TwoTower-example")

trainer = L.Trainer(
    max_epochs=5,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

# Run training and validation on the splits served by the parquet data module.
trainer.fit(model, datamodule=parquet_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params | Mode  | FLOPs
---------------------------------------------------
0 | model | TwoTower | 352 K  | train | 0    
---------------------------------------------------
352 K     Trainable params
0         Non-trainable params
352 K     Total params
1.409     Total estimated model params size (MB)
52        Modules in train mode
0         Modules in eval mode
0         Total Flops


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 189: 'recall@10' reached 0.03958 (best 0.03958), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=0-step=189.ckpt' as top 1


k              1        10        20         5
map     0.003643  0.011629  0.013674  0.009331
ndcg    0.003643  0.018055  0.025700  0.012422
recall  0.003643  0.039576  0.070210  0.022024





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 378: 'recall@10' reached 0.10184 (best 0.10184), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=1-step=378.ckpt' as top 1


k              1        10        20         5
map     0.010763  0.033238  0.037375  0.028159
ndcg    0.010763  0.049178  0.064472  0.036708
recall  0.010763  0.101838  0.162775  0.062924





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 567: 'recall@10' reached 0.13280 (best 0.13280), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=2-step=567.ckpt' as top 1


k              1        10        20         5
map     0.012916  0.041789  0.047532  0.034280
ndcg    0.012916  0.062866  0.084155  0.044635
recall  0.012916  0.132803  0.217751  0.076337





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 756: 'recall@10' reached 0.14920 (best 0.14920), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=3-step=756.ckpt' as top 1


k             1        10        20         5
map     0.01391  0.046216  0.052732  0.037647
ndcg    0.01391  0.070051  0.093930  0.049280
recall  0.01391  0.149197  0.243915  0.084948





Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 945: 'recall@10' reached 0.16079 (best 0.16079), saving model to '/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=5` reached.


k              1        10        20         5
map     0.013247  0.047682  0.054643  0.038307
ndcg    0.013247  0.073782  0.099453  0.050853
recall  0.013247  0.160788  0.262957  0.089419



Now we can get the best model path stored in the checkpoint callback.

In [18]:
# Path of the best checkpoint as tracked by the ModelCheckpoint callback.
best_model_path = checkpoint_callback.best_model_path
best_model_path

'/home/nkulikov/RePlay/examples/twotower/checkpoints/epoch=4-step=945.ckpt'

## Inference

To obtain model scores, we will load the weights from the best checkpoint. To do this, we pass the path to the checkpoint and a model instance to `LightningModule.load_from_checkpoint`.

In [19]:
import torch
import replay

# Rebuild the network with the same arguments that were used for training.
twotower = TwoTower.from_params(
    schema=tensor_schema,
    embedding_dim=EMBEDDING_DIM,
    max_sequence_length=MAX_SEQ_LEN,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    dropout=DROPOUT,
    item_features_reader=FeaturesReader(
        schema=tensor_schema,
        metadata={"item_id": {}},
        path=PATH_ENCODED_FEATURES,
    )
)

# Allow-list RePlay's optimizer and scheduler factory classes for torch
# deserialization; presumably the checkpoint stores them - TODO confirm.
torch.serialization.add_safe_globals([
    replay.nn.lightning.optimizer.OptimizerFactory,
    replay.nn.lightning.scheduler.LRSchedulerFactory,
])

# Load the best checkpoint into the freshly built model.
best_model = LightningModule.load_from_checkpoint(best_model_path, model=twotower)
# Switch to evaluation mode; the trailing semicolon hides the cell output.
best_model.eval();

Configure `ParquetModule` for inference

In [20]:
# Prediction metadata declares only user_id and item_id (no ground_truth entry).
inference_metadata = {"predict": create_meta(shape=MAX_SEQ_LEN)}

# Data module for inference: reads only the test parquet file.
parquet_module = ParquetModule(
    predict_path=TEST_PATH,
    batch_size=BATCH_SIZE,
    metadata=inference_metadata,
    transforms=transforms,
)

  parquet_module = ParquetModule(


During inference, we can use `TopItemsCallback`. Such callback allows you to get scores for each user throughout the entire catalog and get recommendations in the form of ids of items with the highest score values.


Recommendations can be fetched in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame or raw PyTorch tensors. Each format has a corresponding callback. In this example, we'll be using the `PandasTopItemsCallback`.

In [21]:
from replay.nn.lightning.callback import PandasTopItemsCallback

# Inference logs go to a separate directory from the training logs.
csv_logger = CSVLogger(save_dir="twotower/.logs/test", name="TwoTower-example")

# Cut-offs used for the offline metrics; the callback keeps the largest one.
TOPK = [1, 5, 10, 20]

pandas_prediction_callback = PandasTopItemsCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
)

trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)
# Predictions are not returned by the trainer; they are collected by the callback.
trainer.predict(best_model, datamodule=parquet_module, return_predictions=False)

# Fetch the collected recommendations from the callback.
pandas_res = pandas_prediction_callback.get_result()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/nkulikov/new_venv/lib/python3.12/site-packages/lightning/pytorch/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.


Predicting: |          | 0/? [00:00<?, ?it/s]



In [22]:
# Display the recommendations returned by the prediction callback.
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,773,26.849968
0,0,360,26.444904
0,0,1526,26.433756
0,0,2618,26.282482
0,0,1838,26.202333
...,...,...,...
6037,6039,1680,26.977921
6037,6039,1375,26.976624
6037,6039,2125,26.967751
6037,6039,1439,26.963537


### Calculating metrics

*test_gt* is already encoded, so we can use it for computing metrics.

In [23]:
from replay.metrics import MAP, OfflineMetrics, Precision, Recall
from replay.metrics.torch_metrics_builder import metrics_to_df

In [24]:
# test_gt stores one list of items per user after baking, so it is exploded on
# item_id before being passed as the ground truth.
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK)],
    query_column="user_id",
    rating_column="score",
)(pandas_res, test_gt.explode("item_id"))

In [25]:
# Render the computed metrics as a table (metric x k).
metrics_to_df(result_metrics)

k,1,10,20,5
MAP,0.016893,0.050191,0.056355,0.0412
Precision,0.016893,0.015916,0.012479,0.018085
Recall,0.016893,0.159159,0.249586,0.090427


Let's call the encoder's `inverse_transform` method to get the final dataframe with recommendations in terms of the original ids.

In [26]:
# Map the encoded ids in the recommendations back through the encoder.
encoder.inverse_transform(pandas_res)

Unnamed: 0,user_id,item_id,score
0,1,783,26.849968
0,1,364,26.444904
0,1,1566,26.433756
0,1,2687,26.282482
0,1,1907,26.202333
...,...,...,...
6037,6040,1729,26.977921
6037,6040,1396,26.976624
6037,6040,2194,26.967751
6037,6040,1466,26.963537
