# PyTorch Optimization & Training Loops

This notebook covers optimization, loss functions, and training patterns in PyTorch.

## Table of Contents
1. [Canonical Training Loop](#canonical-training-loop)
2. [Optimizers: SGD vs AdamW](#optimizers-sgd-vs-adamw)
3. [Common Loss Functions](#common-loss-functions)
4. [Train vs Eval Modes](#train-vs-eval-modes)
5. [Learning Rate Scheduling](#learning-rate-scheduling)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# Seed both random sources (NumPy and PyTorch) so repeated runs give the same numbers.
np.random.seed(42)
torch.manual_seed(42)

# Prefer a CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Canonical Training Loop

In [None]:
# Universal training loop skeleton
def train_model(model, train_loader, val_loader=None, num_epochs=5, lr=0.001, device='cpu',
                max_grad_norm=1.0):
    """Train a classifier with AdamW and cross-entropy, recording per-epoch metrics.

    Args:
        model: ``nn.Module`` whose forward pass returns class scores with the
            class dimension at index 1. It is moved to ``device``.
        train_loader: Iterable of ``(data, target)`` batches used for optimisation.
        val_loader: Optional iterable of ``(data, target)`` batches evaluated
            after every epoch with gradients disabled.
        num_epochs: Number of passes over ``train_loader``.
        lr: Learning rate handed to ``optim.AdamW`` (weight decay is fixed at 1e-4).
        device: Device the model and every batch are moved to.
        max_grad_norm: Gradient-norm clipping threshold applied before each
            optimizer step; pass ``None`` to disable clipping.

    Returns:
        dict with keys ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'``, each a list with one float per epoch (the validation
        lists stay empty when ``val_loader`` is ``None``). Losses are averaged
        per label, so a short final batch is not over-weighted; accuracies are
        percentages. A loader that yields no samples produces ``nan`` entries
        instead of raising ``ZeroDivisionError``.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    model = model.to(device)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    nan = float('nan')

    for epoch in range(num_epochs):
        # ===== TRAINING PHASE =====
        model.train()  # enables dropout / batch-statistics behaviour

        train_loss_sum = 0.0
        train_correct = 0
        train_total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            # Gradients accumulate across backward() calls, so clear them first.
            optimizer.zero_grad()

            output = model(data)
            loss = criterion(output, target)
            loss.backward()

            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

            optimizer.step()

            # criterion returns the batch mean, so weight it by the number of
            # labels to obtain a true per-label average over the epoch.
            batch_loss = loss.item()
            n_labels = target.numel()
            train_loss_sum += batch_loss * n_labels
            train_total += n_labels
            train_correct += (output.argmax(dim=1) == target).sum().item()

            if batch_idx % 20 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx:3d}: Loss = {batch_loss:.4f}')

        avg_train_loss = train_loss_sum / train_total if train_total else nan
        train_accuracy = 100.0 * train_correct / train_total if train_total else nan

        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_accuracy)

        # ===== VALIDATION PHASE =====
        if val_loader is not None:
            model.eval()  # disables dropout, uses running batch-norm statistics

            val_loss_sum = 0.0
            val_correct = 0
            val_total = 0

            with torch.no_grad():  # no autograd bookkeeping needed for evaluation
                for data, target in val_loader:
                    data, target = data.to(device), target.to(device)

                    output = model(data)
                    n_labels = target.numel()
                    val_loss_sum += criterion(output, target).item() * n_labels
                    val_total += n_labels
                    val_correct += (output.argmax(dim=1) == target).sum().item()

            avg_val_loss = val_loss_sum / val_total if val_total else nan
            val_accuracy = 100.0 * val_correct / val_total if val_total else nan

            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(val_accuracy)

            print(f'Epoch {epoch+1:2d}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%')
        else:
            print(f'Epoch {epoch+1:2d}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%')

    return history

# Small MLP used as the demonstration model throughout the notebook
class SimpleClassifier(nn.Module):
    """Three-layer MLP: ``input_size -> hidden_size -> hidden_size // 2 -> num_classes``.

    ReLU follows each of the first two linear layers, and one shared
    ``nn.Dropout(0.2)`` module is applied after each ReLU. The output is the
    raw result of the last linear layer (no softmax).
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        narrow = hidden_size // 2
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, narrow)
        self.fc3 = nn.Linear(narrow, num_classes)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return the unnormalised class scores for ``x``."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# Generate synthetic dataset
def create_classification_dataset(n_samples=1000, n_features=20, n_classes=5, seed=42):
    """Create a synthetic, nearly linearly separable classification dataset.

    Features are standard-normal. Each label is the argmax of a random linear
    projection of the features plus a small amount of Gaussian noise.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of feature columns.
        n_classes: Number of distinct class labels.
        seed: Value passed to ``torch.manual_seed`` before sampling. Note that
            this reseeds PyTorch's *global* generator. Pass ``None`` to leave
            the global RNG state untouched (results then depend on it).

    Returns:
        Tuple ``(X, y)``: ``X`` is a float tensor of shape
        ``(n_samples, n_features)`` and ``y`` an int64 tensor of shape
        ``(n_samples,)`` holding class indices in ``[0, n_classes)``.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Generate features
    X = torch.randn(n_samples, n_features)

    # Score every class with a random linear map; the 0.1-scaled noise keeps
    # the classes from being perfectly separable.
    weights = torch.randn(n_features, n_classes)
    logits = X @ weights + 0.1 * torch.randn(n_samples, n_classes)
    y = torch.argmax(logits, dim=1)

    return X, y

# Create dataset and dataloaders
# X, y, train_loader and val_loader are module-level names reused by later cells.
X, y = create_classification_dataset(n_samples=1000, n_features=20, n_classes=5)
print(f"Dataset shape: X={X.shape}, y={y.shape}")
# torch.bincount counts how many samples carry each class index.
print(f"Class distribution: {torch.bincount(y)}")

# Split into train/validation
# Positional split: the first 800 rows train, the remaining rows validate.
train_size = 800
X_train, X_val = X[:train_size], X[train_size:]
y_train, y_val = y[:train_size], y[train_size:]

# Create DataLoaders
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Only the training loader shuffles; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# len(loader) is the number of batches, not the number of samples.
print(f"Train loader: {len(train_loader)} batches")
print(f"Val loader: {len(val_loader)} batches")

In [None]:
# Train the model using our universal training loop
# input_size / num_classes match the synthetic dataset built above (20 features, 5 classes).
model = SimpleClassifier(input_size=20, hidden_size=128, num_classes=5)
# Total element count over all parameter tensors.
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train model
print("Starting training...")
# `history` holds the per-epoch loss/accuracy lists that the plotting cell reads.
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
    lr=0.001,
    device=device
)

print("\nTraining completed!")

In [None]:
# Plot training curves
def plot_training_curves(history):
    """Draw loss and accuracy curves side by side, then print final-epoch metrics.

    Args:
        history: dict with list entries ``'train_loss'``, ``'train_acc'``,
            ``'val_loss'`` and ``'val_acc'`` (one value per epoch). An empty
            validation list means that series is neither drawn nor reported.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    epochs = range(1, len(history['train_loss']) + 1)

    # One row per panel: axis, history keys, legend labels, y-label, title.
    panels = [
        (ax1, 'train_loss', 'val_loss',
         'Training Loss', 'Validation Loss',
         'Loss', 'Training and Validation Loss'),
        (ax2, 'train_acc', 'val_acc',
         'Training Accuracy', 'Validation Accuracy',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    ]
    for axis, train_key, val_key, train_label, val_label, y_label, title in panels:
        axis.plot(epochs, history[train_key], 'b-', label=train_label)
        if history[val_key]:
            axis.plot(epochs, history[val_key], 'r-', label=val_label)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.set_title(title)
        axis.legend()
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Print final metrics
    print(f"Final Training Loss: {history['train_loss'][-1]:.4f}")
    print(f"Final Training Accuracy: {history['train_acc'][-1]:.2f}%")
    if not history['val_loss']:
        return
    print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")
    print(f"Final Validation Accuracy: {history['val_acc'][-1]:.2f}%")

# Plot the metrics recorded in `history` by the training run.
plot_training_curves(history)

## Optimizers: SGD vs AdamW

In [None]:
# Compare different optimizers
def compare_optimizers():
    """Train three identically initialised classifiers with SGD, Adam and AdamW.

    Each network is trained for 15 epochs on the module-level ``train_loader``;
    within an epoch the optimizers take turns in the order SGD, Adam, AdamW.
    A summary is printed every fifth epoch.

    Returns:
        dict mapping optimizer name to ``{'loss': [...], 'acc': [...]}``, with
        one mean batch loss and one training accuracy (percent) per epoch.
    """
    # One network per optimizer, built in a fixed order.
    models = {}
    for name in ('SGD', 'Adam', 'AdamW'):
        models[name] = SimpleClassifier(20, 128, 5)

    # Copy the SGD network's initial weights into every network so that the
    # optimizer is the only difference between the runs.
    shared_init = models['SGD'].state_dict()
    for net in models.values():
        net.load_state_dict(shared_init)

    optimizers = {
        'SGD': optim.SGD(models['SGD'].parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4),
        'Adam': optim.Adam(models['Adam'].parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-4),
        'AdamW': optim.AdamW(models['AdamW'].parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=1e-4)
    }

    criterion = nn.CrossEntropyLoss()

    # Per-optimizer metric lists, filled one entry per epoch.
    histories = {}
    for name in models:
        histories[name] = {'loss': [], 'acc': []}

    num_epochs = 15

    for epoch in range(num_epochs):
        for name, net in models.items():
            opt = optimizers[name]

            net.train()
            loss_sum = 0
            n_correct = 0
            n_seen = 0

            for data, target in train_loader:
                opt.zero_grad()
                output = net(data)
                batch_loss = criterion(output, target)
                batch_loss.backward()
                opt.step()

                loss_sum += batch_loss.item()
                predicted = torch.max(output.data, 1).indices
                n_seen += target.size(0)
                n_correct += (predicted == target).sum().item()

            histories[name]['loss'].append(loss_sum / len(train_loader))
            histories[name]['acc'].append(100.0 * n_correct / n_seen)

        # Report only on every fifth epoch.
        if (epoch + 1) % 5 != 0:
            continue
        print(f"Epoch {epoch+1:2d}:")
        for name, record in histories.items():
            loss = record['loss'][-1]
            acc = record['acc'][-1]
            print(f"  {name:5s}: Loss = {loss:.4f}, Acc = {acc:.2f}%")

    return histories

# Compare optimizers
print("Comparing optimizers...")
optimizer_histories = compare_optimizers()

# Plot comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Epochs are numbered from 1 on the x-axis; every optimizer ran the same number of epochs.
epochs = range(1, len(optimizer_histories['SGD']['loss']) + 1)
colors = {'SGD': 'blue', 'Adam': 'green', 'AdamW': 'red'}

# Loss comparison
# NOTE: this loop rebinds the module-level name `history` (previously the
# train_model result) to the per-optimizer dicts.
for name, history in optimizer_histories.items():
    ax1.plot(epochs, history['loss'], color=colors[name], label=f'{name}', marker='o', markersize=3)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Optimizer Comparison: Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy comparison
for name, history in optimizer_histories.items():
    ax2.plot(epochs, history['acc'], color=colors[name], label=f'{name}', marker='o', markersize=3)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Optimizer Comparison: Accuracy')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print final comparison
# The last list entry is the metric from the final epoch.
print("\nFinal Results:")
for name, history in optimizer_histories.items():
    final_loss = history['loss'][-1]
    final_acc = history['acc'][-1]
    print(f"{name:5s}: Loss = {final_loss:.4f}, Accuracy = {final_acc:.2f}%")

print("\nKey differences:")
print("• SGD: Simple, requires learning rate tuning, benefits from momentum")
print("• Adam: Adaptive learning rates, fast convergence, can overfit")
print("• AdamW: Adam with decoupled weight decay, better generalization")

## Common Loss Functions

In [None]:
# Demonstrate different loss functions
print("=== Common Loss Functions Demo ===")

# 1. Classification: CrossEntropyLoss
# Takes raw logits plus integer class indices; softmax is applied internally.
print("\n1. CrossEntropyLoss (Multi-class Classification)")
ce_loss = nn.CrossEntropyLoss()

# Create sample data
batch_size, num_classes = 4, 5
logits = torch.randn(batch_size, num_classes)  # Raw scores from model
targets = torch.randint(0, num_classes, (batch_size,))  # Class indices

loss_ce = ce_loss(logits, targets)
print(f"Logits shape: {logits.shape}")
print(f"Targets shape: {targets.shape}")
print(f"Targets: {targets}")
print(f"CrossEntropyLoss: {loss_ce.item():.4f}")

# Show what happens with perfect predictions
# A logit of 10 against zeros puts almost all softmax mass on the true class,
# so the loss comes out close to (but not exactly) zero.
perfect_logits = torch.zeros_like(logits)
for i, target in enumerate(targets):
    perfect_logits[i, target] = 10.0  # High score for correct class
perfect_loss = ce_loss(perfect_logits, targets)
print(f"Perfect prediction loss: {perfect_loss.item():.4f}")

# 2. Regression: MSELoss
print("\n2. MSELoss (Regression)")
mse_loss = nn.MSELoss()

# Predictions and targets share the shape (batch_size, 1).
predictions = torch.randn(batch_size, 1)
true_values = torch.randn(batch_size, 1)

loss_mse = mse_loss(predictions, true_values)
print(f"Predictions: {predictions.squeeze()}")
print(f"True values: {true_values.squeeze()}")
print(f"MSELoss: {loss_mse.item():.4f}")

# Show difference between MSE and MAE
# L1Loss averages absolute errors instead of squared errors.
mae_loss = nn.L1Loss()
loss_mae = mae_loss(predictions, true_values)
print(f"MAE (L1) Loss: {loss_mae.item():.4f}")

# 3. Binary Classification: BCEWithLogitsLoss
# Takes raw logits (sigmoid is applied internally) and float 0/1 targets.
print("\n3. BCEWithLogitsLoss (Binary Classification)")
bce_loss = nn.BCEWithLogitsLoss()

binary_logits = torch.randn(batch_size, 1)  # Raw logits
binary_targets = torch.randint(0, 2, (batch_size, 1)).float()  # 0 or 1

loss_bce = bce_loss(binary_logits, binary_targets)
print(f"Binary logits: {binary_logits.squeeze()}")
print(f"Binary targets: {binary_targets.squeeze()}")
print(f"BCEWithLogitsLoss: {loss_bce.item():.4f}")

# Show the probabilities these logits correspond to (sigmoid of the logits).
# No manual BCE value is computed here; only the probabilities are printed.
probs = torch.sigmoid(binary_logits)
print(f"Converted to probabilities: {probs.squeeze()}")

# 4. Negative Log Likelihood: NLLLoss
print("\n4. NLLLoss (when you already have log probabilities)")
nll_loss = nn.NLLLoss()

# NLLLoss expects log probabilities
# Reuses `logits` and `targets` from part 1 so the value can be compared with loss_ce.
log_probs = F.log_softmax(logits, dim=1)
loss_nll = nll_loss(log_probs, targets)

print(f"Log probabilities shape: {log_probs.shape}")
print(f"NLLLoss: {loss_nll.item():.4f}")
print(f"Note: CrossEntropyLoss = LogSoftmax + NLLLoss")
print(f"Verification - CE loss: {loss_ce.item():.4f}, NLL loss: {loss_nll.item():.4f}")
# Compare with a small tolerance rather than exact float equality.
print(f"Match: {abs(loss_ce.item() - loss_nll.item()) < 1e-6}")

# 5. Multi-label classification: BCEWithLogitsLoss with multiple outputs
# Each of the num_labels columns is an independent binary decision.
print("\n5. Multi-label Classification")
num_labels = 3
multi_logits = torch.randn(batch_size, num_labels)
multi_targets = torch.randint(0, 2, (batch_size, num_labels)).float()  # Multiple binary labels

multi_bce_loss = nn.BCEWithLogitsLoss()
loss_multi = multi_bce_loss(multi_logits, multi_targets)

print(f"Multi-label logits shape: {multi_logits.shape}")
print(f"Multi-label targets shape: {multi_targets.shape}")
print(f"Multi-label BCE loss: {loss_multi.item():.4f}")
print(f"Sample targets: {multi_targets[0]}  (can have multiple 1s)")

## Train vs Eval Modes

In [None]:
# Demonstrate train vs eval mode differences
# (banner line for this cell's printed output)
print("=== Train vs Eval Mode Demo ===")

# MLP whose layers behave differently in train and eval mode
class ModelWithDropoutAndBN(nn.Module):
    """Two ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages plus a linear head.

    Layer widths are ``input_size -> hidden_size -> hidden_size // 2 ->
    output_size``. The dropout probabilities are 0.5 and 0.3 for the first and
    second stage respectively.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        half_hidden = hidden_size // 2

        # Stage 1 (50% dropout)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.dropout1 = nn.Dropout(0.5)

        # Stage 2 (30% dropout)
        self.fc2 = nn.Linear(hidden_size, half_hidden)
        self.bn2 = nn.BatchNorm1d(half_hidden)
        self.dropout2 = nn.Dropout(0.3)

        # Output head
        self.fc3 = nn.Linear(half_hidden, output_size)

    def forward(self, x):
        """Run both normalised stages and return the head's raw output."""
        stage1 = self.dropout1(F.relu(self.bn1(self.fc1(x))))
        stage2 = self.dropout2(F.relu(self.bn2(self.fc2(stage1))))
        return self.fc3(stage2)

# NOTE: rebinds the module-level name `model` to a fresh, untrained network.
model = ModelWithDropoutAndBN(20, 128, 5)
x = torch.randn(10, 20)  # Batch of 10 samples

print("1. Dropout behavior:")
print("=" * 30)

# Training mode: dropout is active
model.train()
print("Training mode (dropout active):")
out1 = model(x)
out2 = model(x)  # Same input, different output due to dropout

print(f"Output 1 mean: {out1.mean().item():.4f}")
print(f"Output 2 mean: {out2.mean().item():.4f}")
# Each forward pass draws a fresh dropout mask, so the two outputs should differ.
print(f"Outputs are different: {not torch.allclose(out1, out2)}")

# Evaluation mode: dropout is disabled
model.eval()
print("\nEvaluation mode (dropout disabled):")
out3 = model(x)
out4 = model(x)  # Same input, same output

print(f"Output 3 mean: {out3.mean().item():.4f}")
print(f"Output 4 mean: {out4.mean().item():.4f}")
print(f"Outputs are identical: {torch.allclose(out3, out4)}")

print("\n2. BatchNorm behavior:")
print("=" * 30)

# Helper for inspecting the running statistics BatchNorm keeps between batches
def check_bn_stats(model, mode_name):
    """Print the state of ``model.bn1``.

    Shows the first five entries of the layer's running mean and running
    variance, followed by its ``training`` flag, under a ``"<mode_name> mode:"``
    heading.
    """
    first_bn = model.bn1
    head = slice(0, 5)  # only the first five features are shown
    print(f"{mode_name} mode:")
    print(f"  Running mean: {first_bn.running_mean[head]}")
    print(f"  Running var:  {first_bn.running_var[head]}")
    print(f"  Training: {first_bn.training}")

# Reset batch norm statistics
def reset_bn_stats(model):
    """Call ``reset_running_stats()`` on every ``nn.BatchNorm1d`` inside ``model``."""
    batch_norms = [layer for layer in model.modules() if isinstance(layer, nn.BatchNorm1d)]
    for layer in batch_norms:
        layer.reset_running_stats()

# Start from freshly reset running statistics so the update below is easy to see.
reset_bn_stats(model)

# Training mode: updates running statistics
model.train()
_ = model(x)
check_bn_stats(model, "Training")

# Evaluation mode: uses fixed running statistics
model.eval()
# Snapshot the running mean so it can be compared after the eval-mode forward pass.
old_running_mean = model.bn1.running_mean.clone()
_ = model(x)
check_bn_stats(model, "Evaluation")

# Expected to print False, since eval mode does not update the running statistics.
print(f"Running mean changed in eval: {not torch.allclose(old_running_mean, model.bn1.running_mean)}")

print("\n3. Practical implications:")
print("=" * 30)
print("✓ Always call model.train() before training")
print("✓ Always call model.eval() before inference")
print("✓ Use torch.no_grad() during inference to save memory")
print("✓ Dropout provides regularization during training")
print("✓ BatchNorm uses batch statistics in training, running stats in eval")

In [None]:
# Demonstrate correct inference pattern
def inference_example(model=None):
    """Show single-shot and batched inference with gradients disabled.

    Prints the shapes and a few sample values of the predictions; returns
    nothing.

    Args:
        model: Optional ``nn.Module`` that maps a ``(N, 20)`` float tensor to
            per-class scores. When omitted, an untrained
            ``SimpleClassifier(20, 64, 5)`` stands in for a trained model.
    """
    # NOTE: there must be no function-local ``import torch...`` in here. Such
    # an import would make ``torch`` a local name for the whole function, and
    # every earlier use of ``torch`` would raise UnboundLocalError.
    if model is None:
        # Assume we have a trained model
        model = SimpleClassifier(20, 64, 5)
    model.eval()  # Set to evaluation mode

    # Sample test data
    test_data = torch.randn(100, 20)

    print("Inference setup:")

    # Method 1: Basic inference
    with torch.no_grad():  # Disable gradient computation
        predictions = model(test_data)
        probabilities = F.softmax(predictions, dim=1)
        predicted_classes = torch.argmax(predictions, dim=1)

    print(f"Predictions shape: {predictions.shape}")
    print(f"Sample probabilities: {probabilities[0]}")
    print(f"Predicted classes: {predicted_classes[:10]}")

    # Method 2: Batch processing for large datasets
    def batch_inference(model, data, batch_size=32):
        """Run ``model`` over ``data`` in slices of ``batch_size`` rows and concatenate the outputs."""
        model.eval()
        all_predictions = []

        with torch.no_grad():
            for start in range(0, len(data), batch_size):
                all_predictions.append(model(data[start:start + batch_size]))

        return torch.cat(all_predictions, dim=0)

    # Process large dataset in batches
    large_test_data = torch.randn(1000, 20)
    batch_predictions = batch_inference(model, large_test_data, batch_size=64)

    print(f"\nBatch inference on {len(large_test_data)} samples:")
    print(f"Output shape: {batch_predictions.shape}")

    # Memory usage notes (only printed when a CUDA device is available)
    if torch.cuda.is_available():
        print("\nMemory usage patterns:")
        print("• torch.no_grad() reduces memory usage by ~2x")
        print("• Batch processing prevents OOM on large datasets")
        print("• model.eval() ensures consistent results")

# Run the inference walkthrough; it prints prediction shapes and sample values.
inference_example()

## Learning Rate Scheduling

In [None]:
# Demonstrate learning rate scheduling
# (banner line for this cell's printed output)
print("=== Learning Rate Scheduling Demo ===")

def demonstrate_lr_scheduling():
    """Trace the learning rate produced by several schedulers over 50 epochs.

    Every scheduler gets its own freshly created SGD optimizer, so the traces
    are independent of each other. No real training happens: the optimizer
    steps are no-ops (no gradients exist), and ``ReduceLROnPlateau`` is fed a
    synthetic loss that decays and then plateaus after epoch 20.

    Returns:
        dict mapping schedule name to a list of 50 learning rates, each being
        the rate in effect at the start of that epoch. Keys, in order:
        ``'StepLR'``, ``'ExponentialLR'``, ``'CosineAnnealing'``,
        ``'ReduceLROnPlateau'``, ``'No Scheduling'``.
    """
    base_lr = 0.1
    num_epochs = 50

    # Only the optimizer's param groups matter for an LR trace, so a small
    # placeholder layer is enough to give the optimizer something to hold.
    model = nn.Linear(20, 5)

    # Factories rather than instances: a scheduler is bound to one optimizer at
    # construction, and its state_dict() holds bookkeeping keys (base_lrs,
    # last_epoch, ...) that are not constructor arguments, so it cannot be
    # rebuilt from that dict.
    scheduler_factories = {
        'StepLR': lambda opt: optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5),
        'ExponentialLR': lambda opt: optim.lr_scheduler.ExponentialLR(opt, gamma=0.95),
        'CosineAnnealing': lambda opt: optim.lr_scheduler.CosineAnnealingLR(opt, T_max=num_epochs, eta_min=0.001),
        # `verbose` is deprecated; its old default (False) is the behaviour kept here.
        'ReduceLROnPlateau': lambda opt: optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=5),
    }

    # Track learning rates for each scheduler (key order drives the plot legend)
    lr_histories = {name: [] for name in scheduler_factories}
    lr_histories['No Scheduling'] = []

    for name, make_scheduler in scheduler_factories.items():
        optimizer = optim.SGD(model.parameters(), lr=base_lr)
        scheduler = make_scheduler(optimizer)

        for epoch in range(num_epochs):
            lr_histories[name].append(optimizer.param_groups[0]['lr'])

            # PyTorch expects optimizer.step() before scheduler.step(); with no
            # gradients this changes nothing but keeps the call order correct.
            optimizer.step()

            if name == 'ReduceLROnPlateau':
                # This scheduler needs a metric: decaying loss, then a noisy plateau.
                if epoch > 20:
                    fake_loss = 0.15 + 0.02 * np.random.random()
                else:
                    fake_loss = 1.0 * np.exp(-epoch / 10) + 0.1 + 0.05 * np.random.random()
                scheduler.step(fake_loss)
            else:
                scheduler.step()

    # Constant-rate baseline
    lr_histories['No Scheduling'] = [base_lr] * num_epochs

    return lr_histories

# Generate learning rate histories
lr_histories = demonstrate_lr_scheduling()

# Plot learning rate schedules
plt.figure(figsize=(12, 6))

# NOTE: rebinds the module-level names `epochs` and `colors` used by earlier plotting cells.
epochs = range(len(lr_histories['No Scheduling']))
# One colour per schedule, matched by position to the dict's iteration order.
colors = ['blue', 'green', 'red', 'orange', 'purple']

for i, (name, lr_history) in enumerate(lr_histories.items()):
    plt.plot(epochs, lr_history, label=name, color=colors[i], linewidth=2)

plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Scheduling Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.yscale('log')  # Log scale to see all schedules clearly
plt.show()

# Explain each scheduler
print("\nScheduler explanations:")
print("=" * 40)
print("StepLR: Reduces LR by factor γ every step_size epochs")
print("ExponentialLR: Multiplies LR by γ each epoch")
print("CosineAnnealing: Follows cosine curve from max to min LR")
print("ReduceLROnPlateau: Reduces LR when metric stops improving")
print("No Scheduling: Constant learning rate")

print("\nWhen to use each:")
print("• StepLR: Simple, works well with SGD")
print("• ExponentialLR: Smooth decay, good for long training")
print("• CosineAnnealing: Popular for transformers, smooth restart")
print("• ReduceLROnPlateau: Adaptive, responds to actual performance")

In [None]:
# Advanced: Learning rate warmup
class WarmupScheduler:
    """Linear warmup to ``max_lr`` followed by cosine decay towards zero.

    Call ``step()`` once per epoch. For the first ``warmup_epochs`` calls the
    rate rises linearly and reaches ``max_lr`` on the last warmup call. It then
    follows half a cosine period over the remaining
    ``total_epochs - warmup_epochs`` epochs. Calls beyond ``total_epochs`` keep
    the rate at the end-of-schedule value (zero) instead of letting the cosine
    rise again.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten on every step.
        warmup_epochs: Number of linear warmup steps (0 skips warmup).
        max_lr: Peak learning rate.
        total_epochs: Length of the whole schedule, warmup included.
    """

    def __init__(self, optimizer, warmup_epochs, max_lr, total_epochs):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.max_lr = max_lr
        self.total_epochs = total_epochs
        self.current_epoch = 0  # number of step() calls made so far

    def step(self):
        """Set the rate for the current epoch on every param group and return it as a float."""
        if self.current_epoch < self.warmup_epochs:
            # Linear warmup
            lr = self.max_lr * (self.current_epoch + 1) / self.warmup_epochs
        else:
            # Cosine decay after warmup
            decay_epochs = self.total_epochs - self.warmup_epochs
            if decay_epochs > 0:
                # Clamp so stepping past total_epochs cannot push the cosine back up.
                progress = min((self.current_epoch - self.warmup_epochs) / decay_epochs, 1.0)
            else:
                # No decay phase configured: treat the schedule as finished.
                progress = 1.0
            lr = 0.5 * self.max_lr * (1 + np.cos(np.pi * progress))

        # Store a plain Python float rather than a NumPy scalar in the optimizer.
        lr = float(lr)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.current_epoch += 1
        return lr

# Demonstrate warmup scheduling
# NOTE: rebinds the module-level names `model` and `optimizer`.
model = SimpleClassifier(20, 64, 5)
optimizer = optim.AdamW(model.parameters(), lr=0.001)  # Will be overridden

# 10 warmup epochs up to a peak of 0.01, then cosine decay for the rest of the 100 epochs.
warmup_scheduler = WarmupScheduler(
    optimizer=optimizer,
    warmup_epochs=10,
    max_lr=0.01,
    total_epochs=100
)

# Track warmup schedule
# No training happens here; the scheduler is stepped only to record its LR values.
warmup_lrs = []
for epoch in range(100):
    lr = warmup_scheduler.step()
    warmup_lrs.append(lr)

# Plot warmup schedule
plt.figure(figsize=(10, 4))
plt.plot(range(100), warmup_lrs, 'b-', linewidth=2, label='Warmup + Cosine Decay')
# Vertical marker at x=10, the configured warmup length.
plt.axvline(x=10, color='r', linestyle='--', alpha=0.7, label='End of Warmup')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Warmup + Cosine Decay')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("Warmup benefits:")
print("• Prevents early training instability")
print("• Especially important for large batch sizes")
print("• Common in transformer training")
print("• Allows higher maximum learning rates")

# Closing summary for the whole notebook.
print("\n🎉 Optimization & Training exploration completed!")
print("\nKey takeaways:")
print("• Use a systematic training loop structure")
print("• AdamW is generally a good default optimizer")
print("• Match loss function to your task type")
print("• Always set model.train()/model.eval() appropriately")
print("• Learning rate scheduling can significantly improve results")
print("• Monitor both training and validation metrics")