In [None]:
# default_exp models.sasrec

# SASRec
> Self-Attentive Sequential Recommendation Model.

Sequential dynamics are a key feature of many modern recommender systems, which seek to capture the ‘context’ of users’ activities on the basis of actions they have performed recently. To capture such patterns, two approaches have proliferated: Markov Chains (MCs) and Recurrent Neural Networks (RNNs). Markov Chains assume that a user’s next action can be predicted on the basis of just their last (or last few) actions, while RNNs in principle allow for longer-term semantics to be uncovered. Generally speaking, MC-based methods perform best in extremely sparse datasets, where model parsimony is critical, while RNNs perform better in denser datasets where higher model complexity is affordable. SASRec captures the long-term semantics (like an RNN), but, using an attention mechanism, makes its predictions based on relatively few actions (like an MC).

At each time step, SASRec seeks to identify which items are ‘relevant’ from a user’s action history, and use them to predict the next item. Extensive empirical studies show that this method outperforms various state-of-the-art sequential models (including MC/CNN/RNN-based approaches) on both sparse and dense datasets. Moreover, the model is an order of magnitude more efficient than comparable CNN/RNN-based models.

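We adopt the binary cross entropy loss as the objective function, where $r_{i,t}$ is the predicted relevance score of item $i$ at time step $t$, $o_t$ is the expected next item at step $t$, and $j \notin S^u$ ranges over negative items that do not appear in the user's sequence $S^u$:

$$-\sum_{S^u\in S} \sum_{t \in [1,2,\dots,n]}\left[ \log(\sigma(r_{o_t,t})) + \sum_{j \notin S^u} \log(1-\sigma(r_{j,t})) \right]$$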

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import torch
from torch import nn
import math

from recohut.models.layers.attention import *
from recohut.models.layers.encoding import Encoder

In [None]:
#exporti
class SASEmbedding(nn.Module):
    """Token-plus-position embedding front end with an attached attention mask.

    Hyperparameters are read from ``args``: ``num_items`` (the token table gets
    ``num_items + 1`` rows), ``bert_hidden_units``, ``bert_max_len`` and
    ``bert_dropout``.
    """

    def __init__(self, args):
        super().__init__()
        self.token = TokenEmbedding(
            vocab_size=args.num_items + 1, embed_size=args.bert_hidden_units)
        self.position = PositionalEmbedding(
            max_len=args.bert_max_len, d_model=args.bert_hidden_units)
        self.dropout = nn.Dropout(p=args.bert_dropout)

    def get_mask(self, x):
        """Build a boolean mask of shape ``(batch, 1, seq, seq)``.

        For a 2-D ``x`` the entry ``[b, 0, i, j]`` is ``x[b, j] > 0``. For an
        input with more than two dimensions every entry is ``True``, because
        the mask is derived from an all-ones ``(batch, seq)`` stand-in.
        """
        if x.dim() > 2:
            x = torch.ones(x.shape[:2], device=x.device)
        seq_len = x.size(1)
        is_positive = x > 0
        return is_positive.unsqueeze(1).repeat(1, seq_len, 1).unsqueeze(1)

    def forward(self, x):
        """Embed ``x`` and return ``(dropout(embedding), mask)``.

        A 2-D ``x`` is looked up in the token table; an ``x`` with more than
        two dimensions is multiplied with the token weight matrix instead, and
        the positional module is fed an all-ones ``(batch, seq)`` tensor.
        """
        mask = self.get_mask(x)
        if x.dim() > 2:
            positions = self.position(torch.ones(x.shape[:2], device=x.device))
            embedded = torch.matmul(x, self.token.weight) + positions
        else:
            embedded = self.token(x) + self.position(x)
        return self.dropout(embedded), mask

In [None]:
#exporti
class SASModel(nn.Module):
    """Stack of ``SASTransformerBlock`` layers followed by item scoring.

    Hyperparameters are read from ``args``: ``bert_hidden_units``,
    ``bert_num_heads``, ``bert_head_size``, ``bert_dropout``,
    ``bert_attn_dropout`` and ``bert_num_blocks``. Each block is built with a
    feed-forward width of ``4 * bert_hidden_units``.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.bert_hidden_units
        heads = args.bert_num_heads
        head_size = args.bert_head_size
        dropout = args.bert_dropout
        attn_dropout = args.bert_attn_dropout
        layers = args.bert_num_blocks

        self.transformer_blocks = nn.ModuleList([SASTransformerBlock(
            hidden, heads, head_size, hidden * 4, dropout, attn_dropout) for _ in range(layers)])

    def forward(self, x, embedding_weight, mask):
        """Run ``x`` through every block and score it against ``embedding_weight``.

        Args:
            x: input passed to the first transformer block.
            embedding_weight: 2-D matrix; its transpose is right-multiplied
                onto the final hidden states.
            mask: forwarded unchanged to every block.

        Returns:
            ``matmul(hidden, embedding_weight^T)``.
        """
        for transformer in self.transformer_blocks:
            # Call the module rather than ``.forward`` so that registered
            # forward/pre-forward hooks are honoured.
            x = transformer(x, mask)
        scores = torch.matmul(x, embedding_weight.permute(1, 0))
        return scores

In [None]:
#export
class SASRec(nn.Module):
    """Self-attentive sequential recommender: embedding front end + transformer stack.

    ``args`` is forwarded unchanged to ``SASEmbedding`` and ``SASModel``. After
    construction the parameters of ``self.model`` are re-initialised by
    ``truncated_normal_init``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embedding = SASEmbedding(args)
        self.model = SASModel(args)
        self.truncated_normal_init()

    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        """Re-draw ``self.model`` parameters from a truncated normal, in place.

        Uses inverse-CDF sampling: draw uniformly between the transformed CDF
        values of ``lower`` and ``upper``, apply ``erfinv``, then scale by
        ``std * sqrt(2)`` and shift by ``mean``. Parameters whose name contains
        ``'layer_norm'`` are left untouched. Only ``self.model`` is visited;
        ``self.embedding`` keeps whatever initialisation it already has.
        """
        sqrt_two = math.sqrt(2.)
        # Standard-normal CDF evaluated at the two truncation bounds.
        cdf_lower = (1. + math.erf(((lower - mean) / std) / sqrt_two)) / 2.
        cdf_upper = (1. + math.erf(((upper - mean) / std) / sqrt_two)) / 2.

        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'layer_norm' in name:
                    continue
                param.uniform_(2 * cdf_lower - 1, 2 * cdf_upper - 1)
                param.erfinv_()
                param.mul_(std * sqrt_two)
                param.add_(mean)

    def forward(self, x):
        """Embed ``x`` and return the transformer stack's scores.

        The token embedding weight is passed to ``self.model`` as the scoring
        matrix, together with the mask produced by the embedding module.
        """
        embedded, mask = self.embedding(x)
        return self.model(embedded, self.embedding.token.weight, mask)

In [None]:
class Args:
    """Toy hyperparameter namespace used to smoke-test ``SASRec``.

    Each attribute is read by ``SASEmbedding`` and/or ``SASModel`` and is
    declared exactly once.
    """
    # Catalogue size; the token table is built with ``num_items + 1`` rows.
    num_items = 10

    # Shared by the embedding layer and the transformer blocks.
    bert_hidden_units = 4
    bert_dropout = 0.2

    # Embedding-only: maximum sequence length for the positional embedding.
    bert_max_len = 8

    # Transformer-stack settings.
    bert_num_blocks = 4
    bert_num_heads = 2
    bert_head_size = 4
    bert_attn_dropout = 0.2

# Smoke test: build a tiny SASRec from the toy hyperparameters defined in Args.
args = Args()
model = SASRec(args)
# `parameters` is referenced without being called, so the cell displays the
# bound method's repr, which embeds the printed module tree.
model.parameters

<bound method Module.parameters of SASRec(
  (embedding): SASEmbedding(
    (token): TokenEmbedding(11, 4, padding_idx=0)
    (position): PositionalEmbedding(
      (pe): Embedding(9, 4)
    )
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (model): SASModel(
    (transformer_blocks): ModuleList(
      (0): SASTransformerBlock(
        (layer_norm): LayerNorm()
        (attention): SASMultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=4, out_features=8, bias=True)
            (1): Linear(in_features=4, out_features=8, bias=True)
            (2): Linear(in_features=4, out_features=8, bias=True)
          )
          (attention): Attention()
          (dropout): Dropout(p=0.2, inplace=False)
          (layer_norm): LayerNorm()
        )
        (feed_forward): SASPositionwiseFeedForward(
          (conv1): Conv1d(4, 16, kernel_size=(1,), stride=(1,))
          (activation): ReLU()
          (dropout): Dropout(p=0.2, inplace=False)
        

In [None]:
#export
class SASRec_v2(nn.Module):
    """Self-attentive sequential recommender built on the shared ``Encoder`` stack.

    Item ids are embedded, summed with learned position embeddings,
    layer-normalised and passed through ``Encoder`` under a combined
    padding + causal attention mask.

    Args:
        item_size: number of rows in the item embedding table (id 0 is padding).
        hidden_size: embedding / hidden dimension.
        attribute_size: stored on the instance; not otherwise used by this class.
        max_seq_length: number of rows in the position embedding table.
        mask_id: stored on the instance; not otherwise used by this class.
        num_attention_heads: forwarded to ``Encoder``.
        num_hidden_layers: forwarded to ``Encoder``.
        hidden_dropout_prob: dropout on the embeddings; also forwarded to ``Encoder``.
        attention_probs_dropout_prob: forwarded to ``Encoder``.
        hidden_act: activation name forwarded to ``Encoder``.
        initializer_range: std of the normal init applied to Linear/Embedding weights.

    References:
        1. https://github.com/RecoHut-Stanzas/STOSA/blob/ee14e2eabcc60922eb52cc7d3231df4954d9ff16/seqmodels.py#L5
    """
    def __init__(self,
                 item_size,
                 hidden_size,
                 attribute_size,
                 max_seq_length,
                 mask_id,
                 num_attention_heads,
                 num_hidden_layers=2,
                 hidden_dropout_prob=0.2,
                 attention_probs_dropout_prob=0.2,
                 hidden_act='gelu',
                 initializer_range=0.02):
        super().__init__()
        self.item_size = item_size
        self.hidden_size = hidden_size
        self.attribute_size = attribute_size
        self.max_seq_length = max_seq_length
        self.mask_id = mask_id
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range

        self.item_embeddings = nn.Embedding(item_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_seq_length, hidden_size)
        self.item_encoder = Encoder(hidden_size, hidden_act, num_attention_heads, 
                                    hidden_dropout_prob, attention_probs_dropout_prob, 
                                    num_hidden_layers)
        self.layernorm = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        # Stored for callers; not used by any method of this class.
        self.criterion = nn.BCELoss(reduction='none')
        self.apply(self.init_weights)

    def add_position_embedding(self, sequence):
        """Embed item ids and add learned position embeddings.

        Args:
            sequence: integer id tensor of shape ``(batch, seq_len)``; position
                ids ``0..seq_len-1`` are looked up in ``position_embeddings``,
                so ``seq_len`` must not exceed ``max_seq_length``.

        Returns:
            ``dropout(layernorm(item_emb + position_emb))``.
        """
        seq_length = sequence.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device)
        position_ids = position_ids.unsqueeze(0).expand_as(sequence)
        item_embeddings = self.item_embeddings(sequence)
        position_embeddings = self.position_embeddings(position_ids)
        sequence_emb = item_embeddings + position_embeddings
        sequence_emb = self.layernorm(sequence_emb)
        sequence_emb = self.dropout(sequence_emb)
        return sequence_emb

    def forward(self, input_ids):
        """Encode a batch of item-id sequences.

        Args:
            input_ids: integer tensor of shape ``(batch, seq_len)``; ids ``<= 0``
                are treated as padding.

        Returns:
            The last entry of the encoder's per-layer output, unpacked as
            ``(sequence_output, attention_scores)``.
        """
        # 1 for real items, 0 for padding -> (batch, 1, 1, seq_len), int64.
        attention_mask = (input_ids > 0).long()
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Causal mask: 1 on and below the diagonal, 0 above. It is created on
        # the input's device so the product below does not raise a device
        # mismatch when the model runs on an accelerator.
        max_len = attention_mask.size(-1)
        attn_shape = (1, max_len, max_len)
        subsequent_mask = torch.triu(
            torch.ones(attn_shape, device=input_ids.device), diagonal=1)
        subsequent_mask = (subsequent_mask == 0).unsqueeze(1).long()  # (1, 1, L, L)

        # Broadcasts to (batch, 1, L, L): a key is visible only if it is not
        # padding and does not lie in the future of the query position.
        extended_attention_mask = extended_attention_mask * subsequent_mask
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        # Additive form: 0 where visible, -10000 where masked.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        sequence_emb = self.add_position_embedding(input_ids)

        item_encoded_layers = self.item_encoder(sequence_emb,
                                                extended_attention_mask,
                                                output_all_encoded_layers=True)

        sequence_output, attention_scores = item_encoded_layers[-1]
        return sequence_output, attention_scores

    def init_weights(self, module):
        """Initialise one sub-module; meant to be used with ``nn.Module.apply``.

        Linear and Embedding weights are drawn from
        ``N(0, initializer_range**2)``, LayerNorm is reset to weight 1 / bias 0,
        and Linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            # NOTE(review): this draws the whole table, so an Embedding's
            # ``padding_idx`` row is overwritten with random values as well.
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

In [None]:
# Toy hyperparameters for a smoke test of SASRec_v2.
item_size = 10
hidden_size = 10
attribute_size = 5
max_seq_length = 20
mask_id = 20
num_attention_heads = 2
num_hidden_layers = 2
hidden_dropout_prob = 0.2
attention_probs_dropout_prob = 0.2
hidden_act = 'gelu'
initializer_range = 0.02

model = SASRec_v2(
    item_size=item_size,
    hidden_size=hidden_size,
    attribute_size=attribute_size,
    max_seq_length=max_seq_length,
    mask_id=mask_id,
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    hidden_dropout_prob=hidden_dropout_prob,
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    hidden_act=hidden_act,
    initializer_range=initializer_range,
)

# Random ids in [0, 5) with shape (item_size, hidden_size) == (10, 10); the
# forward pass reads this as 10 sequences of length 10.
x = torch.randint(0, 5, (item_size, hidden_size))

output = model.forward(x)
# Iterate each returned tensor along its first dimension and record the shape
# of every slice, in order.
output_shapes = [list(piece.shape) for part in output for piece in part]

test_eq(output_shapes[0], [10, 10])
test_eq(output_shapes[11], [2, 10, 10])

> References
1. https://cseweb.ucsd.edu/~jmcauley/pdfs/icdm18.pdf
2. https://github.com/Yueeeeeeee/RecSys-Extraction-Attack/blob/main/model/sasrec.py

In [None]:
#hide
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d -p recohut

Author: Sparsh A.

Last updated: 2021-12-31 06:51:33

recohut: 0.0.8

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

PIL       : 7.1.2
torch     : 1.10.0+cu111
matplotlib: 3.2.2
numpy     : 1.19.5
IPython   : 5.5.0

