In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import subprocess
from models import *
from utils import *
from parse import *
import random
from copy import deepcopy as cc
from transformers import BertTokenizer, BertModel
import gc
from torch import nn
from torch.nn import GELU as GELU
import argparse
from solver import Solver
# Switch off Python's automatic cyclic garbage collector for this kernel, then
# run one explicit collection pass (gc.collect() still works while automatic
# collection is disabled). As the last expression of the cell, the value
# returned by gc.collect() is what the notebook displays.
# NOTE(review): the reason for disabling the collector is not visible in this
# notebook -- confirm it is intentional, since uncollected reference cycles
# will accumulate for the rest of the session.
gc.disable()
gc.collect()

8

In [36]:
class Config:
    """Hyper-parameters for the BERT encoder assembled in this notebook.

    Every setting is a class attribute with a default value, so it can be read
    either from the class itself or from an instance (``Config().dim``).
    """
    vocab_size: int = 30522  # Size of the vocabulary
    dim: int = 768  # Dimension of the hidden layers in the Transformer encoder
    n_layers: int = 12  # Number of hidden layers
    n_heads: int = 12  # Number of heads in the multi-headed attention layers
    dim_ff: int = 768*4  # Dimension of the intermediate layer in the position-wise feed-forward net
    activ_fn: str = "gelu"  # Non-linear activation function type in the hidden layers
    p_drop_hidden: float = 0.1  # Dropout probability of the various hidden layers
    p_drop_attn: float = 0.1  # Dropout probability of the attention layers
    max_len: int = 512  # Maximum length for the positional embeddings
    n_segments: int = 2  # Number of sentence segments
    layer_norm_eps: float = 1e-12  # eps value for the LayerNorms (a float, not an int)
    output_attentions: bool = False  # Whether to output the attention scores

config = Config()

+ #  <u>Embedding Layer</u>

In [37]:
class BertEmbeddings(nn.Module):
    """Input embedding block: word + position + token-type embeddings.

    The three lookups are summed, layer-normalised and passed through dropout.

    Args:
        cfg: configuration object providing ``vocab_size``, ``dim``,
            ``max_len``, ``n_segments``, ``layer_norm_eps`` and
            ``p_drop_hidden``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Attribute names are used as state-dict keys, so they are kept as-is.
        self.word_embeddings = nn.Embedding(cfg.vocab_size, cfg.dim, padding_idx=0)  # token ids
        self.position_embeddings = nn.Embedding(cfg.max_len, cfg.dim)  # absolute positions
        self.token_type_embeddings = nn.Embedding(cfg.n_segments, cfg.dim)  # segment ids

        self.LayerNorm = nn.LayerNorm(cfg.dim, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x, seg):
        """Embed a batch of token ids.

        Args:
            x: integer token ids; indexed as (B, S) by ``x.size(1)``.
            seg: token-type ids, looked up in ``token_type_embeddings`` and
                added element-wise to the word embeddings.

        Returns:
            The normalised, dropped-out sum of the three embeddings.
        """
        # Positions 0..S-1 on the same device as the input, repeated per batch row.
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        positions = positions.unsqueeze(0).expand_as(x)  # (S,) -> (B, S)

        token_part = self.word_embeddings(x)
        position_part = self.position_embeddings(positions)
        segment_part = self.token_type_embeddings(seg)

        summed = token_part + position_part + segment_part
        return self.dropout(self.LayerNorm(summed))

* # <u>Self Attention Layer</u>

In [38]:

# X -> self_attn -> X [calculated by self attention]
class BertSelfAttention(nn.Module):
    """Multi-headed scaled dot-product self-attention.

    The input is projected to queries, keys and values, the feature dimension
    is split into ``n_heads`` heads, and each head attends over the sequence.

    Args:
        cfg: configuration object providing ``dim``, ``n_heads`` and
            ``p_drop_attn``.
    """
    def __init__(self, cfg):
        super().__init__()
        self.query = nn.Linear(cfg.dim, cfg.dim)
        self.key = nn.Linear(cfg.dim, cfg.dim)
        self.value = nn.Linear(cfg.dim, cfg.dim)
        self.dropout = nn.Dropout(cfg.p_drop_attn)
        self.n_heads = cfg.n_heads

    def forward(self, x, attention_mask = None, output_attentions = False):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape (B(batch_size), S(seq_len), D(dim)).
            attention_mask: optional (B, S) tensor; key positions whose mask
                value equals 0 are excluded from attention for every query
                and every head.
            output_attentions: when true, the attention weights (after softmax
                and dropout) are returned as well.

        Returns:
            dict with
              * ``'hidden_states'``: attended values, shape (B, S, D);
              * ``'attn_scores'``: attention weights of shape (B, H, S, S) if
                ``output_attentions`` is set, otherwise ``None``.

        Raises:
            AssertionError: if D is not divisible by the number of heads.
        """
        B, S, D = x.shape
        H = self.n_heads
        W = D // H  # width of one head; D = H * W
        assert W * H == D, f"hidden size {D} is not divisible by the number of heads {H}"

        def split_heads(t):
            # (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
            return t.reshape(B, S, H, W).transpose(1, 2)

        q = split_heads(self.query(x))
        k = split_heads(self.key(x))
        v = split_heads(self.value(x))

        # (B, H, S(Q), W) @ (B, H, W, S(K/V)) -> (B, H, S(Q), S(K/V))
        # Scale with a plain Python float so the block needs no numpy import.
        attn_scores = q @ k.transpose(-2, -1) / (W ** 0.5)

        if attention_mask is not None:
            # The mask is (B, S) over key positions; broadcast it over heads and queries.
            blocked = attention_mask[:, None, None, :] == 0
            # Use the most negative finite value of the score dtype: a fixed
            # -1e9 is not representable in float16 and makes masked_fill fail.
            attn_scores = attn_scores.masked_fill(blocked, torch.finfo(attn_scores.dtype).min)
        attn_probs = self.dropout(F.softmax(attn_scores, dim=-1))

        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) -merge-> (B, S, D)
        hidden_states = (attn_probs @ v).transpose(1, 2).contiguous()
        hidden_states = hidden_states.reshape(B, S, D)

        return {
            'hidden_states': hidden_states,
            'attn_scores': attn_probs if output_attentions else None,
        }


