In [None]:
# default_exp models.bert

# BERT
> Bidirectional Encoder Representations from Transformers.

BERT was one of the first autoencoding language models to use the Transformer encoder stack, with slight modifications, for language modeling. Its architecture is a multilayer Transformer encoder based on the original Transformer implementation. The Transformer itself was originally designed for machine translation, but BERT's main contribution is to use the encoder part of that architecture to provide better language modeling. After pretraining, this language model is able to provide a general understanding of the language it was trained on.

In [None]:
#hide
from nbdev.showdoc import *
from fastcore.nb_imports import *
from fastcore.test import *

In [None]:
#export
import torch
from torch import nn as nn
import math

from recohut.models.layers.attention import *

In [None]:
#exporti
class BERTEmbedding(nn.Module):
    """Input embedding stage: token + positional embedding, layer norm, dropout.

    Reads ``num_items``, ``bert_hidden_units``, ``bert_max_len`` and
    ``bert_dropout`` from ``args``. The token vocabulary holds
    ``num_items + 2`` entries.
    """

    def __init__(self, args):
        super().__init__()
        dim = args.bert_hidden_units

        # Submodules are registered in the same order as before so that
        # parameter ordering is unchanged.
        self.token = TokenEmbedding(
            vocab_size=args.num_items + 2, embed_size=dim)
        self.position = PositionalEmbedding(
            max_len=args.bert_max_len, d_model=dim)

        self.layer_norm = LayerNorm(features=dim)
        self.dropout = nn.Dropout(p=args.bert_dropout)

    def get_mask(self, x):
        """Build the boolean attention mask for ``x``.

        Entries of ``x`` greater than 0 count as valid. When ``x`` has more
        than two dimensions, every position of its first two dimensions is
        treated as valid. For a 2-D ``x`` of shape (B, T) the result has shape
        (B, 1, T, T), each row being a copy of the per-position validity flags.
        """
        if x.dim() > 2:
            # Replace the input by an all-ones placeholder over its first two dims.
            x = torch.ones(x.shape[:2]).to(x.device)
        valid = x > 0
        tiled = valid.unsqueeze(1).repeat(1, x.size(1), 1)
        return tiled.unsqueeze(1)

    def forward(self, x):
        """Embed ``x`` and return ``(embedded, mask)``.

        A tensor with at most two dimensions is passed to the token embedding
        as indices. A tensor with more dimensions is multiplied with the token
        embedding matrix instead (presumably a one-hot / soft distribution over
        the vocabulary -- confirm with callers).
        """
        mask = self.get_mask(x)
        if x.dim() <= 2:
            embedded = self.token(x) + self.position(x)
        else:
            # The positional module only sees an all-ones placeholder here.
            placeholder = torch.ones(x.shape[:2]).to(x.device)
            pos = self.position(placeholder)
            embedded = torch.matmul(x, self.token.weight) + pos
        normed = self.layer_norm(embedded)
        return self.dropout(normed), mask


class BERTModel(nn.Module):
    """Stack of Transformer blocks followed by a tied-weight output projection.

    Reads ``bert_hidden_units``, ``bert_num_heads``, ``bert_head_size``,
    ``bert_dropout``, ``bert_attn_dropout``, ``bert_num_blocks`` and
    ``num_items`` from ``args``.
    """

    def __init__(self, args):
        super().__init__()
        hidden = args.bert_hidden_units
        heads = args.bert_num_heads
        head_size = args.bert_head_size
        dropout = args.bert_dropout
        attn_dropout = args.bert_attn_dropout
        layers = args.bert_num_blocks

        # The feed-forward width passed to each block is four times the hidden size.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(hidden, heads, head_size, hidden * 4, dropout, attn_dropout)
            for _ in range(layers)
        ])
        self.linear = nn.Linear(hidden, hidden)
        # One output bias per vocabulary entry; nn.Parameter is trainable by
        # default, so no explicit requires_grad assignment is needed.
        self.bias = nn.Parameter(torch.zeros(args.num_items + 2))
        self.activation = GELU()

    def forward(self, x, embedding_weight, mask):
        """Run the encoder and score every vocabulary entry.

        Args:
            x: embedded input, passed through every transformer block in turn.
            embedding_weight: 2-D matrix whose transpose is used as the output
                projection; the last dimension of the result has
                ``embedding_weight.shape[0]`` entries.
            mask: attention mask forwarded unchanged to every block.

        Returns:
            ``activation(linear(x)) @ embedding_weight.T + bias``.
        """
        for transformer in self.transformer_blocks:
            # Call the module itself (not .forward) so registered hooks run.
            x = transformer(x, mask)
        x = self.activation(self.linear(x))
        scores = torch.matmul(x, embedding_weight.permute(1, 0)) + self.bias
        return scores

