In [5]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import sys
%load_ext autoreload
%autoreload 2
import numpy as np

In [6]:
# Add the parent directory to the import path so the project-local `models`
# package can be imported from this notebook's location.
sys.path.append("../")
from models.bert import BERTLM

In [22]:
class config():
    """Plain attribute container with the settings used throughout this notebook.

    An instance is handed to ``BERTLM``, ``BertVocab`` and the dataset factory,
    which read the attributes they need. All values are fixed in ``__init__``;
    there are no constructor arguments.
    """

    def __init__(self):
        # --- vocabulary -----------------------------------------------------
        self.vocab = "bert-google"
        self.vocab_path = "data/wikitext2/all.txt"
        self.bert_google_vocab = "uncased_L-12_H-768_A-12/vocab.txt"
        self.vocab_max_size = None
        self.vocab_min_frequency = 1

        # --- dataset / loader -----------------------------------------------
        self.dataset = "wikitext2"
        self.seq_len = 40
        self.on_memory = True
        self.corpus_lines = None
        self.train_dataset = "data/wikitext2/all.txt"
        self.encoding = "utf-8"
        self.batch_size = 1
        self.num_workers = 1

        # --- model architecture (BERT-base sized) ---------------------------
        self.hidden_features = 768
        self.layers = 12
        self.heads = 12
        # Prefer the GPU when one is visible to PyTorch.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dropout = 0.1

        # --- optimisation ---------------------------------------------------
        self.train = True
        self.lr = 1e-3
        # First-moment decay of Adam. The previous value (0.999) duplicated
        # beta2, which looks like a copy-paste slip; 0.9 is the standard value.
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 0.01
        self.warmup_steps = 1000

        # --- persistence ----------------------------------------------------
        # NOTE(review): machine-specific absolute path; adjust before running
        # this notebook on another machine.
        self.storage_directory = "/Users/raphaelwinkler/PycharmProjects/simplifying-transformers"
        self.model = "BERTLM"

In [23]:
# Single settings object shared by the model, vocabulary and dataset cells below.
conf = config()

In [5]:
# Build our own BERT masked-LM model. 30522 is presumably the vocabulary size
# (it matches the 30522-row token embedding shown in the model repr) — confirm
# against the BERTLM signature.
bert_ml = BERTLM(conf, 30522)

### Load Pretrained Weights

In [7]:
# Load the dumped checkpoint; the cells below use it as a name -> tensor mapping.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
pt_model = torch.load("torch_dump_model")

In [8]:
# Keep only the checkpoint entries relevant to masked-LM: everything whose name
# mentions neither "pooler" nor "seq_relationship".
mlm_rel_params = {
    key: tensor
    for key, tensor in pt_model.items()
    if not ("pooler" in key or "seq_relationship" in key)
}

In [9]:
# Show the retained checkpoint entries that belong to the embedding sub-module.
embedding_keys = [key for key in mlm_rel_params if 'embedding' in key]
for key in embedding_keys:
    print(key)

bert.embeddings.position_ids
bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias


In [10]:
from copy import deepcopy

### Set Embeddings

In [11]:
# Work on a deep copy of our model's state dict; the live parameters only
# change once `load_state_dict(dic)` is called further down.
dic = deepcopy(bert_ml.state_dict())

In [12]:
# Overwrite the positional table in place with the pretrained position embeddings.
# The `[0]` indexes the leading axis of our `pe` entry — presumably a singleton
# batch axis; TODO confirm against PositionalEmbedding.
dic['bert.embedding.position.pe'][0] = deepcopy(pt_model['bert.embeddings.position_embeddings.weight'])

In [12]:
# Token (word-piece) embedding table taken from the pretrained checkpoint.
dic['bert.embedding.token.weight'] = deepcopy(pt_model['bert.embeddings.word_embeddings.weight'])

In [13]:
# Segment embedding <- Hugging Face token-type embedding.
dic['bert.embedding.segment.weight'] = deepcopy(pt_model['bert.embeddings.token_type_embeddings.weight'])

In [14]:
# Embedding layer norm: our `a_2` entry receives the pretrained LayerNorm weight.
dic['bert.embedding.layer_norm.a_2'] = deepcopy(pt_model['bert.embeddings.LayerNorm.weight'])

In [15]:
# Embedding layer norm: our `b_2` entry receives the pretrained LayerNorm bias.
dic['bert.embedding.layer_norm.b_2'] = deepcopy(pt_model['bert.embeddings.LayerNorm.bias'])

