In [4]:
import torch
from torch import einsum, nn
from einops import rearrange, reduce, repeat
import bert_tests
import math

In [5]:
def raw_attention_pattern(token_activations, num_heads, project_query, project_key):
    """Compute scaled, pre-softmax attention scores for every head.

    Args:
        token_activations: input handed unchanged to both projections. The
            projections' outputs must be 3-D with a last axis divisible by
            ``num_heads`` (the rearrange pattern requires it).
        num_heads: number of heads the projected last axis is split into.
        project_query: callable producing the query vectors.
        project_key: callable producing the key vectors.

    Returns:
        A 4-D tensor (batch, head, key position, query position): each entry is
        the dot product of one key with one query, divided by the square root
        of the per-head channel count.
    """
    projected_queries = project_query(token_activations)
    projected_keys = project_key(token_activations)

    def split_heads(projected):
        # Move the head axis in front of the token axis.
        return rearrange(projected, "b n (h c) -> b h n c", h=num_heads)

    per_head_queries = split_heads(projected_queries)
    per_head_keys = split_heads(projected_keys)
    head_size = per_head_queries.shape[-1]

    # Keys are the first operand, so they index the second-to-last output axis.
    scores = einsum("bhnc,bhmc->bhnm", per_head_keys, per_head_queries)
    return scores / math.sqrt(head_size)

bert_tests.test_attention_pattern_fn(raw_attention_pattern)

attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.01187 STD: 0.1058 VALS [-0.08359 0.04135 -0.3284 -0.07717 -0.08166 -0.06112 0.167 -0.09143 -0.05247 0.08712...]


In [6]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
    """Mix the value vectors according to an attention pattern and project the result.

    Args:
        token_activations: input handed to ``project_value``.
        num_heads: number of heads the projected values are split into.
        attention_pattern: raw scores; the einsum below labels its axes
            (batch, head, key, query) and the softmax normalises over the
            key axis (dim -2).
        project_value: callable producing the value vectors.
        project_output: callable applied after the heads are merged back.

    Returns:
        The output of ``project_output`` on the (batch, token, heads * channels) tensor.
    """
    # Normalise over the key axis so the weights for each query sum to one.
    attention_weights = torch.softmax(attention_pattern, dim=-2)

    per_head_values = rearrange(
        project_value(token_activations), "b n (h c) -> b h n c", h=num_heads
    )
    mixed = einsum("bhkq, bhkc -> bhqc", attention_weights, per_head_values)

    # Concatenate the heads back into a single channel axis.
    merged = rearrange(mixed, "b h n c -> b n (h c)")
    return project_output(merged)

bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.005397 STD: 0.1188 VALS [-0.1328 -0.03846 0.0552 0.05028 -0.2345 0.2486 -0.05643 -0.1067 -0.05023 0.2406...]


In [7]:
class RawAttentionPattern(nn.Module):
    """Module wrapper around ``raw_attention_pattern`` that owns the query and key projections."""

    def __init__(self, hidden_size):
        """Create square ``hidden_size -> hidden_size`` query and key projections."""
        super().__init__()
        # Query is built before key; the order decides which random draws each layer receives.
        self.project_query = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.project_key = nn.Linear(in_features=hidden_size, out_features=hidden_size)

    def forward(self, token_activations, num_heads):
        """Return the raw attention scores for ``token_activations`` split into ``num_heads`` heads."""
        return raw_attention_pattern(
            token_activations=token_activations,
            num_heads=num_heads,
            project_query=self.project_query,
            project_key=self.project_key,
        )

class MultiHeadedSelfAttention(nn.Module):
    """Multi-head self-attention: score pattern, value mixing and output projection."""

    def __init__(self, num_heads, hidden_size):
        """Store the head count and build the pattern, value and output sub-modules (in that order)."""
        super().__init__()
        self.num_heads, self.hidden_size = num_heads, hidden_size
        self.pattern = RawAttentionPattern(hidden_size)
        self.project_value = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.project_out = nn.Linear(in_features=hidden_size, out_features=hidden_size)

    def forward(self, token_activations):
        """Run self-attention over ``token_activations`` and return the projected result."""
        return bert_attention(
            token_activations=token_activations,
            num_heads=self.num_heads,
            attention_pattern=self.pattern(token_activations, self.num_heads),
            project_value=self.project_value,
            project_output=self.project_out,
        )