In [39]:
#  f(X) -> Linear (D => D) -> Dropout -> X1
#  X, f(X) -> LayerNorm( X + X1 )
class BertSelfOutput(nn.Module):
    """Output stage of the attention sub-layer.

    Applies a D -> D linear projection and dropout to the attention result,
    adds the residual input, and layer-normalises the sum.

    Args:
        cfg: configuration object providing ``dim``, ``layer_norm_eps`` and
            ``p_drop_hidden``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim, cfg.dim)
        self.LayerNorm = nn.LayerNorm(cfg.dim, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

In [49]:
# X -> self_attn -> f(X) -> Linear (D => D) -> Dropout -> X1
# Y = LayerNorm( X + X1 )
class BertAttention(nn.Module):
    """Complete attention sub-layer: self-attention followed by its output stage.

    Args:
        cfg: configuration object forwarded to ``BertSelfAttention`` and
            ``BertSelfOutput``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.self = BertSelfAttention(cfg)
        self.output = BertSelfOutput(cfg)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Attend over ``hidden_states`` and merge the result with the residual.

        Returns:
            dict with ``'hidden_states'`` (result of the output stage, which
            receives the original ``hidden_states`` as residual input) and
            ``'attn_scores'`` (passed through unchanged from the
            self-attention module).
        """
        attended = self.self(hidden_states, attention_mask, output_attentions)
        merged = self.output(attended['hidden_states'], hidden_states)
        return {'hidden_states': merged, 'attn_scores': attended['attn_scores']}

* # <u>Forward Expansion layer</u>

In [50]:
# f(X) -> Linear(D => 4xD) -> GELU() -> X2
class BertIntermediate(nn.Module):
    """Expansion half of the feed-forward net: Linear(D => dim_ff) then GELU.

    The activation is always GELU; ``cfg.activ_fn`` is not consulted here.

    Args:
        cfg: configuration object providing ``dim`` and ``dim_ff``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim, cfg.dim_ff)
        self.GELU = nn.GELU()

    def forward(self, hidden_states):
        """Return ``GELU(dense(hidden_states))``."""
        return self.GELU(self.dense(hidden_states))

