# Minimal Transformer from Scratch

**Learning Objectives:**
1. Build a complete transformer encoder block in PyTorch
2. Implement multi-head attention, feedforward network, and layer normalization
3. Add sinusoidal positional encoding
4. Train on simple sequence tasks (copy, reverse)
5. Visualize attention patterns during training

**Prerequisites:** [transformer-architecture](../transformers/transformer-architecture.md), [multi-head-attention](../transformers/multi-head-attention.md)

**Key Insight:** A transformer block is just: `x + Attention(LayerNorm(x))` followed by `x + FFN(LayerNorm(x))`

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Positional Encoding

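Self-attention treats its input as an unordered set, so a transformer has no inherent notion of position. We add sinusoidal encodings to the token embeddings, where $pos$ is the token position and $i$ indexes pairs of embedding dimensions: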

$$PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}})$$
$$PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}})$$

In [None]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding."""
    
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # Register as buffer (not a parameter, but saved with model)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Visualize positional encoding
pe = PositionalEncoding(d_model=64, max_len=100, dropout=0.0)
encoding = pe.pe[0, :50, :].numpy()

plt.figure(figsize=(12, 4))
plt.imshow(encoding, cmap='RdBu', aspect='auto')
plt.xlabel('Dimension')
plt.ylabel('Position')
plt.title('Sinusoidal Positional Encoding')
plt.colorbar()
plt.show()

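The `div_term` line in `PositionalEncoding` evaluates $1 / 10000^{2i/d_{model}}$ in log space rather than with a direct power. As a minimal standard-library sketch (the helper names `direct_freq`, `log_space_freq`, and `encoding_row` are illustrative and not part of the notebook), the two forms can be checked against each other:

```python
import math

def direct_freq(i, d_model):
    """Frequency for dimension pair i, straight from the formula."""
    return 1.0 / (10000.0 ** (2 * i / d_model))

def log_space_freq(i, d_model):
    """Same quantity, written the way div_term computes it."""
    return math.exp((2 * i) * (-math.log(10000.0) / d_model))

def encoding_row(pos, d_model):
    """One row of the table: sin in even slots, cos in odd slots."""
    row = []
    for i in range(d_model // 2):
        angle = pos * direct_freq(i, d_model)
        row.extend([math.sin(angle), math.cos(angle)])
    return row

d_model = 64
max_diff = max(abs(direct_freq(i, d_model) - log_space_freq(i, d_model))
               for i in range(d_model // 2))
print(f"max difference between the two forms: {max_diff:.2e}")
print(f"position 0, first four dims: {encoding_row(0, d_model)[:4]}")
```

Position 0 alternates between 0 (sine) and 1 (cosine) because every angle is zero there.
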
## 2. Multi-Head Attention

Multiple attention heads run in parallel, so each can learn a different pattern:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

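where each head is $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ and $\text{Attention}(Q, K, V) = \text{softmax}(QK^\top / \sqrt{d_k})\,V$ is scaled dot-product attention.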

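One common way to realise the per-head projections and the `Concat` step is to project once to `d_model` features and then split that dimension into `n_heads` chunks of size `d_k`, which is what the `MultiHeadAttention` class below does. The following sketch, with small illustrative sizes, shows that the split is a pure reshape and that merging the heads recovers the original tensor:

```python
import torch

batch_size, seq_len, d_model, n_heads = 2, 5, 8, 4
d_k = d_model // n_heads

x = torch.arange(batch_size * seq_len * d_model, dtype=torch.float)
x = x.view(batch_size, seq_len, d_model)

# Split the feature dimension into heads: (batch, n_heads, seq_len, d_k)
heads = x.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)

# Head h sees the contiguous feature slice [h * d_k, (h + 1) * d_k)
h = 1
same_slice = torch.equal(heads[:, h], x[:, :, h * d_k:(h + 1) * d_k])

# Concat: undo the transpose, then flatten the heads back into d_model
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
round_trip = torch.equal(merged, x)

print(heads.shape, same_slice, round_trip)
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires contiguous memory.
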
In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention."""
    
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        self.scale = np.sqrt(self.d_k)
        
        # Store attention weights for visualization
        self.attn_weights = None
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, d_model)
            mask: (seq_len, seq_len) or None
        Returns:
            output: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_model)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Reshape for multi-head: (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Compute attention scores: (batch, n_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax to get attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Store for visualization
        self.attn_weights = attn_weights.detach()
        
        # Apply attention to values: (batch, n_heads, seq_len, d_k)
        context = torch.matmul(attn_weights, V)
        
        # Concatenate heads: (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final projection
        output = self.W_o(context)
        
        return output

# Test multi-head attention
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64
out = mha(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weights shape: {mha.attn_weights.shape}")

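The `mask` argument of `MultiHeadAttention.forward` follows the convention that entries equal to 0 are blocked. As a minimal sketch of that mechanism outside the class (the lower-triangular mask is just one illustrative choice, and `scores` is random stand-in data), filling blocked scores with `-inf` before the softmax gives them exactly zero weight while each row still sums to 1:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)   # stand-in for Q @ K^T / sqrt(d_k)

# 1 = may attend, 0 = blocked; lower-triangular blocks later key positions
mask = torch.tril(torch.ones(seq_len, seq_len))

masked_scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(weights)
print("row sums:", weights.sum(dim=-1))
```

A `(seq_len, seq_len)` mask broadcasts over the `(batch, n_heads, seq_len, seq_len)` score tensor. Note that a row in which every key is blocked would produce NaN after the softmax, so each query should keep at least one visible key.
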
## 3. Position-wise Feed-Forward Network

Two linear layers with activation in between:

$$\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2$$

The inner dimension is typically 4x the model dimension. The implementation below substitutes GELU for the ReLU in this formula.

In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network."""
    
    def __init__(self, d_model, d_ff=None, dropout=0.1):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        """
        x = self.linear1(x)
        x = F.gelu(x)  # GELU is common in modern transformers
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# Test feed-forward
ff = FeedForward(d_model=64)
x = torch.randn(2, 10, 64)
out = ff(x)
print(f"FFN input shape: {x.shape}")
print(f"FFN output shape: {out.shape}")
print(f"FFN parameters: {sum(p.numel() for p in ff.parameters()):,}")

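The formula in this section uses ReLU, while the `FeedForward` class calls `F.gelu`. To make the difference concrete, here is a small standard-library sketch comparing the two on a few inputs. It assumes the exact (erf-based) GELU definition, $x \cdot \Phi(x)$; the helper names are illustrative.

```python
import math

def relu(x):
    return max(0.0, x)

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

for x in [-3.0, -1.0, 0.0, 1.0, 3.0]:
    print(f"x={x:+.1f}  relu={relu(x):.4f}  gelu={gelu(x):.4f}")
```

GELU is smooth and slightly negative for small negative inputs, whereas ReLU is exactly zero there; for large positive or negative inputs the two nearly coincide.
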
## 4. Transformer Encoder Block

A single transformer block combines attention and FFN with residual connections and layer norm:

```
x -> LayerNorm -> MultiHeadAttention -> + -> LayerNorm -> FFN -> + -> output
     |____________________________|        |_________________|
              residual                          residual
```

In [None]:
class TransformerBlock(nn.Module):
    """Single transformer encoder block."""
    
    def __init__(self, d_model, n_heads, d_ff=None, dropout=0.1):
        super().__init__()
        
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Pre-norm architecture (typically more stable to train than post-norm):
        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))
        """
        # Self-attention with residual
        attn_out = self.attention(self.norm1(x), mask)
        x = x + self.dropout1(attn_out)
        
        # Feed-forward with residual
        ff_out = self.feed_forward(self.norm2(x))
        x = x + self.dropout2(ff_out)
        
        return x

# Test transformer block
block = TransformerBlock(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)
out = block(x)
print(f"Block input shape: {x.shape}")
print(f"Block output shape: {out.shape}")
print(f"Block parameters: {sum(p.numel() for p in block.parameters()):,}")

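To contrast the pre-norm ordering used in `TransformerBlock` with the post-norm alternative, the sketch below wires both around a stand-in sublayer (a single `nn.Linear`, chosen only for illustration; the function names are hypothetical). With deliberately large inputs, pre-norm leaves the residual path unscaled, while post-norm renormalizes the whole sum:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)   # stand-in for attention or FFN

def pre_norm(x):
    # Normalize only the sublayer input; the residual path stays untouched
    return x + sublayer(norm(x))

def post_norm(x):
    # Normalize after the residual add, so the skip path is rescaled too
    return norm(x + sublayer(x))

x = 10.0 * torch.randn(2, 5, d_model)    # deliberately large activations
with torch.no_grad():
    pre_out, post_out = pre_norm(x), post_norm(x)

print(f"input std:     {x.std().item():.2f}")
print(f"pre-norm std:  {pre_out.std().item():.2f}")
print(f"post-norm std: {post_out.std().item():.2f}")
```

Because the identity path in pre-norm is never rescaled, gradients can flow through the residual stream unchanged, which is one commonly cited reason for its training stability.
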
## 5. Complete Transformer Encoder

Stack multiple blocks with embedding and output layers:

In [None]:
class TransformerEncoder(nn.Module):
    """Complete transformer encoder."""
    
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff=None, 
                 max_len=512, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Embedding layers
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        
        # Output layer
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len) token indices
            mask: optional attention mask
        Returns:
            logits: (batch_size, seq_len, vocab_size)
        """
        # Embed tokens and add positional encoding
        x = self.embedding(x) * np.sqrt(self.d_model)
        x = self.pos_encoding(x)
        
        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x, mask)
        
        # Final norm and output projection
        x = self.norm(x)
        logits = self.output(x)
        
        return logits

# Test complete model
model = TransformerEncoder(
    vocab_size=100,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1
)

x = torch.randint(0, 100, (2, 10))  # batch=2, seq_len=10
logits = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

## 6. Task: Sequence Copy

Train the transformer to copy input sequences. This simple task tests whether the model can route information from one position to another.

In [None]:
class CopyDataset(Dataset):
    """Dataset for sequence copying task."""
    
    def __init__(self, vocab_size, seq_len, num_samples):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
        
        # Reserve 0 for padding, 1 for separator
        self.data_vocab_size = vocab_size - 2
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Generate random sequence (values 2 to vocab_size-1)
        seq = torch.randint(2, self.vocab_size, (self.seq_len,))
        
        # Input:  [seq, SEP, zeros]
        # Target: [zeros, 0, seq]  (zero targets are excluded from the loss)
        sep = torch.tensor([1])  # Separator token
        padding = torch.zeros(self.seq_len, dtype=torch.long)
        
        input_seq = torch.cat([seq, sep, padding])
        target_seq = torch.cat([padding, torch.tensor([0]), seq])
        
        return input_seq, target_seq

# Create dataset
vocab_size = 20
seq_len = 8
dataset = CopyDataset(vocab_size, seq_len, num_samples=10000)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Show example
x, y = dataset[0]
print(f"Input:  {x.tolist()}")
print(f"Target: {y.tolist()}")
print(f"\nThe model should reproduce the first {seq_len} tokens in the positions after the separator.")

## 7. Training Loop

In [None]:
def train_epoch(model, dataloader, optimizer, device):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        logits = model(batch_x)
        
        # Compute loss (only on target positions, not padding)
        # Reshape for cross entropy
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = batch_y.view(-1)
        
        # Mask out padding positions (target == 0)
        mask = targets_flat != 0
        loss = F.cross_entropy(logits_flat[mask], targets_flat[mask])
        
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        
        # Compute accuracy
        preds = logits_flat.argmax(dim=-1)
        correct += (preds[mask] == targets_flat[mask]).sum().item()
        total += mask.sum().item()
    
    return total_loss / len(dataloader), correct / total

def test_model(model, dataset, device, n_samples=5):
    """Test the model on a few examples."""
    model.eval()  # Evaluation mode: disables dropout
    
    print("\nTest examples:")
    print("-" * 50)
    
    for i in range(n_samples):
        x, y = dataset[i]
        x = x.unsqueeze(0).to(device)
        
        with torch.no_grad():
            logits = model(x)
            preds = logits.argmax(dim=-1)
        
        x_list = x[0].cpu().tolist()
        y_list = y.tolist()
        p_list = preds[0].cpu().tolist()
        
        # Extract the relevant part (after separator); read the length from
        # the dataset rather than relying on a global variable
        sep_idx = dataset.seq_len
        input_seq = x_list[:sep_idx]
        target_seq = y_list[sep_idx+1:]
        pred_seq = p_list[sep_idx+1:]
        
        match = "Correct!" if target_seq == pred_seq else "Wrong"
        print(f"Input:  {input_seq}")
        print(f"Target: {target_seq}")
        print(f"Output: {pred_seq} ({match})")
        print()

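`train_epoch` excludes padding targets by boolean-indexing the flattened logits and targets. An equivalent alternative, sketched here on random stand-in data, is the `ignore_index` argument of `F.cross_entropy`, which skips those targets and averages over the remaining ones:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, n_positions = 20, 12
logits = torch.randn(n_positions, vocab_size)
targets = torch.randint(1, vocab_size, (n_positions,))
targets[:4] = 0                          # pretend these positions are padding

# Approach used in train_epoch: keep only the non-padding positions
mask = targets != 0
loss_indexed = F.cross_entropy(logits[mask], targets[mask])

# Alternative: let cross_entropy skip padding targets itself
loss_ignore = F.cross_entropy(logits, targets, ignore_index=0)

print(loss_indexed.item(), loss_ignore.item())
```

Both give the same value with the default mean reduction. The explicit mask is still needed for the accuracy computation.
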
In [None]:
# Create model
model = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Training
n_epochs = 20
losses = []
accuracies = []

for epoch in range(n_epochs):
    loss, acc = train_epoch(model, dataloader, optimizer, device)
    losses.append(loss)
    accuracies.append(acc)
    
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

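To interpret the printed loss and accuracy, it helps to know the chance-level baselines. The sketch below assumes the copy-task setup above (`vocab_size = 20`, with tokens 0 and 1 reserved) and computes the cross-entropy of a uniform predictor and the accuracy of uniform guessing among data tokens:

```python
import math

vocab_size = 20                    # as in the copy task above
n_data_tokens = vocab_size - 2     # tokens 0 (padding) and 1 (separator) are never targets

uniform_loss = math.log(vocab_size)
chance_accuracy = 1.0 / n_data_tokens

print(f"loss of a uniform predictor:      {uniform_loss:.4f}")
print(f"chance accuracy over data tokens: {chance_accuracy:.4f}")
```

A loss near `ln(20)` means the model has not yet learned anything about the task. Note also that `train_epoch` measures accuracy with dropout active, so it may slightly understate what the model achieves in evaluation mode.
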
In [None]:
# Plot training progress
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].grid(True, alpha=0.3)

axes[1].plot(accuracies)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Copy Accuracy')
axes[1].set_ylim(0, 1)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Test on examples
test_model(model, dataset, device)

## 8. Visualize Attention Patterns

Let's see what the model learned to attend to:

In [None]:
def visualize_attention(model, x, layer_idx=0):
    """Visualize attention weights for a given input."""
    model.eval()  # Evaluation mode: disables dropout
    
    # Forward pass
    with torch.no_grad():
        _ = model(x.unsqueeze(0).to(device))
    
    # Get attention weights from specified layer
    attn = model.blocks[layer_idx].attention.attn_weights[0]  # (n_heads, seq_len, seq_len)
    attn = attn.cpu().numpy()
    
    n_heads = attn.shape[0]
    seq_len_full = attn.shape[1]
    
    # Plot attention for each head
    fig, axes = plt.subplots(1, n_heads, figsize=(3 * n_heads, 3))
    
    for head in range(n_heads):
        ax = axes[head] if n_heads > 1 else axes
        im = ax.imshow(attn[head], cmap='Blues', aspect='auto')
        ax.set_title(f'Head {head}')
        ax.set_xlabel('Key')
        if head == 0:
            ax.set_ylabel('Query')
    
    plt.suptitle(f'Attention Weights - Layer {layer_idx}')
    plt.tight_layout()
    return fig

# Visualize on a test example
x, y = dataset[0]
print(f"Input sequence: {x.tolist()}")
print(f"Format: [data tokens] [SEP=1] [padding=0]")
print()

# Layer 0 attention
visualize_attention(model, x, layer_idx=0)
plt.show()

# Layer 1 attention
visualize_attention(model, x, layer_idx=1)
plt.show()

## 9. Task: Sequence Reversal

A harder task - reverse the input sequence:

In [None]:
class ReverseDataset(Dataset):
    """Dataset for sequence reversal task."""
    
    def __init__(self, vocab_size, seq_len, num_samples):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.num_samples = num_samples
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        # Generate random sequence
        seq = torch.randint(2, self.vocab_size, (self.seq_len,))
        
        # Reversed sequence
        rev_seq = seq.flip(0)
        
        # Input:  [seq, SEP, zeros]
        # Target: [zeros, 0, reversed_seq]  (zero targets are excluded from the loss)
        sep = torch.tensor([1])
        padding = torch.zeros(self.seq_len, dtype=torch.long)
        
        input_seq = torch.cat([seq, sep, padding])
        target_seq = torch.cat([padding, torch.tensor([0]), rev_seq])
        
        return input_seq, target_seq

# Create reverse dataset
reverse_dataset = ReverseDataset(vocab_size, seq_len, num_samples=10000)
reverse_dataloader = DataLoader(reverse_dataset, batch_size=64, shuffle=True)

# Show example
x, y = reverse_dataset[0]
print(f"Input:  {x.tolist()}")
print(f"Target: {y.tolist()}")
print(f"\nThe model should output the first {seq_len} tokens in reverse order.")

In [None]:
# Train on reverse task
reverse_model = TransformerEncoder(
    vocab_size=vocab_size,
    d_model=64,
    n_heads=4,
    n_layers=3,  # Slightly deeper for harder task
    dropout=0.1
).to(device)

optimizer = torch.optim.Adam(reverse_model.parameters(), lr=1e-3)

n_epochs = 30
rev_losses = []
rev_accuracies = []

for epoch in range(n_epochs):
    loss, acc = train_epoch(reverse_model, reverse_dataloader, optimizer, device)
    rev_losses.append(loss)
    rev_accuracies.append(acc)
    
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1:3d} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

In [None]:
# Plot results
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(rev_losses)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Reverse Task - Training Loss')
axes[0].grid(True, alpha=0.3)

axes[1].plot(rev_accuracies)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Reverse Task - Accuracy')
axes[1].set_ylim(0, 1)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Test reverse task
print("\nReverse task test examples:")
print("-" * 50)
reverse_model.eval()  # Evaluation mode: disables dropout

for i in range(5):
    x, y = reverse_dataset[i]
    x_in = x.unsqueeze(0).to(device)
    
    with torch.no_grad():
        logits = reverse_model(x_in)
        preds = logits.argmax(dim=-1)
    
    input_seq = x[:seq_len].tolist()
    target_seq = y[seq_len+1:].tolist()
    pred_seq = preds[0, seq_len+1:].cpu().tolist()
    
    match = "Correct!" if target_seq == pred_seq else "Wrong"
    print(f"Input:    {input_seq}")
    print(f"Expected: {target_seq}")
    print(f"Got:      {pred_seq} ({match})")
    print()

In [None]:
# Visualize attention for reverse task
x, y = reverse_dataset[0]
print(f"Input: {x[:seq_len].tolist()} (reversed: {x[:seq_len].flip(0).tolist()})")
print()

visualize_attention(reverse_model, x, layer_idx=0)
plt.show()

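print("\nLook for an anti-diagonal band in the output-slot rows (the last seq_len queries):")
print("output slot i, at position seq_len + 1 + i, would ideally attend to input position (seq_len - 1 - i).")
print("The pattern may appear in only some heads, or in a later layer.")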

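For reference when reading the heatmaps, the sketch below builds the idealised lookup pattern for each task as a 0/1 matrix: output slot `i` sits at position `seq_len + 1 + i` and needs input position `i` for copy or `seq_len - 1 - i` for reverse. This is an illustration of what a perfect single-head solution would look like, not a claim about what the trained model does; the helper `ideal_pattern` is hypothetical.

```python
seq_len = 4                      # small value for readability
total_len = 2 * seq_len + 1      # [data tokens] [SEP] [output slots]

def ideal_pattern(source_for_slot):
    """Build a 0/1 matrix: rows are queries, columns are keys."""
    attn = [[0] * total_len for _ in range(total_len)]
    for i in range(seq_len):
        query = seq_len + 1 + i              # position of the i-th output slot
        attn[query][source_for_slot(i)] = 1  # the input position it needs
    return attn

copy_pattern = ideal_pattern(lambda i: i)
reverse_pattern = ideal_pattern(lambda i: seq_len - 1 - i)

for name, pattern in [("copy", copy_pattern), ("reverse", reverse_pattern)]:
    print(name)
    for row in pattern[seq_len + 1:]:        # only the output-slot rows
        print(" ", row)
```

In the full attention matrix these patterns occupy the lower-left block: an offset diagonal for copy and an anti-diagonal for reverse. A trained model may spread the same information flow across several heads or layers.
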
## 10. Architecture Summary

```
TransformerEncoder
├── Embedding (vocab_size -> d_model)
├── PositionalEncoding (sinusoidal)
├── TransformerBlock x N
│   ├── LayerNorm
│   ├── MultiHeadAttention
│   │   ├── Linear (Q projection)
│   │   ├── Linear (K projection)
│   │   ├── Linear (V projection)
│   │   └── Linear (output projection)
│   ├── Residual Connection
│   ├── LayerNorm
│   ├── FeedForward
│   │   ├── Linear (d_model -> 4*d_model)
│   │   ├── GELU
│   │   └── Linear (4*d_model -> d_model)
│   └── Residual Connection
├── LayerNorm (final)
└── Linear (d_model -> vocab_size)
```

In [None]:
# Parameter breakdown
def count_parameters(model):
    """Count parameters in each component."""
    counts = {}
    
    counts['embedding'] = model.embedding.weight.numel()
    counts['output'] = sum(p.numel() for p in model.output.parameters())
    counts['layer_norm'] = sum(p.numel() for p in model.norm.parameters())
    
    block_params = sum(p.numel() for p in model.blocks.parameters())
    counts['transformer_blocks'] = block_params
    
    # Breakdown of a single block (block 0); the 'transformer_blocks' entry
    # above covers all blocks combined
    block = model.blocks[0]
    counts['  - attention'] = sum(p.numel() for p in block.attention.parameters())
    counts['  - feed_forward'] = sum(p.numel() for p in block.feed_forward.parameters())
    counts['  - norms'] = sum(p.numel() for p in block.norm1.parameters()) + \
                          sum(p.numel() for p in block.norm2.parameters())
    
    return counts

counts = count_parameters(model)
total = sum(p.numel() for p in model.parameters())

print("Parameter Breakdown:")
print("-" * 40)
for name, count in counts.items():
    pct = count / total * 100
    print(f"{name:25s} {count:8,d} ({pct:5.1f}%)")
print("-" * 40)
print(f"{'Total':25s} {total:8,d}")

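The counts printed above can be cross-checked by hand. This sketch assumes the copy-task configuration (`d_model=64`, `vocab_size=20`, `n_layers=2`, `d_ff = 4 * d_model`) and uses a hypothetical helper `linear_params`. The positional encoding is a buffer, so it contributes no parameters.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out

d_model, vocab_size, n_layers = 64, 20, 2
d_ff = 4 * d_model

attention = 4 * linear_params(d_model, d_model)   # W_q, W_k, W_v, W_o
feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
norms = 2 * (2 * d_model)                         # scale + shift for two LayerNorms
block = attention + feed_forward + norms

embedding = vocab_size * d_model
final_norm = 2 * d_model
output = linear_params(d_model, vocab_size)
total = embedding + n_layers * block + final_norm + output

print(f"attention per block:    {attention:,}")    # 16,640
print(f"feed-forward per block: {feed_forward:,}") # 33,088
print(f"total:                  {total:,}")        # 102,676
```

The feed-forward network holds roughly twice as many parameters as attention in each block, since its two matrices are each `d_model x 4*d_model` against four `d_model x d_model` attention projections.
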
## Summary

| Component | Purpose |
|-----------|--------|
| Positional Encoding | Inject position information |
| Multi-Head Attention | Learn multiple attention patterns |
| Feed-Forward Network | Non-linear transformation per position |
| Layer Normalization | Stabilize training |
| Residual Connections | Enable gradient flow |

**Key Insights:**
1. The transformer is surprisingly simple: attention + FFN + residuals
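2. Pre-norm (LayerNorm before attention/FFN) tends to train more stably than post-norm, especially as depth grows
3. The model can learn task-specific attention patterns: for copy, output slot `i` attends to input position `i` (an offset diagonal); for reverse, to input position `seq_len - 1 - i` (an anti-diagonal)
4. Within each block, the FFN holds about twice as many parameters as attention when `d_ff = 4 * d_model` (roughly `8 * d_model^2` versus `4 * d_model^2`)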

**Next:** [09-minimal-gpt.ipynb](09-minimal-gpt.ipynb) builds a decoder-only (autoregressive) transformer for text generation.

## Exercises

1. **Sorting Task:** Create a dataset where the model must output the input sequence in sorted order. How does this compare to copy/reverse?

2. **Causal Masking:** Modify MultiHeadAttention to support causal (autoregressive) masking. Verify that position i cannot attend to positions > i.

3. **Depth Ablation:** Train models with 1, 2, 4, and 6 layers on the reverse task. How does depth affect learning speed and final accuracy?

4. **Head Analysis:** After training, analyze what each attention head has learned. Do different heads specialize in different patterns?