In [16]:
# Translation table for one encoder layer:
# Hugging Face parameter suffix -> parameter suffix inside our transformer block.
mapping = dict([
    # self-attention input projections (query, key, value)
    ('attention.self.query.weight', 'attention.linear_layers.0.weight'),
    ('attention.self.query.bias', 'attention.linear_layers.0.bias'),
    ('attention.self.key.weight', 'attention.linear_layers.1.weight'),
    ('attention.self.key.bias', 'attention.linear_layers.1.bias'),
    ('attention.self.value.weight', 'attention.linear_layers.2.weight'),
    ('attention.self.value.bias', 'attention.linear_layers.2.bias'),
    # attention output projection and the layer norm that follows it
    ('attention.output.dense.weight', 'attention.output_linear.weight'),
    ('attention.output.dense.bias', 'attention.output_linear.bias'),
    ('attention.output.LayerNorm.weight', 'input_sublayer.norm.a_2'),
    ('attention.output.LayerNorm.bias', 'input_sublayer.norm.b_2'),
    # position-wise feed-forward network
    ('intermediate.dense.weight', 'feed_forward.w_1.weight'),
    ('intermediate.dense.bias', 'feed_forward.w_1.bias'),
    ('output.dense.weight', 'feed_forward.w_2.weight'),
    ('output.dense.bias', 'feed_forward.w_2.bias'),
    # layer norm after the feed-forward network
    ('output.LayerNorm.weight', 'output_sublayer.norm.a_2'),
    ('output.LayerNorm.bias', 'output_sublayer.norm.b_2'),
])

In [17]:
# Reverse lookup: our parameter suffix -> Hugging Face parameter suffix.
inv_mapping = {ours: theirs for theirs, ours in mapping.items()}

In [18]:
# Total number of entries in our model's state dict (sanity reference for the copy below).
len(bert_ml.state_dict())

203

In [19]:
# Copy every encoder-layer tensor from the pretrained checkpoint into `dic`,
# translating our parameter names to Hugging Face names via `inv_mapping`.
cnt = 0  # number of tensors copied; inspected in the next cell
# Iterate over the model's own blocks instead of a hard-coded layer count.
for layer, block in enumerate(bert_ml.bert.transformer_blocks):
    for name, p_val in block.named_parameters():
        to_copy = f'bert.encoder.layer.{layer}.' + inv_mapping[name]
        param_to_copy = deepcopy(pt_model[to_copy])
        # Validate first so a mismatching tensor is never written into `dic`.
        assert p_val.shape == param_to_copy.shape, (
            f"shape mismatch for {to_copy}: "
            f"{tuple(param_to_copy.shape)} vs {tuple(p_val.shape)}"
        )
        dic[f'bert.transformer_blocks.{layer}.' + name] = param_to_copy
        cnt += 1

In [20]:
# Tensors copied by the loop above; 12 layers x 16 mapped names = 192 expected.
cnt

192

### Set Last Layers

In [21]:
# Masked-LM head: (key in our state dict, key in the pretrained checkpoint).
mask_lm_pairs = (
    ('mask_lm.linear.weight', 'cls.predictions.transform.dense.weight'),
    ('mask_lm.linear.bias', 'cls.predictions.transform.dense.bias'),
    ('mask_lm.decoder.weight', 'cls.predictions.decoder.weight'),
    ('mask_lm.decoder.bias', 'cls.predictions.decoder.bias'),
    ('mask_lm.layer_norm.a_2', 'cls.predictions.transform.LayerNorm.weight'),
    ('mask_lm.layer_norm.b_2', 'cls.predictions.transform.LayerNorm.bias'),
)
for our_key, hf_key in mask_lm_pairs:
    dic[our_key] = deepcopy(pt_model[hf_key])

In [22]:
# Push the assembled tensors into the model; the printed result confirms that
# every key matched.
bert_ml.load_state_dict(dic)

<All keys matched successfully>

In [23]:
# Persist the converted model through the project's own `save_model` method.
# NOTE(review): the meaning of `running=True` and the output location are not
# visible in this notebook (presumably under conf.storage_directory) — confirm.
bert_ml.save_model(running=True)

In [24]:
# Switch to evaluation mode; `eval()` returns the module, so its repr is displayed below.
bert_ml.eval()

