# Chapter 3: Transformer Blocks

Welcome to the third notebook in our LLM from Scratch series! In this chapter, we'll explore the **Transformer Block** - the fundamental building block that gets stacked to create powerful models like GPT.

## What You'll Learn

1. **Transformer block architecture**: Combining attention and feedforward networks
2. **Residual connections**: Why they're essential for deep networks
3. **Layer normalization**: Stabilizing training
4. **Pre-norm vs post-norm**: Modern architecture choices
5. **Feedforward networks**: Position-wise processing
6. **Hands-on experimentation** with our implementation

Let's build the complete transformer block!

## 1. The Transformer Block Architecture

A transformer block combines **two main components**:

### Component 1: Multi-Head Self-Attention
- Allows positions to communicate with each other
- Captures relationships and dependencies
- We learned about this in Notebook 02!

### Component 2: Feedforward Network (FFN)
- Processes each position independently
- Adds non-linear transformations
- Increases model capacity

### The Complete Block:

```
Input
  ↓
LayerNorm → Multi-Head Attention → + (residual)
  ↓                                  ↑
  └──────────────────────────────────┘
  ↓
LayerNorm → FeedForward Network → + (residual)
  ↓                                 ↑
  └─────────────────────────────────┘
  ↓
Output
```

This pattern is called **Pre-LayerNorm** architecture, which is what modern models like GPT-2/GPT-3 use!

## 2. Residual Connections: The Key to Deep Networks

**Residual connections** (also called skip connections) add each layer's input back to its output (`x + layer(x)`), giving gradients a direct path through the network.

### Without Residual Connections:
```python
x = layer1(x)
x = layer2(x)
x = layer3(x)
# Gradients must pass through every layer's Jacobian; in deep stacks they can vanish
```

### With Residual Connections:
```python
x = x + layer1(x)  # Residual connection
x = x + layer2(x)  # Residual connection
x = x + layer3(x)  # Residual connection
# The identity term gives gradients a direct path back → more stable training
```

### Why This Matters:

- ✅ **Enables training deep networks** (GPT-3 has 96 layers!)
- ✅ **Mitigates vanishing gradients** (the identity path carries gradients straight back)
- ✅ **Makes the identity function easy to learn** (a layer that outputs 0 leaves its input unchanged)
- ✅ **Better optimization landscape** (easier to find good solutions)

This innovation from ResNet (He et al., 2015) revolutionized deep learning!
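
To see the vanishing-gradient argument numerically, here is a minimal sketch (not part of the notebook's `src.llm` package) that measures the gradient reaching the input of a 50-layer stack with and without residual connections. The layer type (Linear + Tanh), the depth, and the deliberately small weight scale are illustrative assumptions chosen to make the effect easy to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 50, 64

# Illustrative stack: Linear + Tanh with small weights, so each layer
# shrinks the signal (and the backpropagated gradient) by roughly half
layers = nn.ModuleList(
    nn.Sequential(nn.Linear(width, width), nn.Tanh()) for _ in range(depth)
)
for layer in layers:
    nn.init.normal_(layer[0].weight, std=0.5 / width**0.5)
    nn.init.zeros_(layer[0].bias)


def input_grad_norm(residual: bool) -> float:
    """Backpropagate from the stack's output and return the gradient norm at the input."""
    x = torch.randn(8, width, requires_grad=True)
    h = x
    for layer in layers:
        # residual=True is the x + layer(x) pattern; False is the plain stack
        h = h + layer(h) if residual else layer(h)
    h.sum().backward()
    return x.grad.norm().item()


plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"gradient norm at input  plain: {plain:.1e}  residual: {res:.1e}")
```

The plain stack's input gradient should come out many orders of magnitude smaller than the residual one; the exact values depend on the seed and the PyTorch version.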

## 3. Layer Normalization: Stabilizing Training

**Layer normalization** normalizes each token's feature vector to have mean 0 and variance 1, then rescales and shifts it with learnable parameters.

### Formula:

$$\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where:
- $\mu$ = mean across features (for each example)
- $\sigma^2$ = variance across features (for each example)
- $\gamma$ = learnable scale parameter
- $\beta$ = learnable shift parameter
- $\epsilon$ = small constant for numerical stability
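
The formula maps directly onto a few tensor operations. This sketch checks a hand-written version against PyTorch's `nn.LayerNorm`; it assumes, as `nn.LayerNorm` does, that $\sigma^2$ is the biased (divide-by-N) variance and that $\epsilon$ sits inside the square root. The random $\gamma$ and $\beta$ values only make the comparison non-trivial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, eps = 8, 1e-5
x = torch.randn(2, 3, d_model)  # (batch, seq_len, d_model)

ln = nn.LayerNorm(d_model, eps=eps)
with torch.no_grad():
    ln.weight.uniform_(0.5, 1.5)  # gamma (initialized to 1 by default)
    ln.bias.uniform_(-0.5, 0.5)   # beta (initialized to 0 by default)

    # mu and sigma^2 are computed per token, over the feature dimension only
    mu = x.mean(dim=-1, keepdim=True)
    var = ((x - mu) ** 2).mean(dim=-1, keepdim=True)  # biased variance
    manual = ln.weight * (x - mu) / torch.sqrt(var + eps) + ln.bias

    matches = torch.allclose(manual, ln(x), atol=1e-5)

print("manual LayerNorm matches nn.LayerNorm:", matches)
```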

### Why Layer Norm (not Batch Norm)?

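**Batch Normalization**: Normalizes each feature across the batch dimension
- ❌ Needs reasonably large batches for reliable statistics
- ❌ Awkward with variable-length, padded sequences
- ❌ Different behavior during train/test (batch statistics vs. running averages)

**Layer Normalization**: Normalizes across the feature dimension of each token
- ✅ Works with any batch size, including 1
- ✅ Well suited to sequences: each token is normalized on its own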
- ✅ Same behavior during train/test
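
The batch-size point can be checked directly: LayerNorm's output for one example does not depend on what else is in the batch, while BatchNorm's (in training mode) does. This is a minimal sketch; `nn.BatchNorm1d` expects `(batch, channels, length)`, so the sequence tensor is transposed around it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 5, 16)  # (batch, seq_len, d_model)

ln = nn.LayerNorm(16)
bn = nn.BatchNorm1d(16).train()  # training mode: uses the current batch's statistics


def batchnorm_seq(t: torch.Tensor) -> torch.Tensor:
    # BatchNorm1d normalizes each channel over the batch and length dims
    return bn(t.transpose(1, 2)).transpose(1, 2)


with torch.no_grad():
    # Normalize the full batch, then example 0 alone, and compare example 0
    ln_same = torch.allclose(ln(x)[0], ln(x[:1])[0], atol=1e-6)
    bn_same = torch.allclose(batchnorm_seq(x)[0], batchnorm_seq(x[:1])[0], atol=1e-6)

print("LayerNorm output independent of batch:", ln_same)  # True
print("BatchNorm output independent of batch:", bn_same)  # False
```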

### Benefits:

1. **Stable gradients**: Helps keep activation and gradient scales in check
2. **Faster training**: Often allows higher learning rates
3. **Possibly better generalization**: Sometimes credited with a mild regularizing effect

## 4. Pre-Norm vs Post-Norm

Two ways to arrange LayerNorm and residual connections:

### Post-Norm (Original Transformer):
```
x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))
```

### Pre-Norm (Modern, GPT-2/3):
```
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
```
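
Written as PyTorch modules, the two orderings differ only in where the LayerNorm calls sit. This is a minimal sketch that uses `nn.MultiheadAttention` for the attention sublayer and omits the causal mask and dropout for brevity; it is not the notebook's `TransformerBlock`, whose internals may differ.

```python
import torch
import torch.nn as nn


class PostNormBlock(nn.Module):
    """Original Transformer ordering: sublayer, add residual, then LayerNorm."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ln1(x + self.attn(x, x, x, need_weights=False)[0])
        return self.ln2(x + self.ffn(x))


class PreNormBlock(PostNormBlock):
    """GPT-2-style ordering: same layers, but LayerNorm only inside each branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.ln1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.ffn(self.ln2(x))


torch.manual_seed(0)
x = 5 * torch.randn(2, 10, 32) + 1  # deliberately far from unit scale
post = PostNormBlock(32, 4, 128).eval()
pre = PreNormBlock(32, 4, 128).eval()
with torch.no_grad():
    post_std = post(x).std(dim=-1).mean().item()
    pre_std = pre(x).std(dim=-1).mean().item()
print(f"per-token std  post-norm: {post_std:.2f}  pre-norm: {pre_std:.2f}")
```

The post-norm block always emits LayerNorm'd values (per-token std close to 1), while the pre-norm residual stream keeps the input's scale and is only normalized inside each branch. GPT-2 adds one extra LayerNorm after its final block.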

### Why Pre-Norm is Better:

- ✅ **Better gradient flow**: Residual path doesn't go through LayerNorm
- ✅ **More stable training**: Especially for very deep models
- ✅ **Easier to train**: Less sensitive to hyperparameters such as learning-rate warmup
- ✅ **Used in modern LLMs**: GPT-2, GPT-3, LLaMA, etc.

Our implementation uses **Pre-Norm**!

## 5. Feedforward Network (FFN)

The FFN is a simple 2-layer MLP applied **independently** to each position:

### Architecture:
```
Input: (batch, seq_len, d_model)
  ↓
Linear: d_model → d_ff (expand, typically 4x)
  ↓
GELU activation
  ↓
Linear: d_ff → d_model (compress back)
  ↓
Dropout
  ↓
Output: (batch, seq_len, d_model)
```
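
A minimal sketch of the diagram above, assuming `nn.Linear` layers with biases and PyTorch's exact GELU. `FeedForwardSketch` is a hypothetical name; the notebook's own `FeedForward` in `src.llm.transformer` may differ in details. The last lines check the position-wise claim: running one position by itself gives the same result as running the whole sequence.

```python
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Hypothetical position-wise FFN: Linear -> GELU -> Linear -> Dropout."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),  # expand, typically d_ff = 4 * d_model
            nn.GELU(),
            nn.Linear(d_ff, d_model),  # compress back to d_model
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension only, so each (batch, position)
        # vector is transformed independently with the same weights
        return self.net(x)


torch.manual_seed(0)
ffn = FeedForwardSketch(d_model=16, d_ff=64, dropout=0.0)
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)
with torch.no_grad():
    full = ffn(x)
    only_pos3 = ffn(x[:, 3:4, :])  # feed a single position by itself
same = torch.allclose(full[:, 3:4, :], only_pos3, atol=1e-5)
print(full.shape, same)  # torch.Size([2, 5, 16]) True
```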

### Key Points:

1. **Position-wise**: Same transformation applied to each position independently
2. **Expansion**: Typically expand to 4 × d_model (e.g., 512 → 2048)
3. **GELU activation**: Smoother than ReLU and, empirically, tends to work better in language models
4. **Adds capacity**: Provides non-linear transformations

### Why GELU (not ReLU)?

**ReLU**: $\text{ReLU}(x) = \max(0, x)$
- Hard cutoff at 0
- Non-smooth

**GELU**: $\text{GELU}(x) = x \cdot \Phi(x)$ (Φ = standard normal CDF)
- Smooth function
- Probabilistic interpretation: scales $x$ by the probability $\Phi(x)$ that a standard normal variable is less than $x$
- Better for language models empirically
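
Both definitions translate directly into code. This sketch builds exact GELU from the normal CDF via `torch.erf`, plus the tanh approximation used in the original GPT-2 code, and compares each with `F.gelu`; the `approximate="tanh"` argument assumes PyTorch 1.12 or newer.

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 13)

# Exact GELU: x * Phi(x), with standard normal CDF Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
phi = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
gelu_exact = x * phi

# tanh approximation of the same curve
gelu_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

exact_ok = torch.allclose(gelu_exact, F.gelu(x), atol=1e-6)
tanh_ok = torch.allclose(gelu_tanh, F.gelu(x, approximate="tanh"), atol=1e-6)
print(exact_ok, tanh_ok)
print("max |exact - tanh|:", (gelu_exact - gelu_tanh).abs().max().item())
print("min GELU value:", gelu_exact.min().item())  # negative, unlike ReLU
```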

## 6. Hands-On: Building Transformer Blocks

Let's experiment with our implementation!

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from src.llm.transformer import TransformerBlock, FeedForward
from src.llm.attention import create_causal_mask

# Set random seed
torch.manual_seed(42)

print("PyTorch version:", torch.__version__)

### 6.1 Feedforward Network

In [None]:
# Create a feedforward network
d_model = 128
d_ff = 512  # 4 * d_model

ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.0)

print(f"Feedforward Network:")
print(f"  Input dimension: {d_model}")
print(f"  Hidden dimension: {d_ff}")
print(f"  Expansion factor: {d_ff / d_model}x")
print(f"\nParameters:")
total_params = sum(p.numel() for p in ffn.parameters())
print(f"  Total: {total_params:,}")

# Test forward pass
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_model)
output = ffn(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Shape preserved: {x.shape == output.shape}")

### 6.2 Visualizing GELU Activation

In [None]:
import torch.nn.functional as F

# Compare ReLU and GELU
x = torch.linspace(-3, 3, 200)
relu = F.relu(x)
gelu = F.gelu(x)

plt.figure(figsize=(10, 6))
plt.plot(x.numpy(), relu.numpy(), label='ReLU', linewidth=2)
plt.plot(x.numpy(), gelu.numpy(), label='GELU', linewidth=2)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.xlabel('Input', fontsize=12)
plt.ylabel('Output', fontsize=12)
plt.title('ReLU vs GELU Activation Functions', fontsize=14)
plt.legend(fontsize=12)
plt.tight_layout()
plt.show()

print("Key differences:")
print("- ReLU: Hard cutoff at 0 (sharp corner)")
print("- GELU: Smooth transition (differentiable everywhere)")
print("- GELU: Allows small negative values (probabilistic)")

### 6.3 Complete Transformer Block

In [None]:
# Create a transformer block
d_model = 128
n_heads = 8
d_ff = 512

block = TransformerBlock(
    d_model=d_model,
    n_heads=n_heads,
    d_ff=d_ff,
    dropout=0.0  # Disable for deterministic behavior
)

print(f"Transformer Block Configuration:")
print(f"  Model dimension: {d_model}")
print(f"  Attention heads: {n_heads}")
print(f"  FFN dimension: {d_ff}")

# Count parameters
total_params = sum(p.numel() for p in block.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Breakdown
attn_params = sum(p.numel() for p in block.attn.parameters())
ffn_params = sum(p.numel() for p in block.ffn.parameters())
ln_params = sum(p.numel() for p in block.ln1.parameters()) + sum(p.numel() for p in block.ln2.parameters())

print(f"  Attention: {attn_params:,} ({attn_params/total_params*100:.1f}%)")
print(f"  Feedforward: {ffn_params:,} ({ffn_params/total_params*100:.1f}%)")
print(f"  LayerNorm: {ln_params:,} ({ln_params/total_params*100:.1f}%)")

### 6.4 Forward Pass Through Transformer Block

In [None]:
# Create input
batch_size = 2
seq_len = 8
x = torch.randn(batch_size, seq_len, d_model)

# Create causal mask
mask = create_causal_mask(seq_len)

print(f"Input:")
print(f"  Shape: {x.shape}")
print(f"  Mean: {x.mean():.4f}")
print(f"  Std: {x.std():.4f}")

# Forward pass
output, attn_weights = block(x, mask=mask, return_attention=True)

print(f"\nOutput:")
print(f"  Shape: {output.shape}")
print(f"  Mean: {output.mean():.4f}")
print(f"  Std: {output.std():.4f}")

print(f"\nAttention weights:")
print(f"  Shape: {attn_weights.shape}")
print(f"  (batch_size, n_heads, seq_len, seq_len)")

### 6.5 Visualizing Information Flow

In [None]:
# Track intermediate values through the block
def forward_with_intermediates(block, x, mask):
    """Forward pass that returns intermediate values."""
    intermediates = {}
    
    # Initial input
    intermediates['input'] = x.clone()
    
    # After first LayerNorm
    ln1_out = block.ln1(x)
    intermediates['after_ln1'] = ln1_out.clone()
    
    # After attention
    attn_out, _ = block.attn(ln1_out, mask=mask)
    intermediates['after_attn'] = attn_out.clone()
    
    # After first residual
    x = x + attn_out
    intermediates['after_residual1'] = x.clone()
    
    # After second LayerNorm
    ln2_out = block.ln2(x)
    intermediates['after_ln2'] = ln2_out.clone()
    
    # After FFN
    ffn_out = block.ffn(ln2_out)
    intermediates['after_ffn'] = ffn_out.clone()
    
    # After second residual
    x = x + ffn_out
    intermediates['output'] = x.clone()
    
    return intermediates

# Get intermediate values
intermediates = forward_with_intermediates(block, x, mask)

# Plot statistics
stages = list(intermediates.keys())
means = [intermediates[stage].mean().item() for stage in stages]
stds = [intermediates[stage].std().item() for stage in stages]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Mean values
ax1.plot(means, marker='o', linewidth=2, markersize=8)
ax1.set_xticks(range(len(stages)))
ax1.set_xticklabels(stages, rotation=45, ha='right')
ax1.set_ylabel('Mean', fontsize=12)
ax1.set_title('Mean Activation Through Transformer Block', fontsize=14)
ax1.grid(True, alpha=0.3)
ax1.axhline(y=0, color='r', linestyle='--', alpha=0.3)

# Standard deviation
ax2.plot(stds, marker='s', linewidth=2, markersize=8, color='orange')
ax2.set_xticks(range(len(stages)))
ax2.set_xticklabels(stages, rotation=45, ha='right')
ax2.set_ylabel('Standard Deviation', fontsize=12)
ax2.set_title('Std Activation Through Transformer Block', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Observations:")
print("- Right after each LayerNorm, activations have mean ≈ 0 and std ≈ 1 (γ=1, β=0 at init)")
print("- The residual stream itself is never normalized: each sublayer's output is added to the un-normalized input")
print("- This pattern repeats for attention and FFN")

## 7. Stacking Transformer Blocks

Real models stack **many transformer blocks** to increase capacity:
- GPT-2 Small: 12 blocks
- GPT-2 Medium: 24 blocks
- GPT-2 Large: 36 blocks
- GPT-3: 96 blocks!

Let's build a stack of blocks:

In [None]:
# Create a stack of transformer blocks
n_layers = 6
d_model = 128
n_heads = 8
d_ff = 512

blocks = nn.ModuleList([
    TransformerBlock(
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        dropout=0.1
    )
    for _ in range(n_layers)
])

print(f"Transformer Stack:")
print(f"  Number of layers: {n_layers}")
print(f"  Model dimension: {d_model}")
print(f"  Attention heads: {n_heads}")

total_params = sum(p.numel() for p in blocks.parameters())
print(f"\nTotal parameters: {total_params:,}")
print(f"Parameters per block: {total_params // n_layers:,}")

### 7.1 Forward Pass Through Stack

In [None]:
# Input
batch_size = 2
seq_len = 10
x = torch.randn(batch_size, seq_len, d_model)
mask = create_causal_mask(seq_len)

# Pass through all blocks
blocks.eval()  # Set to eval mode
layer_outputs = [x]

with torch.no_grad():
    for i, block in enumerate(blocks):
        x, _ = block(x, mask=mask)
        layer_outputs.append(x.clone())
        print(f"Layer {i+1}: mean={x.mean():.4f}, std={x.std():.4f}")

print(f"\nFinal output shape: {x.shape}")

### 7.2 Visualizing Representations Across Layers

In [None]:
# Analyze how representations change across layers
# For first sequence in batch, first position
position_idx = 0
representations = torch.stack([out[0, position_idx] for out in layer_outputs])  # (n_layers+1, d_model)

plt.figure(figsize=(12, 6))
plt.imshow(representations.T.numpy(), aspect='auto', cmap='RdBu_r', vmin=-2, vmax=2)
plt.colorbar(label='Activation Value')
plt.xlabel('Layer', fontsize=12)
plt.ylabel('Feature Dimension', fontsize=12)
plt.title(f'Representation Evolution Across Layers (Position {position_idx})', fontsize=14)
plt.xticks(range(n_layers + 1), ['Input'] + [f'L{i+1}' for i in range(n_layers)])
plt.tight_layout()
plt.show()

print("Each column shows the feature vector after a transformer layer.")
print("Notice how the representation evolves as we go deeper (here with untrained, randomly initialized weights).")

## 8. Understanding Capacity and Depth

Let's explore how depth affects model capacity:
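
The arithmetic inside `count_parameters` can be written out per block. With $d = d_{\text{model}}$, the common choice $d_{ff} = 4d$, and the same assumption the function makes (every linear layer has a bias):

$$
\begin{aligned}
\text{Attention} &= 4d^2 + 4d \\
\text{FFN} &= (d\,d_{ff} + d_{ff}) + (d_{ff}\,d + d) = 8d^2 + 5d \\
\text{LayerNorms} &= 2(d + d) = 4d \\
\text{Per block} &= 12d^2 + 13d
\end{aligned}
$$

For the "Small" configuration below ($d = 128$, 6 layers), that is $12 \cdot 16384 + 13 \cdot 128 = 198{,}272$ parameters per block and $1{,}189{,}632$ in total. The number of heads does not appear, because heads split $d_{\text{model}}$ rather than adding weights.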

In [None]:
def count_parameters(d_model: int, n_heads: int, d_ff: int, n_layers: int) -> int:
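    """Count parameters in a stack of transformer blocks (embeddings excluded).

    Assumes every Linear layer has a bias. n_heads does not change the count:
    heads split d_model rather than adding new weights.
    """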
    # Single block parameters
    # Attention: 4 projections (Q, K, V, O) each d_model × d_model
    attn_params = 4 * d_model * d_model + 4 * d_model  # weights + biases
    
    # FFN: two projections
    ffn_params = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    
    # LayerNorm: 2 sets of (scale + shift)
    ln_params = 2 * (d_model + d_model)
    
    block_params = attn_params + ffn_params + ln_params
    return block_params * n_layers

# Compare different configurations
configs = [
    {"name": "Tiny", "d_model": 64, "n_heads": 4, "d_ff": 256, "n_layers": 4},
    {"name": "Small", "d_model": 128, "n_heads": 8, "d_ff": 512, "n_layers": 6},
    {"name": "Medium", "d_model": 256, "n_heads": 8, "d_ff": 1024, "n_layers": 12},
    {"name": "Large", "d_model": 512, "n_heads": 16, "d_ff": 2048, "n_layers": 24},
]

print(f"{'Config':<10} {'d_model':<10} {'n_heads':<10} {'d_ff':<10} {'n_layers':<10} {'Parameters':>12}")
print("-" * 67)

for config in configs:
    params = count_parameters(
        config["d_model"],
        config["n_heads"],
        config["d_ff"],
        config["n_layers"]
    )
    print(f"{config['name']:<10} {config['d_model']:<10} {config['n_heads']:<10} "
          f"{config['d_ff']:<10} {config['n_layers']:<10} {params:>12,}")

## 9. Key Takeaways

Let's recap what we've learned about transformer blocks:

1. **Transformer block = Attention + FFN**:
   - Multi-head attention captures relationships
   - Feedforward network adds capacity
   - Both use residual connections and layer normalization

2. **Residual connections enable deep networks**:
   - Direct path for gradients
   - Prevents vanishing gradients
   - Essential for training 100+ layer models

3. **Layer normalization stabilizes training**:
   - Normalizes across features (not batch)
   - Works well with variable-length sequences
   - Allows higher learning rates

4. **Pre-norm architecture is modern best practice**:
   - Better gradient flow than post-norm
   - More stable training
   - Used in GPT-2, GPT-3, and beyond

5. **Feedforward network expands and contracts**:
   - Typically 4× expansion (d_model → 4 * d_model → d_model)
   - GELU activation for smooth gradients
   - Processes each position independently

6. **Stacking blocks increases capacity**:
   - More layers = more complex patterns
   - GPT-3 has 96 transformer blocks!
   - Parameters grow linearly with depth (about 12 × d_model² per block when d_ff = 4 × d_model)

## Next Steps

Now that we understand transformer blocks, we're ready to build the **complete GPT model** - adding embeddings, positional encoding, and the output layer!

Continue to **Notebook 04: Complete GPT Model** →

---

## Further Reading

- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017)
- [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) (He et al., 2015)
- [Layer Normalization](https://arxiv.org/abs/1607.06450) (Ba et al., 2016)
- [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745) (Xiong et al., 2020)