# 13. Learning Rate Experiments

This notebook trains the same MLP on MNIST with several learning rates (and learning-rate schedules) and compares how each choice affects training and validation loss.

## Experiment Overview
- **Goal**: Compare different learning rates and schedules
- **Model**: MLP with various learning rate strategies
- **Features**: LR scheduling, convergence analysis, optimization curves
- **Learning**: Understanding how the learning rate affects convergence speed and training stability

## What You'll Learn
- Learning rate impact on convergence
- Learning rate scheduling
- Optimization landscape exploration
- Training stability analysis


In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import sys
import os

# Add scripts directory to path
sys.path.append('../scripts')
from utils import load_mnist_data, get_device, set_seed

# Fix the random seed via the project helper so runs are reproducible
set_seed(42)

# Pick the compute device via the project helper and report the choice
device = get_device()
print(f"Using device: {device}")

# Load MNIST through the project helper, which returns three loaders
# (train, validation, test) — see ../scripts/utils for the split details.
print("Loading MNIST dataset...")
train_loader, val_loader, test_loader = load_mnist_data(batch_size=64, test_split=0.2)

# Report how many samples ended up in each split
for split_label, split_loader in (
    ("Training", train_loader),
    ("Validation", val_loader),
    ("Test", test_loader),
):
    print(f"{split_label} samples: {len(split_loader.dataset)}")


In [None]:
# Define MLP model
class LearningRateMLP(nn.Module):
    """Three-layer fully connected classifier used for the learning-rate sweeps.

    Each sample is flattened, passed through two ReLU hidden layers (each
    followed by dropout with p=0.2), and mapped to raw logits by a final
    linear layer. No softmax is applied; the training code pairs this model
    with ``nn.CrossEntropyLoss``.

    Args:
        input_size: Number of features per sample after flattening
            (784 matches a flattened 28x28 MNIST image).
        hidden_size: Width of both hidden layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10):
        super().__init__()
        # Layers are created in the order fc1 -> fc2 -> fc3 so that seeded
        # weight initialisation draws the same random values.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        # One dropout module, shared by both hidden layers.
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for input batch ``x``."""
        batch_size = x.size(0)
        out = x.view(batch_size, -1)
        for hidden_layer in (self.fc1, self.fc2):
            out = self.dropout(F.relu(hidden_layer(out)))
        return self.fc3(out)

def _resolve_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cpu')


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    total = 0.0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss over ``loader`` without updating weights."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            total += criterion(model(data), target).item()
    return total / len(loader)


