In [5]:
import torch
from torch import einsum, nn
from einops import rearrange, reduce, repeat
import bert_tests
import math

In [6]:
def raw_attention_pattern(token_activations, num_heads, project_query, project_key):   
    print(token_activations.shape)
    queries = project_query(token_activations)
    keys = project_key(token_activations)
    print(queries.shape, keys.shape)
    queries = rearrange(queries, "b n (h c) -> b h n c", h=num_heads)
    keys = rearrange(keys, "b n (h c) -> b h n c", h=num_heads)
    result = einsum("bhnc,bhmc->bhnm", keys, queries)
    return result / math.sqrt(queries.shape[-1])

bert_tests.test_attention_pattern_fn(raw_attention_pattern)

torch.Size([2, 3, 768])
torch.Size([2, 3, 768]) torch.Size([2, 3, 768])
attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.002978 STD: 0.1198 VALS [0.0555 -0.01219 0.02629 -0.06629 0.02201 0.06769 -0.01915 -0.02644 0.1231 -0.0728...]


In [7]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
    softmax_fn = torch.nn.Softmax(dim=-2)
    attention_pattern = softmax_fn(attention_pattern)
    values = project_value(token_activations)
    values = rearrange(values, "b n (h c) -> b h n c", h=num_heads)
    output = einsum("bhkq, bhkc -> bhqc", attention_pattern, values)
    output = rearrange(output, "b h n c -> b n (h c)")
    result = project_output(output)
    return result

bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001244 STD: 0.1123 VALS [0.02294 0.04522 0.1502 -0.04545 -0.08154 -0.1219 -0.01069 -0.03793 -0.1267 0.09468...]


In [8]:
class MultiHeadedSelfAttention(nn.Module):

    def __init__(self, num_heads, hidden_size):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.query_embedding = nn.Linear(hidden_size, hidden_size)
        self.key_embedding = nn.Linear(hidden_size, hidden_size)
        self.value_embedding = nn.Linear(hidden_size, hidden_size)
        self.output_embedding = nn.Linear(hidden_size, hidden_size)

    def forward(self, token_activations):
        raw_attention = raw_attention_pattern(token_activations, self.num_heads, self.query_embedding, self.key_embedding)
        return bert_attention(token_activations, self.num_heads, raw_attention, self.value_embedding, self.output_embedding)

bert_tests.test_bert_attention(MultiHeadedSelfAttention)

torch.Size([2, 3, 768])
torch.Size([2, 3, 768]) torch.Size([2, 3, 768])
bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [10]:
def bert_mlp(token_activations, linear_1, linear_2):
    return linear_2(nn.GELU()(linear_1(token_activations)))
bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.0003685 STD: 0.1087 VALS [-0.07879 -0.1109 0.1264 -0.173 -0.06065 0.1507 -0.03468 -0.2432 -0.09689 0.05654...]


In [32]:
class BertMLP(nn.Module):
    def __init__(self, input_size, intermediate_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, input_size)
    
    def forward(self, X):
        return bert_mlp(X, self.linear1, self.linear2)


In [29]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_dim):
        super().__init__()
        self.weight = torch.ones(normalized_dim)
        self.bias = torch.zeros(normalized_dim)

    def forward(self, X):
        std = torch.std(X.detach(), -1, unbiased=False)
        mean = torch.mean(X.detach(), -1)

        print(((X.detach() - mean.unsqueeze(-1)) / std.unsqueeze(-1)).shape, self.weight.shape)
        return ((X.detach() - mean.unsqueeze(-1)) / std.unsqueeze(-1)) * self.weight + self.bias

bert_tests.test_layer_norm(LayerNorm)

torch.Size([20, 10]) torch.Size([10])
layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: 2.384e-09 STD: 1.003 VALS [-0.5207 0.05752 2.278 0.05417 -1.073 -0.2957 -1.152 -0.5042 1.302 -0.1456...]


In [34]:
class BertBlock(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_heads, dropout):

        # attention: batches tokens heads * channels -> batches tokens heads * channels
        # layer_norm: batches tokens heads * channels -> batches tokens heads * channels
        # mlp: batches tokens heads * channels -> batches tokens heads * channels
        # layer_norm: batches tokens heads * channels
        super().__init__()
        self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        self.mlp = BertMLP(hidden_size, intermediate_size) # hidden_size = heads * channels
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        residual1 = torch.clone(X)
        output1 = self.layer_norm(self.attention(X) + residual1)
        residual2 = torch.clone(output1)
        return self.layer_norm(self.dropout(self.mlp(output1)) + residual2)

bert_tests.test_bert_block(BertBlock)



torch.Size([2, 3, 768])
torch.Size([2, 3, 768]) torch.Size([2, 3, 768])
torch.Size([2, 3, 768]) torch.Size([768])
torch.Size([2, 3, 768]) torch.Size([768])
bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -4.139e-09 STD: 1 VALS [0.007131 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [37]:
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer(['Hello, I am a sentence.']))

{'input_ids': [[101, 8667, 117, 146, 1821, 170, 5650, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [54]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.embed_matrix = torch.nn.Parameter(torch.randn(vocab_size, embed_size))

    def forward(self, X):
        return (self.embed_matrix[X, :]).to(X.device)

bert_tests.test_embedding(Embedding)

x = torch.randint(0, 10, (2, 3)).cuda()
print(x.is_cuda)
emb1 = Embedding(10, 5)
print(emb1(x).is_cuda)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.4328 STD: 1.09 VALS [-1.529 -1.113 1.017 -0.9385 -1.151 -0.8435 0.0199 -0.7648 1.023 -1.396...]
True
True


In [None]:
def bert_embedding(
    input_ids,
    token_type_ids,
    position_embedding,
    token_embedding,
    token_type_embedding,
    layer_norm
):
    pass
    

bert_tests.test_bert_embedding_fn(bert_embedding)

In [None]:


class BertEmbedding(nn.Module):
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout):
        super().__init__()
        self.vocab_embed = Embedding(vocab_size, dropout)
        self.position_embed = torch.nn.Parameter(torch.randn(max_position_embeddings, hidden_size))
        self.token_type_embed = torch.nn.Parameter(torch.randn(type_vocab_size, hidden_size))
        self.layer_norm = LayerNorm(hidden_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, input_ids, token_type_ids):
        return bert_embedding(
            input_ids, 
            token_type_ids, 
            self.position_embed, 
            self.vocab_embed,
            self.token_type_embed,
            self.layer_norm)