# Imports

In [2]:
# if running on Google colab
!pip install einops
import torch as t
from torch import einsum
from einops import rearrange, repeat, reduce
import math

from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/mlab/

!pip install transformers
!pip install torchtyping
import days.w2d1.bert_tests as bert_tests

# If running elsewhere, install the dependencies (einops, transformers, torchtyping), then:
"""
import torch as t
from torch import einsum
from einops import rearrange, repeat, reduce
import math
import bert_tests # this command might need to be fiddled with depending on where this file is stored
"""

Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
Mounted at /content/gdrive
/content/gdrive/MyDrive/mlab
Collecting transformers
  Downloading transformers-4.19.1-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 4.4 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 36.0 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.6.0-py3-none-any.whl (84 kB)
[K     |████████████████████████████████| 84 kB 2.4 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 46.4 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, tra

# Part 1: Attention

In [3]:
# Pre-softmax (scaled dot-product) attention scores for multi-headed self-attention.
def raw_attention_scores(token_activations, num_heads, project_query, project_key):
  """Return scaled, pre-softmax attention scores.

  Args:
    token_activations: activations fed to both projections. After projection
      the features must be a 3-D tensor [batch, seq_len, num_heads * head_size]
      (required by the rearrange pattern below).
    num_heads: number of heads to split the projected features into.
    project_query: callable producing query features (e.g. an nn.Linear).
    project_key: callable producing key features (e.g. an nn.Linear).

  Returns:
    Tensor of shape [batch_size, num_heads, seq_length (key), seq_length (query)]
    holding <key, query> / sqrt(head_size).
  """
  def split_heads(projected):
    # [batch, seq, heads*head_size] -> [batch, heads, seq, head_size]
    return rearrange(projected, 'b sl (nh hs) -> b nh sl hs', nh=num_heads)

  q = split_heads(project_query(token_activations))
  k = split_heads(project_key(token_activations))
  scale = math.sqrt(q.shape[-1])
  # Note the output layout: key axis before query axis.
  return einsum('bhqd,bhkd->bhkq', q, k) / scale

bert_tests.test_attention_pattern_fn(raw_attention_scores)

attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.006629 STD: 0.1046 VALS [0.04644 0.09279 -0.2193 0.05949 0.05956 0.1955 -0.09895 0.01574 -0.07148 -0.165...]


In [4]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
  """Apply a raw attention pattern to the value vectors and project the result.

  Args:
    token_activations: activations fed to `project_value`; its output must be
      [batch, seq_len, num_heads * head_size] (required by the rearrange below).
    num_heads: number of attention heads.
    attention_pattern: pre-softmax scores laid out as
      [batch, num_heads, seq_len (key), seq_len (query)], e.g. the output of
      `raw_attention_scores`.
    project_value: callable producing value features.
    project_output: callable mapping the merged heads back to the output width.

  Returns:
    `project_output` applied to the attention-weighted values, with the heads
    merged back into the last dimension.
  """
  v = rearrange(project_value(token_activations), 'b sl (nh hs) -> b nh sl hs', nh=num_heads)
  # Normalise over the key axis (-2) so each query's weights sum to 1.
  weights = attention_pattern.softmax(dim=-2)
  # For each query q: sum over keys k of weights[k, q] * v[k].
  per_head = einsum('bhkd,bhkq->bhqd', v, weights)
  merged = rearrange(per_head, 'b nh sl hs -> b sl (nh hs)')
  return project_output(merged)

bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 0.001727 STD: 0.1128 VALS [0.08029 0.3392 0.0663 -0.1077 -0.009387 0.1234 0.1148 -0.2861 -0.07218 0.05894...]


In [5]:
from torch import nn

class MultiHeadedSelfAttention(nn.Module):
  """Multi-headed self-attention: Q/K/V projections, scaled dot-product
  attention per head, then an output projection back to `hidden_size`.

  Args:
    num_heads: number of attention heads.
    hidden_size: width of the input and output activations.
    head_size: width of each head's query/key/value vectors. Defaults to 64
      (= 768 // 12 for the configuration instantiated in this notebook); for
      other configurations pass e.g. `hidden_size // num_heads`.
  """
  def __init__(self, num_heads, hidden_size, head_size=64):
    super().__init__()
    self.head_size = head_size
    self.num_heads = num_heads
    inner_size = num_heads * head_size  # total width of all heads together
    # Created in this order so seeded initialisation is unchanged; the
    # attribute names are state_dict keys targeted by mapkey.
    self.project_query = nn.Linear(hidden_size, inner_size)
    self.project_key   = nn.Linear(hidden_size, inner_size)
    self.project_value = nn.Linear(hidden_size, inner_size)
    self.project_output= nn.Linear(inner_size, hidden_size)

  def forward(self, input):
    """Self-attend over `input` (presumably [batch, seq_len, hidden_size])."""
    raw_scores = raw_attention_scores(input, self.num_heads, self.project_query, self.project_key)
    return bert_attention(input, self.num_heads, raw_scores, self.project_value, self.project_output)

bert_tests.test_bert_attention(MultiHeadedSelfAttention)


bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


# Part 2: Transformer Encoder block

In [6]:
from torch.nn.functional import gelu

def bert_mlp(token_activations, linear_1, linear_2):
  """Position-wise feed-forward sub-layer: linear_1 -> GELU -> linear_2.

  Args:
    token_activations: input activations.
    linear_1: callable expanding to the intermediate width (e.g. nn.Linear).
    linear_2: callable projecting back to the input width (e.g. nn.Linear).

  Returns:
    linear_2(gelu(linear_1(token_activations))).
  """
  hidden = linear_1(token_activations)
  activated = gelu(hidden)
  return linear_2(activated)

bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.0001934 STD: 0.1044 VALS [-0.1153 0.1189 -0.0813 0.1021 0.0296 0.06182 0.0341 0.1446 0.2622 -0.08507...]


In [7]:
class BertMLP(nn.Module):
  """Module wrapper around `bert_mlp` that owns its two linear layers.

  Args:
    input_size: width of the activations entering and leaving the MLP.
    intermediate_size: width of the expanded hidden layer.
  """
  def __init__(self, input_size, intermediate_size):
    super().__init__()
    # Attribute names are state_dict keys (mapkey renames to `mlp.linear_N`).
    self.linear_1 = nn.Linear(input_size, intermediate_size)
    self.linear_2 = nn.Linear(intermediate_size, input_size)

  def forward(self, input):
    return bert_mlp(input, linear_1=self.linear_1, linear_2=self.linear_2)

In [57]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension, with a learned affine map.

  Args:
    normalized_dim: size of the last input dimension (the one normalised);
      also the size of the learned `weight` and `bias`.
    eps: constant added to the variance for numerical stability. Defaults to
      1e-5, the same default as torch.nn.LayerNorm.
  """
  def __init__(self, normalized_dim, eps=1e-5):
    super().__init__()
    self.eps = eps  # plain float: not a parameter, not in the state_dict
    self.weight = nn.Parameter(t.ones(normalized_dim))
    self.bias   = nn.Parameter(t.zeros(normalized_dim))

  def forward(self, input):
    centered = input - input.mean(-1, keepdim=True)
    # Biased (population) variance, as torch.nn.LayerNorm uses.
    variance = centered.var(-1, keepdim=True, unbiased=False)
    normalized = centered / (variance + self.eps).sqrt()
    return normalized * self.weight + self.bias

bert_tests.test_layer_norm(LayerNorm)


layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: -4.768e-09 STD: 1.003 VALS [1.126 0.6667 -0.174 1.782 -0.9279 -1.816 -0.578 0.5947 -0.2722 -0.4015...]


In [58]:
class BertBlock(nn.Module):
  """One BERT encoder layer (post-LayerNorm): self-attention, then an MLP.

  Each sub-layer's output is passed through dropout before it is added to the
  sub-layer's input and layer-normalised ("residual dropout"). This applies to
  the attention output as well as the MLP output, as in the reference BERT
  layer. nn.Dropout is the identity in eval mode, so inference is unaffected.

  Args:
    hidden_size: width of the residual stream.
    intermediate_size: hidden width of the MLP.
    num_heads: number of attention heads.
    dropout: dropout probability applied to both sub-layer outputs.
  """
  def __init__(self, hidden_size, intermediate_size, num_heads, dropout):
    super().__init__()
    # Attribute names are state_dict keys (see mapkey); creation order fixes
    # the RNG draws used to initialise the linear layers.
    self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
    self.layer_norm1 = LayerNorm(hidden_size)
    self.mlp = BertMLP(hidden_size, intermediate_size)
    self.dropout = nn.Dropout(dropout)
    self.layer_norm2 = LayerNorm(hidden_size)

  def forward(self, input):
    attn_out = self.dropout(self.attention(input))
    post_attn = self.layer_norm1(input + attn_out)
    mlp_out = self.dropout(self.mlp(post_attn))
    return self.layer_norm2(post_attn + mlp_out)

bert_tests.test_bert_block(BertBlock)


bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 1.656e-09 STD: 1 VALS [0.007132 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


# Part 3: BERT Embedding

In [59]:
class Embedding(nn.Module):
  """Learned lookup table mapping integer ids to dense vectors.

  Args:
    vocab_size: number of rows in the table (distinct ids).
    embed_size: length of each embedding vector.
  """
  def __init__(self, vocab_size, embed_size):
    super().__init__()
    # Standard-normal init. `emb_matrix` is the name that pretrained
    # `*_embedding.weight` keys are renamed to when loading.
    initial = t.randn(vocab_size, embed_size)
    self.emb_matrix = nn.Parameter(initial)

  def forward(self, input):
    """Select the table rows indexed by `input`; for an integer index tensor
    the result has shape input.shape + (embed_size,)."""
    table = self.emb_matrix
    return table[input]

bert_tests.test_embedding(Embedding)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.06748 STD: 1.062 VALS [1.176 -0.1914 0.8212 1.047 -0.481 0.7106 -1.304 -1.307 -0.438 -0.2764...]


In [60]:
def bert_embedding(
    input_ids,            # : [batch, seqlen]
    token_type_ids,       # : [batch, seqlen]
    position_embedding,   # : Embedding
    token_embedding,      # : Embedding
    token_type_embedding, # : Embedding
    layer_norm,           # : LayerNorm
    dropout               # : nn.Dropout
):
  """Sum token, token-type and position embeddings, then LayerNorm, then dropout.

  Returns:
    dropout(layer_norm(token_embedding(input_ids)
                       + token_type_embedding(token_type_ids)
                       + position_embedding(arange(seqlen)))).
  """
  seqlen = input_ids.size(1)
  # Positions 0..seqlen-1; the [seqlen, ...] result broadcasts over the batch.
  positions = t.arange(0, seqlen, device=input_ids.device)
  summed = (token_embedding(input_ids)
            + token_type_embedding(token_type_ids)
            + position_embedding(positions))
  # Normalise first, then drop out (reference BERT order). Dropout before
  # LayerNorm would have its rescaling undone by the normalisation.
  return dropout(layer_norm(summed))

bert_tests.test_bert_embedding_fn(bert_embedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 8.278e-10 STD: 1 VALS [-1.319 -0.4378 -2.074 0.9679 0.9274 1.479 -0.501 -1.9 -0.212 0.7961...]


In [61]:
class BertEmbedding(nn.Module):
  """Module wrapper around `bert_embedding` that owns its sub-modules.

  Args:
    vocab_size: rows of the token embedding table.
    hidden_size: width of every embedding vector (and of the LayerNorm).
    max_position_embeddings: rows of the position embedding table, which
      bounds the sequence length that can be embedded.
    type_vocab_size: rows of the token-type embedding table.
    dropout: dropout probability given to nn.Dropout.
  """
  def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout):
    super().__init__()
    # Keep this creation order: each Embedding draws its init from the RNG.
    self.token_embedding      = Embedding(vocab_size, hidden_size)
    self.position_embedding   = Embedding(max_position_embeddings, hidden_size)
    self.token_type_embedding = Embedding(type_vocab_size, hidden_size)
    self.layer_norm = LayerNorm(hidden_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self, input_ids, token_type_ids):
    """Embed `input_ids` together with their `token_type_ids`."""
    return bert_embedding(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        position_embedding=self.position_embedding,
        token_embedding=self.token_embedding,
        token_type_embedding=self.token_type_embedding,
        layer_norm=self.layer_norm,
        dropout=self.dropout,
    )
    
bert_tests.test_bert_embedding(BertEmbedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 1.242e-09 STD: 1 VALS [-0.009385 -0.4919 0.9852 -0.3535 -3.624 1.333 1.163 1.449 1.063 0.246...]


# Part 4: Putting it all together

In [62]:
class Bert(nn.Module):
  """BERT encoder plus a masked-language-model head producing vocabulary logits.

  Args:
    vocab_size, hidden_size, max_position_embeddings, type_vocab_size:
      forwarded to BertEmbedding.
    dropout: forwarded to BertEmbedding and to every BertBlock.
    intermediate_size, num_heads: forwarded to every BertBlock.
    num_layers: number of stacked BertBlocks.
  """
  def __init__(
      self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size,
      dropout, intermediate_size, num_heads, num_layers: int
  ):
    super().__init__()
    # Attribute names are the state_dict keys that mapkey targets; keep them.
    self.embedding = BertEmbedding(
        vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
    self.transformer = nn.Sequential(
        *[BertBlock(hidden_size, intermediate_size, num_heads, dropout) for _ in range(num_layers)])
    # LM head: dense -> GELU -> LayerNorm -> unembedding to vocabulary logits.
    self.linear = nn.Linear(hidden_size, hidden_size)
    self.layer_norm = LayerNorm(hidden_size)
    self.unembed = nn.Linear(hidden_size, vocab_size)

  def forward(self, input_ids, token_type_ids=None):
    """Return logits of shape [batch, seqlen, vocab_size].

    Args:
      input_ids: integer token ids, [batch, seqlen].
      token_type_ids: optional segment ids with the same shape as input_ids.
        When omitted every token is assigned type 0.
    """
    if token_type_ids is None:
      token_type_ids = t.zeros(*input_ids.shape, dtype=int, device=input_ids.device)
    hidden = self.embedding(input_ids, token_type_ids)
    hidden = self.transformer(hidden)
    hidden = self.layer_norm(gelu(self.linear(hidden)))
    return self.unembed(hidden)

bert_tests.test_bert(Bert)

bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.432 0.1186 -0.7165 -0.5261 0.4967 1.223 0.3165 -0.3247 -0.5716...]


# Part 5: Load pretrained weights

In [63]:
# Hyperparameters of the pretrained `bert-base-cased` checkpoint loaded below.
my_bert = Bert(
    num_layers=12,                # number of stacked BertBlocks
    num_heads=12,                 # attention heads per block
    hidden_size=768,              # residual-stream width
    intermediate_size=3072,       # MLP hidden width
    vocab_size=28996,             # rows of the token embedding / logits width
    max_position_embeddings=512,  # rows of the position embedding table
    type_vocab_size=2,            # rows of the token-type embedding table
    dropout=0.1,                  # dropout probability
)
pretrained_bert = bert_tests.get_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [64]:
import re
def mapkey(k: str) -> str:
  """Translate one pretrained-BERT state_dict key into the matching my_bert key.

  Used to rename `pretrained_bert.state_dict()` keys so they can be loaded
  into `my_bert`. The rewrites run in order, and later rules depend on earlier
  ones (e.g. the LM head's `layer_norm`, renamed to `layer_norm1` by the
  look-behind rule, is renamed again near the end). Classification-head keys
  are mapped to the empty string so the caller can discard them.

  Args:
    k: a key from the pretrained model's state_dict.

  Returns:
    The corresponding key of this notebook's `Bert`, or '' for
    classification-head keys.
  """
  # Embedding tables: this notebook's Embedding stores its table as `emb_matrix`.
  k = k.replace('_embedding.weight', '_embedding.emb_matrix')
  # Drop the `.pattern` path segment (presumably the pretrained model nests the
  # query/key projections in a `pattern` submodule — confirm).
  k = k.replace('.pattern', '')
  # Presumably project_out -> project_output. NOTE(review): this rewrites EVERY
  # occurrence of 'out' ('dropout' -> 'dropoutput', 'output' -> 'outputput');
  # harmless for the current checkpoint (all keys matched), but fragile.
  k = k.replace('out', 'output')
  # First LayerNorm of each block: any `layer_norm` not preceded by 'dual.'
  # (residual.) or 'ding.' (embedding.) becomes `layer_norm1`. Both look-behind
  # alternatives are 5 characters, as Python's `re` requires fixed width.
  k = re.sub(r'(?<!(dual\.|ding\.))layer_norm', 'layer_norm1', k)
  # residual.mlp1 / residual.mlp2 -> mlp.linear_1 / mlp.linear_2 (BertMLP).
  k = re.sub(r'residual\.mlp(?=[1-9])', 'mlp.linear_', k)
  # The block's second LayerNorm sits under `residual` in the pretrained model.
  k = re.sub(r'residual\.layer_norm', 'layer_norm2', k)
  # LM head: dense layer, LayerNorm (undoing the layer_norm1 rename above,
  # which also matched it) and the vocabulary projection.
  k = k.replace('lm_head.mlp', 'linear')
  k = k.replace('lm_head.layer_norm1', 'layer_norm')
  k = k.replace('lm_head.unembedding', 'unembed')
  # Classification head has no counterpart in Bert: the whole key becomes ''.
  k = re.sub(r'classification.*', '', k)
  return k

# Sanity check: print every pretrained key that mapkey does not send to an
# existing my_bert key. Only the classification head should be listed.
# my_bert's key view is built once instead of once per pretrained key.
my_bert_keys = my_bert.state_dict().keys()
for k in pretrained_bert.state_dict():
  if mapkey(k) not in my_bert_keys:
    print(k)

classification_head.weight
classification_head.bias


In [65]:
# Rename every pretrained tensor to my_bert's naming scheme. mapkey sends the
# classification-head keys (which my_bert has no counterpart for) to '', so
# they are skipped here. Popping '' afterwards would raise KeyError for a
# checkpoint that has no classification head.
load_dict = {}
for k, v in pretrained_bert.state_dict().items():
  new_key = mapkey(k)
  if new_key:
    load_dict[new_key] = v
# Strict (default) loading fails loudly on any missing or unexpected key.
my_bert.load_state_dict(load_dict)

<All keys matched successfully>

In [66]:
bert_tests.test_same_output(my_bert, pretrained_bert, tol=1e-4)

comparing Berts MATCH!!!!!!!!
 SHAPE (10, 20, 28996) MEAN: -2.732 STD: 2.413 VALS [-5.65 -6.041 -6.096 -6.062 -5.946 -5.777 -5.977 -6.015 -6.028 -5.935...]
