In [1]:
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken

In [3]:
from transformers import GPT2LMHeadModel

# Load the pretrained GPT-2 checkpoint and list every tensor in its state dict,
# so the names/shapes can be matched against the modules built below.
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # 124M
sd_hf = model_hf.state_dict()  # parameter name -> raw tensor

for k, v in sd_hf.items():
    print(f"{k} {v.shape}")

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 

### Tokenization and token/position embeddings

In [244]:
 B = 4 # batch size
T = 8 # sequence length
n_embd = 32
vocab_size = 50257

def get_first_batch(B, T, path='input.txt', max_chars=10000):
    """Return the first (inputs, targets) batch of GPT-2 tokens from a text file.

    The first `max_chars` characters of `path` are tokenized with tiktoken's
    'gpt2' encoding. The first B*T+1 tokens are then laid out as B rows of T
    tokens. Targets are the inputs shifted left by one token (next-token
    prediction).

    Args:
        B: batch size (number of rows).
        T: sequence length (tokens per row).
        path: text file to read (the notebook uses a Shakespeare dataset).
        max_chars: number of leading characters to tokenize; None reads the
            whole file.

    Returns:
        (x, y): two integer tensors of shape (B, T).

    Raises:
        ValueError: if the excerpt yields fewer than B*T+1 tokens.
    """
    # Explicit encoding so the result does not depend on the platform default.
    # Text-mode read(n) reads n characters, i.e. the same as text[:n],
    # without loading the rest of the file.
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read(max_chars)
    enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(text)

    n_needed = B * T + 1  # +1 so the last input position still has a target
    if len(tokens) < n_needed:
        raise ValueError(
            f"need {n_needed} tokens for B={B}, T={T} but the excerpt "
            f"of {path!r} only has {len(tokens)}"
        )
    buf = torch.tensor(tokens[:n_needed])
    x = buf[:-1].view(B, T)
    y = buf[1:].view(B, T)
    return x, y

# Embed one real batch: token lookup plus a learned vector per position.
idx, targets = get_first_batch(B, T)  # both (B, T)
wte = nn.Embedding(vocab_size, n_embd)  # token embedding table
wpe = nn.Embedding(T, n_embd)  # position embedding table

tok_emb, pos_emb = wte(idx), wpe(torch.arange(T))  # (B, T, n_embd), (T, n_embd)
x = tok_emb + pos_emb  # position table broadcasts over the batch dim -> (B, T, n_embd)
x.shape

torch.Size([4, 8, 32])

In [200]:
def single_head_attention(x, head_size):
    """One causal self-attention head with freshly initialised projections.

    Demo helper: the key/query/value projections are new, randomly initialised
    `nn.Linear` layers on every call, so repeated calls give different outputs
    unless the torch RNG is re-seeded first.

    Args:
        x: input of shape (B, T, C); C is read from x itself.
        head_size: output width of the key/query/value projections.

    Returns:
        Tensor of shape (B, T, head_size).
    """
    T = x.shape[1]
    # Take the input width from x rather than the notebook-global n_embd, so
    # the function works for any embedding size and on a fresh kernel.
    C = x.shape[-1]

    key = nn.Linear(C, head_size, bias=False)
    query = nn.Linear(C, head_size, bias=False)
    value = nn.Linear(C, head_size, bias=False)

    q = query(x)  # (B, T, head_size)
    k = key(x)    # (B, T, head_size)
    attn_wei = q @ k.transpose(-2, -1)  # (B, T, T)
    attn_wei = attn_wei * head_size**-0.5  # smaller scores -> softmax more diffused/less peaky

    # Causal mask: query position t may only attend to key positions <= t.
    # Built on x's device so the function also works for GPU inputs.
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
    attn_wei = attn_wei.masked_fill(~causal, float('-inf'))
    attn_wei = F.softmax(attn_wei, dim=-1)  # (B, T, T), each row sums to 1

    v = value(x)  # (B, T, head_size)
    return attn_wei @ v  # (B, T, head_size)

def multi_head_attention(x, n_embd, n_head):
    """Causal multi-head self-attention built from independent random heads.

    Runs `n_head` calls of `single_head_attention` (each with its own fresh
    projections), concatenates them along the feature dim, and mixes them with
    a freshly initialised output projection.

    Args:
        x: input of shape (B, T, n_embd).
        n_embd: model width; the concatenated heads must add back up to it.
        n_head: number of heads.

    Returns:
        Tensor of shape (B, T, n_embd).

    Raises:
        ValueError: if n_head is not positive or does not divide n_embd.
    """
    # Otherwise the concatenated width (n_head * head_size) != n_embd and the
    # failure would only surface as a shape error inside `proj`.
    if n_head <= 0 or n_embd % n_head != 0:
        raise ValueError(
            f"n_embd ({n_embd}) must be a positive multiple of n_head ({n_head})"
        )
    head_size = n_embd // n_head

    out_heads = [single_head_attention(x, head_size) for _ in range(n_head)]
    out = torch.cat(out_heads, dim=-1)  # (B, T, n_embd)

    proj = nn.Linear(n_embd, n_embd)
    return proj(out)  # (B, T, n_embd)

def transformer_block(x, n_embd, n_head):
    """Pre-norm transformer block (functional demo version).

    Computes h = x + MHA(LN1(x)) and then h + FFN(LN2(h)). Every module is
    created inside the call (the Linear layers here and in
    multi_head_attention are freshly initialised each time), so outputs vary
    between calls unless the RNG is re-seeded.

    Args:
        x: input of shape (B, T, n_embd).
        n_embd: model width.
        n_head: number of attention heads.

    Returns:
        Tensor of shape (B, T, n_embd).
    """
    norm_attn = nn.LayerNorm(n_embd)
    norm_ffwd = nn.LayerNorm(n_embd)

    hidden = 4 * n_embd  # width of the feed-forward inner layer
    feed_forward = nn.Sequential(
        nn.Linear(n_embd, hidden),
        nn.ReLU(),
        nn.Linear(hidden, n_embd),
    )

    # Residual around attention, then residual around the MLP (pre-norm in both).
    h = x + multi_head_attention(norm_attn(x), n_embd, n_head)  # (B, T, n_embd)
    return h + feed_forward(norm_ffwd(h))  # (B, T, n_embd)

# Example usage: push a random (B, T, n_embd) batch through one block.
B, T, n_embd = 4, 8, 32
n_head = 2
x = torch.randn(B, T, n_embd)  # input to transformer block

out = transformer_block(x, n_embd, n_head)
out.shape

torch.Size([4, 8, 32])

In [194]:
# BatchNorm vs LayerNorm on a (batch_size, num_features) array:
# - BatchNorm normalises each feature (column) across the batch (axis=0).
# - LayerNorm normalises each example (row) across its features (axis=-1).
batch_size, num_features = 10, 5
eps = 1e-5  # keeps the division finite when a variance is ~0
x = np.random.randn(batch_size, num_features) * 10 - 3  # mean -3, std 10 before normalising

gamma_bn = np.ones((num_features,))
beta_bn = np.zeros((num_features,))
batch_mean = np.mean(x, axis=0)
batch_var = np.var(x, axis=0)
# Divide by the std as well as centering; without it the result is only
# mean-centred, not normalised.
x_normalized_bn = (x - batch_mean) / np.sqrt(batch_var + eps)
out_bn = gamma_bn * x_normalized_bn + beta_bn

gamma_ln = np.ones((num_features,))
beta_ln = np.zeros((num_features,))
feature_mean = np.mean(x, axis=-1, keepdims=True)
feature_var = np.var(x, axis=-1, keepdims=True)
x_normalized_ln = (x - feature_mean) / np.sqrt(feature_var + eps)
out_ln = gamma_ln * x_normalized_ln + beta_ln


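### The same model as PyTorch `nn.Module` classes

The structure mirrors the functions above. One difference: the feed-forward layer in `Block` uses GELU (tanh approximation), whereas `transformer_block` uses ReLU.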

In [12]:
class SingleHeadAttention(nn.Module):
    """One causal self-attention head.

    Args:
        n_embd: width of the input features.
        head_size: width of the key/query/value projections and of the output.
    """

    def __init__(self, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.head_size = head_size

    def forward(self, x):
        """Attend over x of shape (B, T, n_embd); returns (B, T, head_size)."""
        T = x.shape[1]

        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)
        v = self.value(x)

        attn_wei = q @ k.transpose(-2, -1)  # (B, T, T)
        attn_wei = attn_wei * self.head_size**-0.5  # smaller scores -> softmax more diffused/less peaky
        # Causal mask built on x's device: a CPU mask would make masked_fill
        # fail when the module and its inputs live on a GPU.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        attn_wei = attn_wei.masked_fill(~causal, float('-inf'))  # autoregressive masking
        attn_wei = F.softmax(attn_wei, dim=-1)  # (B, T, T)
        return attn_wei @ v  # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """`n_head` causal attention heads run side by side, concatenated, then projected.

    Args:
        n_embd: model width; split evenly across the heads.
        n_head: number of heads.

    Raises:
        ValueError: if n_head is not positive or does not divide n_embd.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Fail at construction time. Otherwise the concatenated width
        # (n_head * head_size) != n_embd and forward() dies inside self.proj.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be a positive multiple of n_head ({n_head})"
            )
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList([SingleHeadAttention(n_embd, self.head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map (B, T, n_embd) -> (B, T, n_embd)."""
        out = torch.cat([head(x) for head in self.heads], dim=-1)  # (B, T, n_embd)
        return self.proj(out)  # (B, T, n_embd)

class Block(nn.Module):
    """Pre-norm transformer block.

    Causal multi-head attention, then a GELU (tanh-approximation) MLP. Each
    sub-layer is wrapped in a residual connection.

    Args:
        n_embd: model width.
        n_head: number of attention heads.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Submodule names and registration order are part of the state_dict
        # layout, so they are kept as ln1, ln2, mhsa, ffwd.
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mhsa = MultiHeadAttention(n_embd, n_head)
        hidden = 4 * n_embd  # inner layer width
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden, n_embd),
        )

    def forward(self, x):
        """Map (B, T, n_embd) -> (B, T, n_embd)."""
        attended = x + self.mhsa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))
    
class GPTLanguageModel(nn.Module):
    """GPT-style decoder-only language model.

    Args:
        vocab_size: number of token ids.
        n_embd: model width.
        n_head: attention heads per block.
        n_layers: number of transformer blocks.
        T: maximum sequence length (rows of the learned position table).
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layers, T):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embd),
            wpe = nn.Embedding(T, n_embd),
            h = nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layers)]),
            ln_f = nn.LayerNorm(n_embd) # final layernorm before classifier
        ))
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids of shape (B, t) with t <= the T given at construction.
            targets: optional token ids of shape (B, t).

        Returns:
            (logits, loss): logits has shape (B, t, vocab_size). loss is the
            mean cross-entropy over all B*t positions, or None when targets
            is None.

        Raises:
            ValueError: if t exceeds the position-table size.
        """
        # Use the actual sequence length of this batch, not a notebook global.
        seq_len = idx.shape[1]
        max_len = self.transformer.wpe.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds the model's maximum {max_len}")

        tok_emb = self.transformer.wte(idx)  # (B, t, n_embd)
        positions = torch.arange(seq_len, device=idx.device)
        pos_emb = self.transformer.wpe(positions)  # (t, n_embd), broadcast over B
        x = tok_emb + pos_emb  # (B, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)  # (B, t, n_embd)
        logits = self.lm_head(x)  # (B, t, vocab_size)

        loss = None  # inference mode: no targets, no loss
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),  # (B*t, vocab_size)
                targets.reshape(-1),  # (B*t,)
            )
        return logits, loss
        

# Hyper-parameters for a tiny toy model with a 5-token vocabulary.
B, T, n_embd = 4, 8, 32
n_head = 2
n_layers = 4  # number of transformer blocks
vocab_size = 5

# Random token ids and random next-token targets, both (B, T).
idx = torch.randint(0, vocab_size, (B, T))
targets = torch.randint(0, vocab_size, (B, T))

# The same modules GPTLanguageModel builds, wired up by hand so every
# intermediate tensor stays inspectable in the notebook namespace.
transformer = nn.ModuleDict({
    'wte': nn.Embedding(vocab_size, n_embd),
    'wpe': nn.Embedding(T, n_embd),
    'h': nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layers)]),
    'ln_f': nn.LayerNorm(n_embd),  # final layernorm before classifier
})
lm_head = nn.Linear(n_embd, vocab_size, bias=False)

tok_emb = transformer.wte(idx)  # (B, T, n_embd)
pos_emb = transformer.wpe(torch.arange(T))  # (T, n_embd), broadcast over B
x = tok_emb + pos_emb  # (B, T, n_embd)
for block in transformer.h:
    x = block(x)
x = transformer.ln_f(x)  # (B, T, n_embd)
logits = lm_head(x)  # (B, T, vocab_size)
# Flatten to (B*T, vocab_size) logits vs (B*T,) targets for cross-entropy.
loss = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
loss

tensor(1.9355, grad_fn=<NllLossBackward0>)

In [55]:
# Compare three ways of computing the same mean cross-entropy loss.
z = logits.view(-1, logits.size(-1))  # (B*T, vocab_size) flattened logits
y = targets.view(-1)                  # (B*T,) class indices
print(z.shape, y.shape)

loss1 = nn.CrossEntropyLoss()(z, y)  # module API
loss2 = F.cross_entropy(z, y)        # functional API

# Manual version. Shift each row by its max before exponentiating: softmax is
# unchanged mathematically, but exp() can no longer overflow to inf for large
# logits, which would turn the probabilities into nan.
z_shifted = z - z.max(dim=-1, keepdim=True).values
log_probs = z_shifted - torch.log(torch.exp(z_shifted).sum(dim=-1, keepdim=True))
probs = log_probs.exp()  # equals torch.softmax(z, dim=-1)
# One-hot in the logits' dtype (np.eye would produce float64 and upcast the result).
y_one_hot = F.one_hot(y, num_classes=vocab_size).to(probs.dtype)
correct_class_probs = (probs * y_one_hot).sum(dim=-1)
# Use the log-probs directly rather than log(prob), so a probability that
# underflows to 0 cannot produce log(0) = -inf.
loss3 = -(log_probs * y_one_hot).sum(dim=-1).mean()

loss1.item(), loss2.item(), loss3.item()

torch.Size([32, 5]) torch.Size([32])


(1.935465931892395, 1.935465931892395, 1.9354659436984933)

### Single-head attention with rotary position embeddings (RoPE)

In [37]:
def get_pos_emb(emb_dim, seq_len, base=10000):
    """Precompute rotary position embedding (RoPE) cos/sin tables.

    Frequency i (i = 0 .. emb_dim/2 - 1) is base ** (-2i / emb_dim). The
    angle table position * frequency is duplicated along the last dim, which
    matches the half-split layout of `rotate_half` (first half / second half).

    Args:
        emb_dim: per-head dimension; must be a positive even number.
        seq_len: number of positions.
        base: frequency base; the default 10000 keeps the previous behavior.

    Returns:
        (cos, sin): float tensors, each of shape (seq_len, emb_dim).

    Raises:
        ValueError: if emb_dim is not a positive even number. An odd value
            would otherwise silently produce an (emb_dim + 1)-wide table.
    """
    if emb_dim <= 0 or emb_dim % 2 != 0:
        raise ValueError(f"emb_dim must be a positive even number, got {emb_dim}")
    inv_freq = 1.0 / (base ** (torch.arange(0, emb_dim, 2).float() / emb_dim))  # (emb_dim // 2,)
    position_ids = torch.arange(seq_len, dtype=torch.float)  # (seq_len,)
    freqs = torch.outer(position_ids, inv_freq)  # (seq_len, emb_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, emb_dim)
    return emb.cos(), emb.sin()

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by their position-dependent angles (RoPE).

    Each tensor t becomes t * cos + rotate_half(t) * sin. cos/sin must
    broadcast against q and k, e.g. (T, head_size) tables against
    (B, T, head_size) projections.

    Returns:
        (q_embed, k_embed) with the same shapes as q and k.
    """
    def _rotate(t):
        return t * cos + rotate_half(t) * sin

    return _rotate(q), _rotate(k)

def rotate_half(x):
    """Swap and negate the two halves of the last dimension: [x1, x2] -> [-x2, x1].

    The last dim is split into a first half and a second half, not into
    interleaved pairs. For example, [q1, q2, q3, q4] -> [-q3, -q4, q1, q2].
    Combined with cos/sin tables whose two halves are identical, this rotates
    each pair (x[i], x[i + d/2]) by that pair's angle.

    Raises:
        ValueError: if the last dimension has odd size. chunk(2) would
            otherwise split it unevenly and give a wrong rotation.
    """
    if x.shape[-1] % 2 != 0:
        raise ValueError(f"last dimension must be even, got {x.shape[-1]}")
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


# Shapes for the RoPE demo. B: batch_size, T: max_seq_len
B, T, n_embd = 4, 8, 32
n_head = 2
head_size = n_embd // n_head  # RoPE tables are built per head
x = torch.randn(B, T, n_embd)  # input to transformer block

# Key/query projections for a single head.
key = nn.Linear(n_embd, head_size, bias=False)
query = nn.Linear(n_embd, head_size, bias=False)
q, k = query(x), key(x)  # each (B, T, head_size)

cos, sin = get_pos_emb(head_size, T)  # each (T, head_size)
print(cos.shape, sin.shape)
q, k = apply_rotary_pos_emb(q, k, cos, sin)  # positions now encoded in q and k