In [24]:
import torch

import datasets
from models import BERTLM
from models.merged_retrained_bert import MergedRetrainedBert
from models.merged_retrained_bert import RetrainedBlock
from copy import deepcopy

In [25]:
class config():
    def __init__(self):
        self.vocab = "bert-google"
        self.vocab_path = "../data/wikitext2/all.txt"
        self.bert_google_vocab = "../data/uncased_L-12_H-768_A-12/vocab.txt"
        self.test_dataset = "../data/wikitext2/test_data_single_sentence.txt"
        self.vocab_max_size = None
        self.vocab_min_frequency = 1
        self.dataset = "wikitext2"
        self.seq_len = 40
        self.on_memory = True
        self.corpus_lines = None
        self.train_dataset = "../data/wikitext2/test_data_single_sentence.txt"
        self.encoding = "utf-8"
        self.batch_size = 1
        self.num_workers = 1
        self.hidden_features = 768
        self.layers = 12
        self.heads = 12
        self.device = "cpu"
        self.dropout = 0.1
        self.train = True
        self.lr = 1e-3
        self.adam_beta1=0.999
        self.adam_beta2=0.999
        self.adam_weight_decay = 0.01
        self.warmup_steps =1000
        self.storage_directory = "C:/Users/Raphi/PycharmProjects/simplifying-transformers"
        self.model = "MergedRetrainedBert"
        self.model_checkpoint = ""

In [26]:
config = config()

In [27]:
vocab = datasets.get_vocab(config)

Using Bert Vocab


98856it [00:00, 122650.20it/s]
30522it [00:00, 1017239.28it/s]


