# Transformers — Try it in PyTorch

This is an **optional** hands-on companion to [Chapter 8](https://learnai.robennals.org/transformers). You'll build a complete transformer from scratch, train it on simple sequences, and then train one on real text to generate children's stories.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

## One Transformer Block

Each block does two things:
1. **Attention**: every word looks at the other words it is allowed to see (in a language model, only the words before it) and gathers relevant context
2. **Feed-forward network**: a small neural network that processes what attention gathered

Plus a **shortcut connection** (residual) that adds the input directly to the output, so the block only needs to learn *refinements*.

In [None]:
class TransformerBlock(nn.Module):
    """One post-norm transformer layer: self-attention followed by a feed-forward net.

    Each sub-layer's output is added back onto its own input (a residual, or
    "shortcut", connection) and the sum is normalized with LayerNorm, so the
    block only has to learn refinements of what it receives.

    Args:
        embed_dim: width of every token vector (same on input and output).
        n_heads: number of attention heads; nn.MultiheadAttention requires
            it to divide ``embed_dim``.
        ff_dim: hidden width of the feed-forward network.
    """

    def __init__(self, embed_dim, n_heads, ff_dim):
        super().__init__()
        # Submodules are created in this fixed order so that parameter
        # initialisation (and hence seeded results) is reproducible.
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """Refine ``x`` and also return the attention weights.

        Args:
            x: token vectors laid out (batch, seq_len, embed_dim), because the
                attention layer is built with ``batch_first=True``.
            mask: optional ``attn_mask`` forwarded unchanged to
                nn.MultiheadAttention (e.g. a causal -inf/0 mask).

        Returns:
            ``(refined, weights)``. ``refined`` has the same shape as ``x``.
            ``weights`` is exactly what nn.MultiheadAttention returns.
        """
        context, weights = self.attn(x, x, x, attn_mask=mask)
        hidden = self.ln1(x + context)                # shortcut around attention
        hidden = self.ln2(hidden + self.ff(hidden))   # shortcut around the FFN
        return hidden, weights

# Smoke test: push random vectors through a single block and inspect shapes.
torch.manual_seed(42)
block = TransformerBlock(embed_dim=16, n_heads=2, ff_dim=32)
x = torch.randn(1, 5, 16)  # (batch=1, seq_len=5, embed_dim=16)

with torch.no_grad():
    out, weights = block(x)

n_rows, n_cols = weights.shape[1], weights.shape[2]
summary = [
    f"Input shape:  {x.shape}",
    f"Output shape: {out.shape}  (same! the block refines, it doesn't change the shape)",
    f"Attention weights: {weights.shape}  ({n_rows}×{n_cols} attention matrix)",
    "\nThe shortcut connection means the output = input + refinements.",
    "This is why you can stack many blocks without losing information.",
]
for line in summary:
    print(line)

## A Complete Transformer

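Stack the pieces: token embedding → position embedding → N transformer blocks → linear output. This is the same basic architecture behind GPT, Claude, and other modern language models. Real models add refinements and are vastly larger, but the skeleton is the one you see here.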

In [None]:
class TinyTransformer(nn.Module):
    """A minimal decoder-only (GPT-style) language model.

    Pipeline: token embedding + learned position embedding -> ``n_layers``
    TransformerBlocks under a causal mask -> final LayerNorm -> a linear layer
    producing one logit per vocabulary entry at every position.

    Args:
        vocab_size: number of distinct token ids.
        embed_dim: width of every token vector.
        n_heads: attention heads per block.
        ff_dim: hidden width of each block's feed-forward network.
        n_layers: number of stacked blocks.
        max_len: longest sequence the model can process. There is one learned
            position vector per position, so longer inputs are rejected.
    """

    def __init__(self, vocab_size, embed_dim, n_heads, ff_dim, n_layers, max_len=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, n_heads, ff_dim)
            for _ in range(n_layers)
        ])
        self.ln_final = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits for every position of ``x``.

        Args:
            x: integer token ids, shape (batch, seq_len).

        Returns:
            ``(logits, all_weights)``: logits has shape
            (batch, seq_len, vocab_size); all_weights holds the attention
            weights returned by each block, in order.

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` given at construction.
        """
        seq_len = x.shape[1]
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Without this check nn.Embedding fails with an opaque
            # "index out of range" error from the position lookup.
            raise ValueError(
                f"sequence length {seq_len} exceeds the model's max_len {max_len}"
            )

        # Embed tokens and add position
        tok_emb = self.embedding(x)
        pos_emb = self.pos_embedding(torch.arange(seq_len, device=x.device))
        h = tok_emb + pos_emb

        # Causal mask: -inf above the diagonal blocks attention to later
        # positions; 0 elsewhere leaves the scores unchanged. Create it on the
        # activations' device/dtype so the model also works on GPU.
        mask = torch.full(
            (seq_len, seq_len), float("-inf"), device=h.device, dtype=h.dtype
        ).triu(diagonal=1)

        # Pass through transformer blocks
        all_weights = []
        for block in self.blocks:
            h, w = block(h, mask=mask)
            all_weights.append(w)

        h = self.ln_final(h)
        logits = self.output(h)  # (batch, seq_len, vocab_size)
        return logits, all_weights

# A 16-token vocabulary, 32-dim vectors and two stacked blocks.
model = TinyTransformer(
    vocab_size=16,
    embed_dim=32,
    n_heads=2,
    ff_dim=64,
    n_layers=2,
)

# Count every learnable number in the model.
n_params = 0
for param in model.parameters():
    n_params += param.numel()

print(f"Parameters: {n_params:,}")
print("Architecture: embed(32) → 2 blocks × (2-head attention + FFN(64)) → output(16)")
print("\nThis is the same architecture as GPT — just much, much smaller!")

## Training on a Simple Pattern

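Let's train our tiny transformer to learn a simple pattern: sequences where each token copies the one 2 positions back. This is easy enough to verify but requires the model to actually attend to the right position. Because the model predicts the *next* token, the token it must output is the one just before its current input, so each position has to look back one step.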

In [None]:
def make_copy_data(n_samples=500, seq_len=8, vocab_size=16):
    """Build sequences in which every token repeats the one two places earlier.

    Each sequence starts with two random tokens drawn from ``[0, vocab_size)``.
    That pair is then repeated, e.g. ``[3, 7, 3, 7, 3, 7, ...]``.

    Args:
        n_samples: number of sequences (rows).
        seq_len: length of each sequence. A sequence never has fewer than its
            two seed tokens.
        vocab_size: exclusive upper bound on token ids.

    Returns:
        An int64 tensor of shape (n_samples, seq_len) when seq_len >= 2.
    """
    def one_sequence():
        # One randint call per sample keeps the RNG stream stable for seeded runs.
        seq = list(torch.randint(0, vocab_size, (2,)).tolist())
        while len(seq) < seq_len:
            seq.append(seq[-2])  # copy from two positions back
        return seq

    return torch.tensor([one_sequence() for _ in range(n_samples)])

torch.manual_seed(42)
data = make_copy_data()

print("Example sequences (each token copies 2 positions back):")
for i in range(5):
    print(f"  {data[i].tolist()}")

# Next-token prediction: the input is every token except the last and the
# target is the same sequence shifted left by one. These slices never change,
# so take them once instead of on every epoch.
x = data[:, :-1]     # input
y = data[:, 1:]      # target (shifted by 1)

# Train
model.train()  # the evaluation cell below switches to eval mode; restore it if this cell is re-run
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
losses = []

for epoch in range(100):
    logits, _ = model(x)
    # cross_entropy wants (N, vocab) logits and (N,) targets. Read the vocab
    # size from the logits instead of hard-coding it so it always matches the model.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 25 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training loss (copy-2-back task)")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# Test: does the model learn the pattern?
model.eval()
# Evaluate at the training length. The model has one learned position vector
# per position, and only positions seen in training (data.shape[1] - 1 of
# them) were ever updated. Longer test sequences would hit untrained positions.
test = make_copy_data(n_samples=10, seq_len=data.shape[1])

with torch.no_grad():
    logits, attn_weights = model(test[:, :-1])
    preds = logits.argmax(dim=-1)

# Target t is sequence token t+1. Target 0 (the second token) is random and
# cannot be predicted. Every later target repeats token t-1, which is already
# in the input, so the copy pattern applies from position 1 onward.
targets = test[:, 1:]
correct = (preds[:, 1:] == targets[:, 1:]).float().mean()
print(f"Accuracy on copy-2-back positions: {correct.item():.1%}")

print("\nExample predictions:")
for i in range(3):
    print(f"  Input:  {test[i, :-1].tolist()}")
    print(f"  Target: {targets[i].tolist()}")
    print(f"  Pred:   {preds[i].tolist()}")
    print()

In [None]:
# Visualize attention. To predict token t+1 (which equals token t-1), the
# query at input position t needs the input one step back. Look for weight
# on the diagonal just below the main one.
with torch.no_grad():
    _, attn = model(test[:1, :-1])

# squeeze=False always returns a 2-D grid, so one layer needs no special case.
fig, axes = plt.subplots(1, len(attn), figsize=(5 * len(attn), 4), squeeze=False)
axes = axes[0]

for layer_idx, layer_attn in enumerate(attn):
    # nn.MultiheadAttention averages over heads by default, returning
    # (batch, query, key). Only average manually if per-head weights
    # (batch, heads, query, key) were returned; otherwise .mean(dim=0)
    # would collapse the query axis and leave a 1-D array imshow rejects.
    layer_map = layer_attn[0]
    if layer_map.dim() == 3:
        layer_map = layer_map.mean(dim=0)
    avg_attn = layer_map.cpu().numpy()
    ax = axes[layer_idx]
    ax.imshow(avg_attn, cmap="Blues", vmin=0)
    ax.set_xlabel("Attending to position")
    ax.set_ylabel("From position")
    ax.set_title(f"Layer {layer_idx + 1} attention (avg over heads)")

plt.suptitle("The model should learn to attend one position back", y=1.02)
plt.tight_layout()
plt.show()

print("Look for a diagonal just below the main one — each position looking")
print("back one step to fetch the token two places before the one it predicts.")

## From Patterns to Stories

The same architecture, trained on real text instead of symbol patterns, learns language. Let's train on a small corpus of children's story beginnings and watch it generate text.

This is a *tiny* model on a *tiny* dataset — the output won't be Shakespeare. But you'll see the same predict-the-next-token loop that powers ChatGPT and Claude.

In [None]:
# A small corpus of story beginnings, tokenized word by word below.
stories = [
    "once upon a time there was a little cat who loved to play",
    "once upon a time there was a little dog who loved to run",
    "once upon a time there was a little girl named lily",
    "once upon a time there was a little boy named tom",
    "the cat sat on the mat and looked at the bird",
    "the dog ran in the park and played with the ball",
    "she was very happy because she found a new friend",
    "he was very sad because he lost his toy",
    "the little girl smiled and said hello to everyone",
    "the little boy laughed and ran around the garden",
    "one day the cat found a box and climbed inside",
    "one day the dog found a bone and was very happy",
    "they played together all day and had so much fun",
    "the sun was shining and the birds were singing",
    "she loved her cat and her cat loved her",
    "he loved his dog and his dog loved him",
]

# Word-level vocabulary: real words get ids 1..N in alphabetical order and
# id 0 is reserved for padding.
all_words = sorted({word for story in stories for word in story.split()})
word2id = {word: idx for idx, word in enumerate(all_words, start=1)}
word2id["<pad>"] = 0
id2word = dict(zip(word2id.values(), word2id.keys()))
vocab_size = len(word2id)
print(f"Vocabulary: {vocab_size} words")

# One row of ids per story, right-padded with 0 up to the longest story.
split_stories = [story.split() for story in stories]
max_len = max(map(len, split_stories))
tokenized = torch.zeros(len(stories), max_len, dtype=torch.long)
for row, words in zip(tokenized, split_stories):
    # `row` is a view into `tokenized`, so this writes in place.
    row[:len(words)] = torch.tensor([word2id[w] for w in words])

print(f"Max sequence length: {max_len}")
print(f"Example: {stories[0][:40]}...")
print(f"Tokens:  {tokenized[0, :8].tolist()}")

In [None]:
# Build the same-sized transformer and fit it to the story corpus.
torch.manual_seed(42)
story_model = TinyTransformer(
    vocab_size=vocab_size,
    embed_dim=32,
    n_heads=2,
    ff_dim=64,
    n_layers=2,
    max_len=max_len,
)

optimizer = torch.optim.Adam(story_model.parameters(), lr=0.003)
losses = []

# Next-word prediction: inputs drop each story's last slot, targets drop the
# first. Both are fixed for the whole run.
story_inputs = tokenized[:, :-1]
story_targets = tokenized[:, 1:].reshape(-1)

for epoch in range(500):
    logits, _ = story_model(story_inputs)
    # Padding targets (id 0) contribute nothing to the loss.
    loss = F.cross_entropy(
        logits.reshape(-1, vocab_size), story_targets, ignore_index=0
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training loss (story generation)")
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

## Auto-regressive Generation

Now the fun part: predict → append → repeat. This is exactly how every large language model generates text.

In [None]:
def generate(model, prompt, max_new_tokens=15, temperature=0.8, max_context=None):
    """Generate text auto-regressively: predict, append, repeat.

    Args:
        model: a model whose forward returns ``(logits, ...)`` with logits
            shaped (batch, seq_len, vocab_size), e.g. TinyTransformer.
        prompt: space-separated words; words not in ``word2id`` are skipped.
        max_new_tokens: maximum number of words to append.
        temperature: the logits are divided by this (floored at 0.01) before
            sampling; lower values give more predictable text.
        max_context: longest window of recent tokens fed to the model. By
            default it is the size of ``model.pos_embedding`` when the model
            has one. A learned-position model cannot process more positions
            than it has position vectors.

    Returns:
        The prompt plus the generated words as a single string, with padding
        removed.

    Raises:
        ValueError: if no word of ``prompt`` is in the vocabulary.
    """
    if max_context is None:
        pos_embedding = getattr(model, "pos_embedding", None)
        if pos_embedding is not None:
            max_context = pos_embedding.num_embeddings

    model.eval()
    tokens = [word2id[w] for w in prompt.split() if w in word2id]
    if not tokens:
        # torch.tensor([[]]) would be a float tensor the embedding rejects.
        raise ValueError(f"prompt {prompt!r} has no words in the vocabulary")
    ids = torch.tensor([tokens], dtype=torch.long)

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed only the most recent max_context tokens. The model was
            # trained with padding targets ignored, so it rarely emits the
            # stop token, and without cropping the sequence soon outgrows its
            # position embeddings. Cropped windows no longer start at position
            # 0, so quality may drop, but generation keeps running.
            context = ids if max_context is None else ids[:, -max_context:]
            logits, _ = model(context)
            # Take prediction for the last position
            next_logits = logits[0, -1] / max(temperature, 0.01)
            probs = F.softmax(next_logits, dim=-1)
            next_id = torch.multinomial(probs, 1)
            if next_id.item() == 0:  # stop on padding
                break
            ids = torch.cat([ids, next_id.unsqueeze(0)], dim=1)

    return " ".join(id2word[i.item()] for i in ids[0] if i.item() != 0)

# Sample three continuations for each prompt.
prompts = ["once upon a", "the little", "she was", "one day the"]

for prompt in prompts:
    print(f"Prompt: '{prompt}'")
    for _ in range(3):
        print(f"  → {generate(story_model, prompt, temperature=0.7)}")
    print()

In [None]:
# Same prompt, increasing sampling temperature.
prompt = "once upon a time"
print(f"Prompt: '{prompt}'\n")

temperatures = (0.1, 0.5, 1.0, 2.0)
for temp in temperatures:
    print(f"Temperature {temp}:")
    for _ in range(3):
        print(f"  {generate(story_model, prompt, temperature=temp)}")
    print()

for takeaway in (
    "Low temperature → repetitive and safe.",
    "High temperature → creative but chaotic.",
    "This is the same trade-off in ChatGPT and Claude!",
):
    print(takeaway)

## The Power of Scale

Our tiny story model has roughly twenty thousand parameters and was trained on 16 sentences. GPT-3 has 175 billion parameters. GPT-4 is widely reported (though OpenAI has not confirmed it) to have over a **trillion**, trained on trillions of words. The core recipe is the same — embed, attend, transform, predict — and the biggest difference is scale.

Let's see how even a small increase in model size helps:

In [None]:
# Parameter counts for three model sizes on the story vocabulary.
configs = [
    # (name, embed_dim, n_heads, ff_dim, n_layers)
    ("Tiny",   16, 1, 32,  1),
    ("Small",  32, 2, 64,  2),
    ("Medium", 64, 4, 128, 3),
]

print(f"{'Name':<8} {'Embed':>6} {'Heads':>6} {'FF':>6} {'Layers':>7} {'Params':>10}")
print("-" * 50)
for name, emb, heads, ff, layers in configs:
    # Use the story max_len so "Small" matches story_model's parameter count
    # (the default max_len=128 would add unused position vectors).
    m = TinyTransformer(vocab_size, emb, heads, ff, layers, max_len=max_len)
    n = sum(p.numel() for p in m.parameters())
    print(f"{name:<8} {emb:>6} {heads:>6} {ff:>6} {layers:>7} {n:>10,}")

print("\nGPT-3: 175,000,000,000 parameters")
# OpenAI has not published GPT-4's size; ~1.8T is a widely reported estimate.
print("GPT-4: ~1,800,000,000,000 parameters (unofficial estimate)")
print("\nSame architecture. Same training task (predict the next token).")
print("Scale is what turns a toy into an AI that can write, reason, and code.")

---

*This notebook accompanies [Chapter 8: Transformers](https://learnai.robennals.org/transformers). The interactive widgets in the web version let you step through a toy transformer, explore attention patterns on real stories, and generate text with temperature control.*

*New to PyTorch? See the [PyTorch from Scratch](https://learnai.robennals.org/appendix-pytorch) appendix for a beginner-friendly introduction.*