# Example of feature extraction in sequential models
Note that all the given examples can be run without PySpark, using only Pandas.

In [1]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter
from replay.utils import get_spark_session
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback, 
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import SasRec
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionDataset,
    SasRecTrainingDataset,
    SasRecValidationDataset,
    SasRecPredictionBatch,
    SasRecModel
)

import pandas as pd
import numpy as np

## Prepare data
### Load raw movielens-1M interactions, item features and user features.
In the current implementation, SASRec does not take the features of items or users into account. They are used only to obtain the complete list of users and items.

In [None]:
!pip install rs-datasets

In [3]:
from rs_datasets import MovieLens

In [4]:
# Load MovieLens-1M and pull out the three raw tables used throughout the notebook.
movielens = MovieLens("1m")
interactions, user_features, item_features = (
    movielens.ratings,
    movielens.users,
    movielens.items,
)

In [5]:
# Attach synthetic interaction-level features, one per feature type used by the
# schemas below: categorical list, numerical list, categorical and numerical.
#
# Seed NumPy's legacy global RNG first so that "Restart Kernel -> Run All" yields the
# same synthetic values on every run. Later cells that draw from np.random continue
# the same seeded stream, so they become reproducible as well.
np.random.seed(42)

sz = interactions.shape[0]  # number of interaction rows
cols = 8  # length of every list-valued feature
interactions["int_cat_list"] = np.random.randint(0, 10, size=(sz, cols)).tolist()
interactions["int_num_list"] = np.random.randn(sz, cols).tolist()
interactions["int_cat"] = np.random.randint(0, 10, size=sz)
interactions["int_num"] = np.random.randn(sz)

interactions

Unnamed: 0,user_id,item_id,rating,timestamp,int_cat_list,int_num_list,int_cat,int_num
0,1,1193,5,978300760,"[2, 3, 0, 0, 4, 9, 9, 3]","[-0.9452159835358088, 0.11843513213509653, -1....",5,-0.990908
1,1,661,3,978302109,"[2, 9, 3, 0, 0, 5, 4, 8]","[-1.4540425647970485, -0.3245363595041769, -1....",7,-0.267634
2,1,914,3,978301968,"[9, 0, 6, 9, 9, 7, 7, 6]","[0.049276029393911756, -0.7011654552905875, 2....",6,0.746226
3,1,3408,4,978300275,"[8, 4, 1, 3, 8, 5, 3, 7]","[1.6044119024135526, 1.1472465490348889, 0.672...",2,-0.078762
4,1,2355,5,978824291,"[8, 3, 2, 2, 2, 7, 5, 5]","[-1.061623858284662, 1.4147818817122946, 0.549...",8,-0.115445
...,...,...,...,...,...,...,...,...
1000204,6040,1091,1,956716541,"[1, 4, 0, 4, 9, 9, 0, 7]","[-1.2390545237060733, -0.0038395034983150665, ...",7,1.968532
1000205,6040,1094,5,956704887,"[1, 3, 9, 0, 3, 8, 7, 0]","[1.894659362573173, -0.11866467453941414, 0.41...",9,1.068694
1000206,6040,562,5,956704746,"[1, 1, 7, 7, 4, 2, 0, 9]","[0.3693274230584605, 0.9522883597124906, 1.392...",8,-1.524719
1000207,6040,1096,4,956715648,"[2, 2, 6, 2, 9, 5, 9, 7]","[-1.7092599668894297, -0.08893790573181981, -0...",0,1.842639


In [6]:
# Synthetic user-level features: a categorical list, a numerical list,
# a single categorical value and a single numerical value per user.
sz, cols = len(user_features), 5

user_features["user_cat_list"] = np.random.randint(0, 10, size=(sz, cols)).tolist()
user_features["user_num_list"] = np.random.randn(sz, cols).tolist()

user_features["user_cat"] = np.random.randint(0, 10, size=sz)
user_features["user_num"] = np.random.randn(sz)

user_features