In [None]:
#export
class BERT(nn.Module):
    """BERT encoder: ``BERTEmbedding`` followed by ``BERTModel``.

    The output projection in ``BERTModel`` reuses the token-embedding matrix
    of ``BERTEmbedding`` (weight tying).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.embedding = BERTEmbedding(self.args)
        self.model = BERTModel(self.args)
        self.truncated_normal_init()

    def truncated_normal_init(self, mean=0, std=0.02, lower=-0.04, upper=0.04):
        """Re-initialise the parameters of ``self.model`` in place.

        Every parameter of ``self.model`` whose name does not contain
        ``'layer_norm'`` is drawn from a normal distribution with the given
        ``mean`` and ``std``, truncated to ``[lower, upper]``. Parameters of
        ``self.embedding`` are not touched.
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if 'layer_norm' not in name:
                    # Library routine: same inverse-CDF sampling as a manual
                    # uniform_/erfinv_ implementation, plus a final clamp.
                    nn.init.trunc_normal_(param, mean=mean, std=std, a=lower, b=upper)

    def forward(self, x):
        """Embed ``x``, run the encoder and return the per-item scores."""
        x, mask = self.embedding(x)
        scores = self.model(x, self.embedding.token.weight, mask)
        return scores

In [None]:
class Args:
    """Toy hyper-parameters used to build a small BERT for demonstration."""
    # Number of items; the models size their vocabulary as num_items + 2.
    num_items = 10
    # Maximum length handed to the positional embedding.
    bert_max_len = 8
    # Hidden width shared by the embeddings and the transformer blocks.
    bert_hidden_units = 4
    # Attention configuration forwarded to each transformer block.
    bert_num_heads = 2
    bert_head_size = 4
    # Number of stacked transformer blocks.
    bert_num_blocks = 4
    # Dropout probabilities (general and attention-specific).
    bert_dropout = 0.2
    bert_attn_dropout = 0.2

# Build the model from the toy configuration as a construction smoke test.
args = Args()
model = BERT(args)
# The bound method is evaluated without being called, so the cell displays its
# repr, which embeds the module tree of `model`.
model.parameters

<bound method Module.parameters of BERT(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(12, 4, padding_idx=0)
    (position): PositionalEmbedding(
      (pe): Embedding(9, 4)
    )
    (layer_norm): LayerNorm()
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (model): BERTModel(
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=4, out_features=8, bias=True)
            (1): Linear(in_features=4, out_features=8, bias=True)
            (2): Linear(in_features=4, out_features=8, bias=True)
          )
          (attention): Attention()
          (dropout): Dropout(p=0.2, inplace=False)
          (output_linear): Linear(in_features=8, out_features=4, bias=True)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=4, out_features=16, bias=True)
          (w_2): Linear(in_features=16, out_features=4, bias

> References
1. https://github.com/Yueeeeeeee/RecSys-Extraction-Attack/blob/main/model/bert.py

In [None]:
#hide
# Record the execution environment (author, timestamp, machine info and the
# versions of imported packages) below the notebook for reproducibility.
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-31 07:01:15

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.144+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

numpy     : 1.19.5
matplotlib: 3.2.2
torch     : 1.10.0+cu111
PIL       : 7.1.2
IPython   : 5.5.0

