# Week 3.2: Model Training Optimizations

In this notebook, we'll introduce optimization techniques used to make training large language models faster and more efficient. 

## Overview of Training Optimizations

1. **Starting point**: Unoptimized GPT-2 training (over 1000ms per iteration)
2. **GPUs, mixed precision**: First speedup (to ~1000ms)
3. **Tensor Cores, TF32 precision**: Second speedup (to ~333ms) 
4. **float16, gradient scalers, bfloat16**: Third speedup (to ~300ms)
5. **torch.compile, Python overhead, kernel fusion**: Fourth speedup (to ~130ms)
6. **Flash attention**: Fifth speedup (to ~96ms)
7. **Vocabulary size optimization** (from 50257 → 50304): Final speedup (to ~93ms)

## Optimization 1: GPUs and Mixed Precision

The first optimization moves the model and data onto a GPU. The mixed precision techniques that build on this step are covered in the sections that follow.

### Key changes:
- Moving the model and data to the GPU: `model.to(device)` moves a module in place, while a tensor must be reassigned, e.g. `batch = batch.to(device)`
- Setting up device detection:
```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'
```

### Benefits:
- GPUs provide massive parallelism for matrix operations
- Initial speedup bringing iteration time to approximately 1000ms
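
The per-iteration times quoted in this notebook depend on how they are measured: CUDA kernels launch asynchronously, so a timer has to wait for the GPU to finish before it stops the clock. Below is a minimal sketch of such a timer. `step_fn` is a hypothetical callable that runs one training iteration, and `sync` is an optional blocking call; on a CUDA device you would pass `torch.cuda.synchronize`.

```python
import time

def time_step_ms(step_fn, sync=None, warmup=1, iters=5):
    """Return the average wall-clock time of step_fn in milliseconds.

    step_fn: hypothetical callable that runs one training iteration.
    sync:    optional callable that blocks until queued device work is done
             (for example torch.cuda.synchronize on a CUDA device).
    """
    for _ in range(warmup):      # warm-up runs are excluded from the timing
        step_fn()
    if sync is not None:
        sync()                   # make sure warm-up work has finished
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    if sync is not None:
        sync()                   # wait for the device before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / iters

def tokens_per_second(batch_size, seq_len, ms_per_iter):
    """Throughput for batches of batch_size sequences with seq_len tokens each."""
    return batch_size * seq_len / (ms_per_iter / 1000.0)
```

Tokens per second is often a more useful number than milliseconds alone, because it stays comparable when the batch size or sequence length changes.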

## Optimization 2: Tensor Cores and TF32 Precision

NVIDIA GPUs from the Ampere architecture onward (e.g., A100, RTX 3090) have Tensor Cores that support TensorFloat32 (TF32), a reduced-precision format for faster matrix multiplications.

### Relevant code ([commit 5265b20](https://github.com/karpathy/build-nanogpt/commit/5265b20)):
```python
# enable TensorFloat32 (TF32) for matrix multiplications
torch.set_float32_matmul_precision('high')
```

### Benefits:
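- TF32 keeps float32's sign bit and 8-bit exponent but shortens the mantissa from 23 bits to 10, for 19 bits in total
- Provides float32's numerical range with roughly float16's precision; tensors are still stored as float32, so no other code has to change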
- Reduced iteration time to approximately 333ms (a 3x improvement)
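
To make the precision trade-off concrete, the sketch below emulates TF32 on the CPU by zeroing the 13 low mantissa bits of a float32 value. It is an illustration under a simplifying assumption: it truncates, whereas real hardware may round, and `truncate_to_tf32` is a name made up for this example.

```python
import struct

def truncate_to_tf32(x):
    """Emulate TF32 precision by keeping 10 of float32's 23 mantissa bits.

    Simplification: this truncates; a real conversion may round instead.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits &= 0xFFFFE000                                   # zero the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(truncate_to_tf32(1.0 + 2**-10))  # still distinguishable from 1.0
print(truncate_to_tf32(1.0 + 2**-11))  # below TF32 resolution, collapses to 1.0
print(truncate_to_tf32(3.0e38))        # large values survive: the exponent is unchanged
```

The last line is the key difference from float16, whose 5-bit exponent cannot represent values anywhere near that large.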

## Optimization 3: float16, Gradient Scalers, and bfloat16

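Further precision optimizations run the forward pass in a 16-bit floating point format. With float16, the loss must also be scaled so that small gradients do not underflow to zero; bfloat16 has enough range that scaling is generally unnecessary.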

### Relevant code ([commit 177e4cd](https://github.com/karpathy/build-nanogpt/commit/177e4cd)):
```python
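import torch

# autocast expects a device type ('cuda' or 'cpu'), not an index such as 'cuda:0'
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

# prefer bfloat16 where supported; fall back to float16 on older CUDA GPUs
use_bf16 = device_type == 'cpu' or torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
context_manager = torch.amp.autocast(device_type=device_type, dtype=dtype)

# loss scaling is only needed for float16; with bfloat16 the scaler is a no-op
# (newer PyTorch releases spell this torch.amp.GradScaler('cuda', ...))
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))

# in training loop (model, optimizer, X, Y are defined elsewhere):
optimizer.zero_grad(set_to_none=True)
with context_manager:
    logits, loss = model(X, Y)

# scale the loss so that small float16 gradients do not underflow
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # only required if gradients are inspected or clipped before the step
scaler.step(optimizer)
scaler.update()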
```

### Benefits:
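- float16 uses 16 bits instead of 32, halving the memory and bandwidth needed for tensors held in that format and speeding up computation
- bfloat16 (brain floating point) keeps float32's 8-bit exponent, so it covers the same numerical range as float32 at the cost of fewer mantissa bits than float16
- Gradient scaling prevents numerical underflow in float16 gradients; it is generally unnecessary with bfloat16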
- Reduced iteration time to approximately 300ms
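
The underflow problem and the two remedies can be reproduced on the CPU with the standard library. This is a sketch under stated assumptions: `tiny_grad` and `loss_scale` are illustrative values rather than numbers from a real run, and `to_bfloat16` emulates bfloat16 by truncating a float32 bit pattern instead of rounding.

```python
import struct

def to_float16(x):
    """Round a Python float to IEEE float16 and back (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_bfloat16(x):
    """Emulate bfloat16 by keeping the top 16 bits of the float32 pattern.

    Simplification: truncation rather than round-to-nearest.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

tiny_grad = 1e-8       # illustrative gradient magnitude
loss_scale = 65536.0   # illustrative power-of-two scale factor

unscaled = to_float16(tiny_grad)                             # underflows to 0.0
recovered = to_float16(tiny_grad * loss_scale) / loss_scale  # survives, then is unscaled
in_bf16 = to_bfloat16(tiny_grad)                             # fits without any scaling

print(unscaled, recovered, in_bf16)
```

Scaling the loss multiplies every gradient by the same factor during the backward pass, which is why dividing by that factor afterwards recovers the original value to within float16 precision.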

## Optimization 4: torch.compile and Kernel Fusion

PyTorch 2.0 introduced `torch.compile()`, which captures the model's operations as a graph and generates optimized kernels for it, using kernel fusion and other techniques.

### Relevant code ([commit fb8bd6e](https://github.com/karpathy/build-nanogpt/commit/fb8bd6e)):
```python
# apply torch.compile for PyTorch 2.0+ optimization
model = torch.compile(model)
```

### Benefits:
- Fuses multiple operations into optimized CUDA kernels
- Reduces Python overhead by generating optimized code
- Captures the computation graph at run time, so the compiler can optimize across operations instead of one call at a time
- Reduced iteration time to approximately 130ms (more than a 2x improvement)
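
Kernel fusion is easier to picture with a small CPU analogy. The sketch below is not what `torch.compile` emits; it is a plain-Python illustration in which each list comprehension stands in for one GPU kernel that reads its input from memory and writes its output back. The tanh-approximate GELU is used as the example, and both function names are made up here.

```python
import math

def gelu_unfused(xs):
    """tanh-approximate GELU as a chain of elementwise passes.

    Each step reads a full list and writes a full intermediate list,
    mimicking separate kernels that round-trip through GPU memory.
    """
    cubed = [x * x * x for x in xs]                              # pass 1
    inner = [x + 0.044715 * c for x, c in zip(xs, cubed)]        # pass 2
    scaled = [math.sqrt(2.0 / math.pi) * v for v in inner]       # pass 3
    squashed = [math.tanh(v) for v in scaled]                    # pass 4
    return [0.5 * x * (1.0 + t) for x, t in zip(xs, squashed)]   # pass 5

def gelu_fused(xs):
    """Same math in a single pass: each element is read once and written once."""
    k = math.sqrt(2.0 / math.pi)
    return [0.5 * x * (1.0 + math.tanh(k * (x + 0.044715 * (x * x * x)))) for x in xs]

sample = [-2.0, -0.5, 0.0, 0.5, 2.0]
print(gelu_unfused(sample))
print(gelu_fused(sample))
```

The arithmetic is identical in both versions. What fusion removes is the four intermediate lists, which on a GPU correspond to extra trips to and from device memory.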

## Optimization 5: Flash Attention

Flash Attention is an exact attention implementation that avoids writing the full sequence-by-sequence attention matrix to GPU memory, which reduces memory usage and increases speed.

### Relevant code ([commit 7ee630c](https://github.com/karpathy/build-nanogpt/commit/7ee630c)):
```python
# In CausalSelfAttention forward method
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # fused attention; PyTorch can dispatch to a FlashAttention kernel when one is available
```

### Benefits:
- Optimized memory access patterns for attention computation
- Computes attention in small blocks that fit in fast SRAM
- Reduced memory bandwidth requirements
- Reduced iteration time to approximately 96ms
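
The block-wise computation relies on an "online" softmax that keeps a running maximum and running sum, so earlier partial results can be rescaled as new blocks arrive. The following is a simplified pure-Python sketch for a single query vector, with no causal mask and no batching; the function names and the `block_size` parameter are assumptions for this illustration, and real kernels also tile over queries.

```python
import math

def attention_naive(q, keys, values):
    """Single-query attention that materializes every score before the softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]

def attention_blocked(q, keys, values, block_size=2):
    """Same result, visiting keys and values one block at a time (online softmax)."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * dim                                 # unnormalized weighted sum of values
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescale earlier partial results
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Only one block of scores exists at any time in `attention_blocked`, which is the property that lets a GPU kernel keep its working set in on-chip SRAM.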

## Optimization 6: Vocabulary Size Optimization

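The final optimization pads the vocabulary size from 50257 up to 50304, the next multiple of 64 (and also a multiple of 128), so that the embedding and output-layer matrix dimensions divide evenly into the block sizes GPU kernels use.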

### Relevant code ([commit 7230096](https://github.com/karpathy/build-nanogpt/commit/7230096129a52ae763cd810f2bb9e61f67ec9ab7)):
```python
# Change from
model = GPT(GPTConfig())
# to
model = GPT(GPTConfig(vocab_size=50304))
```

### Benefits:
- Changed vocabulary size from 50257 to 50304 (a multiple of 64)
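- Matrix dimensions that divide evenly into power-of-two tiles avoid slower special-case handling of leftover elements
- The added rows cost a little extra computation, yet the iteration is still faster overall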
- Final reduction in iteration time to approximately 93ms
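
The padded size can be derived rather than memorized. A minimal sketch, where `pad_to_multiple` is a helper name introduced for this example:

```python
def pad_to_multiple(n, multiple):
    """Round n up to the nearest multiple of `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple

gpt2_vocab = 50257
padded = pad_to_multiple(gpt2_vocab, 64)
print(padded, padded - gpt2_vocab)  # 50304 47
```

The 47 extra rows correspond to token ids the tokenizer never produces, so they take up space in the embedding table without ever appearing in the training data.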

## Implementation Example

Below is a simplified implementation that incorporates these optimizations for a training loop:

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

# Set up device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'

# Enable TF32 for faster matrix multiplications on Ampere+ GPUs
if device_type == 'cuda':
    torch.set_float32_matmul_precision('high')

# Set up model. YourTransformerModel is a placeholder for your own class;
# its forward(x, y) is assumed to return (logits, loss).
model = YourTransformerModel(vocab_size=50304)  # Use hardware-friendly size
model.to(device)

# Apply torch.compile (PyTorch 2.0+)
if hasattr(torch, 'compile'):
    model = torch.compile(model)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Set up mixed precision training: bfloat16 where supported, else float16
use_bf16 = device_type == 'cpu' or torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16
context_manager = torch.amp.autocast(device_type=device_type, dtype=dtype)
# Loss scaling is only needed for float16; otherwise the scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))

# Training step for one batch
def train_batch(x, y):
    # Move data to device
    x, y = x.to(device), y.to(device)

    # Forward pass with mixed precision
    with context_manager:
        logits, loss = model(x, y)

    # Backward pass with gradient scaling
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale first so clipping sees the true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
    scaler.step(optimizer)
    scaler.update()

    return loss.item()
```

## Conclusion and Key Takeaways

With these optimizations, Karpathy's build-nanogpt walkthrough reduces the per-iteration training time from over 1000ms to about 93ms, more than a 10x speedup. The optimizations fall into different levels of complexity:

1. **Basic**: Moving to GPU
2. **Intermediate**: Using reduced precision (TF32, bfloat16, or float16 with gradient scaling)
3. **Advanced**: Leveraging torch.compile and Flash Attention
4. **Fine-tuning**: Optimizing vocabulary size for hardware alignment

These optimizations compound, with each building on the previous ones, and can be applied to many deep learning models, not just transformers.
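
The compounding can be checked with the approximate timings quoted in this notebook. The sketch below multiplies out the per-step speedups; the step labels are shorthand chosen for this example, and the list starts at the roughly 1000ms figure for Optimization 1 because the earlier unoptimized time is given only as "over 1000ms".

```python
# Approximate per-iteration times (ms) quoted in this notebook.
timings_ms = [
    ("on GPU (Optimization 1)", 1000),
    ("TF32", 333),
    ("bfloat16 autocast", 300),
    ("torch.compile", 130),
    ("flash attention", 96),
    ("vocab_size=50304", 93),
]

def speedups(timings):
    """Return (name, step_speedup, cumulative_speedup) for each step after the first."""
    base = timings[0][1]
    rows = []
    for (_, prev), (name, cur) in zip(timings, timings[1:]):
        rows.append((name, prev / cur, base / cur))
    return rows

for name, step, total in speedups(timings_ms):
    print(f"{name:20s} step x{step:.2f}  cumulative x{total:.2f}")
```

No single step delivers the full gain; the cumulative factor is the product of the individual step factors.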

When training your own models, consider applying these optimizations early to reduce training time and resource usage. The order presented here is also a good guideline for implementation, as it progresses from simpler to more complex changes.