In [25]:
import datasets
from models import BERTLM
from models.squish_bert import SquishBert
from models.squish_bert import RetrainedBlock

In [26]:
class config():
    def __init__(self):
        self.vocab = "bert-google"
        self.vocab_path = "../data/wikitext2/all.txt"
        self.bert_google_vocab = "../data/uncased_L-12_H-768_A-12/vocab.txt"
        self.vocab_max_size = None
        self.vocab_min_frequency = 1
        self.dataset = "wikitext2"
        self.seq_len = 40
        self.on_memory = True
        self.corpus_lines = None
        self.train_dataset = "../data/wikitext2/test_data_single_sentence.txt"
        self.encoding = "utf-8"
        self.batch_size = 1
        self.num_workers = 1
        self.hidden_features = 768
        self.layers = 12
        self.heads = 12
        self.device = "cpu"
        self.dropout = 0.1
        self.train = True
        self.lr = 1e-3
        self.adam_beta1=0.999
        self.adam_beta2=0.999
        self.adam_weight_decay = 0.01
        self.warmup_steps =1000
        self.storage_directory = "C:/Users/Raphi/PycharmProjects/simplifying-transformers"
        self.model = "SquishBert"

In [27]:
# Instantiate the settings object. The instance reuses the name `config`
# (later cells read it), which shadows the class. Resolving the class through
# type() keeps this cell re-runnable: a plain `config = config()` fails with
# "'config' object is not callable" the second time it is executed.
_config_cls = config if isinstance(config, type) else type(config)
config = _config_cls()

In [28]:
# Build the vocabulary selected by the settings object; the log output below
# reports that the BERT vocabulary is used.
vocab = datasets.get_vocab(config)

Using Bert Vocab


98856it [00:00, 114284.49it/s]
30522it [00:00, 984621.72it/s]


In [29]:
# Per-layer attention layout of the reduced model: twelve layers, each with
# 12 heads of width 16 (12 * 16 = 192, the projection width in the printout).
heads = [12] * 12
dks = [16] * 12

# Reduced model first, reference model second (construction order kept).
squished = SquishBert(config, vocab_size=len(vocab), dks=dks, heads=heads)

# NOTE(review): no checkpoint load for `bert` is visible in this notebook --
# confirm that BERTLM's constructor restores the pretrained weights; otherwise
# the copy cells below transfer freshly initialised values.
bert = BERTLM(config, vocab_size=len(vocab))

In [30]:
# Print the SquishBert module tree as constructed, before any weights are copied in.
print(squished)

