In [None]:
# default_exp models.bert4rec

# BERT4Rec
> Implementation of the BERT4Rec transformer-based recommender model in PyTorch Lightning.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
from typing import Any, Iterable, List, Optional, Tuple, Union, Callable

import torch
import torch.nn as nn
from torch.nn import Linear

from recohut.models.bases.sequential import SequentialModel

In [None]:
#export
class BERT4Rec(SequentialModel):
    """Transformer-encoder sequential recommender (BERT4Rec style).

    Item ids are embedded, summed with learned position embeddings, run
    through a stack of ``nn.TransformerEncoder`` layers and projected to one
    score per vocabulary entry at every sequence position.

    ``self.vocab_size``, ``self.channels`` and ``self.dropout`` are read
    immediately after ``super().__init__(**kwargs)``, so the base class is
    expected to have set them (presumably from ``kwargs`` -- confirm against
    ``SequentialModel``).
    """

    def __init__(self, **kwargs):
        """Build the embedding tables, the encoder stack and the output head.

        Args:
            **kwargs: forwarded unchanged to ``SequentialModel.__init__``.
        """
        super().__init__(**kwargs)
        self.item_embeddings = torch.nn.Embedding(
            self.vocab_size, embedding_dim=self.channels
        )
        # Learned position table; its row count (512) is the longest
        # sequence ``encode_src`` can handle.
        self.input_pos_embedding = torch.nn.Embedding(512, embedding_dim=self.channels)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.channels, nhead=4, dropout=self.dropout
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.linear_out = Linear(self.channels, self.vocab_size)
        # NOTE(review): registered but not applied by any method of this
        # class; kept so the module layout stays unchanged.
        self.do = nn.Dropout(p=self.dropout)

    def forward(self, src_items):
        """Score every vocabulary entry at every position.

        Args:
            src_items: 2-D integer tensor of item ids, ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` (raw scores).
        """
        return self.linear_out(self.encode_src(src_items))

    def encode_src(self, src_items):
        """Embed the item ids, add position embeddings and run the encoder.

        Args:
            src_items: 2-D integer tensor of item ids, ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, channels)``.

        Raises:
            IndexError: if ``seq_len`` exceeds the number of rows in
                ``input_pos_embedding``.
        """
        item_emb = self.item_embeddings(src_items)
        seq_len = item_emb.size(1)

        # Fail early with a readable message instead of an opaque embedding
        # index error (or a device-side assert on GPU).
        max_positions = self.input_pos_embedding.num_embeddings
        if seq_len > max_positions:
            raise IndexError(
                f"sequence length {seq_len} exceeds the {max_positions} positions "
                "supported by input_pos_embedding"
            )

        positions = torch.arange(seq_len, device=item_emb.device)
        # (seq_len, channels) broadcasts over the batch dimension, so the
        # position indices do not need to be repeated per batch row. The sum
        # is out-of-place, leaving the embedding output unmodified.
        hidden = item_emb + self.input_pos_embedding(positions)

        # The encoder layer was built with the default batch_first=False, so
        # it takes (seq_len, batch, channels); swap back afterwards.
        hidden = self.encoder(hidden.permute(1, 0, 2))
        return hidden.permute(1, 0, 2)

Example

In [None]:
class Args:
    """Settings for the example run below.

    Every value is stored as an instance attribute so that ``args.__dict__``
    can be unpacked as keyword arguments by the cells that follow.
    """

    def __init__(self):
        settings = vars(self)
        settings.update(pad=0, mask=1, cap=0, seed=42)
        # Model size and optimisation.
        settings.update(
            vocab_size=10000,
            channels=128,
            dropout=0.4,
            learning_rate=1e-4,
        )
        # Data and output locations.
        settings.update(
            history_size=30,
            data_dir='/content/data',
            log_dir='/content/recommender_logs',
            model_dir='/content/recommender_models',
        )
        # Batching and training schedule.
        settings.update(batch_size=32, shuffle=True, max_epochs=2, val_epoch=1)
        # Hardware and checkpoint monitoring.
        settings.update(gpus=None, monitor='valid_loss', mode='min')