BERTLM(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(30522, 768, padding_idx=0)
      (position): PositionalEmbedding()
      (segment): SegmentEmbedding(2, 768, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm()
    )
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Linear(in_features=768, out_features=768, bias=True)
            (2): Linear(in_features=768, out_features=768, bias=True)
          )
          (output_linear): Linear(in_features=768, out_features=768, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=30

In [2]:
from parent_bert import get_pretrained_berd

# Load our BERT with the converted pretrained weights through the project helper.
bert_ml = get_pretrained_berd()
# The Hugging Face reference model below is compared in eval mode, so put our
# model in eval mode too; otherwise dropout would perturb the logits being compared.
# NOTE(review): assumes the helper returns a torch.nn.Module (it is called like
# one further down) — confirm.
bert_ml.eval();


### Comparison Time

### The warning below shows that the NSP (next-sentence-prediction) head weights are discarded

In [3]:
from transformers import BertForMaskedLM
# Reference implementation: Hugging Face masked-LM BERT initialised from the
# 'bert-base-uncased' checkpoint.
ml_model = BertForMaskedLM.from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Evaluation mode for the reference model; the returned module's repr is displayed below.
ml_model.eval()

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [7]:
from transformers import BertTokenizer
# Tokenizer for the same 'bert-base-uncased' checkpoint as the reference model.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Test sentence with a single masked position for the models to fill in.
text = "Today is a cold and [MASK] day."
# return_tensors='pt' yields PyTorch tensors; later cells read its
# 'input_ids' and 'token_type_ids' entries.
encoded_input = tokenizer(text, return_tensors='pt')

In [8]:
# Run the reference model. The first [0] takes the first element of the model
# output (presumably the logits), the second [0] takes the only sentence in the batch.
op = ml_model(**encoded_input)[0][0]

#### Hugging Face output (top prediction per position)

In [9]:
# Print the highest-scoring vocabulary token for each of the first 10 positions.
for position in range(10):
    best_id = op[position].argmax().item()
    print(tokenizer.ids_to_tokens[best_id])

.
today
is
a
cold
and
rainy
day
.
.


### Hugging Face uses segment id 0 here; switching it to 1 does not change the predictions

We need segment id 1 because our model reserves index 0 for padding.

In [10]:
# Our model reserves segment id 0 for padding, so real tokens must carry id 1.
# Assign ones instead of `+= 1` so re-running this cell is idempotent: a second
# increment would give id 2, which is out of range for the 2-row segment embedding.
# (Same result as `+= 1` for this single-sentence input, whose ids start at 0.)
encoded_input['token_type_ids'] = torch.ones_like(encoded_input['token_type_ids'])

In [31]:
# Re-run the reference model with the modified segment ids; same indexing as
# before (first output element, first sentence in the batch).
op = ml_model(**encoded_input)[0][0]

In [11]:
# Top prediction per position after switching the segment ids.
for position in range(10):
    best_id = op[position].argmax().item()
    print(tokenizer.ids_to_tokens[best_id])

.
today
is
a
cold
and
rainy
day
.
.


### TESTING WITH OUR MODEL

In [17]:
# Re-pack the Hugging Face encoding into the dict layout our model's cells use.
# Slice the tensors directly instead of copying ten elements one by one: this
# drops the hard-coded length 10, so shorter sentences no longer raise
# IndexError and longer ones are no longer silently truncated.
# `[:1]` keeps only the first sentence as a batch of one; `.int()` converts to
# 32-bit integers as before.
new_data = {
    'bert_input': encoded_input['input_ids'][:1].int(),
    'segment_label': encoded_input['token_type_ids'][:1].int(),
}

In [18]:
# Run our model on the same sentence: [0] takes the only batch item and [:10]
# keeps the first 10 positions, matching the 10 rows printed for Hugging Face.
op_our = bert_ml(new_data['bert_input'], new_data['segment_label'])[0][:10]

## The raw outputs below are close, though not identical

In [19]:
# Hugging Face scores, one row per token position.
op

tensor([[ -6.4448,  -6.3908,  -6.4175,  ...,  -5.8137,  -5.6503,  -3.9097],
        [-14.2450, -14.5067, -14.0750,  ..., -12.1318, -12.2131, -13.8554],
        [-10.1342,  -9.8525,  -9.9614,  ...,  -8.5781,  -6.2810,  -5.1799],
        ...,
        [-13.8870, -14.3584, -13.9323,  ..., -11.1448, -10.5272, -11.7787],
        [-13.3764, -12.8394, -13.1806,  ..., -10.7812, -11.3668, -11.1040],
        [-19.0271, -19.0210, -19.0633,  ..., -17.7467, -15.8703, -12.1800]],
       grad_fn=<SelectBackward0>)

In [20]:
# Scores from our model for the same positions, for side-by-side comparison.
op_our

tensor([[ -6.5247,  -6.4878,  -6.5218,  ...,  -5.9450,  -5.7974,  -3.9458],
        [-14.1594, -14.3849, -14.0029,  ..., -12.1670, -11.9847, -14.2297],
        [ -8.6822,  -8.4283,  -8.5070,  ...,  -8.0076,  -5.6218,  -4.0874],
        ...,
        [-13.1808, -13.5999, -13.2546,  ..., -11.0337,  -9.9279, -10.9528],
        [-11.6317, -11.2551, -11.4205,  ...,  -8.9649,  -9.4258, -10.4603],
        [-14.4416, -14.7170, -14.5038,  ..., -13.1296, -11.8556, -10.9936]],
       grad_fn=<SliceBackward0>)

In [25]:
from datasets.vocabulary import BertVocab
import datasets
# NOTE(review): `datasets` here appears to be a project-local package (it
# provides `datasets.vocabulary` and `datasets.get`), not the Hugging Face
# library of the same name — confirm.

# Vocabulary configured from the notebook's `conf` settings.
vocab = BertVocab(conf)

# vocab.pad_index

# Look up the dataset class registered as "wikitext2" and instantiate it with
# the config and vocabulary.
train_dataset = datasets.get(dataset_name="wikitext2")(config=conf, vocab=vocab)

# Data loader over that dataset; iterated in the cells below.
train_loader = train_dataset.get_data_loader()

Using Bert Vocab


98856it [00:00, 144997.86it/s]
30522it [00:00, 1131455.62it/s]
Loading Dataset: 98856it [00:00, 428726.94it/s]


## Decode with both the BERT tokenizer and our own vocabulary to confirm they agree

In [26]:
# Decode our model's top prediction per position with the Hugging Face tokenizer.
for position in range(10):
    best_id = op_our[position].argmax().item()
    print(tokenizer.ids_to_tokens[best_id])

.
today
is
a
cold
and
rainy
day
.
.


In [27]:
# Same decoding, this time through our own vocabulary's index -> string table.
for position in range(10):
    best_id = op_our[position].argmax().item()
    print(vocab.itos[best_id])

.
today
is
a
cold
and
rainy
day
.
.


## Good! The outputs match for this test case :)

Now let us check whether this also holds on the WikiText2 dataset :)

