# Common Pitfalls

**Learn to recognize common mistakes and how to avoid them**

## Introduction

**Even experienced practitioners make these mistakes.** This guide helps you avoid common pitfalls and debug issues quickly.

Each pitfall includes:
- **Symptom:** How to recognize the problem
- **Cause:** Why it happens
- **Solution:** How to fix it
- **Prevention:** How to avoid it

## Pitfall 1: Loss Becomes NaN

**Symptom:**
```
Epoch 1, Step 10:  Loss = 2.34
Epoch 1, Step 20:  Loss = 1.98
Epoch 1, Step 30:  Loss = 5.67
Epoch 1, Step 40:  Loss = inf
Epoch 1, Step 50:  Loss = nan  <-- Training dead!
```

**Causes:**
1. Learning rate too high (most common)
2. Gradient explosion
3. Numerical instability (FP16 underflow/overflow)
4. Bad batch (malformed data)

In [1]:
import torch

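def check_for_nan_gradients(model):
    """Check if any gradients are NaN or Inf."""
    has_nan = False
    for name, param in model.named_parameters():
        if param.grad is not None:
            # isfinite is False for NaN and for +/-Inf, so this also catches
            # the overflow that usually shows up just before the NaN.
            if not torch.isfinite(param.grad).all():
                print(f"Non-finite gradient in {name}")
                has_nan = True
    return has_nan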

# Solutions
print("Solutions for NaN Loss:")
print("  1. Reduce learning rate by 10x")
print("  2. Enable/strengthen gradient clipping (max_grad_norm=0.5)")
print("  3. Use BF16 instead of FP16 (more stable)")
print("  4. Add warmup steps")
print()
print("If loss becomes NaN, you must restart from last good checkpoint!")

Solutions for NaN Loss:
  1. Reduce learning rate by 10x
  2. Enable/strengthen gradient clipping (max_grad_norm=0.5)
  3. Use BF16 instead of FP16 (more stable)
  4. Add warmup steps

If loss becomes NaN, you must restart from last good checkpoint!
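
Gradient clipping and skipping a bad batch can be combined in one guarded training step. The following is a minimal sketch, not a definitive implementation: `safe_training_step` is a hypothetical helper name, and the toy linear model exists only to make the example runnable. It guards against a single bad batch; once the weights themselves contain NaN, restoring a checkpoint is still required.

```python
import torch
from torch import nn

def safe_training_step(model, optimizer, loss, max_grad_norm=0.5):
    """Apply one clipped optimizer step; skip it if anything is non-finite.

    Returns True if the step was applied, False if it was skipped.
    (Hypothetical helper, not part of PyTorch.)
    """
    optimizer.zero_grad()
    if not torch.isfinite(loss):
        return False  # diverged loss or bad batch: leave the weights untouched

    loss.backward()
    # clip_grad_norm_ rescales gradients in place and returns the pre-clip norm
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    if not torch.isfinite(grad_norm):
        optimizer.zero_grad()
        return False

    optimizer.step()
    return True

# Toy usage on a tiny linear model (illustrative only)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)

applied = safe_training_step(model, optimizer, model(x).pow(2).mean())
skipped = safe_training_step(model, optimizer, model(x).sum() * float("nan"))
print(applied, skipped)
```

Counting how often a step is skipped is a cheap early warning: occasional skips point to bad batches, while a run of skips suggests the learning rate is too high.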


## Pitfall 2: Loss Not Decreasing

**Symptom:**
```
Epoch 1, Step 100:  Loss = 2.45
Epoch 1, Step 200:  Loss = 2.44
Epoch 1, Step 300:  Loss = 2.43
Epoch 1, Step 400:  Loss = 2.42  <-- Barely moving!
```

**Causes:**
1. Learning rate too low
2. Model frozen (forgot to set trainable parameters)
3. Optimizer misconfigured (for example, created before the trainable parameters were set, or loaded with stale state)
4. Insufficient model capacity (LoRA rank too small)

In [2]:
def verify_training_setup(model, optimizer):
    """Verify model and optimizer are configured correctly."""
    
    # Check trainable params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    
    if trainable == 0:
        print("ERROR: No trainable parameters!")
        return False
    
    print(f"Trainable: {trainable:,} ({100*trainable/total:.2f}%)")
    
    # Check optimizer
    if len(optimizer.param_groups) == 0:
        print("ERROR: Optimizer has no parameter groups!")
        return False
    
    lr = optimizer.param_groups[0]['lr']
    print(f"Learning rate: {lr}")
    
    if lr < 1e-6:
        print("WARNING: Learning rate very low!")
    
    return True

print("Solutions for Loss Not Decreasing:")
print("  1. Increase learning rate (10x for LoRA)")
print("  2. Verify trainable parameters exist")
print("  3. Increase LoRA rank if using LoRA")
print("  4. Check optimizer is correctly configured")

Solutions for Loss Not Decreasing:
  1. Increase learning rate (10x for LoRA)
  2. Verify trainable parameters exist
  3. Increase LoRA rank if using LoRA
  4. Check optimizer is correctly configured
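
One way an optimizer ends up misconfigured is being created before some parameters were made trainable, so those parameters never receive updates. The sketch below checks for that case; `params_missing_from_optimizer` is a hypothetical helper name and the two-layer model is only a stand-in.

```python
import torch
from torch import nn

def params_missing_from_optimizer(model, optimizer):
    """Return names of trainable parameters the optimizer will never update."""
    optimized_ids = {
        id(p) for group in optimizer.param_groups for p in group["params"]
    }
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and id(p) not in optimized_ids
    ]

# Toy setup: the optimizer was built from the first layer only
demo_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
stale_optimizer = torch.optim.AdamW(demo_model[0].parameters(), lr=1e-4)

missing = params_missing_from_optimizer(demo_model, stale_optimizer)
print(missing)
```

An empty list means every trainable parameter is covered; any names returned are parameters whose gradients are computed but never applied.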


## Pitfall 3: Overfitting (Train Loss << Val Loss)

**Symptom:**
```
Epoch 1: Train loss = 1.8, Val loss = 2.0  (gap = 0.2)
Epoch 2: Train loss = 1.2, Val loss = 2.1  (gap = 0.9)
Epoch 3: Train loss = 0.8, Val loss = 2.4  (gap = 1.6) <-- Overfitting!
```

In [3]:
def check_overfitting(train_loss, val_loss, threshold=0.5):
    """Check if model is overfitting."""
    gap = val_loss - train_loss
    
    if gap > threshold:
        print(f"Warning: Train/val gap = {gap:.2f} (overfitting!)")
        return True
    return False

print("Solutions for Overfitting:")
print("  1. Add/increase regularization (weight_decay=0.1)")
print("  2. Add dropout (lora_dropout=0.1)")
print("  3. Use early stopping")
print("  4. Reduce model capacity (lower LoRA rank)")
print("  5. Reduce epochs")
print("  6. Add data augmentation")

Solutions for Overfitting:
  1. Add/increase regularization (weight_decay=0.1)
  2. Add dropout (lora_dropout=0.1)
  3. Use early stopping
  4. Reduce model capacity (lower LoRA rank)
  5. Reduce epochs
  6. Add data augmentation
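
Early stopping can be sketched in a few lines of plain Python. The class below is illustrative rather than any library's API: `EarlyStopping`, `patience`, and `min_delta` are assumed names, and the validation losses are the ones from the symptom above.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Validation losses from the symptom above
stopper = EarlyStopping(patience=2)
stop_epoch = None
for epoch, val_loss in enumerate([2.0, 2.1, 2.4], start=1):
    if stopper.step(val_loss):
        stop_epoch = epoch
        break

print(stop_epoch, stopper.best)
```

In practice the checkpoint saved at the best validation loss (epoch 1 here) is the one to keep, not the one from the epoch where training stopped.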


## Pitfall 4: Catastrophic Forgetting

**Symptom:**

After fine-tuning, the model **loses general capabilities**:

```python
# Before fine-tuning (base model)
prompt = "What is the capital of France?"
response = "The capital of France is Paris."

# After fine-tuning on medical data
prompt = "What is the capital of France?"
response = "The capital of France is diabetes mellitus."  # What?!
```

In [4]:
def evaluate_general_knowledge(model, tokenizer, test_cases=None):
    """
    Evaluate on general knowledge to detect forgetting.

    Uses a small default prompt set when test_cases is not provided.
    """
    if test_cases is None:
        test_cases = [
            "What is 2 + 2?",
            "Who wrote Romeo and Juliet?",
            "What is the capital of France?",
            "What is water made of?",
        ]
    
    # Placeholder: a full version would generate a response for each prompt
    # with model/tokenizer and score it against a known answer.
    print("General knowledge test cases:")
    for case in test_cases:
        print(f"  - {case}")

print("Solutions for Catastrophic Forgetting:")
print("  1. Lower learning rate (especially for full fine-tuning)")
print("  2. Use LoRA instead of full fine-tuning")
print("  3. Mix general data with specialized data")
print("  4. For DPO: Ensure KL penalty (beta=0.1)")
print("  5. Train for fewer epochs")
print()
print("Prevention: Test on general benchmarks before and after training")

Solutions for Catastrophic Forgetting:
  1. Lower learning rate (especially for full fine-tuning)
  2. Use LoRA instead of full fine-tuning
  3. Mix general data with specialized data
  4. For DPO: Ensure KL penalty (beta=0.1)
  5. Train for fewer epochs

Prevention: Test on general benchmarks before and after training
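
Mixing general data into the specialized set can be sketched as follows. This is an assumption-laden illustration: `mix_datasets` and `general_fraction` are hypothetical names, the 10% default is arbitrary, and the string examples stand in for real training records.

```python
import random

def mix_datasets(specialized, general, general_fraction=0.1, seed=0):
    """Return a shuffled list in which about `general_fraction` of items are general."""
    if not 0.0 <= general_fraction < 1.0:
        raise ValueError("general_fraction must be in [0, 1)")

    # Solve n_general / (n_specialized + n_general) = general_fraction
    n_general = round(len(specialized) * general_fraction / (1.0 - general_fraction))
    n_general = min(n_general, len(general))  # cannot sample more than exists

    rng = random.Random(seed)
    mixed = list(specialized) + rng.sample(list(general), n_general)
    rng.shuffle(mixed)
    return mixed

# Placeholder records standing in for real examples
medical = [f"medical-{i}" for i in range(90)]
general = [f"general-{i}" for i in range(50)]

mixed = mix_datasets(medical, general, general_fraction=0.1)
n_general_in_mix = sum(item.startswith("general") for item in mixed)
print(len(mixed), n_general_in_mix)
```

The right fraction depends on the task; comparing general-benchmark scores before and after training is the way to tell whether it is high enough.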


## Pitfall 5: Reference Model Divergence (DPO/RLHF)

**Symptom:**
```
Step 10:  KL = 0.05
Step 20:  KL = 0.08
Step 30:  KL = 0.15
Step 40:  KL = 0.35
Step 50:  KL = 1.20  <-- Too high!
```

In [5]:
import torch.nn.functional as F

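def compute_kl_divergence(policy_model, ref_model, batch):
    """Compute the mean per-token KL(policy || reference) over the vocabulary."""
    with torch.no_grad():
        policy_logits = policy_model(**batch).logits
        ref_logits = ref_model(**batch).logits
        
        # Stay in log space: calling .log() on softmax output gives -inf
        # (and then NaN) whenever a probability underflows to zero.
        policy_log_probs = F.log_softmax(policy_logits, dim=-1)
        ref_log_probs = F.log_softmax(ref_logits, dim=-1)
        
        # Note: the mean covers every position, including any padding tokens.
        kl = (policy_log_probs.exp() * (policy_log_probs - ref_log_probs)).sum(-1).mean()
    
    return kl.item()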

def verify_reference_frozen(ref_model):
    """Verify reference model is frozen."""
    for param in ref_model.parameters():
        if param.requires_grad:
            return False
    return True

print("Solutions for KL Divergence Explosion:")
print("  1. Verify reference model is frozen")
print("  2. Increase beta (DPO) or KL coefficient (RLHF)")
print("  3. Lower learning rate")
print("  4. Use adaptive KL coefficient")

Solutions for KL Divergence Explosion:
  1. Verify reference model is frozen
  2. Increase beta (DPO) or KL coefficient (RLHF)
  3. Lower learning rate
  4. Use adaptive KL coefficient
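
An adaptive KL coefficient can be sketched as a simple proportional controller of the kind used in some PPO-based RLHF setups. Everything here is an assumption for illustration: the class name, the `target_kl` and `horizon` parameters, and the ±0.2 clip on the error are not taken from this guide.

```python
class AdaptiveKLController:
    """Nudge the KL coefficient up when KL exceeds the target, down when below."""

    def __init__(self, init_kl_coef=0.1, target_kl=0.1, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon  # larger horizon means slower adjustments

    def update(self, current_kl, n_steps):
        """Adjust the coefficient after observing `current_kl` over `n_steps` steps."""
        # Relative error, clipped so a single spike cannot swing the coefficient
        error = max(-0.2, min(0.2, current_kl / self.target_kl - 1.0))
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef

# KL values taken from the symptom above: first far too high, then below target
controller = AdaptiveKLController(init_kl_coef=0.1, target_kl=0.1, horizon=100)
raised = controller.update(current_kl=1.20, n_steps=10)
lowered = controller.update(current_kl=0.05, n_steps=10)
print(raised, lowered)
```

The clip keeps adjustments gradual, so this kind of controller complements a lower learning rate rather than replacing it.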


## Pitfall 6: Wrong Loss Masking

**Symptom:**

The model doesn't learn, or learns to generate prompts instead of responses:

```python
prompt = "Summarize this article: [long text]"
response = "Summarize this article: [repeats prompt]"  # Wrong!
```

In [6]:
import numpy as np

def test_loss_masking(dataset):
    """
    Verify loss masking is correct.
    
    Labels should be -100 for prompt tokens (masked)
    and actual token IDs for response tokens.
    """
    for i in range(min(5, len(dataset))):
        example = dataset[i]
        labels = example['labels']
        
        masked = sum(1 for l in labels if l == -100)
        unmasked = sum(1 for l in labels if l != -100)
        
        print(f"Example {i}: {masked} masked, {unmasked} unmasked")
        
        if unmasked == 0:
            print(f"  ERROR: All tokens masked! No training signal.")
        if masked == 0:
            print(f"  WARNING: No tokens masked (prompt included in loss?)")

print("Correct Loss Masking:")
print("  - Prompt tokens: labels = -100 (ignored)")
print("  - Response tokens: labels = actual token IDs")
print("  - If all masked: model has no training signal")
print("  - If none masked: model learns to repeat prompts")

Correct Loss Masking:
  - Prompt tokens: labels = -100 (ignored)
  - Response tokens: labels = actual token IDs
  - If all masked: model has no training signal
  - If none masked: model learns to repeat prompts
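
The masking rule above can be sketched as a small function that builds one training example. `build_masked_example` is a hypothetical helper and the token IDs are made up; `-100` matches the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def build_masked_example(prompt_ids, response_ids):
    """Concatenate prompt and response; mask the prompt positions in the labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return {"input_ids": input_ids, "labels": labels}

# Made-up token IDs, for illustration only
example = build_masked_example(
    prompt_ids=[101, 7592, 2088],
    response_ids=[2023, 2003, 102],
)
print(example["labels"])
```

The labels stay the same length as `input_ids`, so the example passes straight to a model that shifts labels internally; only the response positions contribute to the loss.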


## Pitfall 7: Reward Hacking (RLHF)

**Symptom:**

The model achieves high reward but generates nonsensical responses:

```python
# Reward model trained on length preference
prompt = "Say hello"
response = "Hello hello hello hello hello..."  # Repeats to maximize length
reward = 10.0  # High reward, but terrible response!
```

In [7]:
def apply_reward_constraints(response, reward):
    """
    Apply rule-based penalties to reward.
    """
    words = response.split()
    
    # Penalize repetition
    unique_ratio = len(set(words)) / len(words) if words else 0
    if unique_ratio < 0.5:
        reward -= 5.0  # Heavy penalty for repetition
    
    # Penalize extreme length
    if len(words) > 300:
        reward -= 2.0
    if len(words) < 5:
        reward -= 3.0
    
    return reward

def check_reward_hacking(responses, rewards):
    """Detect if policy is exploiting reward model."""
    import numpy as np
    
    if np.std(rewards) < 0.1:
        print("Warning: All rewards similar (may be hacking)")
    
    # Check for repetition in high-reward responses
    high_reward_idx = np.argsort(rewards)[-5:]  # Top 5
    for idx in high_reward_idx:
        response = responses[idx]
        words = response.split()
        if len(set(words)) < len(words) * 0.5:
            print(f"Warning: High-reward response {idx} is repetitive")

print("Solutions for Reward Hacking:")
print("  1. Increase KL penalty")
print("  2. Add rule-based constraints")
print("  3. Use ensemble of reward models")
print("  4. Train reward model on diverse data")

Solutions for Reward Hacking:
  1. Increase KL penalty
  2. Add rule-based constraints
  3. Use ensemble of reward models
  4. Train reward model on diverse data
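
One possible way to use an ensemble of reward models is to penalize disagreement, so a response that only one model scores highly earns less. The sketch below assumes the per-model scores are already computed; `conservative_ensemble_reward` and `disagreement_penalty` are hypothetical names, and mean minus standard deviation is just one of several reasonable aggregation rules.

```python
import statistics

def conservative_ensemble_reward(scores, disagreement_penalty=1.0):
    """Aggregate reward-model scores as mean minus a penalty on their spread."""
    mean = statistics.fmean(scores)
    spread = statistics.pstdev(scores) if len(scores) > 1 else 0.0
    return mean - disagreement_penalty * spread

# Same mean score (4.0), very different agreement between models
agreed = conservative_ensemble_reward([4.0, 4.0, 4.0])
disputed = conservative_ensemble_reward([10.0, 1.0, 1.0])
print(agreed, disputed)
```

A response that exploits a quirk of one reward model tends to look like the second case: one very high score and low scores from the rest.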


## Debugging Strategies

In [8]:
def bisect_debug(model, dataloader, optimizer):
    """Find which stage of a training step fails by testing each stage in order."""
    
    # Test 1: Model loads correctly?
    print("Test 1: Model loads...")
    try:
        _ = model.config
        print("  PASS")
    except Exception as e:
        print(f"  FAIL: {e}")
        return
    
    # Test 2: Forward pass works?
    print("Test 2: Forward pass...")
    try:
        batch = next(iter(dataloader))
        outputs = model(**batch)
        print("  PASS")
    except Exception as e:
        print(f"  FAIL: {e}")
        return
    
    # Test 3: Backward pass works?
    print("Test 3: Backward pass...")
    try:
        loss = outputs.loss
        loss.backward()
        print("  PASS")
    except Exception as e:
        print(f"  FAIL: {e}")
        return
    
    # Test 4: Optimizer step works?
    print("Test 4: Optimizer step...")
    try:
        optimizer.step()
        print("  PASS")
    except Exception as e:
        print(f"  FAIL: {e}")
        return
    
    print("\nAll components work individually!")

def check_gradients(model):
    """Check if gradients are computed correctly."""
    grad_norms = {}
    
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            grad_norm = param.grad.norm().item()
            grad_norms[name] = grad_norm
            
            if grad_norm == 0:
                print(f"Warning: Zero gradient in {name}")
            elif grad_norm > 100:
                print(f"Warning: Large gradient in {name}: {grad_norm:.2f}")
    
    if grad_norms:
        import numpy as np
        avg_grad = np.mean(list(grad_norms.values()))
        print(f"Average gradient norm: {avg_grad:.4f}")

## Quick Reference: Debugging Checklist

**1. Environment Setup**
- [ ] PyTorch installed correctly
- [ ] GPU accessible (`torch.cuda.is_available()`)
- [ ] Correct CUDA version

**2. Data**
- [ ] Dataset loads without errors
- [ ] Loss masking correct (labels have -100 for prompt)
- [ ] No empty examples
- [ ] Tokenization works correctly

**3. Model**
- [ ] Model loads correctly
- [ ] Has trainable parameters
- [ ] LoRA applied if intended
- [ ] Model on correct device

**4. Training**
- [ ] Learning rate reasonable
- [ ] Gradient clipping enabled
- [ ] Warmup steps configured
- [ ] Loss decreases

**5. Method-Specific**
- [ ] DPO: Reference model frozen
- [ ] DPO: Beta in reasonable range
- [ ] RLHF: KL coefficient set
- [ ] RLHF: Value network separate from policy
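
The environment and device items in the checklist above can be checked programmatically. This is a minimal sketch: `environment_report` and `model_devices` are hypothetical helper names, and the small linear layer is only a stand-in for a real model.

```python
import torch

def environment_report():
    """Collect the basic facts needed for the environment checklist items."""
    cuda_ok = torch.cuda.is_available()
    return {
        "torch_version": torch.__version__,
        "cuda_available": cuda_ok,
        "cuda_build": torch.version.cuda,  # None on CPU-only builds
        "device_count": torch.cuda.device_count() if cuda_ok else 0,
    }

def model_devices(model):
    """Return the set of devices holding the model's parameters."""
    return {str(p.device) for p in model.parameters()}

report = environment_report()
for key, value in report.items():
    print(f"{key}: {value}")

# More than one device in this set often means part of the model was not moved
devices = model_devices(torch.nn.Linear(2, 2))
print(devices)
```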

## Summary: Most Common Mistakes

**Top 10 pitfalls by frequency:**

1. **Learning rate too high** -> Loss becomes NaN
2. **Wrong loss masking** -> Model doesn't learn properly
3. **No trainable parameters** -> Loss doesn't decrease
4. **Overfitting** -> Train loss << val loss
5. **Reference model not frozen** -> KL divergence explodes
6. **Gradient clipping disabled** -> Training unstable
7. **Poor data quality** -> Model learns bad patterns
8. **Catastrophic forgetting** -> Loses general knowledge
9. **Reward hacking** (RLHF) -> High reward, bad outputs
10. **Batch size too large** -> OOM errors

**Quick fixes:**

| Problem | Quick Fix |
|---------|----------|
| Loss = NaN | Reduce LR by 10x, add gradient clipping |
| Loss not decreasing | Check trainable params, increase LR |
| Overfitting | Add regularization, reduce epochs |
| Forgetting | Lower LR, use LoRA, mix general data |
| KL divergence high | Increase beta/KL coefficient |
| OOM | Reduce batch size, enable grad checkpointing |

## Congratulations!

You've completed the Fine-Tuning a Transformer section. You now understand:

- **SFT:** Supervised fine-tuning with instruction formatting and loss masking
- **Reward Models:** Training models to predict human preferences
- **RLHF:** Reinforcement learning from human feedback with PPO
- **DPO:** Direct preference optimization as a simpler alternative
- **Advanced Topics:** Memory optimization, hyperparameters, evaluation, and debugging

Ready to try it yourself? Check out the Try It notebook!