# 5. Transformer Block

**Combining all components into one repeatable unit**

A transformer block combines all our components into one repeatable unit. The full transformer model is just many of these blocks stacked on top of each other (GPT-3 has 96 blocks!).

## What's in a Block?

Each block contains four key components:

- **Multi-head attention:** Communication layer—tokens gather information from other tokens
- **Feed-forward network:** Computation layer—each token processes its gathered information
- **Layer normalization:** Stabilizes training by normalizing activations (prevents them from growing too large or small)
- **Residual connections:** "Skip connections" that create gradient highways for training deep networks

## Pre-LN Architecture

We use the Pre-LN (Pre-Layer Normalization) approach used in modern models like GPT-2 and GPT-3. This means we apply layer normalization *before* each sub-layer (attention or FFN) rather than after. This makes training more stable, especially for very deep networks.

```
Input x
   │
   ├──────────────────────┐
   │                      │ (residual)
   ↓                      │
[LayerNorm]               │
   ↓                      │
[Multi-Head Attention]    │
   ↓                      │
[Dropout]                 │
   ↓                      │
   + ←────────────────────┘
   │
   ├──────────────────────┐
   │                      │ (residual)
   ↓                      │
[LayerNorm]               │
   ↓                      │
[Feed-Forward]            │
   ↓                      │
[Dropout]                 │
   ↓                      │
   + ←────────────────────┘
   │
   ↓
Output
```

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """Scaled dot-product self-attention evaluated over several heads at once.

    The model dimension is divided evenly between ``num_heads`` heads, so each
    head works with ``d_model // num_heads`` features. Queries, keys and
    values are all projected from the same input tensor.

    Args:
        d_model: Size of the feature dimension; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Every head receives an equally sized slice of the model dimension.
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Query / key / value projections, followed by the output projection
        # that mixes the concatenated heads back together.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected):
        """Reshape (batch, seq, d_model) into (batch, heads, seq, d_k)."""
        n_batch, n_tokens, _ = projected.shape
        per_head = projected.view(n_batch, n_tokens, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional tensor broadcastable to the
                (batch, heads, seq_len, seq_len) score tensor. Entries equal
                to 1 mark positions that receive no attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        n_batch, n_tokens, _ = x.shape

        queries = self._split_heads(self.W_q(x))
        keys = self._split_heads(self.W_k(x))
        values = self._split_heads(self.W_v(x))

        # Similarity of every query with every key, scaled by sqrt(d_k).
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf turns into a weight of exactly zero after the softmax.
            logits = logits.masked_fill(mask == 1, float('-inf'))

        weights = self.dropout(F.softmax(logits, dim=-1))

        # Weighted sum of the values, then merge the heads back into d_model.
        merged = (weights @ values).transpose(1, 2).contiguous()
        return self.W_o(merged.view(n_batch, n_tokens, self.d_model))


class FeedForward(nn.Module):
    """Two-layer MLP that acts on the feature dimension of its input.

    Features are expanded from ``d_model`` to ``d_ff``, passed through GELU,
    and projected back to ``d_model``. Dropout follows both the activation
    and the final projection.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)    # expand to the hidden size
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)    # project back to d_model
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return the transformed tensor; all leading dimensions are preserved."""
        hidden = self.dropout1(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))

In [None]:
class TransformerBlock(nn.Module):
    """A single transformer block with Pre-LN architecture.

    Two residual sub-layers are applied in sequence::

        x = x + Dropout(Attention(LayerNorm(x)))
        x = x + FeedForward(LayerNorm(x))

    Exactly one dropout is applied to each branch before it is added back to
    the residual stream. For the attention branch that dropout lives here
    (``dropout1``). For the feed-forward branch it is the dropout that
    ``FeedForward`` already applies to its own output, so the block does not
    add another one; stacking a second dropout on top would compound the two
    and drop roughly ``1 - (1 - p)**2`` of the activations instead of ``p``.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.
        d_ff: Hidden size of the feed-forward network.
        dropout: Dropout probability shared by every dropout layer in the
            block.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Layer norms (applied BEFORE attention and FFN)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Attention and FFN
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)

        # Residual dropout for the attention branch. MultiHeadAttention only
        # drops attention weights, never its projected output, so the
        # branch-level dropout has to be applied here.
        self.dropout1 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional attention mask, passed through unchanged to the
                attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # First sub-layer: Pre-LN -> attention -> dropout -> residual add
        x = x + self.dropout1(self.attention(self.norm1(x), mask=mask))

        # Second sub-layer: Pre-LN -> feed-forward -> residual add.
        # FeedForward ends with its own dropout, so none is added here.
        x = x + self.ffn(self.norm2(x))

        return x

