In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append("../")
from models.bert import BERT 
from copy import deepcopy

### Load Bert-Base-Uncased Weights dumped as a Dictionary

In [2]:
# Load the dumped pre-trained weights (iterated below as a name -> tensor dictionary).
# map_location="cpu" lets the dump load on a machine without the device it was saved
# from; the config used in this notebook targets the CPU anyway.
# NOTE(review): torch.load unpickles the file - only load dumps from a trusted source.
bert_pretrained_model = torch.load("../../../torch_dump", map_location="cpu")

### Define a stand-in config class so we can instantiate our BERT model without the config file

In [3]:
class Conf():
    """Stand-in for the project's config file, holding plain attributes.

    An instance is passed to ``BERT(config, vocab_size=...)`` in this notebook.
    Which attributes the model constructor actually reads is not visible here;
    the optimizer-related ones are presumably only needed for training - TODO confirm.
    """

    def __init__(self):
        # Model architecture
        self.hidden_features = 768
        self.layers = 12
        self.heads = 12
        self.device = 'cpu'
        self.dropout = 0.1

        # Optimizer / schedule settings
        self.lr = 1e-4
        # Was 0.999 (identical to beta2), which looks like a copy-paste slip;
        # 0.9 is the usual Adam first-moment decay and the value used for BERT pre-training.
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 1e-5
        self.warmup_steps = 10000

In [4]:
# Instantiate the stand-in configuration that is handed to the BERT constructor below
config = Conf()

### Instantiate our BERT Model 
- Ideally we would do this for BERTLM, but the same procedure applies.

In [5]:
# Take the vocabulary size from the checkpoint's word-embedding matrix (30522 rows for
# this dump) instead of hard-coding it, so the model always matches the loaded weights.
vocab_size = bert_pretrained_model["bert.embeddings.word_embeddings.weight"].shape[0]
bert_model = BERT(config, vocab_size=vocab_size)

In [6]:
# Shapes of all tensors in the pre-trained checkpoint, in the order they are stored
param_size_default = [tensor.shape for tensor in bert_pretrained_model.values()]

In [7]:
# Shapes of all parameters of our model, in the order parameters() yields them
param_size = [weight.shape for weight in bert_model.parameters()]

In [8]:
# Display the shapes collected from the pre-trained checkpoint
param_size_default

[torch.Size([1, 512]),
 torch.Size([30522, 768]),
 torch.Size([512, 768]),
 torch.Size([2, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([3072, 768]),
 torch.Size([3072]),
 torch.Size([768, 3072]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([3072, 768]),
 torch.Size([3072]),
 torch.Size([768, 3072]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768

In [9]:
# Display the shapes collected from our model, for comparison with the checkpoint shapes
param_size

[torch.Size([30522, 768]),
 torch.Size([3, 768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([3072, 768]),
 torch.Size([3072]),
 torch.Size([768, 3072]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([3072, 768]),
 torch.Size([3072]),
 torch.Size([768, 3072]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([768, 768]),
 torch.Size([768]),
 torch.Size([3072, 768]),
 torch.Size([3072]),
 torch.Size([768,

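We see some minor disagreements, especially in the initial embeddings: the pre-trained dump contains position ids, position embeddings, a `[2, 768]` token-type embedding and an embedding LayerNorm, whereas our model only has a token embedding and a `[3, 768]` segment embedding. We will ignore these differences for the time being and proceed.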

Observe that the pre-trained weights are stored layer by layer for the transformer modules.

We will copy each of these weights into the matching parameter of our network.

In [10]:
# List the parameter names stored in the pre-trained checkpoint
for pretrained_name in bert_pretrained_model.keys():
    print(pretrained_name)

bert.embeddings.position_ids
bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attenti

### Our Bert Model

In [11]:
# Display the module tree of our model
bert_model

BERT(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(30522, 768, padding_idx=0)
    (position): PositionalEmbedding()
    (segment): SegmentEmbedding(3, 768, padding_idx=0)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): ModuleList(
    (0): TransformerBlock(
      (attention): MultiHeadedAttention(
        (linear_layers): ModuleList(
          (0): Linear(in_features=768, out_features=768, bias=True)
          (1): Linear(in_features=768, out_features=768, bias=True)
          (2): Linear(in_features=768, out_features=768, bias=True)
        )
        (output_linear): Linear(in_features=768, out_features=768, bias=True)
        (attention): Attention()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward): PositionwiseFeedForward(
        (w_1): Linear(in_features=768, out_features=3072, bias=True)
        (w_2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (a

In [12]:
# List the parameter names of our model, to line them up against the checkpoint keys
for param_name, _param in bert_model.named_parameters():
    print(param_name)

embedding.token.weight
embedding.segment.weight
transformer_blocks.0.attention.linear_layers.0.weight
transformer_blocks.0.attention.linear_layers.0.bias
transformer_blocks.0.attention.linear_layers.1.weight
transformer_blocks.0.attention.linear_layers.1.bias
transformer_blocks.0.attention.linear_layers.2.weight
transformer_blocks.0.attention.linear_layers.2.bias
transformer_blocks.0.attention.output_linear.weight
transformer_blocks.0.attention.output_linear.bias
transformer_blocks.0.feed_forward.w_1.weight
transformer_blocks.0.feed_forward.w_1.bias
transformer_blocks.0.feed_forward.w_2.weight
transformer_blocks.0.feed_forward.w_2.bias
transformer_blocks.0.input_sublayer.norm.a_2
transformer_blocks.0.input_sublayer.norm.b_2
transformer_blocks.0.output_sublayer.norm.a_2
transformer_blocks.0.output_sublayer.norm.b_2
transformer_blocks.1.attention.linear_layers.0.weight
transformer_blocks.1.attention.linear_layers.0.bias
transformer_blocks.1.attention.linear_layers.1.weight
transformer_bl

Notice that the transformer-block weights are the ones we will swap.

The mapping below was written by hand after comparing our transformer code with the keys of the pre-trained state dict.

In [13]:
# Hand-written table: pre-trained parameter name -> our parameter name.
# Both sides are relative to a single transformer layer/block, so the
# per-layer prefix has to be added by whoever uses the table.
mapping = {
    # Query / key / value projections
    "attention.self.query.weight": "attention.linear_layers.0.weight",
    "attention.self.query.bias": "attention.linear_layers.0.bias",
    "attention.self.key.weight": "attention.linear_layers.1.weight",
    "attention.self.key.bias": "attention.linear_layers.1.bias",
    "attention.self.value.weight": "attention.linear_layers.2.weight",
    "attention.self.value.bias": "attention.linear_layers.2.bias",
    # Attention output projection
    "attention.output.dense.weight": "attention.output_linear.weight",
    "attention.output.dense.bias": "attention.output_linear.bias",
    # Norm attached to the attention sub-layer
    "attention.output.LayerNorm.weight": "input_sublayer.norm.a_2",
    "attention.output.LayerNorm.bias": "input_sublayer.norm.b_2",
    # Feed-forward network
    "intermediate.dense.weight": "feed_forward.w_1.weight",
    "intermediate.dense.bias": "feed_forward.w_1.bias",
    "output.dense.weight": "feed_forward.w_2.weight",
    "output.dense.bias": "feed_forward.w_2.bias",
    # Norm attached to the feed-forward sub-layer
    "output.LayerNorm.weight": "output_sublayer.norm.a_2",
    "output.LayerNorm.bias": "output_sublayer.norm.b_2",
}

In [14]:
# Reverse lookup: our parameter name -> pre-trained parameter name
inv_mapping = {ours: pretrained for pretrained, ours in mapping.items()}

### Set the State Dictionary

- To replace weights, we will create a new state-dictionary with the BERT weights and then load that State Dictionary for our model

In [15]:
# Work on a deep copy of our model's state dict, so entries can be replaced
# without modifying the tensors the live model currently holds
dic = deepcopy(bert_model.state_dict())

In [16]:
# Replace every transformer-block entry of `dic` with the corresponding tensor from
# the pre-trained checkpoint. Only `transformer_blocks.*` entries are touched here;
# the embedding entries keep the values they already have in `dic`.
# NOTE(review): parameters are paired by name and shape only. Whether the two norm
# layers sit at the same point of the computation in both implementations cannot be
# verified from this notebook - confirm against the model code.
cnt = 0  # number of tensors copied, checked against the expected total after the loop
for layer, block in enumerate(bert_model.transformer_blocks):
    # Walk the model's own blocks instead of a hard-coded range(12), so the loop
    # follows whatever depth the model was built with.
    for name, p_val in block.named_parameters():
        # name is the parameter name relative to the block (a value of `mapping`);
        # p_val is our current parameter and is only used for the shape check.
        if name not in inv_mapping:
            raise KeyError(f"No pre-trained counterpart is mapped for parameter '{name}'")
        # to_copy is the name of the same parameter in the pre-trained BERT model,
        # obtained through inv_mapping (keys and values of mapping swapped).
        to_copy = f'bert.encoder.layer.{layer}.' + inv_mapping[name]
        # Obtain the tensor to copy from the dictionary holding the pre-trained weights
        param_to_copy = bert_pretrained_model[to_copy]
        # Validate before writing, so a mismatch never leaves a wrong tensor in `dic`
        assert p_val.shape == param_to_copy.shape, (
            f"Shape mismatch for {name}: {p_val.shape} vs {to_copy}: {param_to_copy.shape}"
        )
        # Set the value of the parameter in the new state dictionary
        dic[f'transformer_blocks.{layer}.' + name] = param_to_copy
        # Log
        print(f"Layer: {layer}, {name}, {p_val.shape}, \n\t\t\t {to_copy},  {param_to_copy.shape} \n\n")
        cnt += 1

# Every mapped parameter must have been copied exactly once per block
expected_cnt = len(bert_model.transformer_blocks) * len(inv_mapping)
assert cnt == expected_cnt, f"Copied {cnt} tensors, expected {expected_cnt}"

Layer: 0, attention.linear_layers.0.weight, torch.Size([768, 768]), 
			 bert.encoder.layer.0.attention.self.query.weight,  torch.Size([768, 768]) 


Layer: 0, attention.linear_layers.0.bias, torch.Size([768]), 
			 bert.encoder.layer.0.attention.self.query.bias,  torch.Size([768]) 


Layer: 0, attention.linear_layers.1.weight, torch.Size([768, 768]), 
			 bert.encoder.layer.0.attention.self.key.weight,  torch.Size([768, 768]) 


Layer: 0, attention.linear_layers.1.bias, torch.Size([768]), 
			 bert.encoder.layer.0.attention.self.key.bias,  torch.Size([768]) 


Layer: 0, attention.linear_layers.2.weight, torch.Size([768, 768]), 
			 bert.encoder.layer.0.attention.self.value.weight,  torch.Size([768, 768]) 


Layer: 0, attention.linear_layers.2.bias, torch.Size([768]), 
			 bert.encoder.layer.0.attention.self.value.bias,  torch.Size([768]) 


Layer: 0, attention.output_linear.weight, torch.Size([768, 768]), 
			 bert.encoder.layer.0.attention.output.dense.weight,  torch.Size([768, 768])

### Obtain New Model

In [17]:
# Load the patched state dict into our model; the displayed value is the key-matching report
bert_model.load_state_dict(dic)

<All keys matched successfully>

In [18]:
# Serialize the whole model object (not only its state dict) to disk.
# NOTE(review): loading this file back presumably requires the `models.bert` classes
# to be importable under the same path - confirm with whoever consumes the file.
torch.save(bert_model, "./BERT_with_PT_weights.bin")

In [19]:
# Fetch the state dict of the model after the pre-trained weights were loaded
new_dict = bert_model.state_dict()