Unnamed: 0,user_id,gender,age,occupation,zip_code,user_cat_list,user_num_list,user_cat,user_num
0,1,F,1,10,48067,"[8, 6, 6, 9, 3]","[0.36367906143495726, -1.0637550692721314, -0....",5,-0.758647
1,2,M,56,16,70072,"[1, 8, 8, 3, 4]","[-0.14181118212015617, -0.7331238227279658, 0....",4,-0.781051
2,3,M,25,15,55117,"[8, 3, 6, 0, 1]","[-0.26908219336594236, -0.23063841761491402, -...",8,0.649338
3,4,M,45,7,02460,"[6, 8, 3, 5, 2]","[0.97334778247759, -1.1123954097867839, 0.4551...",0,-0.181194
4,5,M,25,20,55455,"[7, 2, 2, 1, 3]","[0.24034122814066178, -0.7078123918666701, -0....",0,0.509186
...,...,...,...,...,...,...,...,...,...
6035,6036,F,25,15,32603,"[9, 0, 0, 7, 6]","[0.2869394624507952, -0.8306940486832483, -0.0...",3,-0.607487
6036,6037,F,45,1,76006,"[5, 1, 3, 5, 0]","[-1.7619847725610451, -0.949514704723724, -2.0...",1,0.633373
6037,6038,F,56,1,14706,"[8, 2, 3, 0, 4]","[1.8133636396157904, 0.7590471733911787, 1.181...",9,0.157026
6038,6039,F,45,0,01060,"[8, 4, 7, 9, 1]","[-0.5381018507382919, -0.6673470540260682, 0.2...",8,-0.242536


In [7]:
# Synthetic item-level features: a categorical list, a numerical list,
# a single categorical value and a single numerical value per item.
sz, cols = len(item_features), 13

item_features["item_cat_list"] = np.random.randint(0, 10, size=(sz, cols)).tolist()
item_features["item_num_list"] = np.random.randn(sz, cols).tolist()

item_features["item_cat"] = np.random.randint(0, 10, size=sz)
item_features["item_num"] = np.random.randn(sz)

item_features

Unnamed: 0,item_id,title,genres,item_cat_list,item_num_list,item_cat,item_num
0,1,Toy Story (1995),Animation|Children's|Comedy,"[1, 1, 3, 2, 6, 6, 2, 8, 3, 8, 3, 5, 7]","[0.2267798026143387, -0.16205344405521666, -0....",2,-0.821077
1,2,Jumanji (1995),Adventure|Children's|Fantasy,"[0, 0, 0, 1, 6, 7, 1, 4, 4, 1, 7, 8, 1]","[0.7842691024723795, 0.17351749019965385, -1.8...",4,-1.139324
2,3,Grumpier Old Men (1995),Comedy|Romance,"[3, 2, 8, 2, 0, 3, 2, 6, 1, 3, 0, 1, 8]","[-1.0548189125745606, -0.47528169864926134, -0...",0,0.054791
3,4,Waiting to Exhale (1995),Comedy|Drama,"[3, 0, 1, 8, 7, 8, 2, 9, 2, 6, 8, 3, 6]","[-0.7922967666883725, 0.9983384891958796, 2.80...",9,-1.265111
4,5,Father of the Bride Part II (1995),Comedy,"[2, 0, 4, 1, 7, 1, 1, 7, 7, 5, 4, 8, 4]","[-1.2110901744432956, -1.7013676917798273, -0....",3,-1.934216
...,...,...,...,...,...,...,...
3878,3948,Meet the Parents (2000),Comedy,"[8, 5, 2, 2, 7, 1, 4, 4, 2, 4, 4, 2, 6]","[1.3017434495626832, -0.9180213461397108, -0.5...",8,1.994910
3879,3949,Requiem for a Dream (2000),Drama,"[2, 8, 3, 1, 2, 1, 4, 9, 9, 0, 8, 8, 4]","[0.2320009008928167, -0.07487716198864543, 0.8...",7,-2.421320
3880,3950,Tigerland (2000),Drama,"[8, 2, 1, 3, 9, 6, 8, 5, 1, 1, 5, 9, 7]","[0.8950098179007996, -0.8872821023368912, -1.4...",9,0.567622
3881,3951,Two Family House (2000),Drama,"[6, 3, 1, 1, 0, 6, 1, 3, 8, 3, 8, 4, 1]","[-0.800715863041395, 0.3454083851292667, 0.297...",9,1.461547


