### **Structure of BERT**

```
BertEmbeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (token_type_embeddings): Embedding(2, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
```

```
BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=768, out_features=768, bias=True)
      (key): Linear(in_features=768, out_features=768, bias=True)
      (value): Linear(in_features=768, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=768, out_features=3072, bias=True)
  )
  (output): BertOutput(
    (dense): Linear(in_features=3072, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
```


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import subprocess
from models import *
from utils import *
from parse import *
import random
from bert_optimizer import BertAdam
from copy import deepcopy as cc
from transformers import BertTokenizer, BertModel
import gc
from torch import nn
from torch.nn import GELU as GELU
gc.disable()
gc.collect()

In [None]:
class Config():
    "Configuration for BERT model"
    vocab_size: int = 30522 # Size of Vocabulary
    dim: int = 768 # Dimension of Hidden Layer in Transformer Encoder
    n_layers: int = 12 # Numher of Hidden Layers
    n_heads: int = 12 # Numher of Heads in Multi-Headed Attention Layers
    dim_ff: int = 768*4 # Dimension of Intermediate Layers in Positionwise Feedforward Net
    activ_fn: str = "gelu" # Non-linear Activation Function Type in Hidden Layers
    p_drop_hidden: float = 0.1 # Probability of Dropout of various Hidden Layers
    p_drop_attn: float = 0.1 # Probability of Dropout of Attention Layers
    max_len: int = 512 # Maximum Length for Positional Embeddings
    n_segments: int = 2 # Number of Sentence Segments
    layer_norm_eps: int = 1e-12 # eps value for the LayerNorms
    output_attentions : bool = False # Weather to output the attention scores

In [None]:
# defining the tree-transformer

# h = 8
# d_model = 768
# cuda_present = True
# d_ff = 2048
# dropout = 0.1
# vocab_size = 30522
# N = 10
# attn = MultiHeadedAttention(h, d_model, no_cuda = True)
# group_attn = GroupAttention(d_model, no_cuda = False)
# ff = PositionwiseFeedForward(d_model, d_ff, dropout)
# position = PositionalEncoding(d_model, dropout)
# word_embed = nn.Sequential(Embeddings(d_model, vocab_size), cc(position))
# model = Encoder(EncoderLayer(d_model, cc(attn), cc(ff), group_attn, dropout), 
#         N, d_model, vocab_size, cc(word_embed))
cnfg = Config()

In [None]:
# defining the BERT
BERT = BertModel.from_pretrained('bert-base-uncased')

### Embedding Layer

In [None]:
class BertEmbeddings(nn.Module):
    "The embedding module from word, position and token_type embeddings."
    def __init__(self, cfg):
        super().__init__()
        self.word_embeddings = nn.Embedding(cfg.vocab_size, cfg.dim, padding_idx=0) # token embedding
        self.position_embeddings = nn.Embedding(cfg.max_len, cfg.dim) # position embedding
        self.token_type_embeddings = nn.Embedding(cfg.n_segments, cfg.dim) # segment(token type) embedding

        self.LayerNorm = nn.LayerNorm(cfg.dim, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x) # (S,) -> (B, S)

        e = self.word_embeddings(x) + self.position_embeddings(pos) + self.token_type_embeddings(seg)
        e = self.LayerNorm(e)
        e = self.dropout(e)
        return e

### Self Attention Layer

In [132]:

# X -> self_attn -> X [calculated by self attention]
class BertSelfAttention(nn.Module):
    """ Multi-Headed Dot Product Attention """
    def __init__(self, cfg):
        super().__init__()
        self.query = nn.Linear(cfg.dim, cfg.dim)
        self.key = nn.Linear(cfg.dim, cfg.dim)
        self.value = nn.Linear(cfg.dim, cfg.dim)
        self.dropout = nn.Dropout(cfg.p_drop_attn)
        self.n_heads = cfg.n_heads

    def forward(self, x, mask = None, output_attentions = False):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        B, S, D = x.shape
        H = self.n_heads
        W = int( D/H )
        assert W * H == D

        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.query(x), self.key(x), self.value(x)
        q, k, v = q.reshape((B, S, H, W)), k.reshape((B, S, H, W)), v.reshape((B, S, H, W))
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -Masking -> softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        
        # Set the attention score at places where MASK = 0 to very low value (-1e9)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = self.dropout(F.softmax(scores, dim=-1))
        
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) -merge-> (B, S, D)
        hidden_states = (scores @ v).transpose(1, 2).contiguous()
        hidden_states = hidden_states.reshape(B, S, D)
        return (hidden_states, scores) if output_attentions else (hidden_states,)


