# 10: Fine-Tuning Basics

## Learning Objectives

1. Understand how fine-tuning adapts pretrained models to specific tasks
2. Load pretrained models and add task-specific heads
3. Observe catastrophic forgetting during aggressive fine-tuning
4. Compare frozen vs unfrozen backbone approaches
5. Implement learning rate schedules for stable fine-tuning

**Prerequisites:** [pretraining](../modern-llms/pretraining.md), [fine-tuning](../modern-llms/fine-tuning.md)

**Framework:** PyTorch + HuggingFace Transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from transformers import get_linear_schedule_with_warmup
import copy
from tqdm.auto import tqdm

torch.manual_seed(42)
np.random.seed(42)

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Part 1: Loading a Pretrained Model

We'll use GPT-2 small (124M parameters) as our base model. This model was pretrained on WebText to predict the next token.

In [None]:
# Load pretrained GPT-2
model_name = "gpt2"  # 124M parameters
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad token

# Load the model
pretrained_model = GPT2LMHeadModel.from_pretrained(model_name)
print(f"Model parameters: {sum(p.numel() for p in pretrained_model.parameters()):,}")

# Show architecture
print(f"\nModel config:")
print(f"  Hidden size: {pretrained_model.config.n_embd}")
print(f"  Layers: {pretrained_model.config.n_layer}")
print(f"  Heads: {pretrained_model.config.n_head}")
print(f"  Vocabulary: {pretrained_model.config.vocab_size}")
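The 124M figure can be reproduced by hand from the config values printed above. The following is a back-of-the-envelope sketch, assuming the standard GPT-2 small layout (context length 1024, a 4x MLP expansion, biases on every linear layer, and an LM head tied to the token embeddings so it adds no parameters). The helper name `gpt2_param_count` is hypothetical.

```python
def gpt2_param_count(vocab_size=50257, n_positions=1024, n_embd=768, n_layer=12):
    """Count GPT-2 parameters, assuming the standard block layout and a tied LM head."""
    embeddings = vocab_size * n_embd + n_positions * n_embd  # wte + wpe

    layer_norm = 2 * n_embd                          # scale + shift
    attn_qkv = n_embd * (3 * n_embd) + 3 * n_embd    # fused Q, K, V projection
    attn_out = n_embd * n_embd + n_embd              # attention output projection
    mlp_up = n_embd * (4 * n_embd) + 4 * n_embd      # expand to 4x width
    mlp_down = (4 * n_embd) * n_embd + n_embd        # project back down
    per_block = 2 * layer_norm + attn_qkv + attn_out + mlp_up + mlp_down

    final_norm = 2 * n_embd
    return embeddings + n_layer * per_block + final_norm


total = gpt2_param_count()
print(f"{total:,}")  # 124,439,808
```

Roughly a third of the total sits in the token embedding matrix, which is one reason the embeddings are usually treated with care during fine-tuning.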

In [None]:
# Demonstrate that it already knows language
def generate_text(model, prompt, max_tokens=50, temperature=0.8):
    """Generate text from a prompt."""
    model.train(False)  # Set to inference mode
    
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    model = model.to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,  # input_ids plus attention_mask
            max_new_tokens=max_tokens,
            temperature=temperature,
            do_sample=True,
            top_k=50,
            pad_token_id=tokenizer.eos_token_id
        )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test generation
print("Pretrained GPT-2 generation:")
print("-" * 50)
print(generate_text(pretrained_model, "The meaning of life is"))

## Part 2: Creating a Fine-Tuning Task

We'll create a small sentiment classification task. Here, fine-tuning means adapting the pretrained language model to classify text as positive (1) or negative (0).

In [None]:
# Simple sentiment dataset
SENTIMENT_DATA = [
    # Positive
    ("This movie was absolutely amazing and wonderful!", 1),
    ("I loved every moment of this incredible experience.", 1),
    ("What a fantastic product, exceeded all expectations!", 1),
    ("The service was excellent and the staff were friendly.", 1),
    ("Best purchase I've ever made, highly recommend!", 1),
    ("This made my day so much better, thank you!", 1),
    ("Absolutely brilliant work, I'm impressed!", 1),
    ("Outstanding quality and great value for money.", 1),
    ("The food was delicious and beautifully presented.", 1),
    ("I had an amazing time, would definitely return!", 1),
    ("Such a wonderful surprise, I'm so happy!", 1),
    ("This exceeded my expectations in every way.", 1),
    ("Perfect in every way, couldn't ask for more.", 1),
    ("What an incredible achievement, well done!", 1),
    ("I'm thrilled with the results, fantastic!", 1),
    # Negative
    ("This was a complete waste of time and money.", 0),
    ("Terrible experience, I would never recommend this.", 0),
    ("The quality was awful and disappointing.", 0),
    ("Worst service I have ever encountered.", 0),
    ("I regret buying this, total disappointment.", 0),
    ("What a disaster, nothing worked as expected.", 0),
    ("Horrible product, broke after one day.", 0),
    ("The food was cold and tasted terrible.", 0),
    ("I'm very unhappy with this purchase.", 0),
    ("Waste of money, don't bother buying this.", 0),
    ("Extremely frustrating experience, avoid at all costs.", 0),
    ("This ruined my entire day, so disappointed.", 0),
    ("Poor quality and even worse customer service.", 0),
    ("I've never been more let down by a product.", 0),
    ("Absolute garbage, want my money back.", 0),
]

# Split into train/test
np.random.shuffle(SENTIMENT_DATA)
train_data = SENTIMENT_DATA[:24]
test_data = SENTIMENT_DATA[24:]

print(f"Training examples: {len(train_data)}")
print(f"Test examples: {len(test_data)}")
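With only 30 examples, a plain shuffle can leave the 6-example test set lopsided (for instance, five negatives and one positive). One possible alternative, sketched here with a hypothetical helper `stratified_split`, shuffles each class separately so both splits keep the 50/50 balance. It is an illustration on toy data, not part of the notebook's pipeline.

```python
import random

def stratified_split(data, test_fraction=0.2, seed=42):
    """Split (text, label) pairs so each label keeps the same share in train and test."""
    rng = random.Random(seed)
    by_label = {}
    for text, label in data:
        by_label.setdefault(label, []).append((text, label))

    train, test = [], []
    for label in sorted(by_label):
        items = by_label[label][:]
        rng.shuffle(items)
        n_test = round(len(items) * test_fraction)
        test.extend(items[:n_test])
        train.extend(items[n_test:])

    # Mix the classes again so neither split is ordered by label
    rng.shuffle(train)
    rng.shuffle(test)
    return train, test


# Toy stand-in for SENTIMENT_DATA: 15 positive and 15 negative examples
toy_data = [(f"positive {i}", 1) for i in range(15)] + [(f"negative {i}", 0) for i in range(15)]
train_split, test_split = stratified_split(toy_data)
print(len(train_split), len(test_split))
print(sum(label for _, label in test_split), "positives in test")
```

On a test set this small, each example is worth about 17 points of accuracy, so reported test accuracy should be read as a rough signal either way.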

In [None]:
class SentimentDataset(Dataset):
    """Dataset for sentiment classification."""
    
    def __init__(self, data, tokenizer, max_length=64):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        text, label = self.data[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'label': torch.tensor(label, dtype=torch.long)
        }

train_dataset = SentimentDataset(train_data, tokenizer)
test_dataset = SentimentDataset(test_data, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

## Part 3: Adding a Classification Head

For classification, we need to:
1. Get the hidden states from GPT-2 and use the last token's representation (with causal attention, only the last token has seen the whole sequence)
2. Add a linear layer to project to class logits

In [None]:
class GPT2ForSentiment(nn.Module):
    """GPT-2 with a classification head for sentiment analysis."""
    
    def __init__(self, pretrained_model, num_classes=2, freeze_backbone=False):
        super().__init__()
        
        # Copy the transformer backbone
        self.transformer = copy.deepcopy(pretrained_model.transformer)
        self.config = pretrained_model.config
        
        # Optionally freeze backbone weights
        if freeze_backbone:
            for param in self.transformer.parameters():
                param.requires_grad = False
        
        # Add classification head (randomly initialized)
        self.classifier = nn.Linear(self.config.n_embd, num_classes)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, input_ids, attention_mask=None):
        # Get transformer outputs
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        hidden_states = outputs.last_hidden_state  # [batch, seq, hidden]
        
        # Pool: use last non-padding token
        # For each sequence, find the position of the last real token
        if attention_mask is not None:
            # Index of the last 1 in the attention mask.
            # This assumes right padding (the GPT2Tokenizer default).
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = hidden_states.shape[0]
            pooled = hidden_states[torch.arange(batch_size, device=hidden_states.device), 
                                   sequence_lengths]
        else:
            pooled = hidden_states[:, -1]  # Just use last token
        
        # Classify
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)
        return logits
    
    def count_trainable_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# Create model with unfrozen backbone
sentiment_model = GPT2ForSentiment(pretrained_model, freeze_backbone=False).to(device)
print(f"Trainable parameters (unfrozen): {sentiment_model.count_trainable_params():,}")

# Create model with frozen backbone
frozen_model = GPT2ForSentiment(pretrained_model, freeze_backbone=True).to(device)
print(f"Trainable parameters (frozen): {frozen_model.count_trainable_params():,}")
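To make the pooling step concrete, here is a small pure-Python sketch of how the index of the last real token falls out of the attention mask. It assumes right padding, and the mask values are made up for illustration. The helper name `last_token_index` is hypothetical.

```python
def last_token_index(attention_mask_row):
    """Index of the last non-padding token, assuming padding is on the right."""
    return sum(attention_mask_row) - 1


# Two right-padded sequences of length 6: 1 = real token, 0 = padding
attention_mask = [
    [1, 1, 1, 1, 0, 0],   # 4 real tokens -> pool position 3
    [1, 1, 1, 1, 1, 1],   # no padding    -> pool position 5
]
indices = [last_token_index(row) for row in attention_mask]
print(indices)  # [3, 5]

# With left padding the same formula would point at the wrong position
left_padded = [0, 0, 1, 1, 1, 1]
print(last_token_index(left_padded), "vs. true last index", len(left_padded) - 1)

# Size of the new head: a Linear(768, 2) layer has 768 * 2 weights + 2 biases
head_params = 768 * 2 + 2
print(head_params)  # 1538
```

The last number is the trainable-parameter count to expect from the frozen model, since the classifier is the only part left with `requires_grad=True`.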

## Part 4: Training and Evaluating

In [None]:
def train_epoch(model, train_loader, optimizer, scheduler=None):
    """Train for one epoch."""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask)
        loss = F.cross_entropy(logits, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler:
            scheduler.step()
        
        total_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    return total_loss / len(train_loader), correct / total


def evaluate(model, test_loader):
    """Evaluate on test set."""
    model.train(False)  # Inference mode
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            logits = model(input_ids, attention_mask)
            preds = logits.argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    return correct / total

## Part 5: Demonstrating Catastrophic Forgetting

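Let's see what happens when we fine-tune aggressively with a high learning rate. Large updates can overwrite the pretrained weights, so the model fits our small task while "forgetting" its language modeling abilities. We track this by measuring perplexity on a few held-out prompts after each epoch.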

In [None]:
def measure_language_modeling(model_transformer, tokenizer, test_prompts):
    """Measure language modeling quality by computing perplexity on test prompts."""
    # Rebuild an LM head model around the (possibly fine-tuned) transformer
    config = GPT2Config.from_pretrained('gpt2')
    lm_model = GPT2LMHeadModel(config).to(device)
    lm_model.transformer = model_transformer
    # GPT-2 ties the LM head to the token embeddings. Re-tie it to the swapped-in
    # backbone; otherwise the head keeps its randomly initialized weights.
    lm_model.lm_head.weight = model_transformer.wte.weight
    lm_model.train(False)
    
    total_loss = 0
    total_tokens = 0
    
    with torch.no_grad():
        for prompt in test_prompts:
            inputs = tokenizer(prompt, return_tensors='pt').to(device)
            outputs = lm_model(**inputs, labels=inputs['input_ids'])
            # The loss is averaged over the seq_len - 1 next-token predictions
            n_predicted = inputs['input_ids'].shape[1] - 1
            total_loss += outputs.loss.item() * n_predicted
            total_tokens += n_predicted
    
    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)
    return perplexity

# Test prompts to measure language modeling quality
LM_TEST_PROMPTS = [
    "The capital of France is Paris, which is known for",
    "In the beginning, there was nothing but darkness and",
    "Machine learning is a field of artificial intelligence that",
    "The quick brown fox jumps over the lazy dog while",
]
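Perplexity is the exponential of the average negative log-likelihood per predicted token. The toy probabilities below are invented purely to show the arithmetic: a model that assigns probability 0.25 to every correct next token has perplexity 4, as if it were guessing uniformly among four options.

```python
import math

def perplexity(token_probs):
    """exp(mean negative log-likelihood) over the probabilities given to the true tokens."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


confident = perplexity([0.5, 0.5, 0.5, 0.5])      # always 50% on the right token
uniform_4 = perplexity([0.25, 0.25, 0.25, 0.25])  # like guessing among 4 options
mixed = perplexity([0.9, 0.9, 0.9, 0.001])        # one badly missed token dominates

print(round(confident, 3), round(uniform_4, 3), round(mixed, 3))
```

For a prompt of n tokens there are n - 1 next-token predictions, since the first token has no preceding context. Because perplexity is exponential in the loss, a modest rise in cross-entropy shows up as a large jump in perplexity.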

In [None]:
# Train with HIGH learning rate (will cause forgetting)
print("Training with HIGH learning rate (1e-3) - demonstrates catastrophic forgetting")
print("=" * 70)

aggressive_model = GPT2ForSentiment(pretrained_model, freeze_backbone=False).to(device)
aggressive_optimizer = torch.optim.AdamW(aggressive_model.parameters(), lr=1e-3)  # Too high!

aggressive_metrics = {
    'train_loss': [],
    'train_acc': [],
    'test_acc': [],
    'lm_perplexity': []
}

# Measure initial perplexity
initial_ppl = measure_language_modeling(aggressive_model.transformer, tokenizer, LM_TEST_PROMPTS)
print(f"Initial LM perplexity: {initial_ppl:.2f}")

for epoch in range(10):
    train_loss, train_acc = train_epoch(aggressive_model, train_loader, aggressive_optimizer)
    test_acc = evaluate(aggressive_model, test_loader)
    lm_ppl = measure_language_modeling(aggressive_model.transformer, tokenizer, LM_TEST_PROMPTS)
    
    aggressive_metrics['train_loss'].append(train_loss)
    aggressive_metrics['train_acc'].append(train_acc)
    aggressive_metrics['test_acc'].append(test_acc)
    aggressive_metrics['lm_perplexity'].append(lm_ppl)
    
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.3f}, Test Acc={test_acc:.3f}, LM PPL={lm_ppl:.1f}")

print(f"\nFinal LM perplexity: {aggressive_metrics['lm_perplexity'][-1]:.2f}")
print(f"Perplexity increased by {aggressive_metrics['lm_perplexity'][-1]/initial_ppl:.1f}x (forgetting!)")

In [None]:
# Train with PROPER low learning rate
print("\nTraining with LOW learning rate (2e-5) - proper fine-tuning")
print("=" * 70)

careful_model = GPT2ForSentiment(pretrained_model, freeze_backbone=False).to(device)
careful_optimizer = torch.optim.AdamW(careful_model.parameters(), lr=2e-5, weight_decay=0.01)

# Learning rate schedule with warmup
total_steps = len(train_loader) * 10
careful_scheduler = get_linear_schedule_with_warmup(
    careful_optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

careful_metrics = {
    'train_loss': [],
    'train_acc': [],
    'test_acc': [],
    'lm_perplexity': []
}

initial_ppl = measure_language_modeling(careful_model.transformer, tokenizer, LM_TEST_PROMPTS)
print(f"Initial LM perplexity: {initial_ppl:.2f}")

for epoch in range(10):
    train_loss, train_acc = train_epoch(careful_model, train_loader, careful_optimizer, careful_scheduler)
    test_acc = evaluate(careful_model, test_loader)
    lm_ppl = measure_language_modeling(careful_model.transformer, tokenizer, LM_TEST_PROMPTS)
    
    careful_metrics['train_loss'].append(train_loss)
    careful_metrics['train_acc'].append(train_acc)
    careful_metrics['test_acc'].append(test_acc)
    careful_metrics['lm_perplexity'].append(lm_ppl)
    
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.3f}, Test Acc={test_acc:.3f}, LM PPL={lm_ppl:.1f}")

print(f"\nFinal LM perplexity: {careful_metrics['lm_perplexity'][-1]:.2f}")
print(f"Perplexity change: {careful_metrics['lm_perplexity'][-1]/initial_ppl:.2f}x (minimal forgetting!)")

In [None]:
# Visualize catastrophic forgetting
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

epochs = range(1, 11)

# Training loss
axes[0].plot(epochs, aggressive_metrics['train_loss'], 'r-', label='High LR (1e-3)', linewidth=2)
axes[0].plot(epochs, careful_metrics['train_loss'], 'b-', label='Low LR (2e-5)', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Training Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Test accuracy
axes[1].plot(epochs, aggressive_metrics['test_acc'], 'r-', label='High LR (1e-3)', linewidth=2)
axes[1].plot(epochs, careful_metrics['test_acc'], 'b-', label='Low LR (2e-5)', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Test Accuracy')
axes[1].set_title('Task Performance')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# LM perplexity (forgetting indicator)
axes[2].plot(epochs, aggressive_metrics['lm_perplexity'], 'r-', label='High LR (1e-3)', linewidth=2)
axes[2].plot(epochs, careful_metrics['lm_perplexity'], 'b-', label='Low LR (2e-5)', linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('LM Perplexity')
axes[2].set_title('Language Modeling Quality\n(Lower = Better)')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('catastrophic_forgetting.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nKey insight: a high learning rate may still fit the task, but it degrades")
print("the model's language modeling ability (high perplexity = catastrophic forgetting)")

## Part 6: Frozen vs Unfrozen Backbone

Another strategy to prevent forgetting: freeze the pretrained weights entirely and only train the classification head.

In [None]:
# Train with frozen backbone (feature extraction)
print("Training with FROZEN backbone (only train classifier head)")
print("=" * 70)

frozen_model = GPT2ForSentiment(pretrained_model, freeze_backbone=True).to(device)
# Can use higher LR since we're not touching pretrained weights
frozen_optimizer = torch.optim.AdamW(frozen_model.parameters(), lr=1e-3)

frozen_metrics = {
    'train_loss': [],
    'train_acc': [],
    'test_acc': [],
    'lm_perplexity': []
}

initial_ppl = measure_language_modeling(frozen_model.transformer, tokenizer, LM_TEST_PROMPTS)
print(f"Initial LM perplexity: {initial_ppl:.2f}")

for epoch in range(10):
    train_loss, train_acc = train_epoch(frozen_model, train_loader, frozen_optimizer)
    test_acc = evaluate(frozen_model, test_loader)
    lm_ppl = measure_language_modeling(frozen_model.transformer, tokenizer, LM_TEST_PROMPTS)
    
    frozen_metrics['train_loss'].append(train_loss)
    frozen_metrics['train_acc'].append(train_acc)
    frozen_metrics['test_acc'].append(test_acc)
    frozen_metrics['lm_perplexity'].append(lm_ppl)
    
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.3f}, Test Acc={test_acc:.3f}, LM PPL={lm_ppl:.1f}")

print(f"\nFinal LM perplexity: {frozen_metrics['lm_perplexity'][-1]:.2f}")
print("No forgetting possible - backbone weights unchanged!")
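Freezing can be checked on a toy model before trusting it on GPT-2. The sketch below uses two small `nn.Linear` layers as stand-ins for the backbone and the head (both hypothetical, chosen only to keep the example fast) and checks that one AdamW step changes the head but leaves the frozen layer untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

backbone = nn.Linear(4, 4)   # stand-in for the pretrained transformer
head = nn.Linear(4, 2)       # stand-in for the new classification head

# Freeze the backbone: no gradients are computed for these parameters
for param in backbone.parameters():
    param.requires_grad = False

backbone_before = backbone.weight.detach().clone()
head_before = head.weight.detach().clone()

# Pass every parameter to the optimizer, as the notebook does
params = list(backbone.parameters()) + list(head.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-3, weight_decay=0.01)

x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
loss = F.cross_entropy(head(backbone(x)), y)
loss.backward()
optimizer.step()

backbone_unchanged = torch.equal(backbone.weight, backbone_before)
head_changed = not torch.equal(head.weight, head_before)
print(backbone_unchanged, head_changed)
```

Parameters without gradients are skipped by the optimizer, so even weight decay does not touch them. Passing only the trainable parameters to the optimizer would work equally well and makes the intent more explicit.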

In [None]:
# Compare all three approaches
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

epochs = range(1, 11)

# Test accuracy comparison
axes[0].plot(epochs, aggressive_metrics['test_acc'], 'r-', label='Unfrozen + High LR', linewidth=2)
axes[0].plot(epochs, careful_metrics['test_acc'], 'b-', label='Unfrozen + Low LR', linewidth=2)
axes[0].plot(epochs, frozen_metrics['test_acc'], 'g-', label='Frozen backbone', linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Test Accuracy')
axes[0].set_title('Task Performance Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# LM perplexity comparison
axes[1].plot(epochs, aggressive_metrics['lm_perplexity'], 'r-', label='Unfrozen + High LR', linewidth=2)
axes[1].plot(epochs, careful_metrics['lm_perplexity'], 'b-', label='Unfrozen + Low LR', linewidth=2)
axes[1].plot(epochs, frozen_metrics['lm_perplexity'], 'g-', label='Frozen backbone', linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('LM Perplexity')
axes[1].set_title('Language Modeling Quality')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('finetuning_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Summary statistics
print("Summary: Final Test Accuracy and LM Perplexity")
print("=" * 60)
print(f"{'Method':<30} {'Test Acc':>12} {'LM PPL':>12}")
print("-" * 60)
print(f"{'Unfrozen + High LR (1e-3)':<30} {aggressive_metrics['test_acc'][-1]:>12.3f} {aggressive_metrics['lm_perplexity'][-1]:>12.1f}")
print(f"{'Unfrozen + Low LR (2e-5)':<30} {careful_metrics['test_acc'][-1]:>12.3f} {careful_metrics['lm_perplexity'][-1]:>12.1f}")
print(f"{'Frozen backbone':<30} {frozen_metrics['test_acc'][-1]:>12.3f} {frozen_metrics['lm_perplexity'][-1]:>12.1f}")
print("-" * 60)
print("\nTrade-offs:")
print("- High LR: Fastest training but destroys pretrained knowledge")
print("- Low LR: Good task performance while preserving general abilities")
print("- Frozen: Zero forgetting but may underfit complex tasks")

## Part 7: Learning Rate Schedules

Proper fine-tuning uses warmup + decay schedules to stabilize training.

In [None]:
# Visualize different learning rate schedules
def visualize_lr_schedules():
    """Compare different learning rate schedules."""
    total_steps = 1000
    warmup_steps = 100
    base_lr = 2e-5
    
    # Create dummy models and optimizers for schedule visualization
    schedules = {}
    
    # Constant LR
    lrs = [base_lr] * total_steps
    schedules['Constant'] = lrs
    
    # Linear warmup + decay
    dummy_param = nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([dummy_param], lr=base_lr)
    scheduler = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    lrs = []
    for _ in range(total_steps):
        lrs.append(opt.param_groups[0]['lr'])
        scheduler.step()
    schedules['Linear warmup + decay'] = lrs
    
    # Cosine with warmup
    dummy_param = nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([dummy_param], lr=base_lr)
    from transformers import get_cosine_schedule_with_warmup
    scheduler = get_cosine_schedule_with_warmup(
        opt, num_warmup_steps=warmup_steps, num_training_steps=total_steps
    )
    lrs = []
    for _ in range(total_steps):
        lrs.append(opt.param_groups[0]['lr'])
        scheduler.step()
    schedules['Cosine warmup + decay'] = lrs
    
    # Plot
    plt.figure(figsize=(10, 5))
    steps = range(total_steps)
    colors = ['gray', 'blue', 'green']
    
    for (name, lrs), color in zip(schedules.items(), colors):
        plt.plot(steps, lrs, label=name, color=color, linewidth=2)
    
    plt.axvline(x=warmup_steps, color='red', linestyle='--', alpha=0.5, label='Warmup ends')
    plt.xlabel('Training Step')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedules for Fine-Tuning')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('lr_schedules.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nWhy warmup?")
    print("- Early gradients are noisy (random classifier head)")
    print("- Small LR during warmup prevents large initial updates")
    print("- Gradual increase lets model stabilize")
    print("\nWhy decay?")
    print("- Reduces LR as we approach convergence")
    print("- Helps model settle into a good minimum")
    print("- Prevents overshooting late in training")

visualize_lr_schedules()
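The two warmup schedules reduce to short closed-form multipliers on the base learning rate. The pure-Python sketch below is written to mirror what `get_linear_schedule_with_warmup` and `get_cosine_schedule_with_warmup` (with its default half-cosine cycle) compute, to the best of my understanding; the function names here are hypothetical. It also works out the step counts for this notebook's setup: 24 examples at batch size 8 give 3 steps per epoch, so 10 epochs is 30 steps with 3 warmup steps.

```python
import math

def linear_multiplier(step, warmup_steps, total_steps):
    """Ramp 0 -> 1 during warmup, then decay linearly back to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def cosine_multiplier(step, warmup_steps, total_steps):
    """Same warmup, then half a cosine wave from 1 down to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr = 2e-5
steps_per_epoch = math.ceil(24 / 8)   # 24 training examples, batch size 8
total_steps = steps_per_epoch * 10    # 10 epochs
warmup_steps = int(0.1 * total_steps)

for step in (0, 1, warmup_steps, total_steps // 2, total_steps):
    lin = base_lr * linear_multiplier(step, warmup_steps, total_steps)
    cos = base_lr * cosine_multiplier(step, warmup_steps, total_steps)
    print(f"step {step:>2}: linear={lin:.2e}  cosine={cos:.2e}")
```

With so few steps, warmup lasts only three optimizer updates here. On a dataset this small the schedule matters less than the base learning rate itself; its effect becomes clearer on longer runs.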

## Part 8: Layer-wise Learning Rates

Advanced technique: use different learning rates for different layers. Lower layers (closer to the input) encode more general features, so they get smaller learning rates and change less.

In [None]:
def create_layerwise_optimizer(model, base_lr=2e-5, decay_factor=0.9):
    """
    Create optimizer with layer-wise learning rate decay.
    
    Later layers get higher LR, earlier layers get lower LR.
    This preserves lower-level features while adapting higher-level ones.
    """
    # Separate parameters into groups
    param_groups = []
    
    n_layers = model.config.n_layer
    
    # Embeddings get lowest LR
    embed_params = []
    for name, param in model.transformer.named_parameters():
        if 'wte' in name or 'wpe' in name:  # Token and position embeddings
            embed_params.append(param)
    if embed_params:
        param_groups.append({
            'params': embed_params,
            'lr': base_lr * (decay_factor ** n_layers)  # Lowest LR
        })
    
    # Each transformer layer gets progressively higher LR
    for layer_idx in range(n_layers):
        layer_params = []
        for name, param in model.transformer.named_parameters():
            # Names are relative to model.transformer, e.g. 'h.0.attn.c_attn.weight',
            # so match on the prefix (there is no leading dot)
            if name.startswith(f'h.{layer_idx}.'):
                layer_params.append(param)
        
        # Earlier layers get lower LR
        layer_lr = base_lr * (decay_factor ** (n_layers - 1 - layer_idx))
        if layer_params:
            param_groups.append({
                'params': layer_params,
                'lr': layer_lr
            })
    
    # Final layer norm
    ln_params = []
    for name, param in model.transformer.named_parameters():
        if 'ln_f' in name:
            ln_params.append(param)
    if ln_params:
        param_groups.append({
            'params': ln_params,
            'lr': base_lr  # Full LR for final norm
        })
    
    # Classification head gets highest LR
    param_groups.append({
        'params': model.classifier.parameters(),
        'lr': base_lr * 10  # Higher LR for new layers
    })
    
    return torch.optim.AdamW(param_groups, weight_decay=0.01)


# Visualize layer-wise learning rates
def visualize_layerwise_lr(base_lr=2e-5, decay_factor=0.9, n_layers=12):
    """Show how learning rate varies by layer."""
    layers = ['Embed'] + [f'Layer {i}' for i in range(n_layers)] + ['LN_f', 'Head']
    lrs = [
        base_lr * (decay_factor ** n_layers),  # Embeddings
        *[base_lr * (decay_factor ** (n_layers - 1 - i)) for i in range(n_layers)],  # Layers
        base_lr,  # Final LN
        base_lr * 10  # Head
    ]
    
    plt.figure(figsize=(12, 5))
    bars = plt.bar(range(len(layers)), lrs, color='steelblue')
    bars[-1].set_color('green')  # Highlight classifier head
    bars[0].set_color('orange')  # Highlight embeddings
    
    plt.xticks(range(len(layers)), layers, rotation=45, ha='right')
    plt.ylabel('Learning Rate')
    plt.title('Layer-wise Learning Rates\n(Earlier layers learn slower to preserve pretrained features)')
    plt.tight_layout()
    plt.savefig('layerwise_lr.png', dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"Embeddings LR: {lrs[0]:.2e}")
    print(f"Layer 0 LR: {lrs[1]:.2e}")
    print(f"Layer {n_layers - 1} LR: {lrs[-3]:.2e}")
    print(f"Classifier LR: {lrs[-1]:.2e}")

visualize_layerwise_lr()
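The mapping from parameter name to learning rate can be sketched without loading the model. The parameter names below follow the pattern I would expect from `model.transformer.named_parameters()` for GPT-2 (`wte.weight`, `h.0...`, `ln_f.weight`); treat them, and the helper `lr_for_param`, as assumptions to check against your own model.

```python
def lr_for_param(name, base_lr=2e-5, decay_factor=0.9, n_layers=12):
    """Return the learning rate for a backbone parameter, based on its name."""
    if name.startswith(("wte.", "wpe.")):
        return base_lr * decay_factor ** n_layers        # embeddings: lowest LR
    if name.startswith("h."):
        layer_idx = int(name.split(".")[1])              # 'h.3.attn...' -> 3
        return base_lr * decay_factor ** (n_layers - 1 - layer_idx)
    if name.startswith("ln_f."):
        return base_lr                                   # final layer norm: full LR
    raise ValueError(f"Unrecognized parameter name: {name}")


example_names = [
    "wte.weight",
    "h.0.attn.c_attn.weight",
    "h.1.mlp.c_fc.bias",
    "h.11.ln_2.weight",
    "ln_f.weight",
]
for name in example_names:
    print(f"{name:<24} {lr_for_param(name):.2e}")
```

Raising an error on unknown names is a deliberate choice: a parameter silently left out of every group would never be updated, and that kind of mistake is easy to miss because training still runs.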

## Part 9: Practical Fine-Tuning Checklist

A summary of best practices for fine-tuning pretrained models.

In [None]:
def print_finetuning_checklist():
    """Print a checklist for fine-tuning."""
    checklist = """
    FINE-TUNING BEST PRACTICES
    ==========================
    
    1. LEARNING RATE
       [ ] Use small LR: 1e-5 to 5e-5 for full fine-tuning
       [ ] Use warmup: 6-10% of total steps
       [ ] Use decay: linear or cosine schedule
       [ ] Consider layer-wise LR decay
    
    2. REGULARIZATION
       [ ] Weight decay: 0.01 typical
       [ ] Dropout: keep or slightly increase
       [ ] Early stopping: monitor validation loss
    
    3. DATA
       [ ] Clean, high-quality task data
       [ ] Balanced classes if classification
       [ ] Reasonable train/val/test split (e.g., 80/10/10)
    
    4. TRAINING
       [ ] Small batch size: 8-32 typical
       [ ] Few epochs: 2-4 often enough
       [ ] Gradient clipping: max_norm=1.0
       [ ] Mixed precision if available
    
    5. EVALUATION
       [ ] Monitor task metrics (accuracy, F1, etc.)
       [ ] Check for overfitting (train vs val gap)
       [ ] Optionally track forgetting metrics
    
    6. ALTERNATIVES
       [ ] Frozen backbone for limited data
       [ ] LoRA/adapters for efficient training
       [ ] Few-shot prompting before fine-tuning
    """
    print(checklist)

print_finetuning_checklist()

## Exercises

1. **Data size experiment**: Try fine-tuning with 5, 10, 20, 50 training examples. How does performance scale?

2. **Epoch selection**: Train for 20 epochs and plot train/val loss. When does overfitting begin?

3. **Different pooling**: Instead of last-token pooling, try mean pooling. Does it help?

4. **Partial freezing**: Freeze only the first N layers. Find the best N.

5. **Multi-task**: Add a second task (e.g., topic classification). Train on both simultaneously.

## Summary

| Concept | Key Point |
|---------|----------|
| Transfer learning | Pretrained models have useful general representations |
| Classification head | Add task-specific layers on top of frozen/unfrozen backbone |
| Catastrophic forgetting | High LR destroys pretrained knowledge |
| Learning rate | Use 1e-5 to 5e-5 with warmup + decay |
| Frozen backbone | Prevents forgetting but may underfit |
| Layer-wise LR | Earlier layers learn slower to preserve features |

**Key insight:** Fine-tuning works because pretrained representations transfer. The challenge is adapting to new tasks without destroying what was learned. Use small learning rates, warmup schedules, and consider which layers to update.