### Split interactions into the train, validation and test datasets using LastNSplitter

In [8]:
# The same splitter (N=1 per "user_id") is applied twice: first to the full log to
# obtain the test events / ground truth, then to the test events to obtain the
# validation events / ground truth.
splitter = LastNSplitter(N=1, divide_column="user_id", query_column="user_id", strategy="interactions")

raw_test_events, raw_test_gt = splitter.split(interactions)

# The training events and the validation events are the very same frame.
raw_train_events, raw_validation_gt = splitter.split(raw_test_events)
raw_validation_events = raw_train_events

### Prepare FeatureSchema required to create Dataset

In [9]:
# Column-name suffix paired with the feature type of the column carrying that suffix.
_SUFFIX_TYPE_PAIRS = [
    ("cat_list", FeatureType.CATEGORICAL_LIST),
    ("num_list", FeatureType.NUMERICAL_LIST),
    ("cat", FeatureType.CATEGORICAL),
    ("num", FeatureType.NUMERICAL),
]
# Parallel lists consumed together via zip(suffix_names, feature_types).
suffix_names = [suffix for suffix, _ in _SUFFIX_TYPE_PAIRS]
feature_types = [feature_type for _, feature_type in _SUFFIX_TYPE_PAIRS]

def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    """Build the FeatureSchema describing the columns passed to a Dataset.

    The base schema always holds the query id (``user_id``), the item id
    (``item_id``) and one ``int_*`` feature per entry of ``suffix_names``.
    Unless ``is_ground_truth`` is true, the base schema is extended with the
    ``timestamp`` column and with the ``item_*`` and ``user_*`` features.

    :param is_ground_truth: when True, return only the base schema.
    :return: the assembled feature schema.
    """
    typed_suffixes = list(zip(suffix_names, feature_types))

    def suffixed_features(prefix, **extra):
        # One FeatureInfo per (suffix, type) pair, e.g. "int_" -> int_cat_list, int_num_list, ...
        return [
            FeatureInfo(column=prefix + suffix, feature_type=feature_type, **extra)
            for suffix, feature_type in typed_suffixes
        ]

    base_schema = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            *suffixed_features("int_"),
        ]
    )
    if is_ground_truth:
        return base_schema

    extra_schema = FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
            # item features
            *suffixed_features("item_", feature_source=FeatureSource.ITEM_FEATURES),
            # query features
            *suffixed_features("user_", feature_source=FeatureSource.QUERY_FEATURES),
        ]
    )
    return base_schema + extra_schema

### Create Dataset for the training stage

In [10]:
# Training dataset: training interactions plus the user and item feature tables,
# described by the full (non ground-truth) schema.
train_dataset = Dataset(
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    categorical_encoded=False,
    check_consistency=True,
)

### Create Datasets (events and ground_truth) for the validation stage

In [11]:
# Validation events: interactions plus the user and item feature tables (full schema).
validation_dataset = Dataset(
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    categorical_encoded=False,
    check_consistency=True,
)

# Validation ground truth: interactions only, described by the ground-truth schema.
validation_gt = Dataset(
    interactions=raw_validation_gt,
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    categorical_encoded=False,
    check_consistency=True,
)

### Create Datasets (events and ground_truth) for the testing stage

In [12]:
# Test events: interactions plus the user and item feature tables (full schema).
test_dataset = Dataset(
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    categorical_encoded=False,
    check_consistency=True,
)

# Test ground truth: interactions only, described by the ground-truth schema.
test_gt = Dataset(
    interactions=raw_test_gt,
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    categorical_encoded=False,
    check_consistency=True,
)

