In [78]:
import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

In [None]:
n_vocab = 30522
n_ctx = 512
n_layers = 12
n_head = 12

d_emb = 768
d_head = 64
d_mlp = 3072

assert d_head*n_head == d_emb

In [None]:
class MLP(nn.Module):
    def __init__(self, d_emb, d_mlp):
        super().__init__()
        self.w1 = nn.Linear(d_emb, d_mlp)
        self.relu = nn.ReLU()
        self.w2 = nn.Linear(d_mlp, d_emb)
        
    def forward(self, x: torch.Tensor):
        # x: [B, N, d_emb] -> [B, N, d_emb]
        return self.w2(self.relu(self.w1(x)))

class AttentionHead(nn.Module):
    def __init__(self, d_emb, d_head):
        super().__init__()
        self.d_head = d_head
        self.wq = nn.Linear(d_emb, d_head, bias=False)
        self.wk = nn.Linear(d_emb, d_head, bias=False)
        self.wv = nn.Linear(d_emb, d_head, bias=False)
    
    def forward(self, x: torch.Tensor):
        # x: [B, T, d_emb] -> [B, T, d_head]
        q = self.wq(x) # [B, T, d_head]
        k = self.wk(x) # [B, T, d_head]
        v = self.wv(x) # [B, T, d_head]
        A = (q @ k.transpose(-1, -2)) / self.d_head**0.5 # [B, T, T]
        A = F.softmax(A, dim=-1) # softmax along key dimension
        x = A @ v
        return x
        
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, d_emb, d_head):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(d_emb=d_emb, d_head=d_head) for _ in range(n_head)])
        self.wo = nn.Linear(d_emb, d_emb) # == torch.cat([wo_1, ..., wo_12])
    
    def forward(self, x: torch.Tensor):
        # x: [B, T, d_emb] -> [B, T, d_emb]
        x = torch.cat([h(x) for h in self.heads], dim=-1) # [B, T, d_head*n_head] = [B, T, d_emb]
        x = self.wo(x) # equivalent to `sum(wo_i * h_i(x) for i in range(n_layers))`
        return x

class EfficientMultiHeadAttention(nn.Module):
    # Do everything at the same time!
    def __init__(self, n_head, d_emb, d_head, masked_attention=False):
        super().__init__()
        self.n_head, self.d_head, self.masked_attention = n_head, d_head, masked_attention
        # q,k,v matrices for all heads, stacked along last dim
        self.wqkv = nn.Linear(d_emb, 3*d_emb, bias=False)
        if masked_attention:
            self.register_buffer('mask', torch.tril(torch.ones((1, 1, n_ctx, n_ctx))))
    
    def forward(self, x: torch.Tensor):
        # [B, T, d_emb] -> [B, T, d_emb]
        B, T, d_emb = x.shape
        assert d_emb == self.n_head*self.d_head
        
        # q,k,v: [B, T, d_emb = n_head*d_head]
        q, k, v = self.wqkv(x).split(d_emb, dim=-1) 
        # reshape to q,k,v: [B, n_head, T, d_head], so that we can compute all attention heads as batched mm
        q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2) # [B, n_head, T, d_head]
        k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2) # [B, n_head, T, d_head]
        v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2) # [B, n_head, T, d_head]
        
        A = (q @ k.transpose(-1, -2)) / self.d_head**0.5 # [B, n_head, T, T]
        if self.masked_attention:
            A = torch.masked_fill(A, self.mask[:, :, :T, :T] == 0, float('-inf'))
        A = F.softmax(A, dim=-1)
        
        x = A @ v # [B, n_head, T, d_head]
        x = x.transpose(1, 2).reshape(B, T, d_emb)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, n_head, d_emb, d_mlp, d_head, masked_attention=False):
        super().__init__()
        self.attn = EfficientMultiHeadAttention(n_head=n_head, d_emb=d_emb, d_head=d_head, masked_attention=masked_attention)
        self.mlp = MLP(d_emb=d_emb, d_mlp=d_mlp)
        self.ln1 = nn.LayerNorm(d_emb)
        self.ln2 = nn.LayerNorm(d_emb)
        
    def forward(self, x: torch.Tensor):
        # x: [B, T, d_emb] -> [B, T, d_emb]
        x = self.ln1(self.attn(x) + x)
        x = self.ln2(self.mlp(x) + x)
        return x

