In [2]:
%load_ext autoreload
%autoreload 2

In [61]:
from collections import OrderedDict
from typing import Callable, Dict, Optional

import torch as t
from torch import nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests

## Attention

In [4]:
def raw_attention_pattern(
    token_activations: t.Tensor,
    num_heads: int,
    project_query: Callable[[t.Tensor], t.Tensor],
    project_key: Callable[[t.Tensor], t.Tensor],
) -> t.Tensor:
    """
    token_activations: Tensor[batch_size, seq_length, hidden_size (768)]
    project_query: function( (Tensor[..., 768]) -> Tensor[..., 768] )
    project_key:   function( (Tensor[..., 768]) -> Tensor[..., 768] )
    return: Tensor[batch_size, head_num, key_token: seq_length, query_token: seq_length]
    """

    queries = rearrange(
        project_query(token_activations), "b s (head d) -> b head s d", head=num_heads
    )
    keys = rearrange(
        project_key(token_activations), "b s (head d) -> b head s d", head=num_heads
    )

    head_size = t.tensor(keys.shape[-1])
    return einsum("bhid, bhjd -> bhij", keys, queries) / t.sqrt(head_size)


bert_tests.test_attention_pattern_fn(raw_attention_pattern)


attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: -0.01314 STD: 0.1147 VALS [-0.1276 0.1329 0.1119 0.06959 -0.05208 0.1612 0.03845 0.0386 -0.1686 0.1203...]


In [5]:
def bert_attention(
    token_activations: t.Tensor,
    num_heads: int,
    attention_pattern: t.Tensor,
    project_value: Callable[[t.Tensor], t.Tensor],
    project_output: Callable[[t.Tensor], t.Tensor],
) -> t.Tensor:
    """
    token_activations: Tensor[batch_size, seq_length, hidden_size (768)],
    num_heads: int,
    attention_pattern: Tensor[batch_size,num_heads, seq_length, seq_length],
    project_value: function( (Tensor[..., 768]) -> Tensor[..., 768] ),
    project_output: function( (Tensor[..., 768]) -> Tensor[..., 768] )
    return: Tensor[batch_size, seq_length, hidden_size]
    """

    attention_prob = t.softmax(attention_pattern, dim=-2)  # dim: b head s s
    values = rearrange(
        project_value(token_activations), "b s (head d) -> b head s d", head=num_heads
    )

    output_by_head = einsum("bhis, bhid -> bhsd", attention_prob, values)
    concatenated = rearrange(output_by_head, "b h s d -> b s (h d)")

    return project_output(concatenated)


bert_tests.test_attention_fn(bert_attention)


attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001297 STD: 0.1089 VALS [0.1178 0.0506 -0.04344 0.1474 0.1352 0.08401 -0.04605 0.08768 0.1694 -0.05225...]


In [6]:
class MultiHeadedSelfAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        hidden_size: int,
        attention_dim: int = 64,
        per_head_output_dim: int = 64,
        output_dim: Optional[int] = None,
    ):
        super().__init__()

        if output_dim is None:
            output_dim = hidden_size

        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.attention_dim = attention_dim
        self.per_head_output_dim = per_head_output_dim
        self.output_dim: int = output_dim

        self.Q = nn.Linear(
            in_features=hidden_size, out_features=num_heads * attention_dim
        )
        self.K = nn.Linear(
            in_features=hidden_size, out_features=num_heads * attention_dim
        )
        self.V = nn.Linear(
            in_features=hidden_size, out_features=num_heads * per_head_output_dim
        )
        self.O = nn.Linear(
            in_features=num_heads * per_head_output_dim, out_features=output_dim
        )

    def forward(self, input: t.Tensor) -> t.Tensor:
        """
        input: Tensor[batch_size, seq_length, hidden_size]
        """

        attention_pattern = raw_attention_pattern(
            input,
            self.num_heads,
            project_key=self.K,
            project_query=self.Q,
        )

        return bert_attention(
            token_activations=input,
            num_heads=self.num_heads,
            attention_pattern=attention_pattern,
            project_value=self.V,
            project_output=self.O,
        )


bert_tests.test_bert_attention(MultiHeadedSelfAttention)


bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [7]:
mhsa = MultiHeadedSelfAttention(
    num_heads=17,
    hidden_size=768,
    attention_dim=37,
    per_head_output_dim=89,
    output_dim=2,
)
mhsa(t.ones((10, 117, 768))).shape


torch.Size([10, 117, 2])

## Transformer Encoder block

In [8]:
def bert_mlp(
    token_activations: t.Tensor,
    linear_1: nn.Module,
    linear_2: nn.Module,
) -> t.Tensor:
    """
    token_activations: torch.Tensor[batch_size,seq_length,768],
    return: torch.Tensor[batch_size, seq_length, 768]
    """

    x = linear_1(token_activations)
    x = F.gelu(x)
    x = linear_2(x)
    return x