SquishBert(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(30522, 768, padding_idx=0)
    (position): PositionalEmbedding()
    (segment): SegmentEmbedding(2, 768, padding_idx=0)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm()
  )
  (layers): Sequential(
    (0): RetrainedTransformer(
      (attentionblock): BlockMultiHeadedAttention(
        (linear_layers): ModuleList(
          (0): Linear(in_features=768, out_features=192, bias=True)
          (1): Linear(in_features=768, out_features=192, bias=True)
          (2): Linear(in_features=768, out_features=192, bias=True)
        )
        (output_linear): Linear(in_features=192, out_features=768, bias=True)
        (attention): Attention()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward): PositionwiseFeedForward(
        (w_1): Linear(in_features=768, out_features=3072, bias=True)
        (w_2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dr

In [31]:
# Print the reference BERTLM module tree for comparison with the SquishBert printout.
print(bert)

BERTLM(
  (bert): BERT(
    (embedding): BERTEmbedding(
      (token): TokenEmbedding(30522, 768, padding_idx=0)
      (position): PositionalEmbedding()
      (segment): SegmentEmbedding(2, 768, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm()
    )
    (transformer_blocks): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadedAttention(
          (linear_layers): ModuleList(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Linear(in_features=768, out_features=768, bias=True)
            (2): Linear(in_features=768, out_features=768, bias=True)
          )
          (output_linear): Linear(in_features=768, out_features=768, bias=True)
          (attention): Attention()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=768, out_features=3072, bias=True)
          (w_2): Linear(in_features=30

Copy the embedding weights (token, segment, layer norm) from BERT into our merged model

In [32]:
# Copy the embedding parameters of the reference BERT into SquishBert.
# The values are copied in place rather than rebinding `.data` to the source
# Parameter: rebinding makes both models share one storage (a later update to
# either would change the other) and accepts a tensor of any shape without
# complaint, so shapes are verified explicitly first.
embedding_pairs = (
    ("embedding.token.weight",
     squished.embedding.token.weight, bert.bert.embedding.token.weight),
    ("embedding.segment.weight",
     squished.embedding.segment.weight, bert.bert.embedding.segment.weight),
    ("embedding.layer_norm.a_2",
     squished.embedding.layer_norm.a_2, bert.bert.embedding.layer_norm.a_2),
    ("embedding.layer_norm.b_2",
     squished.embedding.layer_norm.b_2, bert.bert.embedding.layer_norm.b_2),
)
for label, dst, src in embedding_pairs:
    if dst.shape != src.shape:
        raise ValueError(
            f"{label}: shape mismatch, target {tuple(dst.shape)} vs source {tuple(src.shape)}"
        )
    dst.data.copy_(src.data)

Copy the MLP and sublayer layer-norm weights from BERT, and the attention weights from the retrained blocks

In [33]:
# Load weights for each transformer block.
# Attention parameters come from the individually retrained block checkpoint;
# feed-forward and sublayer layer-norm parameters come from the reference BERT.
# The loop is driven by the `heads`/`dks` lists so it cannot drift from the
# layout the model was built with (previously a hard-coded range(12)).
for index, (block_heads, block_dk) in enumerate(zip(heads, dks)):
    retrained = RetrainedBlock(
        config,
        depth=index,
        hidden=config.hidden_features,
        heads=block_heads,
        dk=block_dk,
        dropout=config.dropout,
    )
    retrained.load_state(load_optimizer=False)

    target_layer = squished.layers[index]
    bert_layer = bert.bert.transformer_blocks[index]
    target_attn = target_layer.attentionblock
    source_attn = retrained.attentionblock

    # (label, destination parameter, source parameter)
    block_pairs = []

    # block_checkpoints: the three input projections and the output projection
    for j in range(3):
        block_pairs.append((f"linear_layers[{j}].weight",
                            target_attn.linear_layers[j].weight,
                            source_attn.linear_layers[j].weight))
        block_pairs.append((f"linear_layers[{j}].bias",
                            target_attn.linear_layers[j].bias,
                            source_attn.linear_layers[j].bias))
    block_pairs.append(("output_linear.weight",
                        target_attn.output_linear.weight,
                        source_attn.output_linear.weight))
    block_pairs.append(("output_linear.bias",
                        target_attn.output_linear.bias,
                        source_attn.output_linear.bias))

    # feed_forward
    block_pairs.append(("feed_forward.w_1.weight",
                        target_layer.feed_forward.w_1.weight,
                        bert_layer.feed_forward.w_1.weight))
    block_pairs.append(("feed_forward.w_1.bias",
                        target_layer.feed_forward.w_1.bias,
                        bert_layer.feed_forward.w_1.bias))
    block_pairs.append(("feed_forward.w_2.weight",
                        target_layer.feed_forward.w_2.weight,
                        bert_layer.feed_forward.w_2.weight))
    block_pairs.append(("feed_forward.w_2.bias",
                        target_layer.feed_forward.w_2.bias,
                        bert_layer.feed_forward.w_2.bias))

    # input_sublayer
    block_pairs.append(("input_sublayer.norm.a_2",
                        target_layer.input_sublayer.norm.a_2,
                        bert_layer.input_sublayer.norm.a_2))
    block_pairs.append(("input_sublayer.norm.b_2",
                        target_layer.input_sublayer.norm.b_2,
                        bert_layer.input_sublayer.norm.b_2))

    # output_sublayer
    block_pairs.append(("output_sublayer.norm.a_2",
                        target_layer.output_sublayer.norm.a_2,
                        bert_layer.output_sublayer.norm.a_2))
    block_pairs.append(("output_sublayer.norm.b_2",
                        target_layer.output_sublayer.norm.b_2,
                        bert_layer.output_sublayer.norm.b_2))

    # Copy values in place instead of rebinding `.data`: rebinding shares
    # storage with the source model and silently accepts any shape (e.g. a
    # checkpoint trained with a different dk), so shapes are checked first.
    for label, dst, src in block_pairs:
        if dst.shape != src.shape:
            raise ValueError(
                f"layer {index} {label}: shape mismatch, "
                f"target {tuple(dst.shape)} vs source {tuple(src.shape)}"
            )
        dst.data.copy_(src.data)

Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_0_16_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_1_16_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_2_16_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_3_16_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model
Block Path: C:/Users/Raphi/PycharmProjects/simplifying-transformers/models/_checkpoints/wikitext2/block_4_16_12/BLOCK-latest.pth
Loaded block checkpoint
Successfully loaded state dict for block model


Copy MaskLM Weights

In [34]:
# Copy the masked-LM head parameters of the reference BERT into SquishBert.
# Values are copied in place rather than rebinding `.data` to the source
# Parameter, which would make both models share storage and would accept a
# tensor of any shape; shapes are therefore verified before each copy.
mask_lm_pairs = (
    ("mask_lm.linear.weight",
     squished.mask_lm.linear.weight, bert.mask_lm.linear.weight),
    ("mask_lm.linear.bias",
     squished.mask_lm.linear.bias, bert.mask_lm.linear.bias),
    ("mask_lm.layer_norm.a_2",
     squished.mask_lm.layer_norm.a_2, bert.mask_lm.layer_norm.a_2),
    ("mask_lm.layer_norm.b_2",
     squished.mask_lm.layer_norm.b_2, bert.mask_lm.layer_norm.b_2),
    ("mask_lm.decoder.weight",
     squished.mask_lm.decoder.weight, bert.mask_lm.decoder.weight),
    ("mask_lm.decoder.bias",
     squished.mask_lm.decoder.bias, bert.mask_lm.decoder.bias),
)
for label, dst, src in mask_lm_pairs:
    if dst.shape != src.shape:
        raise ValueError(
            f"{label}: shape mismatch, target {tuple(dst.shape)} vs source {tuple(src.shape)}"
        )
    dst.data.copy_(src.data)

In [35]:
# Print the module tree of the assembled model once all weights have been copied.
# (A leftover 1000-iteration loop whose body was only `continue` was removed;
# it had no effect on the model or on the output.)
print(squished)

SquishBert(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(30522, 768, padding_idx=0)
    (position): PositionalEmbedding()
    (segment): SegmentEmbedding(2, 768, padding_idx=0)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm()
  )
  (layers): Sequential(
    (0): RetrainedTransformer(
      (attentionblock): BlockMultiHeadedAttention(
        (linear_layers): ModuleList(
          (0): Linear(in_features=768, out_features=192, bias=True)
          (1): Linear(in_features=768, out_features=192, bias=True)
          (2): Linear(in_features=768, out_features=192, bias=True)
        )
        (output_linear): Linear(in_features=192, out_features=768, bias=True)
        (attention): Attention()
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward): PositionwiseFeedForward(
        (w_1): Linear(in_features=768, out_features=3072, bias=True)
        (w_2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dr

In [36]:
# Persist the assembled model. The effect of running=True is defined by
# SquishBert.save_model, which is not visible in this notebook -- TODO confirm.
squished.save_model(running=True)