# 02 - From Bigram to Self-Attention
Following Karpathy's "Let's build GPT from scratch" — building the bigram baseline, then replacing it with single-head self-attention.

Video: https://www.youtube.com/watch?v=kCc8FmEb1nY (up to ~1:21:00)

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Check device
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using device: {device}')

# Read the dataset
with open('../data/shakespeare/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f'Length of dataset: {len(text):,} characters')
print(f'First 200 chars:\n{text[:200]}')

Using device: cpu
Length of dataset: 1,115,394 characters
First 200 chars:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [2]:
# Build character vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f'Vocab size: {vocab_size}')
print(''.join(chars))

# Create encode/decode mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Encode the full dataset and split
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f'Train: {len(train_data):,} tokens, Val: {len(val_data):,} tokens')

Vocab size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


Train: 1,003,854 tokens, Val: 111,540 tokens


In [3]:
# Hyperparameters
batch_size = 32
block_size = 8
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

def get_batch(split):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
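
The batching logic above pairs each input window with the same window shifted by one character. A minimal sketch of that offset in plain Python, using a made-up list of integers in place of the encoded corpus and a fixed start index instead of a random one (both are assumptions for illustration):

```python
# Toy stand-in for the encoded dataset (hypothetical values, not the real corpus).
data = list(range(100, 120))
block_size = 8
i = 3  # hypothetical start offset; get_batch draws this at random

x = data[i:i + block_size]          # inputs
y = data[i + 1:i + block_size + 1]  # targets: the same window shifted by one

# One window of length block_size packs block_size training examples:
# the context x[:t+1] is asked to predict y[t].
examples = [(x[:t + 1], y[t]) for t in range(block_size)]
for context, target in examples:
    print(context, '->', target)
```

This is why a single `(B,T)` batch yields `B*T` loss terms: every position in the window contributes its own prediction.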

## Part 1: Bigram Model (baseline)
The simplest model — each token predicts the next using only a lookup table. No context, no attention. We train this first to have a baseline loss to beat.

In [4]:
class BigramLanguageModel(nn.Module):

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        # idx and targets are (B,T) tensors of integers
        logits = self.token_embedding_table(idx) # (B,T,C) where C=vocab_size

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # (B,C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B,C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx

bigram_model = BigramLanguageModel(vocab_size).to(device)

# check untrained loss: uniform guessing gives -ln(1/65) ≈ 4.17; the N(0,1)-initialized embedding table starts somewhat higher
xb, yb = get_batch('train')
logits, loss = bigram_model(xb, yb)
print(f'Untrained loss: {loss.item():.4f} (expected ~{-torch.log(torch.tensor(1/65.0)).item():.4f})')

# generate from untrained model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(bigram_model.generate(context, max_new_tokens=100)[0].tolist()))

Untrained loss: 4.6485 (expected ~4.1744)

p fvLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3!dcbf?pGXepydZJSrF$Jrqt!:wwWSzPN
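
The untrained loss printed above is noticeably higher than `-ln(1/65)`. A plausible reason is that `nn.Embedding` weights default to a standard normal distribution, so the initial logits are random rather than uniform. The sketch below is a rough plain-Python simulation of that effect, not a reproduction of the notebook's exact number. It averages the cross-entropy over every possible target, which is a reasonable proxy when the logits carry no information about the target; the sample count and seed are arbitrary choices.

```python
import math
import random

random.seed(0)
V = 65  # vocabulary size from the notebook

def mean_cross_entropy(logits):
    """Cross-entropy averaged over every possible target class."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - sum(logits) / len(logits)

uniform_loss = mean_cross_entropy([0.0] * V)  # all logits equal
random_losses = [
    mean_cross_entropy([random.gauss(0.0, 1.0) for _ in range(V)])
    for _ in range(2000)
]
random_loss = sum(random_losses) / len(random_losses)

print(f'uniform logits: {uniform_loss:.4f}')
print(f'N(0,1) logits:  {random_loss:.4f}')
```

Equal logits recover `ln(65)` exactly, while unit-variance random logits cost extra, because the model is confidently wrong about as often as it is confidently right.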


In [5]:
# Train the bigram model
optimizer = torch.optim.AdamW(bigram_model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(bigram_model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    xb, yb = get_batch('train')
    logits, loss = bigram_model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

losses = estimate_loss(bigram_model)
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.7312, val loss 4.7248
step 500: train loss 4.1794, val loss 4.1795
step 1000: train loss 3.7304, val loss 3.7389
step 1500: train loss 3.3840, val loss 3.3916
step 2000: train loss 3.1244, val loss 3.1282
step 2500: train loss 2.9440, val loss 2.9429
step 3000: train loss 2.8018, val loss 2.8063
step 3500: train loss 2.7125, val loss 2.7126
step 4000: train loss 2.6460, val loss 2.6372
step 4500: train loss 2.5988, val loss 2.6008

Final: train loss 2.5614, val loss 2.5765


In [6]:
# Generate from trained bigram model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print("Bigram model output:")
print(decode(bigram_model.generate(context, max_new_tokens=300)[0].tolist()))

Bigram model output:


MADOY'
'tr thSStlleel, noisuan os :
IN:

ToQ$VOFo?uejQGS:
Imy, thack.
pAl s VJusuer f t tor r athicke hivmispZ;
A
a!?jolo.

Swhy BYORSje ar

FoTowobrt
HENED:
Fas heandbrn mus:
Ty.
Vlly y y.
Twhainis mbCHishadjKIQYO al thangjENCINa!

DUCElerFak'lluerDYo-Spstheco, KNGorDO, te, t jusretand s d basorst


### Loss: 4.17 → 2.58 (bigram baseline)

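This is the biggest single drop in the entire journey: from roughly random guessing (4.17 = `-log(1/65)` for a uniform prediction over 65 characters) to 2.58, just by memorizing which characters follow which. The untrained model actually starts a little higher (4.73 at step 0) because its randomly initialized logits are not uniform. What it learns is pure statistics: "t" is often followed by "h", "q" is almost always followed by "u", and spaces tend to come after certain letters.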

The generated text is gibberish, but it's *patterned* gibberish. Notice that train and val loss track closely: with only a 65x65 lookup table (4,225 parameters) and about a million training tokens, there is very little capacity to overfit. This is our baseline to beat.
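
To make the "pure statistics" point concrete, here is a small sketch that builds a bigram model by counting instead of by gradient descent. The toy corpus is made up for illustration, and the loss is measured on the same text the counts came from, so it only shows the direction of the effect.

```python
import math
from collections import Counter, defaultdict

text = "the theme then there"  # hypothetical toy corpus
vocab = sorted(set(text))

# Count how often each character follows each other character.
counts = defaultdict(Counter)
for prev, nxt in zip(text, text[1:]):
    counts[prev][nxt] += 1

def bigram_prob(prev, nxt):
    """Estimated probability that nxt follows prev."""
    total = sum(counts[prev].values())
    return counts[prev][nxt] / total

pairs = list(zip(text, text[1:]))
bigram_nll = -sum(math.log(bigram_prob(p, n)) for p, n in pairs) / len(pairs)
uniform_nll = math.log(len(vocab))  # loss of guessing uniformly over the toy vocab

print(f'uniform: {uniform_nll:.3f}, bigram counts: {bigram_nll:.3f}')
```

The trained `nn.Embedding(vocab_size, vocab_size)` table plays the same role as `counts`: row `prev` holds the (log-scale) preferences for what comes next.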

## Part 2: Building toward attention
The bigram model is isolated — each token only knows itself. We want tokens to communicate with past tokens, but **not the future** (that would be cheating during generation). These cells build the intuition step by step: triangular mask → masked fill → softmax → learned Q·K scores.

In [7]:
# step 1: the lower triangular matrix — who can see whom
tril = torch.tril(torch.ones(5, 5))
tril

tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])

In [8]:
wei = torch.zeros((5,5)) # token weights — how much weight to give each past token
wei

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [9]:
wei = wei.masked_fill(tril==0, float('-inf')) # future tokens can't communicate to past
wei

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [10]:
wei = F.softmax(wei, dim=-1) # exponentiate and normalize — -inf becomes 0, equal scores become equal weights
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
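
Each row of the matrix above is a set of weights over the current and past tokens, so multiplying it with token features gives a running average. A dependency-free sketch of that idea in plain Python (the `values` list is a made-up one-feature-per-token example, and `causal_softmax` is a helper name introduced here):

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position t may only attend to positions <= t."""
    T = len(scores)
    weights = []
    for t in range(T):
        # Masked (future) entries get -inf, so exp(-inf) = 0 removes them.
        row = [scores[t][s] if s <= t else float('-inf') for s in range(T)]
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

T = 5
wei = causal_softmax([[0.0] * T for _ in range(T)])
values = [2.0, 4.0, 6.0, 8.0, 10.0]  # one hypothetical feature per token
out = [sum(w * v for w, v in zip(row, values)) for row in wei]
print(out)  # running means of the values seen so far
```

With all-zero scores every visible token gets the same weight. Self-attention keeps this machinery and only changes where the scores come from.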

In [11]:
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time (tokens), channels (embedding dims)

x = torch.randn(B,T,C)

# self attention
# we don't want equal weights — we want to gather info from past tokens selectively.
# every single token emits a query (what am I looking for?) and a key (what do I contain?)
# wei = query dot product key — if they align, they interact with very high weight

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)

wei = q @ k.transpose(-2,-1) #(B,T,16) @ (B,16,T) --> (B,T,T)

In [12]:
wei[0]

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [13]:
tril = torch.tril(torch.ones((T,T))) # lower triangular mask — who can see whom
wei = wei.masked_fill(tril==0, float('-inf')) # future tokens can't communicate to past
wei = F.softmax(wei, dim=-1) # exponentiate and normalize into proper weights
wei

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1687, 0.8313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2477, 0.0514, 0.7008, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4410, 0.0957, 0.3747, 0.0887, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0069, 0.0456, 0.0300, 0.7748, 0.1427, 0.0000, 0.0000, 0.0000],
         [0.0660, 0.089

## Part 3: Single-Head Self-Attention Model
Now we take the query/key mechanism from above, add a value projection (what each token hands over once it is attended to), and package all three into a proper `Head` module, then build a model that uses it. Attention now sits between the embeddings and the output layer, so tokens can communicate with past tokens using learned, data-dependent weights.
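
Before reading the module, it may help to see the computation for a single query by hand. The sketch below uses made-up 2-dimensional keys, values, and one query in plain Python; the numbers are chosen only so that the first and third keys match the query equally well.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Hypothetical keys and values for three visible tokens, plus the current token's query.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]

head_size = len(query)
scores = [dot(query, k) / math.sqrt(head_size) for k in keys]  # affinities
weights = softmax(scores)                                      # attention weights
out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(head_size)]

print([round(w, 3) for w in weights])
print([round(o, 3) for o in out])
```

The output is a blend of the value vectors, tilted toward the tokens whose keys align with the query.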

In [14]:
class Head(nn.Module):
    """One head of self-attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a parameter — it's a buffer (constant, not trained)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)   # (B,T,head_size) - keys
        q = self.query(x) # (B,T,head_size) - queries

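        # compute attention scores ("affinities")
        # NOTE: this scales by C**-0.5 with C = n_embd; the standard scaling is head_size**-0.5.
        # The two agree when head_size == n_embd (the single head below) but differ for smaller heads.
        wei = q @ k.transpose(-2, -1) * C**-0.5  # (B,T,T), scaled to keep softmax in a useful range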
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # mask future — decoder block, no peeking ahead
        wei = F.softmax(wei, dim=-1)  # (B,T,T)

        # weighted aggregation of values
        v = self.value(x)  # (B,T,head_size)
        out = wei @ v      # (B,T,head_size)
        return out
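
Why scale the scores at all? A rough plain-Python experiment, assuming unit-variance queries and keys (an idealization, not the notebook's actual activations): the variance of a raw dot product grows with `head_size`, dividing by `sqrt(head_size)` brings it back to about 1, and larger scores make softmax much peakier.

```python
import math
import random

random.seed(0)

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def variance(xs):
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)

head_size = 16
raw = []
for _ in range(2000):
    q = [random.gauss(0.0, 1.0) for _ in range(head_size)]
    k = [random.gauss(0.0, 1.0) for _ in range(head_size)]
    raw.append(sum(a * b for a, b in zip(q, k)))
scaled = [s * head_size ** -0.5 for s in raw]

print(f'variance of raw scores:    {variance(raw):.2f}')     # grows with head_size
print(f'variance of scaled scores: {variance(scaled):.2f}')  # stays near 1

# Larger scores push softmax toward a one-hot distribution.
soft = softmax([0.1, -0.2, 0.3])
sharp = softmax([0.1 * 8, -0.2 * 8, 0.3 * 8])
print([round(p, 3) for p in soft])
print([round(p, 3) for p in sharp])
```

A near one-hot softmax means each token listens to a single neighbour, which is not a good starting point for training.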

In [15]:
class SingleHeadAttentionModel(nn.Module):
    """Model with single-head self-attention — tokens can now look at past context"""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # plug in the self-attention head
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_head(x) # apply one head of self-attention. (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop context to block_size — position embedding only goes up to block_size
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # (B,C)
            probs = F.softmax(logits, dim=-1) # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx

model = SingleHeadAttentionModel().to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

# check untrained loss — should be ~4.17 = -ln(1/65)
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
print(f'Untrained loss: {loss.item():.4f} (expected ~{-torch.log(torch.tensor(1/65.0)).item():.4f})')

# generate from untrained model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=100)[0].tolist()))

Parameters: 7,553
Untrained loss: 4.2239 (expected ~4.1744)



-Kt,IAguhOyhYSw-lWBP&o:'EE,mqVK:VEvSq!fQIylPMfuw$k'wuyqlc --keXuNMgg?gmznrfnvSDPPO$kJYpOahNuNOdbVgu3


In [16]:
# Train the model with self-attention
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

losses = estimate_loss(model)
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.2309, val loss 4.2329
step 500: train loss 2.7667, val loss 2.7862
step 1000: train loss 2.5416, val loss 2.5442
step 1500: train loss 2.4865, val loss 2.4844
step 2000: train loss 2.4390, val loss 2.4558
step 2500: train loss 2.4272, val loss 2.4394
step 3000: train loss 2.4199, val loss 2.4222
step 3500: train loss 2.3946, val loss 2.4186
step 4000: train loss 2.3952, val loss 2.4079
step 4500: train loss 2.3866, val loss 2.4106

Final: train loss 2.3887, val loss 2.4057


In [17]:
# Generate — compare with bigram output above
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print("Self-attention model output:")
print(decode(model.generate(context, max_new_tokens=300)[0].tolist()))

Self-attention model output:

Thabee othy theeandurdeves nd thitle, swhiem dsy an'g ply minds countt--ret brecethe suth the ourerle's rth
eve Ra I:
Ses the ripare nomutre herdin fit thil jorunscot when dle ble wakitheen nd shery!
Pobure imyoth, Clande's anke ithe may thay'schen mn'.
D IUSSoupolilfive tietr brathad PEN:
Ans ar,
W


### Loss: 2.58 → 2.40 (single-head self-attention)

A significant drop! The model went from seeing **one character at a time** (bigram) to using **up to 8 characters of context** with learned attention weights. It now knows that "t" after "th" is different from "t" after "zz."

Look at the generated text — you can see word-shaped blobs forming: "theeandurdeves", "thitle", "swhiem". These aren't real words, but they have the *rhythm* of English. The model is learning that certain sequences of characters tend to appear together. That's what context awareness buys you.

## Part 4: Multi-Head Attention
One attention head learns one notion of "relevant." But language has many simultaneous relationships — a word needs to track grammar, position, meaning, and more at the same time.

The fix: run **4 smaller heads in parallel**, each with 8 dimensions (instead of 1 head with 32 dims). Same total parameters, but each head can specialize in a different pattern. We concatenate their outputs back into 32 dims.

In [18]:
class MultiHeadAttention(nn.Module):
    """multiple heads of self-attention running in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

In [19]:
class MultiHeadAttentionModel(nn.Module):
    """Model with 4-head attention — multiple perspectives on the same input"""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # 4 heads x 8 dims each = 32 dims total (same as single head, but more perspectives)
        self.sa_heads = MultiHeadAttention(4, n_embd//4)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_heads(x) # apply 4 heads of self-attention in parallel (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # (B,C)
            probs = F.softmax(logits, dim=-1) # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx

model = MultiHeadAttentionModel().to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

# check untrained loss
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
print(f'Untrained loss: {loss.item():.4f} (expected ~{-torch.log(torch.tensor(1/65.0)).item():.4f})')

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=300)[0].tolist()))

Parameters: 7,553
Untrained loss: 4.2068 (expected ~4.1744)



N!N;D'a'DgaNF
sSpaqGMMXM'T3Mz!?k
mmEHm$evoI&CosEKr'?JSrV!wRMUkOAOPbja$'nPQp-Sv?HB$tspAQ.mkqB:SMpeoX:awfArhDRSouv?JDGeiAyEjpc3rKJlI3rt
EkSpt!rst-CF&mP$Pehv;DsKoE
zm-hMBFXJgsuSlhdgJTepFN!ayfU-UvRv?3Dt?qrlI-VueMVhRgRVK&RGFsAoyYD.XVjf?mv3YFMlB:HACUhuZ
BCmtjS$?.W XwARecSARC;zNM'ws$K$
EgYH$$sXWOD 'BK$e,Pt


In [20]:
# Train the model with multi-head self-attention
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

losses = estimate_loss(model)
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.2115, val loss 4.2183
step 500: train loss 2.6273, val loss 2.6266
step 1000: train loss 2.4787, val loss 2.4832
step 1500: train loss 2.4181, val loss 2.4350
step 2000: train loss 2.3791, val loss 2.3827
step 2500: train loss 2.3553, val loss 2.3466
step 3000: train loss 2.3168, val loss 2.3331
step 3500: train loss 2.2941, val loss 2.3109
step 4000: train loss 2.2730, val loss 2.2958
step 4500: train loss 2.2695, val loss 2.2838

Final: train loss 2.2540, val loss 2.2887


In [21]:
# Generate — compare with single-head output above
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print("Multi-head attention model output:")
print(decode(model.generate(context, max_new_tokens=300)[0].tolist()))

Multi-head attention model output:



CHet Il I
Hmey doy toud, crecer had kve maveser walteallls, Gics:
Ef ith
iblest entush ot dous an gepook, properardy cea thes I lille, mat mooll ca not to ma cople, lles a eat theomy, pray lut mow's domd,
SLODDRD TEN to goow so ast rroimnot so.

Asto bull my grot tharing, to it looms
Thane nefnin pa


### Loss: 2.40 → 2.28 (multi-head attention)

Splitting one 32-dim head into 4 independent 8-dim heads drops val loss from 2.40 to 2.28 — a solid improvement with **zero additional parameters** (same total dimensions, just redistributed).
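
The "zero additional parameters" claim can be checked with plain arithmetic. The sketch below recounts the parameters from the layer shapes defined in this notebook; `head_params` is a helper name introduced here for illustration.

```python
vocab_size, n_embd, block_size = 65, 32, 8  # values used in this notebook

def head_params(head_size):
    # key, query and value projections, each n_embd x head_size with no bias
    return 3 * n_embd * head_size

single_head = head_params(n_embd)          # one 32-dim head
multi_head = 4 * head_params(n_embd // 4)  # four 8-dim heads

shared = (
    vocab_size * n_embd                   # token embedding table
    + block_size * n_embd                 # position embedding table
    + n_embd * vocab_size + vocab_size    # lm_head weight + bias
)

print(single_head, multi_head)  # attention parameters, single vs multi-head
print(shared + multi_head)      # total, to compare with the printed parameter count
```

Both totals agree with the `Parameters:` line printed for the single-head and multi-head models.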

Why does this help? Each head can specialize: one might track character pairs, another positional relationships, another capitalization patterns. With a single head, the model has to cram all these relationships into one set of Q/K/V weights. Multiple heads let it learn several relationships in parallel.

Notice the generated text: short real words such as "had", "not", "pray", and "my" now appear, one line ends in a colon the way speaker labels do (`Gics:`), and an all-caps run (`SLODDRD TEN`) resembles a character name. Structure is emerging.

## Part 5: Adding a FeedForward Layer
Attention lets tokens **communicate** — gather information from other tokens. But after gathering, each token needs to **process** what it collected individually. That's the feedforward layer: a simple `Linear → ReLU` that operates on each token independently.

Think of it as: attention = the meeting, feedforward = taking notes afterward.
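
"Operates on each token independently" means the same function is applied to every position, with no mixing across positions. A tiny plain-Python sketch with made-up 2-dimensional weights (all numbers here are hypothetical):

```python
def relu(v):
    return [max(0.0, x) for x in v]

def feed_forward(token, W, b):
    """Linear -> ReLU applied to a single token vector."""
    return relu([sum(w * x for w, x in zip(row, token)) + bias
                 for row, bias in zip(W, b)])

# Hypothetical 2-dim weights and a 3-token sequence.
W = [[1.0, -1.0], [0.5, 2.0]]
b = [0.0, -1.0]
sequence = [[1.0, -2.0], [0.0, 3.0], [-1.0, -1.0]]

out = [feed_forward(tok, W, b) for tok in sequence]
out_reversed = [feed_forward(tok, W, b) for tok in reversed(sequence)]

print(out)
print(out_reversed)  # same vectors in reversed order: no token sees its neighbours
```

Reordering the tokens only reorders the outputs, which is the opposite of attention, where each output depends on the tokens before it.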

In [22]:
class FeedForward(nn.Module):
    """a simple linear layer followed by a nonlinearity"""

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())

    def forward(self, x):
        return self.net(x)

In [23]:
class AttentionWithFeedForwardModel(nn.Module):
    """Multi-head attention + feedforward — tokens can communicate AND think"""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        self.sa_heads = MultiHeadAttention(4, n_embd//4)  # 4 heads x 8 dims = 32
        self.ffwd = FeedForward(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.sa_heads(x) # communicate: multi-head self-attention (B,T,C)
        x = self.ffwd(x) # think: per-token feedforward (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # (B,C)
            probs = F.softmax(logits, dim=-1) # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx

model = AttentionWithFeedForwardModel().to(device)
print(f'Parameters: {sum(p.numel() for p in model.parameters()):,}')

# check untrained loss
xb, yb = get_batch('train')
logits, loss = model(xb, yb)
print(f'Untrained loss: {loss.item():.4f} (expected ~{-torch.log(torch.tensor(1/65.0)).item():.4f})')

context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=300)[0].tolist()))

Parameters: 8,609
Untrained loss: 4.2110 (expected ~4.1744)



wmT,eVb;JdWhE,3gJmCYISN
HtMqWeokLfm?FIc$$kdJ-;Yz$cUf$bcjUDgGVBlPQf&bSCmv-ibDLAGMpUyKiY.l-k$SaL.AaNBwR;sv;rAYRiiVKbtHqRhl'bcxF&cfC?WZUrG;:M e;KlwxaJkEXjgyxYdOPFMiN.AMQX.KwP kYezhPs;oQPcH'pcQt cn3I
I3!O,B
gF&Li$,mGpfpL w;:cBUVIjDYr:syQSl3 zhf;ASNoTysvVgC'yVVXEfcCH ej,pjE huqS:uWEhIrXykSDccyUxOKxrcw 3b


In [24]:
# Train the model with self-attention + feed forward
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

losses = estimate_loss(model)
print(f"\nFinal: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

step 0: train loss 4.2021, val loss 4.2039
step 500: train loss 2.6014, val loss 2.6017
step 1000: train loss 2.4651, val loss 2.4825
step 1500: train loss 2.4011, val loss 2.4011
step 2000: train loss 2.3534, val loss 2.3675
step 2500: train loss 2.3204, val loss 2.3392
step 3000: train loss 2.2892, val loss 2.3171
step 3500: train loss 2.2778, val loss 2.2982
step 4000: train loss 2.2583, val loss 2.2916
step 4500: train loss 2.2313, val loss 2.2854

Final: train loss 2.2344, val loss 2.2614


In [25]:
# Generate — compare with multi-head output above
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print("Attention + feedforward model output:")
print(decode(model.generate(context, max_new_tokens=300)[0].tolist()))

Attention + feedforward model output:



II whart mich meithe ace, I fdeft thour ais antey hit now
Loly,
Loome tist bre,
Loreced thid younre sacte.

Whilesh ih at
LIKINON:
Whe
Hen, aknding:
Head
And of srod,-Rorf the theala coutele: nis aeikemery, thul Tard.

EMVINI HIANIONORD:
I exentest me thugy wit hithge pakfat?
and apill of it gon, is


### Loss: 2.28 → 2.26 (feedforward added)

A modest drop in val loss from 2.28 to 2.26: the feedforward layer adds 1,056 parameters (the extra `Linear` weight and bias) and lets each token *process* the information it gathered from attention. Without it, each token's output is just a weighted average of value vectors, with no further nonlinear computation applied to the result.

At this small scale the improvement is subtle, but feedforward is where the majority of parameters live in real transformers (GPT-3's feedforward inner dim is 4 x 12,288 = 49,152). Attention followed by feedforward is also the unit that gets repeated to build depth: in notebook 03, we'll stack multiple such blocks and see the loss drop significantly.
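
The parameter figures in this section can be recomputed with plain arithmetic. The first part uses this notebook's shapes. The second part assumes a GPT-style block with four `d x d` attention projections (query, key, value, output) and a feedforward that expands to `4*d` and projects back, ignoring biases and LayerNorm. That layout is an assumption for illustration and is not the layer built in this notebook.

```python
n_embd = 32  # this notebook

# FeedForward here is a single Linear(n_embd, n_embd) with bias.
ffwd_params = n_embd * n_embd + n_embd
print(ffwd_params)          # extra parameters over the multi-head model
print(7_553 + ffwd_params)  # compare with the printed parameter count

# Assumed GPT-style block (see the note above); biases and LayerNorm ignored.
d_model = 12_288                     # GPT-3 width quoted in the text
attn_params = 4 * d_model * d_model  # query, key, value, output projections
ff_inner = 4 * d_model               # feedforward inner dimension
ff_params = 2 * d_model * ff_inner   # expand, then project back
ff_share = ff_params / (attn_params + ff_params)

print(ff_inner, f'{ff_share:.2%}')   # inner dim and feedforward share of block weights
```

Under those assumptions the feedforward holds about two thirds of each block's weights, which is the sense in which it is where "the majority of parameters live".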

**Generated text** is starting to look more structured — real-ish words, punctuation in plausible places, and the colon-after-name pattern (Shakespeare's speaker labels) is becoming more consistent.

## Summary: Loss Progression in This Notebook

| Stage | Model | Val Loss | What Changed |
|-------|-------|----------|-------------|
| Baseline | `BigramLanguageModel` | 2.58 | Lookup table only — no context |
| + Self-attention | `SingleHeadAttentionModel` | 2.40 | Tokens can now see 8 chars of context |
| + Multi-head | `MultiHeadAttentionModel` | 2.28 | 4 parallel perspectives instead of 1 |
| + FeedForward | `AttentionWithFeedForwardModel` | 2.26 | Tokens can process what they gathered |

**Next up (notebook 03):** We add residual connections, LayerNorm, stack 4 transformer blocks, and scale up the model. Loss drops to **1.98**.