In [28]:
# Grab the first batch for inspection. Unlike a loop-and-break, this raises
# StopIteration on an empty loader instead of silently leaving `data` unbound.
data = next(iter(train_loader))

In [29]:
# Inspect the batch layout: the input ids, labels, segment labels and mask index.
data

{'bert_input': tensor([[  101,   100,  2053,   100,  1017,  1024,   103, 11906,  1006,  2887,
           1024,   100,  1010,  5507,  1012,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 'bert_label': tensor([[  0,   0,   0,   0,   0,   0, 100,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
            0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]]),
 'segment_label': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'mask_index': tensor([6])}

In [30]:
# Largest label id in the sample. In the batch shown above the label tensor is 0
# everywhere except the masked position, so this is the id of the true token.
data['bert_label'][0].max().item()

100

In [31]:
# Qualitative check on WikiText2: for the first few batches print the masked
# sentence, our model's prediction at the masked position, and the true token.
MAX_EXAMPLES = 20  # stop after this many batches

# Count from 1 directly instead of re-assigning the loop variable in the body.
for ix, data in enumerate(train_loader, start=1):
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        op_our = bert_ml(data['bert_input'], data['segment_label'])[0]
    mask_index = data['mask_index'][0].item()

    print("==========================")
    print("Input Sentence:\n")
    # Skip the first and last slots, and stop at the first padding id (0).
    str_vals = []
    for elem in data['bert_input'][0][1:-1]:
        if elem.item() == 0:
            break
        str_vals.append(vocab.itos[elem.item()])
    print(" ".join(str_vals))
    print("\n")
    print(f"\nMask Prediction for input {ix}: ", vocab.itos[op_our[mask_index].argmax().item()])
    # Read the label at the masked position (the same index used for the
    # prediction) rather than relying on it being the maximum id in the row.
    true_id = data['bert_label'][0][mask_index].item()
    print(f"True Mask for above input: {vocab.itos[true_id]}")
    if ix == MAX_EXAMPLES:
        break
    print("\n\n")

Input Sentence:

[UNK] no [UNK] 3 : [UNK] chronicles [MASK] japanese : [UNK] , lit . [SEP]



Mask Prediction for input 1:  (
True Mask for above input: (



Input Sentence:

[UNK] of the battlefield 3 ) , commonly referred to as [UNK] chronicles iii outside [MASK] , is a tactical role [UNK] playing video game developed by sega and [UNK] for the playstation portable . [SEP]



Mask Prediction for input 2:  japan
True Mask for above input: japan



Input Sentence:

released in january 2011 in japan , it is the third game in the [MASK] series . [SEP]



Mask Prediction for input 3:  pokemon
True Mask for above input: [UNK]



Input Sentence:

[UNK] the same fusion of tactical and real [UNK] time gameplay as its predecessors , the story runs parallel to the first game and follows the " [UNK] " , a penal military unit serving [MASK] nation of



Mask Prediction for input 4:  the
True Mask for above input: the



Input Sentence:

the game began development [MASK] 2010 , carrying over a larg