bert_tests.test_bert_attention(MultiHeadedSelfAttention)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001554 STD: 0.1736 VALS [-0.08316 -0.09165 -0.03188 -0.03013 0.1001 0.09549 -0.1046 0.07742 0.0424 0.05553...]


In [8]:
def bert_mlp(token_activations, linear_1, linear_2):
    """Feed-forward step: apply ``linear_1``, then exact GELU, then ``linear_2``."""
    hidden = linear_1(token_activations)
    activated = nn.functional.gelu(hidden)
    return linear_2(activated)
bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.0001934 STD: 0.1044 VALS [-0.1153 0.1189 -0.0813 0.1021 0.0296 0.06182 0.0341 0.1446 0.2622 -0.08507...]


In [9]:
class BertMLP(nn.Module):
    """Feed-forward sub-layer: expand to ``intermediate_size``, GELU, project back to ``input_size``."""

    def __init__(self, input_size, intermediate_size):
        """Build the expanding layer first, then the contracting layer."""
        super().__init__()
        self.linear1 = nn.Linear(in_features=input_size, out_features=intermediate_size)
        self.linear2 = nn.Linear(in_features=intermediate_size, out_features=input_size)

    def forward(self, X):
        """Apply ``bert_mlp`` to ``X`` using this module's two linear layers."""
        return bert_mlp(token_activations=X, linear_1=self.linear1, linear_2=self.linear2)


In [10]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    Each vector along the last axis is shifted to zero mean and scaled to unit
    (biased) variance, then multiplied by ``weight`` and offset by ``bias``.

    Args:
        normalized_dim: size of the last axis; ``weight`` and ``bias`` are
            created with this shape.
        eps: constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, normalized_dim, eps=1e-05):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(normalized_dim))
        self.bias = torch.nn.Parameter(torch.zeros(normalized_dim))
        self.eps = eps

    def forward(self, X):
        """Normalise ``X`` over its last axis and apply the affine transform.

        The statistics are computed from ``X`` itself (not a detached copy), so
        gradients flow back to the input through the mean, the variance and
        the centred values. Detaching here would cut every upstream layer off
        from the loss.
        """
        mean = X.mean(dim=-1, keepdim=True)
        # Biased (population) variance, matching the usual layer-norm definition.
        var = X.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (X - mean) / torch.sqrt(var + self.eps)
        return normalized * self.weight + self.bias

bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: -1.431e-08 STD: 1.003 VALS [0.6906 -0.84 1.881 1.711 -0.5116 -0.9577 -0.1387 -0.6943 -0.6741 -0.4662...]


In [11]:
class BertResidual(nn.Module):
    """MLP sub-layer with dropout, a skip connection and a trailing layer norm."""

    def __init__(self, hidden_size, intermediate_size, dropout):
        """Build the layer norm, the two MLP linear layers and the dropout (in that order)."""
        super().__init__()
        self.layer_norm = LayerNorm(hidden_size)
        self.mlp1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.mlp2 = nn.Linear(in_features=intermediate_size, out_features=hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, X):
        """Return ``layer_norm(dropout(mlp(X)) + X)``."""
        skip = X.clone()
        mlp_out = bert_mlp(X, self.mlp1, self.mlp2)
        summed = self.dropout(mlp_out) + skip
        return self.layer_norm(summed)

