In [1]:
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F

import tiktoken

In [146]:
# Inspect the pretrained GPT-2 checkpoint from Hugging Face: print the name
# and shape of every tensor in its state dict.
from transformers import GPT2LMHeadModel

model_hf = GPT2LMHeadModel.from_pretrained("gpt2")  # 124M
sd_hf = model_hf.state_dict()  # parameter name -> raw tensor

for k, v in sd_hf.items():
    print(f"{k} {v.shape}")

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 

### Token processing

In [130]:
# Toy dimensions for the token/embedding walk-through below.
B, T = 4, 8          # batch size, sequence length
n_embd = 32          # embedding width
vocab_size = 50257   # size of the GPT-2 token vocabulary (matches wte above)

def get_first_batch(B, T, path='input.txt', n_chars=10000, enc=None):
    """Return the first (inputs, targets) batch of tokens from a text file.

    The first ``n_chars`` characters of ``path`` (by default the Shakespeare
    ``input.txt``) are tokenized. The first ``B*T + 1`` tokens are split into
    an input batch and a target batch shifted by one token (next-token
    prediction).

    Args:
        B: batch size (number of rows).
        T: sequence length (tokens per row).
        path: UTF-8 text file to read.
        n_chars: number of leading characters to tokenize.
        enc: object with an ``encode(str) -> list[int]`` method; defaults to
            the tiktoken 'gpt2' encoding.

    Returns:
        (x, y): integer tensors of shape (B, T). ``y`` is ``x`` shifted left by
        one token in the flattened token stream.

    Raises:
        ValueError: if the text yields fewer than ``B*T + 1`` tokens.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # fail on or mis-decode non-ASCII text.
    with open(path, 'r', encoding='utf-8') as f:
        text = f.read()
    data = text[:n_chars]
    if enc is None:
        enc = tiktoken.get_encoding('gpt2')
    tokens = enc.encode(data)

    n_needed = B * T + 1  # +1 so the targets can be shifted by one token
    if len(tokens) < n_needed:
        raise ValueError(
            f"need {n_needed} tokens for B={B}, T={T}, got {len(tokens)}"
        )
    buf = torch.tensor(tokens[:n_needed])
    x = buf[:-1].view(B, T)
    y = buf[1:].view(B, T)
    return x, y

# Model inputs and their next-token targets, each (B, T).
idx, targets = get_first_batch(B, T)

# Learned lookup tables: one row per token id / per position.
wte = nn.Embedding(vocab_size, n_embd)
wpe = nn.Embedding(T, n_embd)

tok_emb = wte(idx)                # (B, T, n_embd)
pos_emb = wpe(torch.arange(T))    # (T, n_embd), broadcast over the batch
x = tok_emb + pos_emb             # (B, T, n_embd): input to the transformer

In [157]:
def single_head_attention(x, head_size):
    """One head of causal self-attention, written functionally for illustration.

    The key/query/value projections are freshly constructed (randomly
    initialised) ``nn.Linear`` layers on every call. This is therefore a
    mechanics/shape demo, not a trainable layer.

    Args:
        x: input of shape (..., T, C). The input width C is read from ``x``
            itself rather than from a global.
        head_size: output width of the head.

    Returns:
        Tensor of shape (..., T, head_size). Position t only attends to
        positions <= t.
    """
    T, C = x.shape[-2], x.shape[-1]
    # Create the layers on x's device/dtype so non-CPU inputs also work.
    factory = dict(device=x.device, dtype=x.dtype)

    key = nn.Linear(C, head_size, bias=False, **factory)
    query = nn.Linear(C, head_size, bias=False, **factory)
    value = nn.Linear(C, head_size, bias=False, **factory)

    q = query(x)  # (B, T, head_size)
    k = key(x)
    attn_wei = q @ k.transpose(-2, -1)  # (B, T, T)
    attn_wei = attn_wei * head_size**-0.5  # smaller weights make softmax more diffuse / less peaky

    tril = torch.tril(torch.ones(T, T, device=x.device))
    attn_wei = attn_wei.masked_fill(tril == 0, float('-inf'))  # autoregressive masking
    attn_wei = F.softmax(attn_wei, dim=-1)  # (B, T, T)

    v = value(x)  # (B, T, head_size)
    return attn_wei @ v  # (B, T, head_size)

def multi_head_attention(x, n_embd, n_head):
    """Causal multi-head self-attention, written functionally for illustration.

    Runs ``n_head`` independent ``single_head_attention`` heads of width
    ``n_embd // n_head``, concatenates them on the last axis and applies a
    freshly constructed output projection.

    Args:
        x: input of shape (B, T, n_embd).
        n_embd: embedding width; must be divisible by ``n_head``.
        n_head: number of heads.

    Returns:
        Tensor of shape (B, T, n_embd).

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """
    if n_embd % n_head != 0:
        # Otherwise the concatenated heads would be narrower than n_embd and
        # the projection would fail with an opaque shape error.
        raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
    head_size = n_embd // n_head

    out_heads = [single_head_attention(x, head_size) for _ in range(n_head)]
    out = torch.cat(out_heads, dim=-1)  # (B, T, n_embd)

    proj = nn.Linear(n_embd, n_embd, device=x.device, dtype=x.dtype)
    return proj(out)  # (B, T, n_embd)

def transformer_block(x, n_embd, n_head):
    """One pre-norm transformer block, written functionally for illustration.

    Computes ``out = x + MHSA(LN1(x))`` and then ``out + FFN(LN2(out))``.
    All layers are freshly constructed (randomly initialised) on every call,
    so this is a mechanics demo rather than a trainable layer.

    Args:
        x: input of shape (B, T, n_embd).
        n_embd: embedding width; should equal ``x.shape[-1]``.
        n_head: number of attention heads.

    Returns:
        Tensor of shape (B, T, n_embd).
    """
    # Create the layers on x's device/dtype so non-CPU inputs also work.
    factory = dict(device=x.device, dtype=x.dtype)

    # pre-norm, then mhsa/ffwd, then skip connection add
    ln1 = nn.LayerNorm(n_embd, **factory)
    ln2 = nn.LayerNorm(n_embd, **factory)

    ffwd = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd, **factory),  # inner (expansion) layer
        # tanh-approximated GELU, the same activation as the Block module,
        # so the functional and module versions compute the same thing.
        nn.GELU(approximate='tanh'),
        nn.Linear(4 * n_embd, n_embd, **factory),
    )

    out = x + multi_head_attention(ln1(x), n_embd, n_head)  # (B, T, n_embd)
    out = out + ffwd(ln2(out))  # (B, T, n_embd)
    return out

# Example usage: a single block maps (B, T, n_embd) -> (B, T, n_embd).
B, T, n_embd = 4, 8, 32
n_head = 2
x = torch.randn(B, T, n_embd)  # stand-in input to the transformer block

out = transformer_block(x, n_embd=n_embd, n_head=n_head)
out.shape

torch.Size([4, 8, 32])

### The same model as PyTorch `nn.Module` classes

In [163]:
class SingleHeadAttention(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        n_embd: input width C.
        head_size: width of the key/query/value projections and of the output.
    """

    def __init__(self, n_embd, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.head_size = head_size

    def forward(self, x):
        """Attend over ``x`` of shape (..., T, C) and return (..., T, head_size).

        Position t only attends to positions <= t.
        """
        T = x.shape[-2]  # sequence axis; also correct for unbatched (T, C) input

        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)
        v = self.value(x)

        attn_wei = q @ k.transpose(-2, -1)  # (B, T, T)
        attn_wei = attn_wei * self.head_size**-0.5  # smaller weights make softmax more diffuse / less peaky
        # Build the mask on the input's device; a CPU-only mask breaks GPU inputs.
        tril = torch.tril(torch.ones(T, T, device=x.device))
        attn_wei = attn_wei.masked_fill(tril == 0, float('-inf'))  # autoregressive masking
        attn_wei = F.softmax(attn_wei, dim=-1)  # (B, T, T)
        return attn_wei @ v  # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """Several causal attention heads run in parallel, concatenated, then projected.

    Args:
        n_embd: input/output width; must be divisible by ``n_head``.
        n_head: number of heads, each of width ``n_embd // n_head``.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        if n_embd % n_head != 0:
            # Otherwise the concatenated heads would be narrower than n_embd and
            # self.proj would fail with a shape error only at forward time.
            raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList([SingleHeadAttention(n_embd, self.head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Map (B, T, n_embd) -> (B, T, n_embd)."""
        out = torch.cat([head(x) for head in self.heads], dim=-1)  # (B, T, n_embd)
        return self.proj(out)  # (B, T, n_embd)

class Block(nn.Module):
    """Pre-norm transformer block: attention sub-layer, then MLP sub-layer.

    Each sub-layer sees a LayerNorm'd copy of the residual stream, and its
    output is added back onto that stream.
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mhsa = MultiHeadAttention(n_embd, n_head)
        hidden = 4 * n_embd  # width of the MLP's inner layer
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(approximate='tanh'),
            nn.Linear(hidden, n_embd),
        )

    def forward(self, x):
        """Map (B, T, n_embd) -> (B, T, n_embd)."""
        x = x + self.mhsa(self.ln1(x))      # attention + residual
        return x + self.ffwd(self.ln2(x))   # MLP + residual
    
class GPTLanguageModel(nn.Module):
    """Decoder-only language model: embeddings -> ``Block`` stack -> LM head.

    Args:
        vocab_size: number of token ids.
        n_embd: embedding / residual-stream width.
        n_head: attention heads per block.
        n_layers: number of transformer blocks.
        T: maximum sequence length (rows of the learned position table).
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layers, T):
        super().__init__()
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(vocab_size, n_embd),
            wpe = nn.Embedding(T, n_embd),
            h = nn.ModuleList([Block(n_embd, n_head) for _ in range(n_layers)]),
            ln_f = nn.LayerNorm(n_embd) # final layernorm before classifier
        ))
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids of shape (B, t), with t <= the model's maximum length.
            targets: optional token ids with the same shape as ``idx``.

        Returns:
            (logits, loss): logits has shape (B, t, vocab_size). loss is the
            ``F.cross_entropy`` value against ``targets``, or ``None`` when
            ``targets`` is not given.

        Raises:
            ValueError: if t exceeds the size of the position table.
        """
        # Use the actual input length (not a global T), so shorter sequences work.
        t = idx.size(-1)
        max_t = self.transformer.wpe.num_embeddings
        if t > max_t:
            raise ValueError(f"sequence length {t} exceeds maximum {max_t}")
        pos = torch.arange(0, t, device=idx.device)

        tok_emb = self.transformer.wte(idx) # (B, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # (t, n_embd)
        x = tok_emb + pos_emb # (B, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x) # (B, t, n_embd)
        logits = self.lm_head(x) # (B, t, vocab_size)

        loss = None  # previously unbound (UnboundLocalError) when targets is None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
        

# Example usage of the module versions.
B, T, n_embd = 4, 8, 32 # B: batch_size, T: max_seq_len
n_head = 2
n_layers = 4 # number of transformer blocks
vocab_size = 1000

x = torch.randn(B, T, n_embd)  # input to transformer block
# Named `block`, not `transformer_block`: reusing that name would overwrite the
# functional transformer_block(x, n_embd, n_head) from the earlier cell.
block = Block(n_embd, n_head)
out = block(x)
print(out.shape)  # (B, T, n_embd)

model = GPTLanguageModel(vocab_size, n_embd, n_head, n_layers, T)
idx = torch.randint(0, vocab_size, (B, T))
targets = torch.randint(0, vocab_size, (B, T))

logits, loss = model(idx, targets)
logits.shape, loss.item()


torch.Size([4, 8, 32])


(torch.Size([4, 8, 1000]), 7.085666656494141)