# Example of SASRec training using RandomTargetNextNSplitter

In [None]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter, RandomTargetNextNSplitter, TimeSplitter
from replay.preprocessing.filters import MinCountFilter
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import SequenceTokenizer, SequentialDataset, TensorFeatureSource, TensorSchema, TensorFeatureInfo
from replay.models.nn.sequential import SasRec
from replay.models.nn.sequential.sasrec import (
    SasRecPredictionDataset,
    SasRecTrainingDataset,
    SasRecValidationDataset,
    SasRecPredictionBatch,
    SasRecModel,
)
import pandas as pd

from replay.preprocessing.filters import filter_cold
from replay.preprocessing.utils import merge_subsets

## Prepare data
### Load raw MovieLens-1M interactions, item features, and user features.

In [None]:
!pip install rs-datasets

In [3]:
from rs_datasets import MovieLens

In [4]:
movielens = MovieLens("1m")
# Unpack the three raw tables shipped with MovieLens-1M: ratings (interactions), users and items.
interactions, user_features, item_features = movielens.ratings, movielens.users, movielens.items

In [5]:
interactions.head(n=5)

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [6]:
user_features.head(n=5)

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [7]:
item_features.head(n=5)

Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


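Cast timestamps to `int64` and sort interactions chronologically. Events that share the same timestamp should keep their original relative order, which requires a stable sort (pandas `sort_values` defaults to quicksort, which is not stable).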

In [8]:
# Order events chronologically. kind="stable" keeps the original relative order of events
# that share a timestamp; pandas' default (quicksort) gives no such guarantee.
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp", kind="stable")

In [9]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    """Build the replay ``FeatureSchema`` describing the MovieLens interaction columns.

    :param is_ground_truth: if True, describe only the ``user_id`` (query id) and
        ``item_id`` columns; otherwise also append the ``timestamp`` column.
    :return: schema to pass to ``Dataset``.
    """
    id_schema = FeatureSchema(
        [
            FeatureInfo(column="user_id", feature_hint=FeatureHint.QUERY_ID, feature_type=FeatureType.CATEGORICAL),
            FeatureInfo(column="item_id", feature_hint=FeatureHint.ITEM_ID, feature_type=FeatureType.CATEGORICAL),
        ]
    )
    if not is_ground_truth:
        timestamp_schema = FeatureSchema(
            [FeatureInfo(column="timestamp", feature_type=FeatureType.NUMERICAL, feature_hint=FeatureHint.TIMESTAMP)]
        )
        return id_schema + timestamp_schema
    return id_schema

In [14]:
ITEM_FEATURE_NAME = "item_id_seq"

# Take the item-id column name from the same schema builder used for every Dataset in this
# notebook. The previous `train_dataset` variable is never defined, so this cell failed on a
# fresh kernel.
ITEM_ID_COLUMN = prepare_feature_schema(is_ground_truth=False).item_id_column

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, ITEM_ID_COLUMN)],
        feature_hint=FeatureHint.ITEM_ID,
    )
)

## Split data

In this section, we will examine three different strategies for splitting data into training, validation, and test sets:

1) Leave-one-out (LOO) split.
2) Global temporal split + LOO target.
3) Global temporal split + random target.

### Leave-one-out split

In [12]:
splitter = LastNSplitter(N=1, divide_column="user_id", query_column="user_id", strategy="interactions")

# Peel off each user's last event for test, then the new last event for validation.
raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)

# Leave-one-out trains on exactly the history that feeds the validation step.
raw_train_events = raw_validation_events

In [16]:
# Options shared by every Dataset of the LOO split; side features are attached only to input datasets.
shared_dataset_kwargs = dict(check_consistency=True, categorical_encoded=False)
side_feature_kwargs = dict(query_features=user_features, item_features=item_features)

train_dataset_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)

validation_dataset_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
validation_gt_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    **shared_dataset_kwargs,
)

test_dataset_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
test_gt_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    **shared_dataset_kwargs,
)

In [19]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset_loo)

# Ground-truth datasets are tokenized keeping only the item-id tensor feature.
item_id_only = [tensor_schema.item_id_feature_name]

sequential_train_dataset_loo = tokenizer.transform(train_dataset_loo)

# Validation: tokenize history and ground truth, then keep only users present in both.
sequential_validation_dataset_loo, sequential_validation_gt_loo = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(validation_dataset_loo),
    tokenizer.transform(validation_gt_loo, item_id_only),
)

# Test: the same procedure on test history and ground truth.
sequential_test_dataset_loo, sequential_test_gt_loo = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(test_dataset_loo),
    tokenizer.transform(test_gt_loo, item_id_only),
)

### Global temporal split + LOO target

In [57]:
time_splitter = TimeSplitter(time_threshold=0.1, query_column="user_id")
loo_splitter = LastNSplitter(N=1, divide_column="user_id", query_column="user_id", strategy="interactions")
min_cnt_filter = MinCountFilter(num_entries=2, groupby_column="user_id")

# 1) Global temporal cut (fractional threshold): train_val | test_holdout, then train | val_holdout.
train_val, test_holdout = time_splitter.split(interactions)
train, val_holdout = time_splitter.split(train_val)

# 2) Drop users with fewer than two training interactions.
train = min_cnt_filter.transform(train)

# 3) Remove items that never occur in the filtered training part.
test_holdout, val_holdout, train_val = (
    filter_cold(frame, train, mode="items") for frame in (test_holdout, val_holdout, train_val)
)

# 4) Each user's last holdout event is the target; earlier holdout events extend the input history.
val_input, val_target = loo_splitter.split(val_holdout)
test_input, test_target = loo_splitter.split(test_holdout)

val_input = merge_subsets([val_input, train])
test_input = merge_subsets([test_input, train_val])

# 5) Evaluate only users that are known from the training part.
val_target, test_target = (
    filter_cold(target, train, mode="users", query_column="user_id") for target in (val_target, test_target)
)

In [58]:
# Options shared by every Dataset of this split; side features are attached only to input datasets.
shared_dataset_kwargs = dict(check_consistency=True, categorical_encoded=False)
side_feature_kwargs = dict(query_features=user_features, item_features=item_features)

train_dataset_gts_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=train,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)

validation_dataset_gts_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=val_input,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
validation_gt_gts_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=val_target,
    **shared_dataset_kwargs,
)

test_dataset_gts_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=test_input,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
test_gt_gts_loo = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=test_target,
    **shared_dataset_kwargs,
)

In [59]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset_gts_loo)

# Ground-truth datasets are tokenized keeping only the item-id tensor feature.
item_id_only = [tensor_schema.item_id_feature_name]

sequential_train_dataset_gts_loo = tokenizer.transform(train_dataset_gts_loo)

# Validation: tokenize history and ground truth, then keep only users present in both.
sequential_validation_dataset_gts_loo, sequential_validation_gt_gts_loo = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(validation_dataset_gts_loo),
    tokenizer.transform(validation_gt_gts_loo, item_id_only),
)

# Test: the same procedure on test history and ground truth.
sequential_test_dataset_gts_loo, sequential_test_gt_gts_loo = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(test_dataset_gts_loo),
    tokenizer.transform(test_gt_gts_loo, item_id_only),
)

### Global temporal split + random target

In [85]:
time_splitter = TimeSplitter(time_threshold=0.1, query_column="user_id")
random_splitter = RandomTargetNextNSplitter(N=1, seed=42, divide_column="user_id", query_column="user_id")
min_cnt_filter = MinCountFilter(num_entries=2, groupby_column="user_id")

# 1) Global temporal cut (fractional threshold): train_val | test_holdout, then train | val_holdout.
train_val, test_holdout = time_splitter.split(interactions)
train, val_holdout = time_splitter.split(train_val)

# 2) Drop users with fewer than two training interactions.
train = min_cnt_filter.transform(train)

# 3) Remove items that never occur in the filtered training part.
test_holdout, val_holdout, train_val = (
    filter_cold(frame, train, mode="items") for frame in (test_holdout, val_holdout, train_val)
)

# 4) Targets are chosen per user by the seeded random-target splitter instead of taking the last event.
val_input, val_target = random_splitter.split(val_holdout)
test_input, test_target = random_splitter.split(test_holdout)

val_input = merge_subsets([val_input, train])
test_input = merge_subsets([test_input, train_val])

# 5) Evaluate only users that are known from the training part.
val_target, test_target = (
    filter_cold(target, train, mode="users", query_column="user_id") for target in (val_target, test_target)
)

In [86]:
# Options shared by every Dataset of this split; side features are attached only to input datasets.
shared_dataset_kwargs = dict(check_consistency=True, categorical_encoded=False)
side_feature_kwargs = dict(query_features=user_features, item_features=item_features)

train_dataset_gts_random = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=train,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)

validation_dataset_gts_random = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=val_input,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
validation_gt_gts_random = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=val_target,
    **shared_dataset_kwargs,
)

test_dataset_gts_random = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=test_input,
    **side_feature_kwargs,
    **shared_dataset_kwargs,
)
test_gt_gts_random = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=test_target,
    **shared_dataset_kwargs,
)

In [87]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset_gts_random)

# Ground-truth datasets are tokenized keeping only the item-id tensor feature.
item_id_only = [tensor_schema.item_id_feature_name]

sequential_train_dataset_gts_random = tokenizer.transform(train_dataset_gts_random)

# Validation: tokenize history and ground truth, then keep only users present in both.
sequential_validation_dataset_gts_random, sequential_validation_gt_gts_random = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(validation_dataset_gts_random),
    tokenizer.transform(validation_gt_gts_random, item_id_only),
)

# Test: the same procedure on test history and ground truth.
sequential_test_dataset_gts_random, sequential_test_gt_gts_random = SequentialDataset.keep_common_query_ids(
    tokenizer.transform(test_dataset_gts_random),
    tokenizer.transform(test_gt_gts_random, item_id_only),
)

## Train model

### Train SASRec with LOO split

In [None]:
MAX_SEQ_LEN = 200
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 10
SEED = 42

# Seed Python/NumPy/torch RNGs (and dataloader workers) so weight init, dropout and shuffling
# are reproducible, and the three splitting strategies are compared on equal footing.
L.seed_everything(SEED, workers=True)

# NOTE(review): the same `tensor_schema` object is handed to all three SequenceTokenizer
# instances in this notebook. If tokenization writes feature cardinalities into it, this model
# is sized for whichever split was tokenized last. Confirm this, or run this split's
# tokenization cell right before this one.
model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)

# Separate log/checkpoint locations per split so the three runs are not mixed together.
csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example_loo")

# Keep only the checkpoint with the highest validation recall@10.
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints/loo",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Validation MAP/NDCG/Recall@k. RemoveSeenItems presumably masks items already present in the
# validation input sequences (confirm against the replay docs).
validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset_loo.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset_loo)],
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset_loo,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset_loo,
        sequential_validation_gt_loo,
        sequential_train_dataset_loo,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

In [33]:
# Restore the best checkpoint tracked by the ModelCheckpoint callback of the training cell.
best_checkpoint_path = checkpoint_callback.best_model_path
best_model = SasRec.load_from_checkpoint(best_checkpoint_path, tensor_schema=tensor_schema)

In [None]:
# Test loader: test history + ground truth, with the training sequences as the third
# argument (mirroring the validation loader).
test_sequences = SasRecValidationDataset(
    sequential_test_dataset_loo,
    sequential_test_gt_loo,
    sequential_train_dataset_loo,
    max_sequence_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(dataset=test_sequences, batch_size=BATCH_SIZE)

In [38]:
# Reuse the validation-metrics machinery to score the best model on the LOO test set.
metric_names = ["map", "ndcg", "recall"]
cutoffs = [1, 5, 10, 20]

test_metrics_callback = ValidationMetricsCallback(
    metrics=metric_names,
    ks=cutoffs,
    item_count=train_dataset_loo.item_count,
    postprocessors=[RemoveSeenItems(sequential_test_dataset_loo)],
)

trainer_test = L.Trainer(logger=csv_logger, callbacks=[test_metrics_callback])
trainer_test.validate(best_model, dataloaders=test_dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.040894  0.077621  0.083144  0.069200
ndcg    0.040894  0.102925  0.123141  0.082525
recall  0.040894  0.186258  0.266391  0.123179



[{'recall@1': 0.0408940427005291,
  'ndcg@1': 0.0408940427005291,
  'map@1': 0.0408940427005291,
  'recall@5': 0.12317880988121033,
  'ndcg@5': 0.08252540975809097,
  'map@5': 0.06919977813959122,
  'recall@10': 0.18625827133655548,
  'ndcg@10': 0.10292463004589081,
  'map@10': 0.07762108743190765,
  'recall@20': 0.26639074087142944,
  'ndcg@20': 0.12314070761203766,
  'map@20': 0.08314439654350281}]

### Train SASRec with Global Temporal split + LOO target

In [None]:
MAX_SEQ_LEN = 200
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 10
SEED = 42

# Seed Python/NumPy/torch RNGs (and dataloader workers) so weight init, dropout and shuffling
# are reproducible, and the three splitting strategies are compared on equal footing.
L.seed_everything(SEED, workers=True)

# NOTE(review): the same `tensor_schema` object is handed to all three SequenceTokenizer
# instances in this notebook. If tokenization writes feature cardinalities into it, this model
# is sized for whichever split was tokenized last. Confirm this, or run this split's
# tokenization cell right before this one.
model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)

# Separate log/checkpoint locations per split so the three runs are not mixed together.
csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example_gts_loo")

# Keep only the checkpoint with the highest validation recall@10.
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints/gts_loo",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Validation MAP/NDCG/Recall@k. RemoveSeenItems presumably masks items already present in the
# validation input sequences (confirm against the replay docs).
validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset_gts_loo.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset_gts_loo)],
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset_gts_loo,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset_gts_loo,
        sequential_validation_gt_gts_loo,
        sequential_train_dataset_gts_loo,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

In [71]:
# Restore the best checkpoint tracked by the ModelCheckpoint callback of the training cell.
best_checkpoint_path = checkpoint_callback.best_model_path
best_model = SasRec.load_from_checkpoint(best_checkpoint_path, tensor_schema=tensor_schema)

In [None]:
# Test loader: test history + ground truth, with the training sequences as the third
# argument (mirroring the validation loader).
test_sequences = SasRecValidationDataset(
    sequential_test_dataset_gts_loo,
    sequential_test_gt_gts_loo,
    sequential_train_dataset_gts_loo,
    max_sequence_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(dataset=test_sequences, batch_size=BATCH_SIZE)

In [72]:
test_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    # The item count must come from the split this model and tokenizer were fitted on
    # (global temporal + LOO target), not from the plain LOO split.
    item_count=train_dataset_gts_loo.item_count,
    postprocessors=[RemoveSeenItems(sequential_test_dataset_gts_loo)],
)

trainer_test = L.Trainer(callbacks=[test_metrics_callback], logger=csv_logger)
trainer_test.validate(best_model, test_dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.019569  0.043450  0.047445  0.038536
ndcg    0.019569  0.058902  0.073808  0.046928
recall  0.019569  0.109589  0.169276  0.072407



[{'recall@1': 0.01956947147846222,
  'ndcg@1': 0.01956947147846222,
  'map@1': 0.01956947147846222,
  'recall@5': 0.07240704447031021,
  'ndcg@5': 0.04692849889397621,
  'map@5': 0.038535553961992264,
  'recall@10': 0.10958904027938843,
  'ndcg@10': 0.05890187993645668,
  'map@10': 0.04344966635107994,
  'recall@20': 0.1692759245634079,
  'ndcg@20': 0.07380761206150055,
  'map@20': 0.047445159405469894}]

### Train SASRec with Global Temporal split + random target

In [None]:
MAX_SEQ_LEN = 200
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 10
SEED = 42

# Seed Python/NumPy/torch RNGs (and dataloader workers) so weight init, dropout and shuffling
# are reproducible, and the three splitting strategies are compared on equal footing.
L.seed_everything(SEED, workers=True)

# NOTE(review): the same `tensor_schema` object is handed to all three SequenceTokenizer
# instances in this notebook. If tokenization writes feature cardinalities into it, this model
# is sized for whichever split was tokenized last. Confirm this, or run this split's
# tokenization cell right before this one.
model = SasRec(
    tensor_schema,
    block_count=2,
    head_count=2,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)

# Separate log/checkpoint locations per split so the three runs are not mixed together.
csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example_gts_random")

# Keep only the checkpoint with the highest validation recall@10.
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints/gts_random",
    save_top_k=1,
    verbose=True,
    monitor="recall@10",
    mode="max",
)

# Validation MAP/NDCG/Recall@k. RemoveSeenItems presumably masks items already present in the
# validation input sequences (confirm against the replay docs).
validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset_gts_random.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset_gts_random)],
)

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=SasRecTrainingDataset(
        sequential_train_dataset_gts_random,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=SasRecValidationDataset(
        sequential_validation_dataset_gts_random,
        sequential_validation_gt_gts_random,
        sequential_train_dataset_gts_random,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

In [90]:
# Restore the best checkpoint tracked by the ModelCheckpoint callback of the training cell.
best_checkpoint_path = checkpoint_callback.best_model_path
best_model = SasRec.load_from_checkpoint(best_checkpoint_path, tensor_schema=tensor_schema)

In [None]:
# Test loader: test history + ground truth, with the training sequences as the third
# argument (mirroring the validation loader).
test_sequences = SasRecValidationDataset(
    sequential_test_dataset_gts_random,
    sequential_test_gt_gts_random,
    sequential_train_dataset_gts_random,
    max_sequence_length=MAX_SEQ_LEN,
)
test_dataloader = DataLoader(dataset=test_sequences, batch_size=BATCH_SIZE)

In [93]:
test_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    # The item count must come from the split this model and tokenizer were fitted on
    # (global temporal + random target), not from the plain LOO split.
    item_count=train_dataset_gts_random.item_count,
    postprocessors=[RemoveSeenItems(sequential_test_dataset_gts_random)],
)

trainer_test = L.Trainer(callbacks=[test_metrics_callback], logger=csv_logger)
trainer_test.validate(best_model, test_dataloader)

Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


Validation: |          | 0/? [00:00<?, ?it/s]

k              1        10        20         5
map     0.019569  0.047283  0.051570  0.040753
ndcg    0.019569  0.067384  0.083146  0.051194
recall  0.019569  0.134051  0.196673  0.083170



[{'recall@1': 0.01956947147846222,
  'ndcg@1': 0.01956947147846222,
  'map@1': 0.01956947147846222,
  'recall@5': 0.08317025750875473,
  'ndcg@5': 0.0511941984295845,
  'map@5': 0.040753427892923355,
  'recall@10': 0.1340508759021759,
  'ndcg@10': 0.06738412380218506,
  'map@10': 0.04728318378329277,
  'recall@20': 0.196673184633255,
  'ndcg@20': 0.08314632624387741,
  'map@20': 0.05157012864947319}]