class BertBlock(nn.Module):
    """One encoder block: self-attention with add-and-norm, then the residual MLP stage.

    Every stage keeps the shape (batches, tokens, heads * channels), where
    hidden_size = heads * channels:
        attention -> layer_norm -> mlp -> layer_norm
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, dropout):
        """Build the attention module, its layer norm and the residual MLP stage (in that order)."""
        super().__init__()
        self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        # The MLP, its dropout and the second layer norm all live inside BertResidual.
        self.residual = BertResidual(hidden_size, intermediate_size, dropout)

    def forward(self, X):
        """Return ``residual(layer_norm(attention(X) + X))``."""
        skip = X.clone()
        attended = self.attention(X)
        normed = self.layer_norm(attended + skip)
        return self.residual(normed)


bert_tests.test_bert_block(BertBlock)



bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -5.381e-09 STD: 1 VALS [0.007132 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [12]:
import transformers
# Load the tokenizer for the "bert-base-cased" checkpoint and print how it encodes one sentence
# (the result holds input_ids, token_type_ids and attention_mask, as shown in the cell output).
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
print(tokenizer(['Hello, I am a sentence.']))

{'input_ids': [[101, 8667, 117, 146, 1821, 170, 5650, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [13]:
class Embedding(nn.Module):
    """Lookup table mapping integer ids to rows of a learnable weight matrix."""

    def __init__(self, vocab_size, embed_size):
        """Create a ``(vocab_size, embed_size)`` table initialised from a standard normal."""
        super().__init__()
        initial_table = torch.randn(vocab_size, embed_size)
        self.weight = torch.nn.Parameter(initial_table)

    def forward(self, X):
        """Return the rows of ``weight`` selected by ``X``, moved to ``X``'s device."""
        selected_rows = self.weight[X, :]
        return selected_rows.to(X.device)

bert_tests.test_embedding(Embedding)



embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.06748 STD: 1.062 VALS [1.176 -0.1914 0.8212 1.047 -0.481 0.7106 -1.304 -1.307 -0.438 -0.2764...]


In [14]:
# Device check: the ids are moved to CUDA while `emb1` is never moved explicitly,
# yet the lookup result is reported on CUDA because Embedding.forward calls `.to(X.device)`.
# NOTE(review): `.cuda()` needs an available CUDA device — this cell fails on a CPU-only machine.

x = torch.randint(0, 10, (2, 3)).cuda()
print(x.is_cuda)
emb1 = Embedding(10, 5)
print(emb1(x).is_cuda)

True
True


In [15]:
def bert_embedding(
    input_ids, # batch_size x max sentence length -> token_id
    token_type_ids, # batch_size x max sentence length -> token_type
    position_embedding, # [positions] -> [embeddings for each position]
    token_embedding, # token_ids -> [embeddings for each token]
    token_type_embedding, # token_type -> [embeddings for each token type]
    layer_norm, # keeps dimensions
    dropout # keeps dimensions
):
    """Sum token, token-type and position embeddings, then apply layer norm and dropout.

    Args:
        input_ids: integer tensor whose axis 1 is the sequence axis; its length
            determines the position ids ``0 .. seq_len - 1``.
        token_type_ids: integer tensor passed to ``token_type_embedding``.
        position_embedding: callable mapping position ids to embeddings.
        token_embedding: callable mapping ``input_ids`` to embeddings.
        token_type_embedding: callable mapping ``token_type_ids`` to embeddings.
        layer_norm: callable applied to the summed embeddings.
        dropout: callable applied after ``layer_norm``.

    Returns:
        The normalised, dropped-out embedding sum on ``input_ids``' device.
    """
    device = input_ids.device
    seq_len = input_ids.shape[1]

    # Build the position ids directly on the input's device. Creating them on
    # the CPU (the torch.arange default) mismatches the embedding weights when
    # the model and the ids live on an accelerator.
    position_ids = torch.arange(seq_len, device=device)

    token_embed = token_embedding(input_ids).to(device)
    token_type_embed = token_type_embedding(token_type_ids).to(device)
    # Position embeddings carry no batch axis; broadcasting adds them to every sequence.
    position_embed = position_embedding(position_ids).to(device)

    return dropout(layer_norm(token_embed + token_type_embed + position_embed)).to(device)
    