class Transformer(nn.Module):
    def __init__(self, n_layers=12, n_head=12, n_ctx=512, d_emb=768, d_head=64, d_mlp=3072, p_dropout=0.1, masked_attention=False):
        super().__init__()
        self.tok_emb = nn.Embedding(n_vocab, d_emb)
        self.pos_emb = nn.Embedding(n_ctx, d_emb)
        self.dropout = nn.Dropout(p_dropout)
        self.blocks = nn.ModuleList([TransformerBlock(n_head=n_head, d_emb=d_emb, d_mlp=d_mlp, d_head=d_head, masked_attention=masked_attention)
                                     for _ in range(n_layers)])
    
    def forward(self, x):
        # x: [B, T] (ints of token ids) -> [B, T, d_emb]
        B, T = x.shape
        xe = self.tok_emb(x) # [B, T, d_emb]
        xp = self.pos_emb(torch.arange(0, T)) # [T, d_emb]
        x = xe + xp # [B, T, d_emb]
        x = self.dropout(x)
        for block in self.blocks:
            x = block(x)
        return x
    
    def num_params(self):
        return sum(p.numel() for p in self.parameters())
        

In [127]:
torch.manual_seed(42)

bert = Transformer(masked_attention=False)
gpt = Transformer(masked_attention=True)

In [151]:
torch.manual_seed(42)

tok = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "Hello, my [MASK] is Bob"
tok_ids = tok(text)["input_ids"]
x = torch.tensor(tok_ids).unsqueeze(0)
print(x, x.shape)

y = gpt(x)

y.shape, y

tensor([[ 101, 7592, 1010, 2026,  103, 2003, 3960,  102]]) torch.Size([1, 8])


(torch.Size([1, 8, 768]),
 tensor([[[ 2.9875,  0.3583, -0.8207,  ..., -0.1635,  0.0303, -1.5485],
          [ 3.2522,  0.0058, -0.3906,  ...,  0.2550,  0.0067, -0.7064],
          [ 3.0829,  0.0514, -0.8678,  ...,  0.1277, -0.2871, -0.8008],
          ...,
          [ 2.5182, -0.1373, -0.9167,  ...,  0.0344,  0.2996, -0.9848],
          [ 2.5908,  0.1701, -0.3427,  ..., -0.2808,  0.2711, -0.5008],
          [ 2.1048,  0.7076, -0.0593,  ..., -1.1217, -0.0201, -0.9383]]],
        grad_fn=<NativeLayerNormBackward0>))

In [111]:
m = torch.tril(torch.ones(1,1,n_ctx, n_ctx))
x = torch.randn((1,1,10,10))
torch.masked_fill(x, m[:, :, :10, :10] == 0, float('-inf'))

tensor([[[[ 1.6273,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
              -inf,    -inf,    -inf],
          [-1.6599, -0.8810,    -inf,    -inf,    -inf,    -inf,    -inf,
              -inf,    -inf,    -inf],
          [-0.6337,  1.3482, -0.5252,    -inf,    -inf,    -inf,    -inf,
              -inf,    -inf,    -inf],
          [ 0.3444,  0.6288, -0.3466,  2.1626,    -inf,    -inf,    -inf,
              -inf,    -inf,    -inf],
          [-1.8394,  1.7015, -0.1831,  1.4169, -0.0463,    -inf,    -inf,
              -inf,    -inf,    -inf],
          [-0.6611,  1.9132, -0.7888,  1.0590, -0.9418, -1.5726,    -inf,
              -inf,    -inf,    -inf],
          [-0.3688,  0.5541,  1.5305, -0.7085, -1.5465, -0.4701,  0.1381,
              -inf,    -inf,    -inf],
          [ 0.1376,  1.5309, -0.0048, -0.7227,  0.7957,  0.3171, -0.9580,
            1.0536,    -inf,    -inf],
          [-2.2465,  0.7787,  0.0297, -0.1052, -1.7899,  0.7464,  0.5787,
           -0.5467,  0