# Tiny Transformer (Mini-GPT)
Covers self-attention, multi-head attention (MHA), the feed-forward network (FFN), residual connections, and layer normalization.


In [None]:
# Setup
!pip -q install torch --upgrade
import math, random, textwrap, requests
import torch, torch.nn as nn, torch.nn.functional as F
from tqdm.auto import tqdm
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Fixed seed so parameter initialisation and batch sampling are repeatable.
torch.manual_seed(1337)
print("Device:", device)

# Data
# Download the corpus (Project Gutenberg text file #11). A timeout keeps the
# notebook from hanging on a stalled connection, and raise_for_status() makes
# an HTTP error fail loudly instead of silently training on an error page.
response = requests.get("https://www.gutenberg.org/files/11/11-0.txt", timeout=30)
response.raise_for_status()
text = response.text[:500_000]  # keep at most the first 500k characters

# Character-level vocabulary: every distinct character in the corpus is a token.
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> character


def encode(s):
    """Return a 1-D long tensor holding the vocabulary id of each character in *s*.

    Raises:
        KeyError: if *s* contains a character that is not in the corpus vocabulary.
    """
    return torch.tensor([stoi[c] for c in s], dtype=torch.long)


def decode(t):
    """Return the string spelled by the iterable of token ids *t*."""
    return "".join(itos[int(i)] for i in t)


# Encode the whole corpus once, then hold out the final 10% for validation.
data = encode(text)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# Batching: sequences per batch and tokens per sequence (the context length).
batch_size = 64
block_size = 128

# Model size: embedding width, attention heads per layer, number of layers,
# and the dropout probability.
n_embd = 192
n_head = 6
n_layer = 4
dropout = 0.1

# Optimisation: number of training steps, how often to report the loss,
# and the learning rate.
max_steps = 2500
eval_interval = 250
lr = 3e-4

def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    Args:
        split: "train" selects ``train_data``; any other value selects ``val_data``.

    Returns:
        A pair ``(x, y)`` of tensors stacked from ``batch_size`` windows of
        ``block_size`` tokens each, both moved to ``device``. ``y`` is ``x``
        shifted one position to the right in the source, i.e. the next-token
        target for every input position.
    """
    source = val_data if split != "train" else train_data
    starts = torch.randint(len(source) - block_size - 1, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        # Same window moved one token ahead.
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

class Head(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention."""

    def __init__(self, n_embd, head_size, block_size, dropout):
        """Build the key/query/value projections and the causal mask.

        Args:
            n_embd: size of the last dimension of the input.
            head_size: size of the last dimension of the output.
            block_size: side length of the stored lower-triangular mask.
            dropout: dropout probability applied to the attention weights.
        """
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position t may look at positions <= t only.
        # Stored as a buffer so it follows the module across devices without
        # being a trainable parameter.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", causal)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a (B, T, C) input and return a (B, T, head_size) tensor."""
        _, seq_len, _ = x.shape
        keys = self.key(x)
        queries = self.query(x)
        # Scale by sqrt(head_size) before the softmax.
        scale = math.sqrt(keys.size(-1))
        scores = queries @ keys.transpose(-2, -1) / scale
        # Positions in the future get -inf so softmax assigns them zero weight.
        is_future = self.tril[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)

class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        """Create ``n_head`` heads of size ``n_embd // n_head`` plus the output projection.

        Args:
            n_embd: embedding width of the input and of the output.
            n_head: number of attention heads; must divide ``n_embd`` evenly.
            block_size: maximum sequence length, forwarded to every head.
            dropout: dropout probability, used inside the heads and after the projection.

        Raises:
            ValueError: if ``n_head`` is not positive or does not divide ``n_embd``.
        """
        super().__init__()
        # The concatenated head outputs must be exactly n_embd wide to feed
        # `proj`; reject a bad configuration here rather than failing with a
        # shape mismatch on the first forward pass.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
            )
        head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [Head(n_embd, head_size, block_size, dropout) for _ in range(n_head)]
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x``, concatenate on the last dim, project, apply dropout."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the width, GELU, project back, dropout."""

    def __init__(self, n_embd, dropout):
        """Build the MLP.

        Args:
            n_embd: input and output width.
            dropout: dropout probability applied to the MLP output.
        """
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