bert_tests.test_bert_mlp(bert_mlp)


bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.002065 STD: 0.1061 VALS [0.0343 -0.03701 0.02609 0.009201 -0.07531 -0.07379 0.04684 -0.08373 0.006134 -0.1191...]


In [9]:
class BertMLP(nn.Module):
    def __init__(self, input_size: int, intermediate_size: int):
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=input_size, out_features=intermediate_size
        )
        self.linear_2 = nn.Linear(
            in_features=intermediate_size, out_features=input_size
        )

    def forward(self, input: t.Tensor) -> t.Tensor:
        return bert_mlp(input, self.linear_1, self.linear_2)

In [10]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_dim: int):
        super().__init__()
        self.weight = t.nn.Parameter(t.ones((normalized_dim,)))
        self.bias = t.nn.Parameter(t.zeros((normalized_dim,)))

    def forward(self, input: t.Tensor):  # shape[..., normalized_dim]
        m = t.mean(input, dim=-1, keepdim=True).detach()
        s = t.std(input, dim=-1, keepdim=True, unbiased=False).detach()
        input = (input - m) / s
        return input * self.weight + self.bias


bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: -9.537e-09 STD: 1.003 VALS [-0.3893 -1.309 1.483 0.3582 0.4961 0.1515 -1.697 -0.7905 1.395 0.3024...]


In [11]:
class BertBlock(nn.Module):
    def __init__(
        self, hidden_size: int, intermediate_size: int, num_heads: int, dropout: float
    ):
        super().__init__()

        self.mha = MultiHeadedSelfAttention(
            num_heads=num_heads, hidden_size=hidden_size
        )
        self.ln1 = LayerNorm(normalized_dim=hidden_size)
        self.bmlp = BertMLP(input_size=hidden_size, intermediate_size=intermediate_size)
        self.ln2 = LayerNorm(normalized_dim=hidden_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, input: t.Tensor) -> t.Tensor:
        x1 = self.ln1(self.mha(input) + input)
        return self.ln2(self.dropout(self.bmlp(x1)) + x1)

bert_tests.test_bert_block(BertBlock)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -4.139e-09 STD: 1 VALS [0.007131 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


## BERT Embedding

In [12]:
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer(['Hello, I am a sentence.']))

{'input_ids': [[101, 8667, 117, 146, 1821, 170, 5650, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [13]:
class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embed_size: int):
        super().__init__()
        self.embedding = nn.Parameter(t.randn((vocab_size, embed_size),))

    def forward(self, input: t.LongTensor) -> t.FloatTensor:
        """
        input: tensor[...]
        return: tensor[..., embed_size]
        """
        return self.embedding[input]

bert_tests.test_embedding(Embedding)


embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.06748 STD: 1.062 VALS [1.176 -0.1914 0.8212 1.047 -0.481 0.7106 -1.304 -1.307 -0.438 -0.2764...]


In [14]:
def bert_embedding(
    input_ids: t.Tensor,  # [batch, seqlen]
    token_type_ids: t.Tensor,  # [batch, seqlen]
    position_embedding: Embedding,
    token_embedding: Embedding,
    token_type_embedding: Embedding,
    layer_norm: LayerNorm,
    dropout: nn.Dropout,
) -> t.Tensor:
    seq_len = input_ids.shape[-1]
    device = input_ids.device

    inputs = token_embedding(input_ids)
    tokens = token_type_embedding(token_type_ids)
    positions = position_embedding(t.arange(seq_len, dtype=t.long, device=device))

    return dropout(layer_norm(inputs + tokens + positions))

bert_tests.test_bert_embedding_fn(bert_embedding)


bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 8.278e-10 STD: 1 VALS [-1.319 -0.4378 -2.074 0.9679 0.9274 1.479 -0.501 -1.9 -0.212 0.7961...]


In [15]:
class BertEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        type_vocab_size: int,
        dropout: float,
    ):
        super().__init__()
        self.token_embedding = Embedding(vocab_size, hidden_size)
        self.position_embedding = Embedding(max_position_embeddings, hidden_size)
        self.token_type_embedding = Embedding(type_vocab_size, hidden_size)

        self.layer_norm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids: t.Tensor, token_type_ids: t.Tensor) -> t.Tensor:
        return bert_embedding(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_embedding=self.position_embedding,
            token_embedding=self.token_embedding,
            token_type_embedding=self.token_type_embedding,
            layer_norm=self.layer_norm,
            dropout=self.dropout,
        )


bert_tests.test_bert_embedding(BertEmbedding)


bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -3.104e-09 STD: 1 VALS [-0.009385 -0.4919 0.9852 -0.3535 -3.624 1.333 1.163 1.449 1.063 0.246...]


