In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
import bert_sol as bert
import bert_tests
import gpt
from hook_handler import HookHandler

import torch as t
import transformers


### BERT vs GPT-2 embeddings

In [8]:
# Instantiate the project's BERT implementation with the hyperparameters listed below.
my_bert = bert.Bert(
    vocab_size=28996,
    hidden_size=768,
    max_position_embeddings=512,
    type_vocab_size=2,
    dropout=0.1,
    intermediate_size=3072,
    num_heads=12,
    num_layers=12,
)

# Reference model whose parameters are copied into `my_bert`.
pretrained_bert = bert_tests.get_pretrained_bert()

# Translate every checkpoint key with `bert.mapkey`, leaving out entries whose
# original name starts with 'classification_head'.
mapped_params = {
    bert.mapkey(name): weight
    for name, weight in pretrained_bert.state_dict().items()
    if not name.startswith('classification_head')
}

# Kept as the final expression so the notebook displays the key-matching report.
my_bert.load_state_dict(mapped_params)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [4]:
# Build the GPT model through the project-local `gpt` helper. Judging by the helper's
# name it returns a model with pretrained weights already loaded (presumably GPT-2,
# per the section heading — confirm in gpt.py).
my_gpt = gpt.get_gpt_with_pretrained_weights()

In [17]:
# Compare the activations of the last BERT block and the last GPT block on a batch
# of prompts in which each prompt extends the previous one by one word.
#
# NOTE(review): the token ids come from the BERT tokenizer ("bert-base-cased") and
# the same ids are fed to BOTH models. GPT-2 normally uses its own tokenizer and
# vocabulary, so the GPT activations below presumably do not correspond to the
# English prompt text. The comparison only relies on every row sharing the same
# leading ids — confirm this is intentional.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
my_bert.cuda()
my_gpt.cuda()

# The activations are only printed, never back-propagated, so both forward passes
# run under `t.no_grad()` to avoid building an autograd graph.
with HookHandler() as hh, t.no_grad():
    # Record the output of the final block of each model under a readable key.
    hh.add_save_activation_hook(my_bert.blocks[-1], "bert.blocks[-1]")
    hh.add_save_activation_hook(my_gpt.gpt_blocks[-1], "gpt.blocks[-1]")

    # Prompts are padded to the longest one so they stack into one 2-D id tensor.
    # NOTE(review): no attention mask is passed to either model below, so padding
    # positions may be attended to unless the models mask them internally — verify.
    xs = t.tensor(
        tokenizer(
            [
                "My life motto:",
                "My life motto: Fortune",
                "My life motto: Fortune favors",
                "My life motto: Fortune favors the",
                "My life motto: Fortune favors the bold",
            ],
            padding="longest",
            truncation=True,
        )["input_ids"],
        dtype=t.long,
        device=my_gpt.device,
    )

    # Evaluation mode (e.g. disables dropout) so the printed values are repeatable.
    my_gpt.eval()
    my_bert.eval()

    # Return values are discarded; the hooks capture what we want to inspect.
    my_gpt(xs)
    my_bert(xs)

    # Sequence index 4 is expected to fall inside the prefix shared by all five
    # prompts (assuming the tokenizer prepends one special token — TODO confirm).
    # Rows that are identical across the batch indicate the model did not use
    # anything after that position; rows that differ indicate it did.
    print("bert.blocks[-1]", hh.activations["bert.blocks[-1]"][:, 4])
    print("gpt.blocks[-1]", hh.activations["gpt.blocks[-1]"][:, 4])


bert.blocks[-1] tensor([[-0.2722,  0.2118,  0.1717,  ...,  0.4553, -0.0064,  0.2353],
        [-0.3859,  0.1262,  0.1194,  ...,  0.8246,  0.1501,  0.1782],
        [-0.3540,  0.1779,  0.0127,  ...,  0.6782,  0.2399,  0.1714],
        [-0.3492,  0.1421,  0.2156,  ...,  0.7240,  0.4313,  0.0808],
        [-0.3356,  0.0488,  0.0885,  ...,  0.6900,  0.5473,  0.1323]],
       device='cuda:0')
gpt.blocks[-1] tensor([[10.4828,  2.3441,  3.3644,  ...,  3.3975, -4.6225,  0.4826],
        [10.4828,  2.3441,  3.3644,  ...,  3.3975, -4.6225,  0.4826],
        [10.4828,  2.3441,  3.3644,  ...,  3.3975, -4.6225,  0.4826],
        [10.4828,  2.3441,  3.3644,  ...,  3.3975, -4.6225,  0.4826],
        [10.4828,  2.3441,  3.3644,  ...,  3.3975, -4.6225,  0.4826]],
       device='cuda:0')
All hooks removed!
