# Example of the Bert4Rec training and inference stages
Note that all the given examples can be run without using PySpark, using only Pandas

In [1]:
import lightning as L
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import torch

from replay.metrics import OfflineMetrics, Recall, Precision, MAP, NDCG, HitRate, MRR
from replay.metrics.torch_metrics_builder import metrics_to_df
from replay.splitters import LastNSplitter
from replay.data import (
    FeatureHint,
    FeatureInfo,
    FeatureSchema,
    FeatureSource,
    FeatureType,
    Dataset,
)
from replay.models.nn.optimizer_utils import FatOptimizerFactory
from replay.models.nn.sequential.callbacks import (
    ValidationMetricsCallback,
    SparkPredictionCallback,
    PandasPredictionCallback,
    TorchPredictionCallback,
    QueryEmbeddingsPredictionCallback,
)
from replay.models.nn.sequential.postprocessors import RemoveSeenItems
from replay.data.nn import (
    SequenceTokenizer,
    SequentialDataset,
    TensorFeatureSource,
    TensorSchema,
    TensorFeatureInfo
)
from replay.models.nn.sequential import Bert4Rec
from replay.models.nn.sequential.bert4rec import (
    Bert4RecPredictionDataset,
    Bert4RecTrainingDataset,
    Bert4RecValidationDataset,
    Bert4RecPredictionBatch,
    Bert4RecModel
)

import pandas as pd

## Prepare data
### Load raw movielens-1M interactions, item features and user features.
In the current implementation, the Bert4Rec does not take into account the features of items or users. They are only used to get a complete list of users and items.

In [None]:
!pip install rs-datasets

In [3]:
from rs_datasets import MovieLens

In [4]:
movielens = MovieLens("1m")
interactions = movielens.ratings
user_features = movielens.users
item_features = movielens.items

In [5]:
interactions.head()

Unnamed: 0,user_id,item_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [6]:
user_features.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [7]:
item_features.head()

Unnamed: 0,item_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


Removing duplicates in the timestamp column without changing the original items order where timestamp is the same

In [8]:
interactions["timestamp"] = interactions["timestamp"].astype("int64")
interactions = interactions.sort_values(by="timestamp")
interactions["timestamp"] = interactions.groupby("user_id").cumcount()
interactions

Unnamed: 0,user_id,item_id,rating,timestamp
1000138,6040,858,4,0
1000153,6040,2384,4,1
999873,6040,593,5,2
1000007,6040,1961,4,3
1000192,6040,2019,5,4
...,...,...,...,...
825793,4958,2399,1,446
825438,4958,1407,5,447
825724,4958,3264,4,448
825731,4958,2634,3,449


### Split interactions into the train, validation and test datasets using LastNSplitter

In [9]:
splitter = LastNSplitter(
    N=1,
    divide_column="user_id",
    query_column="user_id",
    strategy="interactions",
)

raw_test_events, raw_test_gt = splitter.split(interactions)
raw_validation_events, raw_validation_gt = splitter.split(raw_test_events)
raw_train_events = raw_validation_events

### Prepare FeatureSchema required to create Dataset

In [10]:
def prepare_feature_schema(is_ground_truth: bool) -> FeatureSchema:
    base_features = FeatureSchema(
        [
            FeatureInfo(
                column="user_id",
                feature_hint=FeatureHint.QUERY_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
            FeatureInfo(
                column="item_id",
                feature_hint=FeatureHint.ITEM_ID,
                feature_type=FeatureType.CATEGORICAL,
            ),
        ]
    )
    if is_ground_truth:
        return base_features

    all_features = base_features + FeatureSchema(
        [
            FeatureInfo(
                column="timestamp",
                feature_type=FeatureType.NUMERICAL,
                feature_hint=FeatureHint.TIMESTAMP,
            ),
        ]
    )
    return all_features

### Create Dataset for the training stage

In [11]:
train_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_train_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the validation stage

In [12]:
validation_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_validation_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
validation_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_validation_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create Datasets (events and ground_truth) for the testing stage

In [13]:
test_dataset = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=False),
    interactions=raw_test_events,
    query_features=user_features,
    item_features=item_features,
    check_consistency=True,
    categorical_encoded=False,
)
test_gt = Dataset(
    feature_schema=prepare_feature_schema(is_ground_truth=True),
    interactions=raw_test_gt,
    check_consistency=True,
    categorical_encoded=False,
)