In [None]:
# Example: push one random batch through a single block
d_model, num_heads, d_ff = 64, 4, 256
batch_size, seq_len = 2, 8

block = TransformerBlock(d_model, num_heads, d_ff)

# Random stand-in for token embeddings: (batch, sequence, features)
x = torch.randn(batch_size, seq_len, d_model)

# Causal mask: ones strictly above the diagonal mark the future positions
mask = torch.ones(seq_len, seq_len).triu(diagonal=1)

output = block(x, mask=mask)

# The block maps (batch, seq, d_model) -> (batch, seq, d_model)
print(
    f"Input shape: {x.shape}",
    f"Output shape: {output.shape}",
    f"\nShape is preserved: {x.shape == output.shape}",
    sep="\n",
)

## Why Residual Connections?

Residual connections create gradient "highways" that allow gradients to flow directly from the output back to early layers. Without them, deep networks struggle to learn.

The math is beautiful:

$$\frac{\partial(x + f(x))}{\partial x} = 1 + \frac{\partial f(x)}{\partial x}$$

The "1" (the identity matrix when $x$ is a vector) gives gradients a direct path back to earlier layers, even if $\frac{\partial f(x)}{\partial x}$ is tiny!

In [None]:
# Compare the gradient that reaches the input with and without a skip connection
x = torch.randn(1, 4, 64, requires_grad=True)
linear = nn.Linear(64, 64)

# Plain layer: the input gradient comes only through the layer's weights
y_no_res = linear(x)
y_no_res.sum().backward()
grad_no_res = x.grad.abs().mean().item()

# Clear the accumulated gradient before the second backward pass
x.grad = None

# Residual form: the identity path contributes an extra 1 to every entry
y_with_res = x + linear(x)
y_with_res.sum().backward()
grad_with_res = x.grad.abs().mean().item()

print(
    f"Average gradient magnitude without residual: {grad_no_res:.4f}",
    f"Average gradient magnitude with residual: {grad_with_res:.4f}",
    f"\nResiduals ensure gradients always flow!",
    sep="\n",
)

In [None]:
# Parameter budget of a single transformer block
total_params = sum(param.numel() for param in block.parameters())

# Closed-form counts (weights plus biases) for each component
breakdown = (
    ("Attention (Q, K, V, O)", 4 * d_model * d_model + 4 * d_model),
    ("FFN (up, down)", 2 * d_model * d_ff + d_ff + d_model),
    ("LayerNorm (2×)", 4 * d_model),
)

print("Transformer Block Parameters:")
for label, count in breakdown:
    print(f"  {label}: {count:,}")
print(f"  Total: {total_params:,} parameters")

## Stacking Blocks

The transformer is just N of these blocks stacked together:

| Model | Blocks | d_model | Heads |
|-------|--------|---------|-------|
| GPT-2 Small | 12 | 768 | 12 |
| GPT-2 Large | 36 | 1280 | 20 |
| GPT-3 | 96 | 12288 | 96 |

Each block refines the representations, building increasingly abstract understanding of the input.

In [None]:
# Build a small stack of identically configured blocks
num_layers = 4
blocks = nn.ModuleList(
    TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
)

# Feed a fresh random batch through the stack, one block after another.
# The loop variable reuses the name `block`, so afterwards it refers to the
# last block of the stack.
x = torch.randn(batch_size, seq_len, d_model)
for i, block in enumerate(blocks):
    x = block(x, mask=mask)
    print(f"After block {i}: shape = {x.shape}, mean = {x.mean():.4f}")

stack_params = sum(param.numel() for param in blocks.parameters())
print(f"\nTotal parameters: {stack_params:,}")

## Next: Complete Model

Now we'll assemble everything—embeddings, transformer blocks, and output projection—into a complete language model.