In [28]:
heads = [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
dks = [64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64]
merged = MergedRetrainedBert(
    config,
    vocab_size=len(vocab),
    dks=dks,
    heads=heads
)

bert = BERTLM(config, vocab_size=len(vocab))
bert.load_state(load_optimizer=False, overwrite_path="../models/_checkpoints/wikitext2/BERTLM-latest.pth")
bert.eval()

Loading checkpoint from [../models/_checkpoints/wikitext2/BERTLM-latest.pth]
Loaded checkpoint
Successfully loaded state dict


BERTLM(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(30522, 768, padding_idx=0)
      (position): PositionalEmbedding()
      (segment): SegmentEmbedding(2, 768, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm()
    )
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Linear(in_features=768, out_features=768, bias=True)
            (2): Linear(in_features=768, out_features=768, bias=True)
          )
          (output_linear): Linear(in_features=768, out_features=768, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=30

In [29]:
print(merged)

MergedRetrainedBert(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(30522, 768, padding_idx=0)
    (position): PositionalEmbedding()
    (segment): SegmentEmbedding(2, 768, padding_idx=0)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm()
  )
  (layers): Sequential(
    (0): RetrainedTransformer(
      (attentionblock): BlockMultiHeadedAttention(
        (linear_layers): ModuleList(
          (0): Linear(in_features=768, out_features=768, bias=True)
          (1): Linear(in_features=768, out_features=768, bias=True)
          (2): Linear(in_features=768, out_features=768, bias=True)
        )
        (output_linear): Linear(in_features=768, out_features=768, bias=True)
        (attention): Attention()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward): PositionwiseFeedForward(
        (w_1): Linear(in_features=768, out_features=3072, bias=True)
        (w_2): Linear(in_features=3072, out_features=768, bias=True)
        (dro

In [30]:
print(bert)

BERTLM(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(30522, 768, padding_idx=0)
      (position): PositionalEmbedding()
      (segment): SegmentEmbedding(2, 768, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm()
    )
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Linear(in_features=768, out_features=768, bias=True)
            (2): Linear(in_features=768, out_features=768, bias=True)
          )
          (output_linear): Linear(in_features=768, out_features=768, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=30

Copy Encoder from Bert to our Merged Model

In [31]:
for name, param in bert.named_parameters():
    print(name, param.size())

bert.embedding.token.weight torch.Size([30522, 768])
bert.embedding.segment.weight torch.Size([2, 768])
bert.embedding.layer_norm.a_2 torch.Size([768])
bert.embedding.layer_norm.b_2 torch.Size([768])
bert.transformer_blocks.0.attention.linear_layers.0.weight torch.Size([768, 768])
bert.transformer_blocks.0.attention.linear_layers.0.bias torch.Size([768])
bert.transformer_blocks.0.attention.linear_layers.1.weight torch.Size([768, 768])
bert.transformer_blocks.0.attention.linear_layers.1.bias torch.Size([768])
bert.transformer_blocks.0.attention.linear_layers.2.weight torch.Size([768, 768])
bert.transformer_blocks.0.attention.linear_layers.2.bias torch.Size([768])
bert.transformer_blocks.0.attention.output_linear.weight torch.Size([768, 768])
bert.transformer_blocks.0.attention.output_linear.bias torch.Size([768])
bert.transformer_blocks.0.feed_forward.w_1.weight torch.Size([3072, 768])
bert.transformer_blocks.0.feed_forward.w_1.bias torch.Size([3072])
bert.transformer_blocks.0.feed_forw

In [32]:
merged.embedding.token.weight.data = deepcopy(bert.bert.embedding.token.weight)
merged.embedding.position.pe = deepcopy(bert.bert.embedding.position.pe)
merged.embedding.segment.weight.data = deepcopy(bert.bert.embedding.segment.weight)
merged.embedding.layer_norm.a_2.data = deepcopy(bert.bert.embedding.layer_norm.a_2)
merged.embedding.layer_norm.b_2.data = deepcopy(bert.bert.embedding.layer_norm.b_2)

Copy MLP weights from Bert and Attention weights from retrained blocks

In [33]:
# load weights for each transformer block
for index in range(12):

    retrained = RetrainedBlock(config, depth=index, hidden=config.hidden_features, heads=heads[index], dk=dks[index], dropout=config.dropout)
    retrained.load_state(load_optimizer=False)

    # for j in range(3):
    #     merged.layers[index].attentionblock.linear_layers[j].weight.data = deepcopy(bert.bert.transformer_blocks[index].attention.linear_layers[j].weight)
    #     merged.layers[index].attentionblock.linear_layers[j].bias.data = deepcopy(bert.bert.transformer_blocks[index].attention.linear_layers[j].bias)
    #
    # merged.layers[index].attentionblock.output_linear.weight.data = deepcopy(bert.bert.transformer_blocks[index].attention.output_linear.weight)
    # merged.layers[index].attentionblock.output_linear.bias.data = deepcopy(bert.bert.transformer_blocks[index].attention.output_linear.bias)

    # block_checkpoints
    for j in range(3):
        merged.layers[index].attentionblock.linear_layers[j].weight.data = retrained.attentionblock.linear_layers[j].weight
        merged.layers[index].attentionblock.linear_layers[j].bias.data = retrained.attentionblock.linear_layers[j].bias

    merged.layers[index].attentionblock.output_linear.weight.data = retrained.attentionblock.output_linear.weight
    merged.layers[index].attentionblock.output_linear.bias.data = retrained.attentionblock.output_linear.bias

    # feed_forward
    merged.layers[index].feed_forward.w_1.weight.data = deepcopy(bert.bert.transformer_blocks[index].feed_forward.w_1.weight)
    merged.layers[index].feed_forward.w_1.bias.data = deepcopy(bert.bert.transformer_blocks[index].feed_forward.w_1.bias)
    merged.layers[index].feed_forward.w_2.weight.data = deepcopy(bert.bert.transformer_blocks[index].feed_forward.w_2.weight)
    merged.layers[index].feed_forward.w_2.bias.data = deepcopy(bert.bert.transformer_blocks[index].feed_forward.w_2.bias)

    #input_sublayer
    merged.layers[index].input_sublayer.norm.a_2.data = deepcopy(bert.bert.transformer_blocks[index].input_sublayer.norm.a_2)
    merged.layers[index].input_sublayer.norm.b_2.data = deepcopy(bert.bert.transformer_blocks[index].input_sublayer.norm.b_2)

    #output_sublayer
    merged.layers[index].output_sublayer.norm.a_2.data = deepcopy(bert.bert.transformer_blocks[index].output_sublayer.norm.a_2)
    merged.layers[index].output_sublayer.norm.b_2.data = deepcopy(bert.bert.transformer_blocks[index].output_sublayer.norm.b_2)

Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_0_64_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_1_64_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_2_64_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_3_64_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_4_64_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model


Copy MaskLM Weights

In [34]:
merged.mask_lm.linear.weight.data = deepcopy(bert.mask_lm.linear.weight)
merged.mask_lm.linear.bias.data = deepcopy(bert.mask_lm.linear.bias)
merged.mask_lm.layer_norm.a_2.data = deepcopy(bert.mask_lm.layer_norm.a_2)
merged.mask_lm.layer_norm.b_2.data = deepcopy(bert.mask_lm.layer_norm.b_2)
merged.mask_lm.decoder.weight.data = deepcopy(bert.mask_lm.decoder.weight)
merged.mask_lm.decoder.bias.data = deepcopy(bert.mask_lm.decoder.bias)

In [35]:
import numpy

sentence = ["[CLS]", "i", "like", "to", "[MASK]", "pizza", "[SEP]"] + ["[PAD]"] * 33
segment_label = torch.from_numpy(numpy.array([1] * 7 + [0] * 33))
ids = numpy.array(list(map(lambda x: vocab.stoi[x], sentence)))
ids = torch.unsqueeze(torch.from_numpy(ids), dim=0)
bert.eval()
merged.eval()
teacher_pred = bert(ids.to(config.device), segment_label.to(config.device))
teacher_pred = teacher_pred[0][4]
print(teacher_pred)
teacher_pred = torch.argmax(teacher_pred, dim=0)
teacher_word = vocab.itos[teacher_pred]
print(teacher_word)

pred = merged(ids.to(config.device), segment_label.to(config.device))
pred = pred[0][4]
print(pred)
pred = torch.argmax(pred, dim=0)
word = vocab.itos[pred]
print(word)

tensor([-22.7625, -22.8823, -21.2407,  ..., -21.9298, -20.7592, -20.0416],
       grad_fn=<SelectBackward0>)
eat
tensor([-16.0766, -16.5733, -16.6725,  ..., -16.2520, -16.4822, -10.4133],
       grad_fn=<SelectBackward0>)
,


In [36]:
for (name, param), (na, pa) in zip(merged.named_parameters(), bert.named_parameters()):
    equal = param == pa
    equal = torch.all(equal)

    print(str(equal.item()) + "\t", name)

    # print(name, param.size(), na, pa.size())

True	 embedding.token.weight
True	 embedding.segment.weight
True	 embedding.layer_norm.a_2
True	 embedding.layer_norm.b_2
False	 layers.0.attentionblock.linear_layers.0.weight
False	 layers.0.attentionblock.linear_layers.0.bias
False	 layers.0.attentionblock.linear_layers.1.weight
False	 layers.0.attentionblock.linear_layers.1.bias
False	 layers.0.attentionblock.linear_layers.2.weight
False	 layers.0.attentionblock.linear_layers.2.bias
False	 layers.0.attentionblock.output_linear.weight
False	 layers.0.attentionblock.output_linear.bias
True	 layers.0.feed_forward.w_1.weight
True	 layers.0.feed_forward.w_1.bias
True	 layers.0.feed_forward.w_2.weight
True	 layers.0.feed_forward.w_2.bias
True	 layers.0.input_sublayer.norm.a_2
True	 layers.0.input_sublayer.norm.b_2
True	 layers.0.output_sublayer.norm.a_2
True	 layers.0.output_sublayer.norm.b_2
False	 layers.1.attentionblock.linear_layers.0.weight
False	 layers.1.attentionblock.linear_layers.0.bias
False	 layers.1.attentionblock.linear_laye

In [37]:
merged.save_model(running=True)