In [1]:
import torch as t
import torch.nn as nn
from torch import einsum
from einops import rearrange, reduce, repeat
import bert_tests

In [2]:
"""
Q = W_Q @ input
K = W_K @ input
V = W_V @ input
attn_pat = normalised_softmax(Q @ K^T)
attention = attn_pat @ V
O = W_O @ attention
"""

'\nQ = W_Q @ input\nK = W_K @ input\nV = W_V @ input\nattn_pat = normalised_softmax(Q @ K^T)\nattention = attn_patn @ V\nO = W_O @ attention\n'

In [3]:
def raw_attention_pattern(token_activations, num_heads, project_query, project_key):
    """Compute the scaled, pre-softmax attention scores for every head.

    Args:
        token_activations: tensor whose last two dims are (seq_len, hidden_size);
            any leading (batch) dims are carried through unchanged.
        num_heads: number of attention heads the projected features are split into.
        project_query: callable mapping activations to queries of shape
            (..., seq_len, num_heads * head_size).
        project_key: callable with the same contract, producing keys.

    Returns:
        Tensor of shape (..., num_heads, seq_len, seq_len) indexed as
        [..., head, key_position, query_position], scaled by 1 / sqrt(head_size).

    Raises:
        ValueError: if the projected feature size is not divisible by num_heads.
    """
    def split_heads(projected):
        # (..., seq, heads * head_size) -> (..., heads, seq, head_size)
        *leading, seq_len, features = projected.shape
        if features % num_heads != 0:
            raise ValueError(
                f"projected size {features} is not divisible by num_heads={num_heads}"
            )
        per_head = projected.reshape(*leading, seq_len, num_heads, features // num_heads)
        return per_head.transpose(-3, -2)

    queries = split_heads(project_query(token_activations))
    keys = split_heads(project_key(token_activations))

    # The scale must follow the actual head size; it was previously fixed at 64,
    # which is only right for hidden_size / num_heads == 64.
    head_size = queries.shape[-1]

    # Output is laid out key-major ([..., key, query]); downstream code in this
    # notebook softmaxes over dim=-2 accordingly.
    scores = t.einsum('...qc,...kc -> ...kq', queries, keys)
    return scores / head_size ** 0.5

# Check raw_attention_pattern with the bert_tests helper; the recorded output reports a match.
bert_tests.test_attention_pattern_fn(raw_attention_pattern)

attention pattern raw MATCH!!!!!!!!
 SHAPE (2, 12, 3, 3) MEAN: 0.007844 STD: 0.1089 VALS [0.1368 0.02227 0.01142 -0.06201 -0.04411 0.09608 -0.04278 -0.04786 0.05769 -0.1033...]


In [4]:
def bert_attention(token_activations, num_heads, attention_pattern, project_value, project_output):
    """Apply an attention pattern to the value vectors and project the result.

    Args:
        token_activations: activations fed to ``project_value``.
        num_heads: number of heads the projected values are split into.
        attention_pattern: raw (pre-softmax) scores; normalised here over dim=-2.
        project_value: callable producing values with a last dim of
            ``num_heads * head_size``.
        project_output: callable applied to the re-merged per-head outputs.

    Returns:
        Whatever ``project_output`` returns for the merged
        (..., seq, num_heads * head_size) tensor.
    """
    # Normalise over the second-to-last axis: the same axis the einsum below
    # contracts against the token axis of the values.
    attn_probs = attention_pattern.softmax(dim=-2)

    values = rearrange(
        project_value(token_activations),
        "... n (h s) -> ... h n s",
        h=num_heads,
    )

    # Weighted sum of values per head, then fold the heads back into the feature dim.
    per_head = t.einsum("...htf,...hts->...hfs", attn_probs, values)
    merged = rearrange(per_head, "... h n s -> ... n (h s)")
    return project_output(merged)

# Check bert_attention with the bert_tests helper; the recorded output reports a match.
bert_tests.test_attention_fn(bert_attention)

attention MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.001597 STD: 0.1131 VALS [-0.1304 -0.002212 0.04429 0.1036 0.1437 -0.000659 -0.2454 0.0493 0.04145 -0.09295...]


In [5]:
class MultiHeadedSelfAttention(nn.Module):
    """Multi-head self-attention built from the notebook's functional helpers.

    Submodule names (``pattern.project_query``, ``pattern.project_key``,
    ``project_value``, ``project_out``) are the state-dict keys, so they are
    kept exactly as they are.
    """

    def __init__(self, num_heads, hidden_size):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # Query and key projections are grouped because together they
        # determine the attention pattern.
        self.pattern = nn.ModuleDict(dict(
            project_query=nn.Linear(hidden_size, hidden_size),
            project_key=nn.Linear(hidden_size, hidden_size),
        ))
        self.project_value = nn.Linear(hidden_size, hidden_size)
        self.project_out = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return the attention output for activations ``x``."""
        query_proj = self.pattern['project_query']
        key_proj = self.pattern['project_key']
        scores = raw_attention_pattern(x, self.num_heads, query_proj, key_proj)
        return bert_attention(x, self.num_heads, scores, self.project_value, self.project_out)



In [6]:
def bert_mlp(token_activations, linear_1, linear_2):
    """Two-layer feed-forward: ``linear_2(gelu(linear_1(x)))``.

    Args:
        token_activations: input passed to ``linear_1``.
        linear_1: first projection (callable).
        linear_2: second projection (callable), applied after the GELU.

    Returns:
        The output of ``linear_2``.
    """
    hidden = linear_1(token_activations)
    activated = t.nn.functional.gelu(hidden)
    return linear_2(activated)

# Check bert_mlp with the bert_tests helper; the recorded output reports a match.
bert_tests.test_bert_mlp(bert_mlp)

bert mlp MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -0.002911 STD: 0.1037 VALS [0.1037 -0.06468 -0.1908 0.1056 -0.1703 -0.1911 0.001357 -0.008986 0.09983 0.07543...]


In [7]:
class BertMLP(nn.Module):
    """Feed-forward sub-layer with a residual connection and layer norm.

    Note: ``LayerNorm`` is defined in a later notebook cell; that cell must
    have been executed before this class is instantiated.
    """

    def __init__(self, input_size: int, intermediate_size: int) -> None:
        super().__init__()
        # Expand to the intermediate width, then project back.
        self.mlp1 = nn.Linear(input_size, intermediate_size)
        self.mlp2 = nn.Linear(intermediate_size, input_size)
        self.layer_norm = LayerNorm(input_size)

    def forward(self, x):
        """Return ``layer_norm(mlp(x) + x)``."""
        mlp_out = bert_mlp(x, self.mlp1, self.mlp2)
        with_residual = mlp_out + x
        return self.layer_norm(with_residual)

In [8]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift.

    The mean and variance are detached, so no gradient flows through the
    statistics themselves; only the direct path through ``x`` (and the
    ``weight`` / ``bias`` parameters) is differentiated.
    """

    # Added to the variance before the square root for numerical stability.
    EPS = 1e-5

    def __init__(self, normalized_dim: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(t.ones(normalized_dim))
        self.bias = nn.Parameter(t.zeros(normalized_dim))

    def forward(self, x):
        """Normalise ``x`` along dim=-1, then apply ``weight`` and ``bias``."""
        mean = x.mean(dim=-1, keepdim=True).detach()
        centered = x - mean
        # Population variance (unbiased=False), taken on the centred values.
        variance = centered.var(dim=-1, unbiased=False, keepdim=True).detach()
        normalized = centered / (variance + self.EPS).sqrt()
        scaled = t.einsum('...i,i->...i', normalized, self.weight)
        return scaled + self.bias

# Check LayerNorm with the bert_tests helper; the recorded output reports a match.
bert_tests.test_layer_norm(LayerNorm)

layer norm MATCH!!!!!!!!
 SHAPE (20, 10) MEAN: 4.768e-09 STD: 1.003 VALS [1.399 -0.9391 -1.153 -0.4013 -0.01391 0.9618 -0.8159 1.842 -0.8274 -0.05183...]


In [9]:
class BertBlock(nn.Module):
    """One encoder layer: self-attention, then the MLP sub-layer, then dropout.

    Data flow in ``forward``:
        y = layer_norm(attention(x) + x)
        z = residual(y)            # BertMLP adds its own residual + layer norm
        return dropout(z)          # only active in training mode

    Args:
        hidden_size: width of the token activations.
        intermediate_size: width of the MLP's hidden layer.
        num_heads: number of attention heads.
        dropout: drop probability in [0, 1].

    Raises:
        ValueError: if ``dropout`` is outside [0, 1].
    """

    def __init__(self, hidden_size : int, intermediate_size : int, num_heads : int, dropout : float):
        super().__init__()
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.layer_norm = LayerNorm(hidden_size)
        self.attention = MultiHeadedSelfAttention(num_heads, hidden_size)
        self.residual = BertMLP(hidden_size, intermediate_size)

    def forward(self, x):
        """Run the block on activations ``x`` and return a tensor of the same shape."""
        attended = self.attention(x)
        normed = self.layer_norm(attended + x)
        out = self.residual(normed)
        # The previous code called a Bernoulli distribution object as if it were
        # a function, which raises TypeError as soon as the module is in training
        # mode. functional.dropout draws a per-element mask on the right
        # device/dtype, rescales by 1 / (1 - p), and is the identity in eval mode.
        return nn.functional.dropout(out, p=self.dropout, training=self.training)

# Check BertBlock with the bert_tests helper; the recorded output reports a match.
bert_tests.test_bert_block(BertBlock)

bert MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -5.381e-09 STD: 1 VALS [0.007132 -0.04372 0.6502 -0.5972 -1.097 0.7267 0.1275 -0.6035 -0.2226 0.2145...]


In [10]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import transformers
# Load the tokenizer published under the "bert-base-cased" name.
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
# Tokenise a one-sentence batch to inspect the returned fields
# (input_ids, token_type_ids, attention_mask in the recorded output).
print(tokenizer(['Hello, I am a sentence.']))


{'input_ids': [[101, 8667, 117, 146, 1821, 170, 5650, 119, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1]]}


In [11]:
class Embedding(nn.Module):
    """Lookup table mapping integer ids to learned vectors.

    Args:
        vocab_size: number of rows in the table (valid ids are 0..vocab_size-1).
        embed_size: length of each embedding vector.
    """

    def __init__(self, vocab_size, embed_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.weight = nn.Parameter(t.randn(vocab_size, embed_size))

    def forward(self, input):
        """Return ``weight[input]`` with shape ``input.shape + (embed_size,)``.

        The rows are gathered directly instead of multiplying a one-hot matrix
        by the table. That avoids allocating a (..., vocab_size) float tensor
        for every call, and works for any weight dtype (the one-hot route
        forced float32). Out-of-range and negative ids are rejected.
        """
        return nn.functional.embedding(input, self.weight)

# Check Embedding with the bert_tests helper; the recorded output reports a match.
bert_tests.test_embedding(Embedding)

embedding MATCH!!!!!!!!
 SHAPE (2, 3, 5) MEAN: -0.06748 STD: 1.062 VALS [1.176 -0.1914 0.8212 1.047 -0.481 0.7106 -1.304 -1.307 -0.438 -0.2764...]


In [12]:
def bert_embedding(input_ids, token_type_ids, position_embedding, token_embedding, token_type_embedding, layer_norm, dropout):
    """Sum token, position and token-type embeddings, then normalise and apply dropout.

    Args:
        input_ids: integer tensor of token ids; its last dim is the sequence length.
        token_type_ids: integer tensor of segment ids, passed to ``token_type_embedding``.
        position_embedding: callable mapping positions ``0..seq_len-1`` to vectors.
        token_embedding: callable mapping ``input_ids`` to vectors.
        token_type_embedding: callable mapping ``token_type_ids`` to vectors.
        layer_norm: callable applied to the summed embeddings.
        dropout: callable applied to the normalised embeddings.

    Returns:
        ``dropout(layer_norm(token + position + token_type))``.
    """
    # Build the position indices directly on the ids' device (no CPU round trip).
    positions = t.arange(input_ids.shape[-1], device=input_ids.device)

    # Sum out of place: the tensor returned by token_embedding belongs to the
    # callee, and the previous in-place += silently overwrote it (and fails
    # outright if that tensor is a leaf that requires grad).
    summed = (
        token_embedding(input_ids)
        + position_embedding(positions)
        + token_type_embedding(token_type_ids)
    )
    return dropout(layer_norm(summed))

# Check bert_embedding with the bert_tests helper; the recorded output reports a match.
bert_tests.test_bert_embedding_fn(bert_embedding)


bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: 4.967e-09 STD: 1 VALS [-1.319 -0.4378 -2.074 0.9679 0.9274 1.479 -0.501 -1.9 -0.212 0.7961...]


In [13]:
class BertEmbedding(nn.Module):
    """Input embedding layer: token + position + token-type, then layer norm and dropout.

    Attribute names double as state-dict keys and are therefore fixed.
    """

    def __init__(self, vocab_size: int, hidden_size: int, max_position_embeddings: int, type_vocab_size: int, dropout: float) -> None:
        super().__init__()
        # One table per id kind, all producing hidden_size-wide vectors.
        self.token_embedding = Embedding(vocab_size, hidden_size)
        self.position_embedding = Embedding(max_position_embeddings, hidden_size)
        self.token_type_embedding = Embedding(type_vocab_size, hidden_size)
        self.layer_norm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, token_type_ids):
        """Embed ``input_ids`` / ``token_type_ids`` via ``bert_embedding``."""
        return bert_embedding(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_embedding=self.position_embedding,
            token_embedding=self.token_embedding,
            token_type_embedding=self.token_type_embedding,
            layer_norm=self.layer_norm,
            dropout=self.dropout,
        )

# Check BertEmbedding with the bert_tests helper; the recorded output reports a match.
bert_tests.test_bert_embedding(BertEmbedding)

bert embedding MATCH!!!!!!!!
 SHAPE (2, 3, 768) MEAN: -2.897e-09 STD: 1 VALS [-0.009385 -0.4919 0.9852 -0.3535 -3.624 1.333 1.163 1.449 1.063 0.246...]


In [14]:
class Bert(nn.Module):
    """BERT encoder with a language-modelling head.

    ``forward`` embeds the ids, runs the stack of ``BertBlock`` layers, and
    maps the result to per-token vocabulary logits through ``lm_head``.
    Submodule names and ``lm_head`` keys are state-dict keys and must not change.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        type_vocab_size: int,
        dropout: float,
        intermediate_size: int,
        num_heads: int,
        num_layers: int,
    ) -> None:
        super().__init__()
        self.embedding = BertEmbedding(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
        blocks = [
            BertBlock(hidden_size, intermediate_size, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)
        self.lm_head = nn.ModuleDict(dict(
            mlp=nn.Linear(hidden_size, hidden_size),
            gelu=nn.GELU(),
            unembedding=nn.Linear(hidden_size, vocab_size),
            layer_norm=LayerNorm(hidden_size),
        ))

    def forward(self, input_ids):
        """Return vocabulary logits for ``input_ids``; every token gets token type 0."""
        token_type_ids = t.zeros_like(input_ids)
        x = self.transformer(self.embedding(input_ids, token_type_ids))
        # Application order differs from the dict's insertion order:
        # the layer norm runs before the unembedding.
        for stage in ('mlp', 'gelu', 'layer_norm', 'unembedding'):
            x = self.lm_head[stage](x)
        return x

# Check the full Bert model with the bert_tests helper; the recorded output reports a match.
bert_tests.test_bert(Bert)

bert MATCH!!!!!!!!
 SHAPE (1, 4, 28996) MEAN: 0.003031 STD: 0.5765 VALS [-0.5742 -0.432 0.1186 -0.7165 -0.5261 0.4967 1.223 0.3165 -0.3247 -0.5716...]


In [15]:
# Instantiate the from-scratch model with randomly initialised weights;
# vocab_size matches the last dim of the logits recorded for test_bert above.
my_bert = Bert(
    vocab_size=28996, hidden_size=768, max_position_embeddings=512, 
    type_vocab_size=2, dropout=0.1, intermediate_size=3072, 
    num_heads=12, num_layers=12
)
# Reference model from the bert_tests helper; the recorded log shows it is
# initialised from the bert-base-cased checkpoint.
pretrained_bert = bert_tests.get_pretrained_bert()

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
# Grab the reference model's parameters keyed by name.
pretrained_state_dict = pretrained_bert.state_dict()
# Bert (as defined in this notebook) has no classification_head, so these
# entries are dropped before loading; load_state_dict would otherwise reject
# them as unexpected keys.
del pretrained_state_dict['classification_head.weight']
del pretrained_state_dict['classification_head.bias']
my_bert.load_state_dict(pretrained_state_dict)
# Compare both models with the bert_tests helper; the recorded output reports a match.
bert_tests.test_same_output(my_bert, pretrained_bert)

comparing Berts MATCH!!!!!!!!
 SHAPE (10, 20, 28996) MEAN: -2.732 STD: 2.413 VALS [-5.65 -6.041 -6.096 -6.062 -5.946 -5.777 -5.977 -6.015 -6.028 -5.935...]


In [17]:
# NOTE(review): `transformers` and `tokenizer` were already set up in an earlier
# cell; this repeats both.
import transformers
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-cased")
uncased_tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")

# Encode with the uncased tokenizer but decode with the cased one: the recorded
# output is unrelated text, showing that the ids do not carry over between the
# two tokenizers.
encoded = uncased_tokenizer.encode("Hi, my name is bert")
tokenizer.decode(encoded)

'[CLS] colleges 天 largest happened smile donation [SEP]'

In [94]:
def ascii_art_probs(sentence, model=None, text_tokenizer=None, top_k=10):
    """Print the most likely fillers for the ``[MASK]`` token in ``sentence``.

    The sentence is printed with the mask shown as ``---``, followed by one
    line per candidate (``<percent>% <word>``), indented so the candidates
    line up under the mask.

    Args:
        sentence: text containing a ``[MASK]`` token. If several are present a
            warning is printed and only the first one is reported.
        model: model mapping a 1-D id tensor to per-token vocabulary logits.
            Defaults to the notebook-level ``my_bert``.
        text_tokenizer: object with ``encode`` / ``decode``. Defaults to the
            notebook-level ``tokenizer``.
        top_k: number of candidates to print.

    Raises:
        ValueError: if the sentence contains no ``[MASK]`` token.

    Side effects:
        Puts ``model`` into eval mode (and leaves it there) and prints to stdout.
    """
    model = my_bert if model is None else model
    tok = tokenizer if text_tokenizer is None else text_tokenizer

    if '[MASK]' not in sentence:
        raise ValueError("sentence must contain a [MASK] token")

    # encode() wraps the text in special tokens, so index 1 is the mask id itself.
    mask_encoding = tok.encode('[MASK]')[1]
    encoding = t.tensor(tok.encode(sentence))

    model.eval()
    # Inference only: no need to build an autograd graph.
    with t.no_grad():
        logits = model(encoding)[encoding == mask_encoding]
    if logits.shape[0] == 0:
        raise ValueError("tokenizer produced no [MASK] token for this sentence")

    probs = nn.functional.softmax(logits, dim=-1)
    probs, word_indices = probs.sort(descending=True, dim=-1)
    probs = probs[:, :top_k]
    word_indices = word_indices[:, :top_k]

    if len(probs) > 1:
        print("please don't double mask")

    sentence_for_display = sentence.replace('[MASK]', '---')
    print(sentence_for_display)

    # Right-align the percentage so it ends just before the mask's column.
    indent = sentence_for_display.index('---') - 2
    line_format = f"%{indent}.1d%% %s"
    for word_id, prob in zip(word_indices[0], probs[0]):
        print(line_format % (prob * 100, tok.decode(word_id)))
# Example: show the model's top predictions for the masked word.
ascii_art_probs("The fish loves to eat [MASK].")

The fish loves to eat ---.
                  17% it
                   9% fish
                   9% them
                   4% meat
                   3% food
                   2% eggs
                   1% honey
                   1% insects
                   1% too
                   1% rice


In [97]:
class CBert(nn.Module):
    """BERT encoder with an additional linear classification head.

    The language-modelling head (``lm_head``) is still constructed, but
    ``forward`` does not use it; only ``classification_head`` is applied.

    NOTE(review): ``forward`` applies the classification head to the encoder
    output at every position rather than to a single pooled token - confirm
    this is the intended output shape.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_position_embeddings: int,
        type_vocab_size: int,
        dropout: float,
        intermediate_size: int,
        num_heads: int,
        num_layers: int,
        num_classes : int,
    ) -> None:
        super().__init__()
        self.embedding = BertEmbedding(vocab_size, hidden_size, max_position_embeddings, type_vocab_size, dropout)
        blocks = [
            BertBlock(hidden_size, intermediate_size, num_heads, dropout)
            for _ in range(num_layers)
        ]
        self.transformer = nn.Sequential(*blocks)
        # Unused in forward; presumably kept so checkpoints that contain
        # lm_head weights still load - TODO confirm.
        self.lm_head = nn.ModuleDict(dict(
            mlp=nn.Linear(hidden_size, hidden_size),
            gelu=nn.GELU(),
            unembedding=nn.Linear(hidden_size, vocab_size),
            layer_norm=LayerNorm(hidden_size),
        ))
        self.classification_head = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        """Return classification logits for ``input_ids``; every token gets token type 0."""
        token_type_ids = t.zeros_like(input_ids)
        encoded = self.transformer(self.embedding(input_ids, token_type_ids))
        return self.classification_head(self.dropout(encoded))

#bert_tests.test_bert(CBert)

In [98]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import torchtext
# Load the IMDB train and test splits, using ./.data as the dataset root
# (the recorded output shows an 84.1M download).
data_train, data_test = torchtext.datasets.IMDB(
    root='.data',
    split=('train', 'test')
)

100%|██████████| 84.1M/84.1M [00:02<00:00, 30.4MB/s]


In [100]:
# Inspect the returned object; the recorded output shows a raw-text iterable dataset.
data_train

<torchtext.data.datasets_utils._RawTextIterableDataset at 0x7fa6051bd850>