In [11]:
import torch
from transformers import AutoTokenizer, BertForMaskedLM

from transformer_lens.HookedEncoderConfig import HookedEncoderConfig
from transformer_lens.components import TokenTypeEmbed, BertEmbed
from transformer_lens.utils import get_corner

In [2]:
def _copy(mine, theirs):
    mine.detach().copy_(theirs)

In [3]:
cfg = HookedEncoderConfig(
    d_vocab=28996,
    d_model=768,
    n_ctx=512,
    eps=1e-12,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
input_ids = tokenizer.encode(text="Hello world!", return_tensors="pt")

huggingface_bert = BertForMaskedLM.from_pretrained("bert-base-cased")
huggingface_word_embed = huggingface_bert.bert.embeddings.word_embeddings

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
token_type_ids = torch.zeros(10, dtype=int)
token_type_ids[7:] = 1
print(f"{token_type_ids=}")

token_type_embed = TokenTypeEmbed(cfg)

huggingface_token_type_embed = huggingface_bert.bert.embeddings.token_type_embeddings

_copy(token_type_embed.W_token_type, huggingface_token_type_embed.weight)

torch.equal(token_type_embed(token_type_ids), huggingface_token_type_embed(token_type_ids))


token_type_ids=tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1])


True

In [5]:
embed = BertEmbed(cfg)
huggingface_embed = huggingface_bert.bert.embeddings


my_parameters = list(embed.named_parameters())
their_parameters = list(huggingface_embed.named_parameters())

state_dict = {
    my_name: their_param
    for (my_name, _), (_, their_param)
      in zip(my_parameters, their_parameters)
}

embed.load_state_dict(state_dict)


torch.equal(embed(input_ids), huggingface_embed(input_ids))


['word_embeddings.W_E', 'position_embeddings.W_pos', 'token_type_embeddings.W_token_type', 'layer_norm.w', 'layer_norm.b']
['word_embeddings.weight', 'position_embeddings.weight', 'token_type_embeddings.weight', 'LayerNorm.weight', 'LayerNorm.bias']


False

In [8]:

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

encoding.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [9]:

input_ids = encoding["input_ids"]
token_type_ids = encoding["token_type_ids"]

embed = BertEmbed(cfg)
my_parameters = list(embed.named_parameters())

huggingface_bert = BertForMaskedLM.from_pretrained("bert-base-cased")
huggingface_embed = huggingface_bert.bert.embeddings
their_parameters = list(huggingface_embed.named_parameters())

state_dict = {
    my_name: their_param
    for (my_name, _), (_, their_param)
    in zip(my_parameters, their_parameters)
}

embed.load_state_dict(state_dict)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [14]:
mine = embed(input_ids, token_type_ids=token_type_ids)

In [15]:
theirs = (huggingface_embed(input_ids, token_type_ids=token_type_ids))

In [22]:
torch.testing.assert_close(mine, theirs)

In [23]:
tokenizer("foo", return_tensors="pt")

{'input_ids': tensor([[ 101,  175, 5658,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1]])}