## Putting it all together

In [63]:
class Bert(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        type_vocab_size: int,
        dropout: float,
        intermediate_size: int,
        num_heads: int,
        num_layers: int,
    ):
        super().__init__()

        self.embedding = BertEmbedding(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            max_position_embeddings=max_position_embeddings,
            type_vocab_size=type_vocab_size,
            dropout=dropout,
        )

        self.transformer = nn.Sequential(
            *[
                BertBlock(
                    hidden_size=hidden_size,
                    intermediate_size=intermediate_size,
                    num_heads=num_heads,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

        self.lm_head = nn.Sequential(
            OrderedDict(
                [
                    (
                        "mlp",
                        nn.Linear(in_features=hidden_size, out_features=hidden_size),
                    ),
                    ("gelu", nn.GELU()),
                    ("layer_norm", LayerNorm(hidden_size)),
                    (
                        "unembedding",
                        nn.Linear(in_features=hidden_size, out_features=vocab_size),
                    ),
                ]
            )
        )

    def forward(self, input_ids: t.Tensor) -> t.Tensor:
        token_type_ids = t.zeros_like(input_ids)
        embed = self.embedding(input_ids, token_type_ids)
        return self.lm_head(self.transformer(embed))


bert_tests.test_bert(Bert)


bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.4321 0.1186 -0.7165 -0.5262 0.4967 1.223 0.3165 -0.3247 -0.5717...]


## LayerNorm experiments

In [64]:
class ExperimentalLayerNorm(nn.Module):
    def __init__(self, normalized_dim: int, eps=1e-5):
        super().__init__()
        self.weight = t.nn.Parameter(t.ones((normalized_dim,), dtype=t.float))
        self.bias = t.nn.Parameter(t.zeros((normalized_dim,), dtype=t.float))
        self.eps = eps

    def forward(self, input: t.Tensor):  # shape[..., normalized_dim]
        m = t.mean(input, dim=-1, keepdim=True)  # .detach()
        v = t.var(input - m, dim=-1, keepdim=True, unbiased=False)  # .detach()
        input = (input - m) / t.sqrt(v + self.eps)
        return input * self.weight + self.bias


N = 5
torchs = nn.LayerNorm(N, eps=1e-5)
ours = LayerNorm(N)
exp = ExperimentalLayerNorm(N, eps=1e-5)


def get_grad(layernorm):
    input = t.arange(N, dtype=t.float, requires_grad=True)

    loss = t.sum(layernorm(input.reshape(1, N)) ** 2)
    loss.backward()

    return input.grad


print(get_grad(torchs))
print(get_grad(ours))
print(get_grad(exp))


tensor([-1.0252e-05, -5.1260e-06,  0.0000e+00,  5.1260e-06,  1.0252e-05])
tensor([-2., -1.,  0.,  1.,  2.])
tensor([-1.0014e-05, -5.0068e-06,  0.0000e+00,  5.0068e-06,  1.0014e-05])


## Load weights

In [65]:
my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12
)
pretrained_bert = bert_tests.get_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [79]:
def hf_to_our_state_dict(hf_dict: Dict[str, t.Tensor]) -> Dict[str, t.Tensor]:
    def include_key(key: str) -> bool:
        if key.startswith('classification_head'):
            return False
        return True

    def transform_key(key: str) -> str:
        subkeys = key.split(".")

        if key.startswith("embedding") and key.endswith("_embedding.weight"):
            subkeys[-1] = "embedding"
            return ".".join(subkeys)

        if subkeys[0] == "transformer":
            if subkeys[2] == "attention":
                subkeys[2] = "mha"

                if subkeys[3] == "pattern":
                    subkeys.pop(3)

                subkeys[3] = {
                    "project_value": "V",
                    "project_query": "Q",
                    "project_key": "K",
                    "project_out": "O",
                }[subkeys[3]]

                return ".".join(subkeys)

            if subkeys[2] == "residual":
                if subkeys[3] != "layer_norm":
                    subkeys[2] = "bmlp"
                    subkeys[3] = {"mlp1": "linear_1", "mlp2": "linear_2"}[subkeys[3]]
                    return ".".join(subkeys)
                
                if subkeys[3] == "layer_norm":
                    subkeys.pop(2)
                    subkeys[2] = "ln1"
                    return ".".join(subkeys)

            if subkeys[2] == "layer_norm":
                subkeys[2] = "ln2"
                return ".".join(subkeys)

        return key

    return {transform_key(k): v for k, v in hf_dict.items() if include_key(k)}


my_bert.load_state_dict(hf_to_our_state_dict(pretrained_bert.state_dict()))


RuntimeError: Error(s) in loading state_dict for Bert:
	Unexpected key(s) in state_dict: "classification_head.weight", "classification_head.bias". 