# Training function with learning rate tracking
def train_with_lr_tracking(model, train_loader, val_loader, epochs=30, lr=0.001,
                           scheduler=None, optimizer=None, device=None,
                           log_every=10):
    """Train model with learning rate tracking.

    Each epoch runs one training pass over ``train_loader`` and then one
    evaluation pass over ``val_loader``, both with cross-entropy loss.
    Reported losses are means over batches, not over samples.

    Args:
        model: Module to train; its parameters are updated in place.
        train_loader: Sized iterable of ``(data, target)`` batches.
        val_loader: Sized iterable of ``(data, target)`` validation batches.
        epochs: Number of epochs to run.
        lr: Learning rate for the default Adam optimizer. Ignored when
            ``optimizer`` is given.
        scheduler: Optional LR scheduler, in one of two forms:
            * a scheduler instance bound to ``optimizer`` (so ``optimizer``
              must be passed too), or
            * a callable ``factory(optimizer) -> scheduler``, applied to the
              optimizer this function actually uses.
            ``ReduceLROnPlateau`` is stepped with the epoch's validation loss.
            All other schedulers are stepped with no argument.
        optimizer: Optional pre-built optimizer over ``model``'s parameters.
            Defaults to ``Adam(model.parameters(), lr=lr)``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if it has none).
        log_every: Print a progress line every this many epochs. ``0`` or
            ``None`` disables logging.

    Returns:
        ``(train_losses, val_losses, learning_rates)``: three lists of length
        ``epochs``. ``learning_rates[i]`` is the LR used during epoch ``i``;
        it is read before the scheduler steps.

    Raises:
        ValueError: If either loader yields no batches, or if a scheduler
            instance is bound to a different optimizer than the one used here.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    if device is None:
        device = _resolve_device(model)
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # A factory (no `step` method) is bound to the optimizer used here.
    if scheduler is not None and not hasattr(scheduler, 'step'):
        scheduler = scheduler(optimizer)
    # Previously a caller-built scheduler was always attached to some other
    # optimizer and silently had no effect. Fail loudly instead.
    if scheduler is not None and getattr(scheduler, 'optimizer', optimizer) is not optimizer:
        raise ValueError(
            "scheduler is bound to a different optimizer; pass that optimizer via "
            "`optimizer=` or pass a factory `lambda opt: Scheduler(opt, ...)`"
        )

    train_losses = []
    val_losses = []
    learning_rates = []

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss = _evaluate(model, val_loader, criterion, device)
        # Capture the LR that was in effect for this epoch, before stepping.
        current_lr = optimizer.param_groups[0]['lr']

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        learning_rates.append(current_lr)

        # Step scheduler (ReduceLROnPlateau requires the monitored metric)
        if scheduler is not None:
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}')

    return train_losses, val_losses, learning_rates

# Test different learning rates. They are listed in ascending order so the
# legends, the bar chart and the summary read from smallest to largest.
learning_rates = [0.0001, 0.001, 0.01, 0.1]
results = {}

for lr in learning_rates:
    print(f"\nTraining with learning rate: {lr}")
    # Re-seed before every run so each learning rate starts from the same
    # random state. Without this, the RNG state left by the previous run
    # gives each model a different initialisation, confounding the comparison.
    # (Assumes set_seed seeds torch's RNG, as its use "for reproducibility"
    # in the setup cell suggests — confirm in ../scripts/utils.)
    set_seed(42)
    model = LearningRateMLP().to(device)
    train_losses, val_losses, lr_history = train_with_lr_tracking(
        model, train_loader, val_loader, epochs=30, lr=lr
    )

    # Keys look like 'LR_0.001'; the per-run dict holds per-epoch histories.
    results[f'LR_{lr}'] = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'learning_rates': lr_history,
    }

# Plot results: 2x2 grid of train loss, val loss, LR history and final val loss
plt.figure(figsize=(15, 10))

plt.subplot(2, 2, 1)
for name, result in results.items():
    plt.plot(result['train_losses'], label=f'{name} (Train)')
plt.title('Training Losses by Learning Rate')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(2, 2, 2)
for name, result in results.items():
    plt.plot(result['val_losses'], label=f'{name} (Val)')
plt.title('Validation Losses by Learning Rate')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(2, 2, 3)
for name, result in results.items():
    plt.plot(result['learning_rates'], label=name)
plt.title('Learning Rate Schedules')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend()
plt.grid(True)
# LRs span several orders of magnitude, so a log axis keeps them separable.
plt.yscale('log')

plt.subplot(2, 2, 4)
final_val_losses = [result['val_losses'][-1] for result in results.values()]
lr_names = list(results.keys())
plt.bar(lr_names, final_val_losses)
plt.title('Final Validation Loss by Learning Rate')
plt.ylabel('Final Val Loss')
plt.xticks(rotation=45)
# Horizontal grid lines only; vertical lines through categorical bars add noise.
plt.grid(True, axis='y')

plt.tight_layout()
# savefig does not create missing directories. Create the output folder first
# so a fresh checkout does not fail here after all the training has run.
plot_path = '../results/plots/learning_rate_experiments.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Print summary: last-epoch train and validation loss of every run
print("\nLearning Rate Experiment Summary:")
for run_name, run in results.items():
    last_train, last_val = run['train_losses'][-1], run['val_losses'][-1]
    print(f"{run_name}: Train Loss: {last_train:.4f}, Val Loss: {last_val:.4f}")