### Create the tensor schema
A schema shows the correspondence of columns from the source dataset with the internal representation of tensors inside the model

In [14]:
ITEM_FEATURE_NAME = "item_id_seq"

tensor_schema = TensorSchema(
    TensorFeatureInfo(
        name=ITEM_FEATURE_NAME,
        is_seq=True,
        feature_type=FeatureType.CATEGORICAL,
        feature_sources=[TensorFeatureSource(FeatureSource.INTERACTIONS, train_dataset.feature_schema.item_id_column)],
        feature_hint=FeatureHint.ITEM_ID,
        embedding_dim=300,
    )
)

### Create sequential datasets using SequenceTokenizer
The SequentialDataset internally store data in the form of sequences of items sorted by increasing interaction time (timestamp). A SequenceTokenizer is used to convert to this format. In addition, the SequenceTokenizer encodes all categorical columns from the source dataset and stores mapping inside itself.
SequentialDataset.keep_common_query_ids is used to leave only sequences from the same users

In [15]:
tokenizer = SequenceTokenizer(tensor_schema, allow_collect_to_master=True)
tokenizer.fit(train_dataset)

sequential_train_dataset = tokenizer.transform(train_dataset)

sequential_validation_dataset = tokenizer.transform(validation_dataset)
sequential_validation_gt = tokenizer.transform(validation_gt, [tensor_schema.item_id_feature_name])

sequential_validation_dataset, sequential_validation_gt = SequentialDataset.keep_common_query_ids(
    sequential_validation_dataset, sequential_validation_gt
)

In [16]:
test_query_ids = test_gt.query_ids
test_query_ids_np = tokenizer.query_id_encoder.transform(test_query_ids)["user_id"].values
sequential_test_dataset = tokenizer.transform(test_dataset).filter_by_query_id(test_query_ids_np)

## Train model
### Create Bert4Rec model instance and run the training stage using lightning
After each epoch validation metrics are shown. You can change the list of validation metrics in ValidationMetricsCallback
The model is determined to be the best and is saved if the metric updates its maximum during validation (see the ModelCheckpoint)

In [None]:
MAX_SEQ_LEN = 100
BATCH_SIZE = 512
NUM_WORKERS = 4
MAX_EPOCHS = 1

model = Bert4Rec(
    tensor_schema,
    block_count=2,
    head_count=4,
    max_seq_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout_rate=0.5,
    optimizer_factory=FatOptimizerFactory(learning_rate=0.001),
)
checkpoint_callback = ModelCheckpoint(
    dirpath=".checkpoints",
    save_top_k=1,
    verbose=True,
    # if you use multiple dataloaders, then add the serial number of the dataloader to the suffix of the metric name.
    # For example,"recall@10/dataloader_idx_0"
    monitor="recall@10",
    mode="max",
)

validation_metrics_callback = ValidationMetricsCallback(
    metrics=["map", "ndcg", "recall"],
    ks=[1, 5, 10, 20],
    item_count=train_dataset.item_count,
    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]
)

csv_logger = CSVLogger(save_dir=".logs/train", name="Bert4Rec_example")

trainer = L.Trainer(
    max_epochs=MAX_EPOCHS,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
)