args = Args()

In [None]:
def pl_trainer(model, datamodule, max_epochs=10, val_epoch=5, gpus=None, log_dir=None,
               model_dir=None, monitor='val_loss', mode='min', *args, **kwargs):
    """Fit ``model`` on ``datamodule``, then run the test loop.

    Args:
        model: module handed to ``Trainer.fit`` and ``Trainer.test``.
        datamodule: data module handed to both calls.
        max_epochs: forwarded to ``Trainer(max_epochs=...)``.
        val_epoch: forwarded to ``Trainer(check_val_every_n_epoch=...)``.
        gpus: forwarded to ``Trainer(gpus=...)``.
        log_dir: TensorBoard save directory; current working directory if None.
        model_dir: checkpoint directory; current working directory if None.
        monitor: metric name given to ``ModelCheckpoint``.
        mode: ``mode`` given to ``ModelCheckpoint``.
        *args, **kwargs: accepted and ignored, so a full settings dict can be
            unpacked into the call.

    Returns:
        Whatever ``Trainer.test`` returns.
    """
    if log_dir is None:
        log_dir = os.getcwd()
    if model_dir is None:
        model_dir = os.getcwd()

    tb_logger = TensorBoardLogger(save_dir=log_dir)
    # The checkpoint is always written under the fixed file name "recommender".
    checkpointing = ModelCheckpoint(
        monitor=monitor, mode=mode, dirpath=model_dir, filename="recommender"
    )

    trainer_options = dict(
        max_epochs=max_epochs,
        logger=tb_logger,
        check_val_every_n_epoch=val_epoch,
        callbacks=[checkpointing],
        num_sanity_val_steps=0,  # no sanity validation pass before training
        gradient_clip_val=1,
        gradient_clip_algorithm="norm",
        gpus=gpus,
    )
    runner = Trainer(**trainer_options)

    runner.fit(model, datamodule=datamodule)
    return runner.test(model, datamodule=datamodule)

In [None]:
# `args.__dict__` already carries `data_dir`, so it is not passed a second
# time (the original extra keyword was misspelled `data_sir`).
ds = ML1mDataModule(**args.__dict__)
ds.prepare_data()

# NOTE(review): the +2 presumably reserves ids beyond the mapped items
# (e.g. padding / mask) -- confirm against the data module.
args.vocab_size = len(ds.data.mapping) + 2

model = BERT4Rec(**args.__dict__)

result_val = pl_trainer(model, ds, **args.__dict__)

# `pl_trainer` keeps its ModelCheckpoint in a local variable, so a global
# `checkpoint_callback` does not exist on a fresh kernel. Reach the callback
# through the trainer that Lightning attached to the model instead.
best_model_path = model.trainer.checkpoint_callback.best_model_path

output_json = {
    # The value is the loss reported by `trainer.test()`; the key name is kept
    # as in the recorded output.
    "val_loss": result_val[0]["test_loss"],
    "best_model_path": best_model_path,
}

print(output_json)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Processing...
Done!
Processing...
Done!

  | Name                | Type               | Params
-----------------------------------------------------------
0 | item_embeddings     | Embedding          | 474 K 
1 | input_pos_embedding | Embedding          | 65.5 K
2 | encoder             | TransformerEncoder | 3.6 M 
3 | linear_out          | Linear             | 478 K 
4 | do                  | Dropout            | 0     
-----------------------------------------------------------
4.6 M     Trainable params
0         Non-trainable params
4.6 M     Total params
18.307    Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Processing...
Done!


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_accuracy': 0.0032284767366945744, 'test_loss': 7.648849010467529}
--------------------------------------------------------------------------------
{'val_loss': 7.648849010467529, 'best_model_path': '/content/recommender_models/recommender.ckpt'}