#  f(X) -> Linear (D => D) -> Dropout -> X1
#  X, f(X) -> LayerNorm( X + X1 )
class BertSelfOutput(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim, cfg.dim)
        self.LayerNorm = nn.LayerNorm(cfg.dim, eps= cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# X -> self_attn -> f(X) -> Linear (D => D) -> Dropout -> X1
# Y = LayerNorm( X + X1 )
class BertAttention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.self = BertSelfAttention(cfg)
        self.output = BertSelfOutput(cfg)

    def forward(self, hidden_states, attention_mask=None, output_attentions = False):
        self_outputs = self.self(hidden_states,  attention_mask, output_attentions)
        attention_output = self.output(self_outputs[0], hidden_states)

        # add attentions if we output them; self_outputs[1:] is the attn score, attention_output is the hidden representation
        outputs = (attention_output,) + self_outputs[1:] if output_attentions else (attention_output,)
        return outputs        

In [139]:
sample_ip = torch.randint(low = 0, high = 100, size=(3, 10))

bert_emb = BertEmbeddings(cnfg)
bert_self_attn = BertAttention(cnfg)

op = bert_emb(sample_ip, torch.zeros_like(sample_ip))
op = bert_self_attn(hidden_states = op, output_attentions = False)

### Forward Expansion layer

In [135]:
# f(X) -> Linear(D => 4xD) -> GELU() -> X2
class BertIntermediate(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim, cfg.dim_ff)
        self.GELU = nn.GELU()

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.GELU(hidden_states)
        return hidden_states


# f(X) -> Linear(4xD => D) -> Drouput -> X1
#  X, f(X) -> LayerNorm( X + X1 )

class BertOutput(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim_ff, cfg.dim)
        self.LayerNorm = nn.LayerNorm(cfg.dim, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


### Defining single BERT-encoder layer

In [147]:
bert_emb = BertEmbeddings(cnfg)
bert_self_attn = BertAttention(cnfg)
Bert_intermeidate = BertIntermediate(cnfg)
Bert_op = BertOutput(cnfg)


sample_ip = torch.randint(low = 0, high = 100, size=(3, 10))
op = bert_emb(sample_ip, torch.zeros_like(sample_ip))
op_attn = bert_self_attn(hidden_states = op, output_attentions = True)
op = Bert_intermeidate(hidden_states = op_attn[0])
op = Bert_op(hidden_states = op, input_tensor = op_attn[0])
op.shape
gc.collect()

25772

In [148]:
class BertLayer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention = BertAttention(cfg)
        self.intermediate = BertIntermediate(cfg)
        self.output = BertOutput(cfg)

    def forward(self, hidden_states, attention_mask=None, output_attentions = False):
        self_attention_outputs = self.attention(hidden_states, attention_mask, output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs if output_attentions else (layer_output,)
        return outputs

### Defining the BERT model

In [180]:
class BertEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(cfg) for _ in range(cfg.n_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions = False, output_hidden_states = False):        
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            print(f'>> Layer {i}')
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        outputs = (hidden_states,) 
        outputs = outputs + (all_hidden_states,) + (all_self_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)


class my_BERT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.embeddings = BertEmbeddings(cfg)
        self.encoder = BertEncoder(cfg)
    
    def forward(self, x, seg, attention_mask=None, output_attentions = False, output_hidden_states = False):

        embedding_output = self.embeddings(x, seg)
        encoded_output = self.encoder(hidden_states = embedding_output, 
                                      attention_mask = attention_mask, 
                                      output_attentions = output_attentions, 
                                      output_hidden_states = output_hidden_states)

        return encoded_output


In [181]:
config = Config()
MY_BERT = my_BERT(config)

In [182]:
my_BERT_params = set([n for n, _ in MY_BERT.named_parameters()])
BERT_params = set([n for n, _ in BERT.named_parameters()])

In [183]:
len(my_BERT_params.intersection(BERT_params))

197

In [184]:
temp = torch.randint(low = 0, high = 100, size = (4, 20))
encoded_temp = MY_BERT(x = temp, seg = torch.zeros_like(temp),  output_hidden_states = True)

>> Layer 0
>> Layer 1
>> Layer 2
>> Layer 3
>> Layer 4
>> Layer 5
>> Layer 6
>> Layer 7
>> Layer 8
>> Layer 9
>> Layer 10
>> Layer 11


In [154]:
a = (1,)
a + (2,) + (3,)

(1, 2, 3)

In [173]:
encoded_temp[1][-1]

tensor([[[-1.1848, -0.0285,  0.5436,  ..., -1.1377, -0.8266, -0.7126],
         [-0.0309, -0.4084, -0.8200,  ..., -0.6317, -0.9142,  0.2376],
         [-0.6651,  0.5134,  0.2074,  ..., -1.2302, -0.4554, -0.4232],
         ...,
         [ 0.4243, -0.3227,  0.7576,  ..., -0.7486, -0.8489,  1.0825],
         [ 0.5961, -1.0654, -0.2513,  ..., -0.9364,  0.6991, -0.7306],
         [ 0.6256,  0.7886,  0.6593,  ..., -1.5783, -0.8665,  0.0558]],

        [[-0.4661, -0.6714,  1.1090,  ..., -0.6477, -1.1112, -0.9664],
         [-0.7594,  0.1596,  0.3844,  ..., -0.9336, -1.4286,  1.4158],
         [-0.5134, -1.2912,  0.9980,  ..., -0.5096,  0.2018,  0.3360],
         ...,
         [ 0.8344,  1.5394,  1.4684,  ..., -0.9579, -1.7836,  0.9458],
         [-0.2102, -1.2844, -0.0779,  ..., -0.5074, -0.6555, -0.3627],
         [-0.0123,  0.3298,  1.0333,  ..., -1.3660, -2.0198, -0.5635]],

        [[ 0.0591, -0.3759,  1.0905,  ..., -1.9186, -0.8224, -0.3560],
         [ 1.1063, -0.4229, -1.0585,  ..., -0

In [174]:
encoded_temp[0]

tensor([[[-1.3521, -0.0554,  0.6890,  ..., -0.9465, -0.6662, -0.9000],
         [-0.3505, -0.3804, -0.7439,  ..., -0.2845, -0.9817, -0.2180],
         [-0.5933,  0.4646,  0.1405,  ..., -0.8407, -0.3823, -0.5760],
         ...,
         [ 0.2114, -0.4964,  0.8479,  ..., -0.4648, -0.8789,  0.7657],
         [ 0.2352, -1.0495, -0.1933,  ..., -0.6638,  0.6567, -0.8506],
         [ 0.7089,  0.6450,  0.7922,  ..., -1.2498, -0.9352, -0.1806]],

        [[-0.5991, -0.3532,  1.2220,  ..., -0.8466, -1.2140, -1.0271],
         [-0.9246,  0.1129,  0.7574,  ..., -0.7914, -1.4160,  1.1726],
         [-0.3398, -0.7966,  0.8751,  ..., -0.2994, -0.1088,  0.2824],
         ...,
         [ 0.6497,  1.7207,  1.7003,  ..., -0.6179, -1.8581,  0.8488],
         [-0.3679, -0.9265, -0.0716,  ..., -0.2223, -0.7588, -0.3392],
         [-0.3000,  0.5441,  1.2114,  ..., -1.1746, -1.9180, -0.8477]],

        [[-0.1487, -0.5375,  1.4424,  ..., -1.6875, -0.6888, -0.7684],
         [ 0.5754, -0.3451, -0.7638,  ..., -0