# f(X) -> Linear(4xD => D) -> Dropout -> X1
#  X, f(X) -> LayerNorm( X + X1 )
class BertOutput(nn.Module):
    """Contraction half of the feed-forward net, with residual and LayerNorm.

    Projects from ``dim_ff`` back to ``dim``, applies dropout, adds the
    residual input and layer-normalises the sum.

    Args:
        cfg: configuration object providing ``dim``, ``dim_ff``,
            ``layer_norm_eps`` and ``p_drop_hidden``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.dense = nn.Linear(cfg.dim_ff, cfg.dim)
        self.LayerNorm = nn.LayerNorm(cfg.dim, eps=cfg.layer_norm_eps)
        self.dropout = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)

* # <u>Defining single BERT-encoder layer</u>

In [51]:
class BertLayer(nn.Module):
    """One encoder layer: attention sub-layer followed by the feed-forward sub-layer.

    Args:
        cfg: configuration object forwarded to ``BertAttention``,
            ``BertIntermediate`` and ``BertOutput``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.attention = BertAttention(cfg)
        self.intermediate = BertIntermediate(cfg)
        self.output = BertOutput(cfg)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        """Run the layer on ``hidden_states``.

        Returns:
            dict with ``'hidden_states'`` (the feed-forward output, which uses
            the attention output as its residual input) and ``'attn_scores'``
            (passed through from the attention sub-layer).
        """
        attn_result = self.attention(hidden_states, attention_mask, output_attentions)
        attn_hidden = attn_result['hidden_states']

        # Feed-forward: expand, then contract with the attention output as residual.
        expanded = self.intermediate(attn_hidden)
        contracted = self.output(expanded, attn_hidden)

        return {'hidden_states': contracted, 'attn_scores': attn_result['attn_scores']}

* # <u>Defining the BERT model</u>

In [52]:
class BertEncoder(nn.Module):
    """Stack of ``cfg.n_layers`` ``BertLayer`` modules applied in sequence.

    Args:
        cfg: configuration object providing ``n_layers`` and forwarded to
            every ``BertLayer``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.layer = nn.ModuleList([BertLayer(cfg) for _ in range(cfg.n_layers)])

    def forward(self, hidden_states, attention_mask=None, output_attentions=False, output_hidden_states=False):
        """Run the input through every layer.

        Returns:
            dict with
              * ``'hidden_states'``: output of the last layer;
              * ``'attn_scores'``: list with one entry per layer when
                ``output_attentions`` is set, otherwise ``None``;
              * ``'all_hidden_states'``: list of the tensors fed INTO each
                layer when ``output_hidden_states`` is set (so the last
                layer's output is not part of it), otherwise ``None``.
        """
        inputs_per_layer = [] if output_hidden_states else None
        attn_per_layer = [] if output_attentions else None

        current = hidden_states
        for block in self.layer:
            if inputs_per_layer is not None:
                inputs_per_layer.append(current)

            block_out = block(current, attention_mask, output_attentions)
            current = block_out['hidden_states']

            if attn_per_layer is not None:
                attn_per_layer.append(block_out['attn_scores'])

        return {
            'hidden_states': current,
            'attn_scores': attn_per_layer,
            'all_hidden_states': inputs_per_layer,
        }


class my_BERT(nn.Module):
    """BERT encoder without a pooling head: embeddings followed by the layer stack.

    Args:
        cfg: configuration object forwarded to ``BertEmbeddings`` and
            ``BertEncoder``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embeddings = BertEmbeddings(cfg)
        self.encoder = BertEncoder(cfg)

    def forward(self, x, seg, attention_mask=None, output_attentions=False, output_hidden_states=False):
        """Embed ``x``/``seg`` and encode the result.

        Returns:
            Whatever ``self.encoder`` returns for the embedded input; the
            mask and the two output flags are forwarded to it unchanged.
        """
        embedded = self.embeddings(x, seg)
        return self.encoder(
            hidden_states=embedded,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

In [53]:
# Build the re-implemented encoder and copy the pretrained weights into it.
config = Config()
MY_BERT = my_BERT(config)
BERT = BertModel.from_pretrained('bert-base-uncased')

# strict=False lets the checkpoint carry entries this model does not define,
# but it would just as silently skip parameters the checkpoint fails to supply,
# leaving them randomly initialised. Keep the report and verify that every
# parameter of MY_BERT was actually loaded.
load_report = MY_BERT.load_state_dict(BERT.state_dict(), strict = False)
assert not load_report.missing_keys, (
    f'Parameters of MY_BERT missing from the checkpoint: {load_report.missing_keys}'
)
load_report

In [74]:
# Compare parameter names between the reference model and the re-implementation.
MY_BERT_params = {name for name, _ in MY_BERT.named_parameters()}
BERT_params = {name for name, _ in BERT.named_parameters()}
# The set difference is "in BERT but not in MY_BERT"; the label must say so.
print(f'Parameters in BERT that are not in the model:: {BERT_params.difference(MY_BERT_params)}')

Parametes in the model that are not in BERT:: {'pooler.dense.weight', 'pooler.dense.bias'}


In [73]:
# Smoke test: push a random batch of 4 sequences of 20 token ids (drawn from
# [0, 100)) through the model, with a single segment and nothing masked out.
temp = torch.randint(low=0, high=100, size=(4, 20))
encoded_temp = MY_BERT(
    x=temp,
    seg=torch.zeros_like(temp),
    attention_mask=torch.ones_like(temp),
    output_hidden_states=True,
    output_attentions=True,
)

12