bert_tests.test_bert_embedding_fn(bert_embedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 0 STD: 1 VALS [0.2316 0.08455 -0.5146 1.436 2.029 -1.117 2.775 -0.5305 -0.4485 -0.2485...]


In [16]:


class BertEmbedding(nn.Module):
    """Input embedding stage: token, position and token-type lookups, then layer norm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout):
        """Build the three lookup tables, the layer norm and the dropout."""
        super().__init__()
        # Creation order (token, position, token type) is kept on purpose:
        # each table is initialised from the global random generator.
        self.token_embedding = Embedding(vocab_size, hidden_size)
        self.position_embedding = Embedding(max_position_embeddings, hidden_size)
        self.token_type_embedding = Embedding(type_vocab_size, hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, input_ids, token_type_ids):
        """Embed ``input_ids`` / ``token_type_ids`` via ``bert_embedding`` with this module's layers."""
        return bert_embedding(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_embedding=self.position_embedding,
            token_embedding=self.token_embedding,
            token_type_embedding=self.token_type_embedding,
            layer_norm=self.layer_norm,
            dropout=self.dropout,
        )



In [24]:
class BertUnembed(nn.Module):
    """Language-model head: linear, GELU, layer norm, then a projection to ``vocab_size`` outputs."""

    def __init__(self, hidden_size, vocab_size):
        """Build the hidden linear layer, layer norm, output projection and GELU (in that order)."""
        super().__init__()
        self.mlp = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        self.unembedding = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.gelu = nn.GELU()

    def forward(self, activations):
        """Return ``unembedding(layer_norm(gelu(mlp(activations))))``."""
        hidden = self.mlp(activations)
        hidden = self.gelu(hidden)
        hidden = self.layer_norm(hidden)
        return self.unembedding(hidden)


class Bert(nn.Module):
    """Full model: embedding stage, a stack of ``BertBlock``s, a language-model head and a classification head."""

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size,
        dropout, intermediate_size, num_heads, num_layers, num_classes=2
    ):
        """Build the embedding, ``num_layers`` blocks, the LM head and the classifier (in that order)."""
        super().__init__()
        self.embedding = BertEmbedding(
            vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout
        )
        blocks = []
        for _ in range(num_layers):
            blocks.append(BertBlock(hidden_size, intermediate_size, num_heads, dropout))
        self.transformer = nn.ModuleList(blocks)
        self.lm_head = BertUnembed(hidden_size, vocab_size)
        self.classification_head = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, only_classification=False):
        """Run the model on ``input_ids``.

        Args:
            input_ids: integer ids passed to the embedding stage.
            only_classification: when true the language-model head is skipped.

        Returns:
            ``(lm_output, classification_output)``; ``lm_output`` is ``None``
            when ``only_classification`` is true. The classifier reads the
            activations at sequence position 0.
        """
        # Every token is assigned token type 0; zeros_like inherits dtype and device.
        token_type_ids = torch.zeros_like(input_ids)

        hidden = self.embedding(input_ids, token_type_ids)
        for block in self.transformer:
            hidden = block(hidden)

        if only_classification:
            lm_output = None
        else:
            lm_output = self.lm_head(hidden)
        classification_output = self.classification_head(hidden[:, 0])
        return lm_output, classification_output

#bert_tests.test_bert(Bert)
bert_tests.test_bert_classification(Bert)

bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.432 0.1186 -0.7165 -0.5261 0.4967 1.223 0.3165 -0.3247 -0.5716...]
bert MATCH!!!!!!!!
 SHAPE (1, 2) MEAN: 0.09479 STD: 1.411 VALS [-0.903 1.093]


In [20]:

# Fetch the reference pretrained model through the bert_tests helper; per the log below it
# initialises a BertForMaskedLM from the bert-base-cased checkpoint.
pretrained_bert = bert_tests.get_pretrained_bert()


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
# Instantiate our own implementation.
# NOTE(review): these hyperparameters presumably mirror the bert-base-cased configuration — confirm.
my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12
)
pretrained_state_dict = pretrained_bert.state_dict()

# Debugging aid (disabled): list parameter names present in our model but absent from the pretrained one.
#pretrained_keys = pretrained_state_dict.keys()
#my_keys = my_bert.state_dict().keys()
#print(set(my_keys) - set(pretrained_keys))

# The pretrained classification head could be loaded too, but it is meant to be
# fine-tuned later (w2d2), so drop its weights and keep our freshly initialised head.
del pretrained_state_dict["classification_head.weight"]
del pretrained_state_dict["classification_head.bias"]

# strict=False is needed because the two keys deleted above are now missing; the returned
# _IncompatibleKeys (displayed as the cell output) should list exactly those two and nothing else.
my_bert.load_state_dict(pretrained_state_dict, strict=False)
# Disabled check comparing our model's output with the pretrained reference.
#bert_tests.test_same_output(my_bert, pretrained_bert)

_IncompatibleKeys(missing_keys=['classification_head.weight', 'classification_head.bias'], unexpected_keys=[])