### Create the tensor schema
The schema maps the columns of the source dataset to the internal representation of tensors inside the model.

In [13]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    [
        # Item id sequence, read from the item-id column of the training dataset schema.
        TensorFeatureInfo(
            name=ITEM_FEATURE_NAME,
            is_seq=True,
            feature_type=FeatureType.CATEGORICAL,
            feature_sources=[
                TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)
            ],
            feature_hint=FeatureHint.ITEM_ID,
        ),
        # Timestamp sequence.
        TensorFeatureInfo(
            name="timestamp_seq",
            is_seq=True,
            feature_type=FeatureType.NUMERICAL,
            feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, "timestamp")],
            feature_hint=FeatureHint.TIMESTAMP,
        ),
        # One tensor feature per (group, suffix) pair. The groups are, in order:
        # interaction features, item features, query features and non-sequential
        # query features (which read the same "user_*" columns under another name).
        *[
            TensorFeatureInfo(
                name=name_prefix + suffix,
                is_seq=is_seq,
                feature_type=feature_type,
                feature_sources=[TensorFeatureSource(source, column_prefix + suffix)],
            )
            for name_prefix, column_prefix, source, is_seq in [
                ("int_", "int_", FeatureSource.INTERACTIONS, True),
                ("item_", "item_", FeatureSource.ITEM_FEATURES, True),
                ("user_", "user_", FeatureSource.QUERY_FEATURES, True),
                ("user_not_seq_", "user_", FeatureSource.QUERY_FEATURES, False),
            ]
            for suffix, feature_type in zip(suffix_names, feature_types)
        ],
    ]
)

### Create sequential datasets using SequenceTokenizer
The SequentialDataset internally stores data in the form of sequences of items sorted by increasing interaction time (timestamp). A SequenceTokenizer is used to convert data to this format. In addition, the SequenceTokenizer encodes all categorical columns from the source dataset and stores the mapping inside itself.
SequentialDataset.keep_common_query_ids is used to keep only the sequences of users that are present in both datasets.

In [14]:
# Fit the tokenizer on the training dataset only, then reuse it for every transform.
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

# Transform the validation events and ground truth (the ground truth is transformed
# with only the item id feature requested), then pass both through keep_common_query_ids.
sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(validation_dataset),
    tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name]),
)

In [15]:
# Small batch size and sequence length: the loader is only used to inspect a batch.
train_dataloader = DataLoader(
    SasRecTrainingDataset(sequential_train_dataset, max_sequence_length=10),
    batch_size=3,
    shuffle=True,
    pin_memory=True,
    num_workers=1,
)

  dataset=SasRecTrainingDataset(
  self._inner = TorchSequentialDataset(


In [16]:
# Print the contents of the first training batch only, then stop iterating.
for batch in train_dataloader:
    print(f"{batch.query_id=}")
    print(f"{batch.padding_mask=}")
    for feature_name, feature_tensor in batch.features.items():
        print(feature_name, feature_tensor)
    break

batch.query_id=tensor([[ 809],
        [1991],
        [4282]])
batch.padding_mask=tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
item_id_seq tensor([[1575, 3366, 1232, 1163, 2918, 1188,  257, 1196, 3083, 1844],
        [3379, 2677, 1934, 2964, 3684, 3509, 3686, 3554, 3793, 3729],
        [2683,  898, 2966,  352, 2286, 3381,  363, 2067, 1015, 1179]])
timestamp_seq tensor([[975395010, 975395042, 975395042, 975395064, 975395064, 975395204,
         975395204, 975395204, 975395267, 975395267],
        [974692292, 974692292, 974692327, 974692361, 974692512, 974692512,
         974692512, 974692559, 974692645, 974692684],
        [965278819, 965278838, 965278884, 965278884, 965278884, 965278884,
         965278884, 965278966, 965278966, 965278966]])
int_cat_list tensor([[[4, 4, 2, 4, 5, 2, 6, 8],
         [3, 9, 1, 7, 1, 2, 6, 5]