train_dataloader = DataLoader(
    dataset=Bert4RecTrainingDataset(
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

validation_dataloader = DataLoader(
    dataset=Bert4RecValidationDataset(
        sequential_validation_dataset,
        sequential_validation_gt,
        sequential_train_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=validation_dataloader,
)

The path to the best model is saved inside checkpoint_callback

In [19]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

## Inference stage
### Prepare Dataloader and logger

In [None]:
prediction_dataloader = DataLoader(
    dataset=Bert4RecPredictionDataset(
        sequential_test_dataset,
        max_sequence_length=MAX_SEQ_LEN,
    ),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

csv_logger = CSVLogger(save_dir=".logs/test", name="Bert4Rec_example")

### Run inference
You can get the recommendations in four formats: PySpark DataFrame, Pandas DataFrame, Polars DataFrame, PyTorch tensors. Each of the types corresponds a callback.

You can filter the results using postprocessors strategy. For example the RemoveSeenItems postprocessor is filtering out the items that already have been seen in test dataset.

You don't need to use all three callbacks. This is shown only for example

Also, you can get user embeddings, that were used to perform predictions, using `get_query_embedding` method inside Bert4RecModel or `QueryEmbeddingsPredictionCallback` for lightning module.

To operate with PySpark DataFrames and use ``SparkPredictionCallback`` you should create a spark session.

In [None]:
from replay.utils.session_handler import get_spark_session
spark_session = get_spark_session()

In [None]:
TOPK = [1, 10, 20, 100]

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

spark_prediction_callback = SparkPredictionCallback(
    spark_session=spark_session,
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

pandas_prediction_callback = PandasPredictionCallback(
    top_k=max(TOPK),
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

torch_prediction_callback = TorchPredictionCallback(
    top_k=max(TOPK),
    postprocessors=postprocessors,
)

query_embeddings_callback = QueryEmbeddingsPredictionCallback()

trainer = L.Trainer(
    callbacks=[
        spark_prediction_callback,
        pandas_prediction_callback,
        torch_prediction_callback,
        query_embeddings_callback,
    ],
    logger=csv_logger,
    inference_mode=True
)
trainer.predict(best_model, dataloaders=prediction_dataloader, return_predictions=False)

spark_res = spark_prediction_callback.get_result()
pandas_res = pandas_prediction_callback.get_result()
torch_user_ids, torch_item_ids, torch_scores = torch_prediction_callback.get_result()
user_embeddings = query_embeddings_callback.get_result()

In [22]:
spark_res.show()

[Stage 0:>                                                          (0 + 1) / 1]

+-------+-------+------------------+
|user_id|item_id|             score|
+-------+-------+------------------+
|      0|   2559|2.1270530223846436|
|      0|   1539|2.1258444786071777|
|      0|    589| 2.071683883666992|
|      0|   1192| 1.998902440071106|
|      0|   1178|1.9865914583206177|
|      0|    585|1.8850951194763184|
|      0|   2647|1.8264131546020508|
|      0|   2821|1.8195174932479858|
|      0|   1284|1.8040269613265991|
|      0|   2789|1.8005681037902832|
|      0|   2502|1.7928752899169922|
|      0|   1220| 1.791283369064331|
|      0|    108|1.7866688966751099|
|      0|   1196|1.7825181484222412|
|      0|   1227|1.7511154413223267|
|      0|   2530| 1.718513011932373|
|      0|   1111| 1.713446855545044|
|      0|    293| 1.691711187362671|
|      0|    642|1.6797393560409546|
|      0|   2327|1.6610043048858643|
+-------+-------+------------------+
only showing top 20 rows



                                                                                

In [23]:
pandas_res

Unnamed: 0,user_id,item_id,score
0,0,2559,2.127053
0,0,1539,2.125844
0,0,589,2.071684
0,0,1192,1.998902
0,0,1178,1.986591
...,...,...,...
6039,6039,3686,0.782319
6039,6039,2512,0.782027
6039,6039,1556,0.781867
6039,6039,584,0.779429


In [24]:
print(torch_user_ids[0], torch_item_ids[0], torch_scores[0])

tensor(0) tensor([2559, 1539,  589, 1192, 1178,  585, 2647, 2821, 1284, 2789, 2502, 1220,
         108, 1196, 1227, 2530, 1111,  293,  642, 2327, 3509, 1366, 2928, 1575,
        3106,  352,   33,  220,  476, 2637, 3402, 1180, 1211,  900, 1726,  847,
        3682, 2847, 2614, 1287, 1195,  315, 3724,  908, 1245, 2918,  537, 3091,
         453,  941, 2105, 1073, 1202, 1513,   31, 3184, 1854, 1166, 2588, 2536,
        1058, 3412, 1214,  918,  593, 1628, 2632, 1543, 1899, 1271, 1204, 3441,
        2433, 2233,  740,  770,  956, 3554, 1529, 1353,  912, 1533, 1282, 1246,
        1120, 1023, 1238,  586,   46, 3430, 2324, 1931, 3107, 1568, 3686, 1373,
        1840, 1885, 1232, 3178]) tensor([2.1271, 2.1258, 2.0717, 1.9989, 1.9866, 1.8851, 1.8264, 1.8195, 1.8040,
        1.8006, 1.7929, 1.7913, 1.7867, 1.7825, 1.7511, 1.7185, 1.7134, 1.6917,
        1.6797, 1.6610, 1.6569, 1.6292, 1.6153, 1.6131, 1.6088, 1.6063, 1.6017,
        1.5603, 1.5589, 1.5530, 1.5476, 1.5298, 1.5297, 1.5266, 1.5113, 1.507

Suppose we want to get the recomendations in PySpark format. 
Let's get the inverse representation of labels using inverse_transform method.

Note that the reverse representation can only be obtained for PySpark and Pandas formats. When working with PyTorch tensors, the reverse representation must be done manually

In [25]:
recommendations = tokenizer.query_and_item_id_encoder.inverse_transform(spark_res)

In [26]:
recommendations.show()

+------------------+-------+-------+
|             score|user_id|item_id|
+------------------+-------+-------+
|2.1270530223846436|      1|   2628|
|2.1258444786071777|      1|   1580|
| 2.071683883666992|      1|    593|
| 1.998902440071106|      1|   1210|
|1.9865914583206177|      1|   1196|
|1.8850951194763184|      1|    589|
|1.8264131546020508|      1|   2716|
|1.8195174932479858|      1|   2890|
|1.8040269613265991|      1|   1304|
|1.8005681037902832|      1|   2858|
|1.7928752899169922|      1|   2571|
| 1.791283369064331|      1|   1240|
|1.7866688966751099|      1|    110|
|1.7825181484222412|      1|   1214|
|1.7511154413223267|      1|   1247|
| 1.718513011932373|      1|   2599|
| 1.713446855545044|      1|   1127|
| 1.691711187362671|      1|    296|
|1.6797393560409546|      1|    648|
|1.6610043048858643|      1|   2396|
+------------------+-------+-------+
only showing top 20 rows



### Run inference on a subset of items

It happens that it is necessary to process an inference not on all items, but on a certain subset (we will call it ``candidates``). For example, you want to make predictions only for the cartoons among all possible movies.

To speed up the inference in this case, you can use the Bert4Rec's property ``candidates_to_score``. It should be a ``torch.LongTensor`` with the IDs of the objects on which you want to process an inference. It is important that the candidate scores will be returned in the order in which their IDs were in the ``candidates_to_score``.

In [27]:
best_model_candidates = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

In [None]:
TOPK = 2
CANDIDATES = torch.LongTensor([42, 1337])

postprocessors = [RemoveSeenItems(sequential_test_dataset)]

pandas_prediction_callback = PandasPredictionCallback(
    top_k=TOPK,
    query_column="user_id",
    item_column="item_id",
    rating_column="score",
    postprocessors=postprocessors,
)

trainer = L.Trainer(callbacks=[pandas_prediction_callback], logger=csv_logger, inference_mode=True)
best_model_candidates.candidates_to_score = CANDIDATES
trainer.predict(best_model_candidates, dataloaders=prediction_dataloader, return_predictions=False)

There will be scores only for items whose IDs are contained in ``candidates_to_score``.

If ``candidates_to_score`` contains a small number of candidates and ``top_k`` parameter is small, it may happen that the required number of items will not remain after the postprocessor is running. In this case, the ``top_k`` items for each user will be returned from the model, then the postprocessor will remove the seen items and if the user has less than the ``top_k`` items, then the non-candidate items with a score equal to ``-inf`` will be added to it.

In [29]:
pandas_prediction_callback.get_result()

Unnamed: 0,user_id,item_id,score
0,0,1337,0.963838
0,0,42,-1.138735
1,1,1337,0.605986
1,1,42,-0.830044
2,2,1337,0.941634
...,...,...,...
6037,6037,42,-1.272029
6038,6038,1337,0.561242
6038,6038,42,-0.793737
6039,6039,42,-0.82424


**Note:** don`t forget to reset ``candidates_to_score`` to ``None`` if they are no longer needed and you want to run the model inference with all items.

In [None]:
best_model_candidates.candidates_to_score = None

### Calculating metrics

In [30]:
init_args = {"query_column": "user_id", "rating_column": "score"}

In [None]:
result_metrics = OfflineMetrics(
    [Recall(TOPK), Precision(TOPK), MAP(TOPK), NDCG(TOPK), MRR(TOPK), HitRate(TOPK)], **init_args
)(recommendations.toPandas(), raw_test_gt)

In [32]:
metrics_to_df(result_metrics)

k,2
HitRate,0.007119
MAP,0.005215
MRR,0.005215
NDCG,0.005714
Precision,0.00356
Recall,0.007119


### User embeddings

Got 6040 x 300 user embeddings, because among all 12 batches: 

11 batches contains 512 samples

1 batch contains 408 left samples

11 * 512 + 408 == 6040

In [33]:
user_embeddings

tensor([[-0.7878,  0.8159,  1.2742,  ..., -1.4545,  1.6887,  0.0550],
        [-0.4360,  0.5662,  1.4676,  ..., -0.6618,  1.4071,  0.3617],
        [-0.6623,  0.9143,  1.3480,  ..., -1.2325,  1.7697,  0.2644],
        ...,
        [-1.3103,  0.8071,  1.0478,  ..., -0.8558,  1.4005,  0.6592],
        [-0.4810,  0.5335,  1.7069,  ..., -0.7740,  1.3061,  0.3048],
        [-0.3925,  0.4476,  1.3665,  ..., -0.3378,  1.2455,  0.2170]])

In [34]:
user_embeddings.shape

torch.Size([6040, 300])

You can access user embeddings directly with `Bert4RecModel` class

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

core_model = Bert4RecModel(
    tensor_schema,
    num_blocks=2,
    num_heads=4,
    max_len=MAX_SEQ_LEN,
    hidden_size=300,
    dropout=0.5
)
core_model.eval()
core_model = core_model.to(device)

# Get first batch of data
data = next(iter(prediction_dataloader))
tensor_map, padding_mask, tokens_mask = data["inputs"], data["pad_mask"], data["token_mask"]

# Ensure everything is on the same device
padding_mask = padding_mask.to(device)
tokens_mask = tokens_mask.to(device)
tensor_map["item_id_seq"] = tensor_map["item_id_seq"].to(device)

# Get user embeddings
user_embeddings_batch = core_model.get_query_embeddings(tensor_map, padding_mask, tokens_mask)
user_embeddings_batch

tensor([[-0.5941,  0.4922,  1.0920,  ..., -0.6684,  0.4808, -0.6346],
        [ 0.3899,  0.0972,  0.6055,  ..., -1.2973,  0.4056, -0.5432],
        [-0.7955,  0.1026,  1.0066,  ..., -1.1204,  0.2826, -0.4639],
        ...,
        [ 0.1230,  0.2439,  0.3789,  ..., -1.8021,  0.0392, -0.5761],
        [-0.4846,  0.6347,  1.2935,  ..., -1.7495,  0.5781, -0.9601],
        [-0.8379,  0.3875,  0.8284,  ..., -1.1218,  0.6631, -0.2801]],
       grad_fn=<SliceBackward0>)

In [36]:
user_embeddings_batch.shape

torch.Size([512, 300])

## Example of launching an inference for a single user without using a trainer (in order to speed up)
An example for the production of an online script

Let's assume that the user's sequence consisted of a sequence of items [1, 2, 3, 4, 5]. 
Сreate a padding mask and tokens mask corresponding to the sequence of items.

It is important to take only the latest MAX_SEQ_LEN or less items.

You can use ``candidates_to_score`` here as well. It is possible to set a property or pass ``candidates_to_score`` as a parameter of predict method.

**Note:** make sure that you set the ``torch.set_num_threads()`` parameter in the product environment. This is important because torch can consume resources exceeding the k8s limit and thus activating CPU throttling.

### Create sequence, padding_mask and tokens_mask

In [38]:
item_sequence = torch.arange(1, 5).unsqueeze(0)[:, -MAX_SEQ_LEN:]
padding_mask = torch.ones_like(item_sequence, dtype=torch.bool)
tokens_mask = padding_mask.roll(-1, dims=0)
tokens_mask[-1, ...] = 0
sequence_item_count = item_sequence.shape[1]

### Wrapping created tensors in the Bert4RecPredictionBatch entity

In [46]:
batch = Bert4RecPredictionBatch(
    query_id=torch.arange(0, item_sequence.shape[0], 1).long(),
    padding_mask=padding_mask,
    features={ITEM_FEATURE_NAME: item_sequence.long()},
    tokens_mask=tokens_mask
)

### Run predict step of the Bert4Rec and get scores from the model

In [44]:
with torch.no_grad():
    scores = best_model.predict(batch)
scores

tensor([[ 3.3967,  0.1647, -0.7265,  ..., -2.8732, -1.8079,  0.8802]])

### Getting five items with the highest score

In [45]:
torch.topk(scores, k=5).indices

tensor([[ 257, 2233,  476,    0,  352]])

You can pass ``candidates_to_score`` in predict().

In [47]:
with torch.no_grad():
    scores = best_model.predict(batch, candidates_to_score=CANDIDATES)
scores

tensor([[-0.9725,  1.6978]])

## Optimized inference on CPU with OpenVino
Bert4Rec model can be compiled into IR format of OpenVino for faster inference on CPU.

Bert4Rec model itself or the path to the checkpoint of the model can be passed as ``model`` parameter. 

Parameter ``mode`` defines inference mode and shape of inputs. Could be one of ``one_query``, ``batch``, ``dynamic_batch_size``. This parameter determines whether the first dimension of the input (the batch size) will be static or dynamic.

Parameter ``num_candidates_to_score`` defines number of item ids to calculate scores if it is necessary. This parameter determines whether the model will make a partial inference and, if so, whether the list of candidates will have a static or dynamic length.

Parameter ``num_threads`` defines number of CPU threads to use.

In [49]:
from replay.models.nn.sequential.compiled import Bert4RecCompiled

In [48]:
best_model = Bert4Rec.load_from_checkpoint(checkpoint_callback.best_model_path)

 Compile model from Bert4Rec model or checkpoint.

In [None]:
opt_model = Bert4RecCompiled.compile(
    model=best_model,  # or checkpoint_callback.best_model_path
    mode="one_query",
)

Wrapping tensors in the Bert4RecPredictionBatch entity

In [51]:
item_sequence = torch.arange(1, 5).unsqueeze(0)[:, -MAX_SEQ_LEN:]
padding_mask = torch.ones_like(item_sequence, dtype=torch.bool)
tokens_mask = padding_mask.roll(-1, dims=0)
tokens_mask[-1, ...] = 0
sequence_item_count = item_sequence.shape[1]

In [52]:
batch = Bert4RecPredictionBatch(
    query_id=torch.arange(0, item_sequence.shape[0], 1).long(),
    padding_mask=padding_mask,
    features={ITEM_FEATURE_NAME: item_sequence.long()},
    tokens_mask=tokens_mask
)

Run predict and get scores from the model

In [53]:
opt_model.predict(batch).shape

torch.Size([1, 3883])

### Compiled model also supports inference on submitted candidates.

Wrapping created tensors in the Bert4RecPredictionBatch entity

In [54]:
batch = Bert4RecPredictionBatch(
    query_id=torch.arange(0, item_sequence.shape[0], 1).long(),
    padding_mask=padding_mask,
    features={ITEM_FEATURE_NAME: item_sequence.long()},
    tokens_mask=tokens_mask
)

Compile model with defined ``num_candidates_to_score``. There are 3 alternatives:
- ``-1`` - sets candidates_to_score shape to dynamic range [1, ?]
- ``N`` - sets candidates_to_score shape to [1, N]
- ``None`` - disable candidates_to_score usage

In [55]:
opt_model = Bert4RecCompiled.compile(
    model=best_model,
    mode="one_query",
    num_candidates_to_score=2,
)

Run predict and get scores from the model

In [56]:
opt_model.predict(batch, CANDIDATES).shape

torch.Size([1, 2])