class Block(nn.Module):
    """One transformer layer: attention then MLP, each as a pre-norm residual branch."""

    def __init__(self, n_embd, n_head, block_size, dropout):
        """Create the attention and feed-forward sublayers with one LayerNorm each."""
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ff = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after adding the attention branch and then the MLP branch."""
        # LayerNorm is applied to the branch input; the residual path is left untouched.
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ff(self.ln2(after_attention))

class TinyGPT(nn.Module):
    """Decoder-only transformer language model over integer token ids.

    Token and learned positional embeddings are summed, passed through
    ``n_layer`` ``Block`` layers and a final LayerNorm, then mapped to
    vocabulary logits by a linear head.
    """

    def __init__(self, vocab_size, n_embd, n_head, n_layer, block_size, dropout):
        """Build the embeddings, the stack of blocks and the output head.

        Args:
            vocab_size: number of distinct token ids.
            n_embd: embedding width.
            n_head: attention heads per block.
            n_layer: number of blocks.
            block_size: maximum sequence length the model accepts.
            dropout: dropout probability forwarded to every block.
        """
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList(
            [Block(n_embd, n_head, block_size, dropout) for _ in range(n_layer)]
        )
        self.ln_final = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids with ``T <= block_size``.
            targets: optional (B, T) tensor of target ids.

        Returns:
            ``(logits, loss)``: logits has shape (B, T, vocab_size); loss is
            ``None`` when ``targets`` is not given.

        Raises:
            ValueError: if ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        # Positions beyond block_size have no embedding row; fail with a clear
        # message instead of an out-of-range index error inside nn.Embedding.
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        tok = self.token_emb(idx)
        pos = self.pos_emb(torch.arange(T, device=idx.device))
        x = tok + pos  # (T, n_embd) broadcasts over the batch dimension
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=300, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The model is switched to eval mode while sampling, so dropout is
        disabled, and its previous train/eval mode is restored afterwards.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: positive divisor applied to the logits before softmax.
            top_k: if given, sample only among the ``top_k`` highest logits.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        Raises:
            ValueError: if ``temperature`` is not positive or ``top_k`` is below 1.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be at least 1, got {top_k}")
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the most recent block_size tokens fit in the context.
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Anything below the k-th largest logit gets zero probability.
                    logits = logits.masked_fill(logits < kth[:, [-1]], float("-inf"))
                probs = F.softmax(logits, dim=-1)
                nxt = torch.multinomial(probs, 1)
                idx = torch.cat([idx, nxt], dim=1)
        finally:
            self.train(was_training)
        return idx

# Build the network from the hyperparameters and place it on the chosen device.
model = TinyGPT(
    vocab_size=vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout,
)
model = model.to(device)
# AdamW over every model parameter at the configured learning rate.
opt = torch.optim.AdamW(model.parameters(), lr=lr)

@torch.no_grad()
def estimate_loss(eval_iters=20):
    """Estimate the mean loss on the train and validation splits.

    The model is put in eval mode for the measurement, and whatever mode it
    was in beforehand is restored on exit, even if evaluation raises.

    Args:
        eval_iters: number of random batches averaged per split.

    Returns:
        Dict with keys "train" and "val" mapping to the mean loss as a float.

    Raises:
        ValueError: if ``eval_iters`` is smaller than 1.
    """
    if eval_iters < 1:
        raise ValueError(f"eval_iters must be at least 1, got {eval_iters}")
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ("train", "val"):
            losses = []
            for _ in range(eval_iters):
                xb, yb = get_batch(split)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = sum(losses) / len(losses)
    finally:
        model.train(was_training)
    return out

# Training loop: one optimiser update per step on a fresh random batch.
for step in range(max_steps):
    # Every eval_interval steps (including step 0) report the averaged losses.
    if step % eval_interval == 0:
        l = estimate_loss()
        train_loss, val_loss = l["train"], l["val"]
        print(f"step {step:4d}: train {train_loss:.3f} | val {val_loss:.3f}")

    xb, yb = get_batch("train")
    _, loss = model(xb, yb)

    # Clear old gradients, backpropagate the new loss, then update the weights.
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

print("Training done.")

def sample_text(prompt="Alice", max_new_tokens=400, temperature=1.0, top_k=None):
    """Generate text from the trained model, starting from ``prompt``.

    Args:
        prompt: starting string; it is encoded with ``encode``.
        max_new_tokens: number of tokens to sample after the prompt.
        temperature: passed through to ``model.generate``.
        top_k: passed through to ``model.generate``.

    Returns:
        The decoded string: the prompt followed by the sampled continuation.
    """
    # Add a leading batch dimension of size 1 and move the prompt to the device.
    context = encode(prompt).unsqueeze(0).to(device)
    generated = model.generate(
        context,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    token_ids = generated[0].tolist()
    return decode(token_ids)

# Print two 400-token samples from the prompt "Alice": first at temperature 1.0
# with no top-k filter, then at temperature 0.8 restricted to the top 50 tokens.
for banner, temp, k in (
    ("\n=== SAMPLE (temp=1.0) ===\n", 1.0, None),
    ("\n=== SAMPLE (temp=0.8, top_k=50) ===\n", 0.8, 50),
):
    print(banner)
    # Re-wrap the generated text to 90 columns for readability.
    print(textwrap.fill(sample_text("Alice", 400, temp, k), width=90))

# Save checkpoint (optional)
import torch

# NOTE(review): "/content" is presumably the Colab working directory; the save
# fails if that directory does not exist. Confirm before running elsewhere.
ckpt = "/content/tinygpt_char_alice.pt"
# Only the weights are stored, under the key "model".
state = {"model": model.state_dict()}
torch.save(state, ckpt)
print("Saved:", ckpt)
