# Module 2.1: The Training Loop Deconstructed

The training loop is where deep learning happens. Understanding each component deeply enables you to:
- Debug training issues effectively
- Implement custom training procedures
- Optimize for speed and memory
- Adapt to different training paradigms

## Learning Objectives
- Understand each step of the training loop and why it matters
- Master loss functions and implement custom ones
- Understand optimizers: SGD, Adam, AdamW and their differences
- Use learning rate schedulers effectively
- Apply gradient clipping for stability
- Handle numerical issues in training

---

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

print(f"PyTorch version: {torch.__version__}")

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

---
## 1. Anatomy of the Training Loop

Every training loop follows this pattern:

```python
for epoch in range(num_epochs):
    for batch in dataloader:
        # 1. Zero gradients
        optimizer.zero_grad()
        
        # 2. Forward pass
        outputs = model(inputs)
        
        # 3. Compute loss
        loss = criterion(outputs, targets)
        
        # 4. Backward pass
        loss.backward()
        
        # 5. Update weights
        optimizer.step()
```

Let's understand each step deeply.

In [None]:
# Create a simple dataset
def create_synthetic_data(n_samples=1000, n_features=10, n_classes=3):
    """Build a random classification dataset labelled by a hidden linear model.

    Features are drawn from a standard normal distribution. Each sample is
    assigned the class whose score is largest under a randomly drawn
    (n_features, n_classes) weight matrix.

    Returns:
        (features, labels): features has shape (n_samples, n_features),
        labels has shape (n_samples,) and holds class indices.
    """
    features = torch.randn(n_samples, n_features)
    # The labelling weights are drawn after the features.
    projection = torch.randn(n_features, n_classes)
    class_scores = torch.matmul(features, projection)
    labels = torch.argmax(class_scores, dim=1)
    return features, labels

X, y = create_synthetic_data()
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

print(f"Dataset size: {len(dataset)}")
print(f"Batch count: {len(dataloader)}")

### Step 1: Zero Gradients

In [None]:
model = nn.Linear(10, 3)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Without zeroing, gradients accumulate
for i in range(3):
    x = torch.randn(5, 10)
    y = model(x).sum()
    y.backward()
    print(f"Iteration {i+1}, weight grad norm: {model.weight.grad.norm():.4f}")

print("\nGradients accumulated!")

In [None]:
# Proper pattern: zero gradients before backward
model = nn.Linear(10, 3)
optimizer = optim.SGD(model.parameters(), lr=0.01)

for i in range(3):
    optimizer.zero_grad()  # or model.zero_grad()
    x = torch.randn(5, 10)
    y = model(x).sum()
    y.backward()
    print(f"Iteration {i+1}, weight grad norm: {model.weight.grad.norm():.4f}")

print("\nGradients are fresh each iteration")

In [None]:
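# set_to_none=True releases the gradient tensors instead of filling them with
# zeros (slightly faster, less memory). It is the default since PyTorch 2.0 and
# is passed explicitly here only to make the behavior visible.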
optimizer.zero_grad(set_to_none=True)

# Now gradients are None instead of zeros
print(f"After set_to_none=True: grad = {model.weight.grad}")

### Step 2: Forward Pass

In [None]:
# Forward pass builds the computation graph
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 3)
)

x = torch.randn(5, 10)
output = model(x)  # Graph is built during this call

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Output has grad_fn: {output.grad_fn is not None}")

### Step 3: Compute Loss

In [None]:
# Loss function reduces outputs to a scalar
criterion = nn.CrossEntropyLoss()

targets = torch.randint(0, 3, (5,))  # Class labels
loss = criterion(output, targets)

print(f"Loss value: {loss.item():.4f}")
print(f"Loss is scalar: {loss.shape == torch.Size([])}")
print(f"Loss requires grad: {loss.requires_grad}")

### Step 4: Backward Pass

In [None]:
# Backward computes gradients for all parameters
print("Before backward:")
print(f"  model[0].weight.grad: {model[0].weight.grad}")

loss.backward()

print("\nAfter backward:")
print(f"  model[0].weight.grad shape: {model[0].weight.grad.shape}")
print(f"  model[0].weight.grad norm: {model[0].weight.grad.norm():.4f}")

### Step 5: Update Weights

In [None]:
# optimizer.step() updates parameters using computed gradients
print("Before step:")
print(f"  model[0].weight[0,:3]: {model[0].weight.data[0,:3]}")

optimizer = optim.SGD(model.parameters(), lr=0.1)
optimizer.step()

print("\nAfter step:")
print(f"  model[0].weight[0,:3]: {model[0].weight.data[0,:3]}")

### Complete Training Loop

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one pass over ``dataloader``, updating ``model`` in place.

    Returns:
        (mean_loss, accuracy_percent): the batch losses weighted by batch
        size and divided by the number of samples seen, and the percentage
        of samples whose highest-scoring class equals the target.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # zero_grad -> forward -> loss -> backward -> step
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so a short final batch
        # does not get the same weight as a full one.
        loss_sum += batch_loss.item() * batch_x.size(0)
        n_seen += batch_y.size(0)
        n_correct += (logits.argmax(dim=1) == batch_y).sum().item()

    return loss_sum / n_seen, 100. * n_correct / n_seen

# Initialize
model = nn.Sequential(
    nn.Linear(10, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 3)
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train
print("Training:")
for epoch in range(10):
    loss, acc = train_epoch(model, dataloader, criterion, optimizer, device)
    if (epoch + 1) % 2 == 0:
        print(f"Epoch {epoch+1:2d}: Loss = {loss:.4f}, Accuracy = {acc:.1f}%")

---
## 2. Loss Functions

Loss functions measure how far predictions are from targets. Understanding them helps you choose the right one and create custom losses.

### 2.1 Common Loss Functions

In [None]:
# Regression losses
predictions = torch.tensor([1.0, 2.0, 3.0, 4.0])
targets = torch.tensor([1.1, 2.2, 2.8, 4.5])

# Mean Squared Error (L2)
mse = nn.MSELoss()
print(f"MSE: {mse(predictions, targets):.4f}")
print(f"Manual: {((predictions - targets) ** 2).mean():.4f}")

# Mean Absolute Error (L1)
mae = nn.L1Loss()
print(f"\nMAE: {mae(predictions, targets):.4f}")
print(f"Manual: {(predictions - targets).abs().mean():.4f}")

# Smooth L1 (Huber) - combines L1 and L2
smooth_l1 = nn.SmoothL1Loss()
print(f"\nSmooth L1: {smooth_l1(predictions, targets):.4f}")

In [None]:
# Classification losses

# Binary Cross Entropy (for binary classification)
# Input should be probabilities (after sigmoid)
probs = torch.tensor([0.9, 0.1, 0.8, 0.3])
binary_targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

bce = nn.BCELoss()
print(f"BCE: {bce(probs, binary_targets):.4f}")

# BCE with logits (more numerically stable)
logits = torch.tensor([2.0, -2.0, 1.5, -1.0])
bce_logits = nn.BCEWithLogitsLoss()
print(f"BCE with logits: {bce_logits(logits, binary_targets):.4f}")

In [None]:
# Cross Entropy for multi-class
# Input: (N, C) logits, Targets: (N,) class indices
logits = torch.tensor([
    [2.0, 1.0, 0.1],   # Predicts class 0
    [0.5, 2.5, 0.3],   # Predicts class 1
    [0.1, 0.2, 3.0],   # Predicts class 2
])
targets = torch.tensor([0, 1, 2])  # All correct

ce = nn.CrossEntropyLoss()
print(f"Cross Entropy (all correct): {ce(logits, targets):.4f}")

# Wrong predictions
wrong_targets = torch.tensor([2, 0, 1])
print(f"Cross Entropy (all wrong): {ce(logits, wrong_targets):.4f}")

In [None]:
# Cross Entropy = LogSoftmax + NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll_loss = nn.NLLLoss()

log_probs = log_softmax(logits)
manual_ce = nll_loss(log_probs, targets)

print(f"CrossEntropyLoss: {ce(logits, targets):.4f}")
print(f"LogSoftmax + NLLLoss: {manual_ce:.4f}")

### 2.2 Loss Function Options

In [None]:
# Reduction modes
predictions = torch.tensor([1.0, 2.0, 3.0])
targets = torch.tensor([1.5, 2.5, 3.5])

# Mean (default)
mse_mean = nn.MSELoss(reduction='mean')
print(f"Mean: {mse_mean(predictions, targets):.4f}")

# Sum
mse_sum = nn.MSELoss(reduction='sum')
print(f"Sum: {mse_sum(predictions, targets):.4f}")

# None (per-element)
mse_none = nn.MSELoss(reduction='none')
print(f"None: {mse_none(predictions, targets)}")

In [None]:
# Class weights for imbalanced data
# If class 0 is rare, give it higher weight
weights = torch.tensor([2.0, 1.0, 1.0])  # Class 0 counts double
ce_weighted = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(10, 3)
targets = torch.randint(0, 3, (10,))

print(f"Unweighted loss: {ce(logits, targets):.4f}")
print(f"Weighted loss: {ce_weighted(logits, targets):.4f}")

In [None]:
# Label smoothing (regularization technique)
# Instead of hard targets [0, 1, 0], train against (1 - eps) * one_hot + eps / K.
# With eps=0.1 and K=3 classes that is roughly [0.033, 0.933, 0.033].
ce_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.tensor([[3.0, 0.5, 0.1]])  # Very confident
targets = torch.tensor([0])

print(f"Without smoothing: {ce(logits, targets):.4f}")
print(f"With smoothing: {ce_smooth(logits, targets):.4f}")
print("\nSmoothing penalizes overconfidence")

### 2.3 Custom Loss Functions

In [None]:
# Method 1: Simple function
def focal_loss(inputs, targets, alpha=0.25, gamma=2.0):
    """
    Focal loss: cross entropy re-weighted so that samples the model already
    gets right contribute little and hard samples dominate.

        FL = -alpha * (1 - p)^gamma * log(p)

    where p is the predicted probability of the true class. Returns the
    mean over all samples.
    """
    per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
    # CE = -log(p), so p is recovered as exp(-CE).
    true_class_prob = torch.exp(-per_sample_ce)
    modulator = (1 - true_class_prob) ** gamma
    weighted = alpha * modulator * per_sample_ce
    return weighted.mean()

logits = torch.randn(10, 3)
targets = torch.randint(0, 3, (10,))

print(f"Cross Entropy: {F.cross_entropy(logits, targets):.4f}")
print(f"Focal Loss: {focal_loss(logits, targets):.4f}")

In [None]:
# Method 2: nn.Module class (for learnable parameters)
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for similarity learning.

    Similar pairs are pulled together, dissimilar pairs are pushed apart
    until they are at least ``margin`` away:

        similar:     distance^2
        dissimilar:  max(0, margin - distance)^2
    """
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, embeddings1, embeddings2, labels):
        """
        Args:
            embeddings1, embeddings2: (N, D) embeddings
            labels: (N,) 1 if similar, 0 if dissimilar

        Returns:
            Mean loss over the N pairs.
        """
        dist = F.pairwise_distance(embeddings1, embeddings2)

        pull_term = dist.pow(2)
        # Only pairs closer than the margin are penalised.
        push_term = F.relu(self.margin - dist).pow(2)

        per_pair = labels * pull_term + (1 - labels) * push_term
        return per_pair.mean()

# Test
criterion = ContrastiveLoss(margin=2.0)
emb1 = torch.randn(5, 64)
emb2 = torch.randn(5, 64)
labels = torch.tensor([1, 1, 0, 0, 1], dtype=torch.float32)

loss = criterion(emb1, emb2, labels)
print(f"Contrastive loss: {loss:.4f}")

In [None]:
# Combining multiple losses
class CombinedLoss(nn.Module):
    """Multi-task loss: ``alpha * cross_entropy + (1 - alpha) * mse``."""
    def __init__(self, alpha=0.5):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.mse = nn.MSELoss()
        # Weight of the classification term; the regression term gets the rest.
        self.alpha = alpha

    def forward(self, class_logits, regression_pred, class_targets, regression_targets):
        """Return the weighted sum of the classification and regression losses."""
        classification_term = self.alpha * self.ce(class_logits, class_targets)
        regression_term = (1 - self.alpha) * self.mse(regression_pred, regression_targets)
        return classification_term + regression_term

criterion = CombinedLoss(alpha=0.7)
print("Multi-task loss defined with weighted combination")

---
## 3. Optimizers

Optimizers update model parameters based on computed gradients. Different optimizers have different update rules and properties.

### 3.1 SGD (Stochastic Gradient Descent)

In [None]:
# Basic SGD: w = w - lr * grad
model = nn.Linear(10, 5)
sgd = optim.SGD(model.parameters(), lr=0.01)

x = torch.randn(3, 10)
y = model(x).sum()
y.backward()

# Manual update matches SGD
w_before = model.weight.data.clone()
sgd.step()
w_after = model.weight.data

expected = w_before - 0.01 * model.weight.grad
print(f"SGD update correct: {torch.allclose(w_after, expected)}")

In [None]:
# SGD with momentum
# v = momentum * v + grad
# w = w - lr * v
sgd_momentum = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Momentum accumulates gradient direction over time,
# helping escape local minima and smooth noisy gradients
print("Momentum accelerates consistent gradient directions")

In [None]:
# SGD with weight decay (L2 regularization)
# Penalizes large weights: loss = original_loss + (weight_decay/2) * ||w||^2
# Gradient becomes: grad + weight_decay * w
sgd_wd = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001)

print("Weight decay prevents overfitting by keeping weights small")

### 3.2 Adam (Adaptive Moment Estimation)

In [None]:
# Adam maintains per-parameter adaptive learning rates
# using first moment (mean) and second moment (variance) of gradients

# m = beta1 * m + (1 - beta1) * grad          # First moment
# v = beta2 * v + (1 - beta2) * grad^2        # Second moment  
# m_hat = m / (1 - beta1^t)                   # Bias correction
# v_hat = v / (1 - beta2^t)                   # Bias correction
# w = w - lr * m_hat / (sqrt(v_hat) + eps)    # Update

model = nn.Linear(10, 5)
adam = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8)

print("Adam adapts learning rate for each parameter based on gradient history")

In [None]:
# Adam state inspection
# After some training steps, Adam holds per-parameter state: a step count plus
# running estimates of the first moment (exp_avg) and second moment (exp_avg_sq)
for _ in range(10):
    adam.zero_grad()
    x = torch.randn(3, 10)
    model(x).sum().backward()
    adam.step()

# Inspect state
for name, param in model.named_parameters():
    if param in adam.state:
        state = adam.state[param]
        print(f"{name}:")
        print(f"  step: {state['step']}")
        print(f"  exp_avg (m) shape: {state['exp_avg'].shape}")
        print(f"  exp_avg_sq (v) shape: {state['exp_avg_sq'].shape}")

### 3.3 AdamW (Adam with Decoupled Weight Decay)

In [None]:
# AdamW fixes a subtle bug in Adam's weight decay implementation
#
# Adam with weight_decay: adds weight_decay * w to gradient BEFORE adaptive scaling
# AdamW: applies weight decay AFTER the Adam update (decoupled)
#
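# Adam (L2):  g = grad + weight_decay * w        # decay enters m and v, so it gets adaptively rescaled
#             w = w - lr * m_hat / (sqrt(v_hat) + eps)
# AdamW:      m and v are built from grad only
#             w = w - lr * m_hat / (sqrt(v_hat) + eps) - lr * weight_decay * w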

model = nn.Linear(10, 5)
adamw = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

print("AdamW is generally preferred over Adam with weight_decay")
print("It's the default in many modern architectures (transformers, etc.)")

### 3.4 Comparing Optimizers

In [None]:
def train_with_optimizer(optimizer_class, optimizer_kwargs, epochs=50):
    """Train a small fixed MLP with the given optimizer.

    The seed is reset first so every optimizer starts from the same initial
    weights. Batches come from the module-level ``dataloader``.

    Returns:
        A list with one entry per epoch: the mean of that epoch's batch losses.
    """
    torch.manual_seed(42)
    layers = [nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)]
    net = nn.Sequential(*layers)
    opt = optimizer_class(net.parameters(), **optimizer_kwargs)
    loss_fn = nn.CrossEntropyLoss()

    history = []
    for _ in range(epochs):
        running = 0
        for batch_x, batch_y in dataloader:
            opt.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
        history.append(running / len(dataloader))
    return history

# Compare optimizers
optimizers = {
    'SGD': (optim.SGD, {'lr': 0.1}),
    'SGD+Momentum': (optim.SGD, {'lr': 0.1, 'momentum': 0.9}),
    'Adam': (optim.Adam, {'lr': 0.01}),
    'AdamW': (optim.AdamW, {'lr': 0.01, 'weight_decay': 0.01}),
}

results = {}
for name, (opt_class, kwargs) in optimizers.items():
    results[name] = train_with_optimizer(opt_class, kwargs)

# Plot
plt.figure(figsize=(10, 5))
for name, losses in results.items():
    plt.plot(losses, label=name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Optimizer Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

### 3.5 Parameter Groups

In [None]:
# Different learning rates for different layers
model = nn.Sequential(
    nn.Linear(10, 20),  # Early layer - lower LR
    nn.ReLU(),
    nn.Linear(20, 5)    # Later layer - higher LR
)

# Define parameter groups
optimizer = optim.Adam([
    {'params': model[0].parameters(), 'lr': 1e-4},   # Early layer
    {'params': model[2].parameters(), 'lr': 1e-3},   # Later layer
], lr=1e-3)  # Default LR (not used here since all groups specify lr)

print("Parameter groups:")
for i, group in enumerate(optimizer.param_groups):
    print(f"  Group {i}: lr = {group['lr']}, params = {len(group['params'])}")

In [None]:
# Modifying learning rates during training
for group in optimizer.param_groups:
    group['lr'] *= 0.1  # Reduce all LRs by 10x

print("After LR reduction:")
for i, group in enumerate(optimizer.param_groups):
    print(f"  Group {i}: lr = {group['lr']}")

---
## 4. Learning Rate Schedulers

Schedulers adjust the learning rate during training, often leading to better convergence.

In [None]:
# StepLR: Decay LR by gamma every step_size epochs
model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

lrs = []
for epoch in range(50):
    lrs.append(optimizer.param_groups[0]['lr'])
    # train_epoch(...)
    # Stand-in for the training above: PyTorch warns when a scheduler is
    # stepped before the optimizer. With no gradients this is a no-op.
    optimizer.step()
    scheduler.step()  # Call after each epoch

# Draw a complete figure inside this cell: notebook cells render their
# figures independently, so a subplot grid cannot be finished by a later cell.
plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title('StepLR (step=10, gamma=0.5)')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.show()

In [None]:
# CosineAnnealingLR: Smooth cosine decay
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

lrs = []
for epoch in range(50):
    lrs.append(optimizer.param_groups[0]['lr'])
    # Stand-in for training: PyTorch warns when a scheduler is stepped
    # before the optimizer. With no gradients this is a no-op.
    optimizer.step()
    scheduler.step()

# Open a figure in this cell rather than selecting a subplot of a figure
# created in another cell, which a notebook has already rendered and closed.
plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title('CosineAnnealingLR (T_max=50)')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.tight_layout()
plt.show()

In [None]:
# OneCycleLR: Popular for fast training (1cycle policy)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    epochs=50,
    steps_per_epoch=len(dataloader)  # Called every batch, not epoch!
)

lrs = []
for epoch in range(50):
    for batch in dataloader:
        lrs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()  # Called every batch!

plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title('OneCycleLR (per-batch updates)')
plt.xlabel('Iteration')
plt.ylabel('Learning Rate')
plt.show()

In [None]:
# ReduceLROnPlateau: Reduce when metric stops improving
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',      # Reduce when metric stops decreasing
    factor=0.5,      # Multiply LR by this factor
    patience=5,      # Wait this many epochs before reducing
    # NOTE: `verbose=True` is intentionally not passed. The argument is
    # deprecated in current PyTorch and may be rejected by newer releases.
    # Log optimizer.param_groups[0]['lr'] after scheduler.step() instead.
)

# Usage:
# for epoch in range(epochs):
#     train_loss = train(...)
#     val_loss = validate(...)
#     scheduler.step(val_loss)  # Pass the metric to monitor!
#     print(f"lr = {optimizer.param_groups[0]['lr']}")

print("ReduceLROnPlateau requires passing the metric to step()")

In [None]:
# Warmup + Decay combination
def get_warmup_scheduler(optimizer, warmup_epochs, total_epochs):
    """Linear warmup followed by cosine decay.

    Builds a ``LambdaLR`` whose multiplier rises linearly from 0 to 1 over
    the first ``warmup_epochs`` epochs, then follows half a cosine from 1
    down to 0, reaching 0 at ``total_epochs``. Stepping past
    ``total_epochs`` keeps the multiplier at 0.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_epochs: Number of warmup epochs (0 disables warmup).
        total_epochs: Epoch at which the decay reaches 0.

    Returns:
        The configured ``optim.lr_scheduler.LambdaLR``.
    """
    # At least one decay epoch, so total_epochs <= warmup_epochs cannot
    # divide by zero on the first post-warmup epoch.
    decay_epochs = max(1, total_epochs - warmup_epochs)

    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return epoch / warmup_epochs
        # Clamp at 1 so extra steps hold the LR at its floor instead of
        # climbing back up the second half of the cosine.
        progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
        return 0.5 * (1 + np.cos(np.pi * progress))

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = get_warmup_scheduler(optimizer, warmup_epochs=5, total_epochs=50)

lrs = []
for epoch in range(50):
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.axvline(x=5, color='r', linestyle='--', label='End of warmup')
plt.title('Warmup + Cosine Decay')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend()
plt.show()

---
## 5. Gradient Clipping

Gradient clipping prevents exploding gradients by limiting gradient magnitudes.

In [None]:
# Gradient explosion example
model = nn.Sequential(*[nn.Linear(10, 10) for _ in range(50)])  # Deep network

# Initialize with values that cause explosion
for layer in model:
    if hasattr(layer, 'weight'):
        nn.init.uniform_(layer.weight, 1.0, 1.1)

x = torch.randn(1, 10)
y = model(x).sum()
y.backward()

# Check gradient norms
for i, layer in enumerate(model[:5]):
    if hasattr(layer, 'weight'):
        print(f"Layer {i} gradient norm: {layer.weight.grad.norm():.2e}")

print("\nGradients exploded in early layers!")

In [None]:
# clip_grad_norm_: Clip by global norm
# If ||grad|| > max_norm, scale all gradients so ||grad|| = max_norm

model = nn.Linear(10, 5)
x = torch.randn(3, 10) * 100  # Large input -> large gradients
model(x).sum().backward()

print(f"Before clipping: {model.weight.grad.norm():.4f}")

# Clip
max_norm = 1.0
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

print(f"Total norm was: {total_norm:.4f}")
print(f"After clipping: {model.weight.grad.norm():.4f}")

In [None]:
# clip_grad_value_: Clip by value
# Clamps each gradient element to [-clip_value, clip_value]

model = nn.Linear(10, 5)
x = torch.randn(3, 10) * 100
model(x).sum().backward()

print(f"Before clipping:")
print(f"  Max: {model.weight.grad.max():.4f}")
print(f"  Min: {model.weight.grad.min():.4f}")

torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

print(f"\nAfter clipping:")
print(f"  Max: {model.weight.grad.max():.4f}")
print(f"  Min: {model.weight.grad.min():.4f}")

In [None]:
# Training loop with gradient clipping
def train_with_clipping(model, dataloader, criterion, optimizer, max_grad_norm=1.0):
    """Train for one pass over ``dataloader``, clipping gradients each step.

    After every backward pass the global gradient norm of ``model`` is
    clipped to ``max_grad_norm`` before the optimizer update is applied.
    """
    model.train()
    for batch in dataloader:
        features, labels = batch
        optimizer.zero_grad()
        batch_loss = criterion(model(features), labels)
        batch_loss.backward()

        # Clipping has to sit between backward() and step(): before backward
        # there is nothing to clip, after step the update is already applied.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

print("Gradient clipping is especially important for RNNs and transformers")

---
## 6. Numerical Stability

In [None]:
# Problem: Overflow in softmax
logits = torch.tensor([1000.0, 1001.0, 1002.0])
# exp(1000) overflows float32 to inf. Nothing is raised: the division
# quietly evaluates inf / inf and the result is NaN, so the failure has to
# be detected explicitly rather than caught as an exception.
naive_softmax = torch.exp(logits) / torch.exp(logits).sum()
print(f"Naive softmax: {naive_softmax}")
if torch.isnan(naive_softmax).any():
    print("Overflow!")

# Solution: Subtract max before exp (F.softmax does this for you)
stable_softmax = F.softmax(logits, dim=0)
print(f"Stable softmax: {stable_softmax}")

In [None]:
# Problem: Log of small probabilities
probs = torch.tensor([0.99, 0.009, 0.001, 1e-10])
log_probs = torch.log(probs)
print(f"Log probs: {log_probs}")
print("Note: log(1e-10) = -23, which can cause issues")

# Solution: Use log_softmax instead of softmax + log
logits = torch.randn(4)
log_probs_stable = F.log_softmax(logits, dim=0)
print(f"\nStable log_softmax: {log_probs_stable}")

In [None]:
# Detecting NaN and Inf
def check_for_nan(model, loss):
    """Check for NaN/Inf in the loss or in any parameter gradient.

    Prints a warning for the first problem found.

    Args:
        model: Module whose ``named_parameters()`` gradients are inspected.
            Parameters with no gradient yet are skipped.
        loss: Loss tensor, either a scalar or per-element values such as
            those produced with ``reduction='none'``.

    Returns:
        True if a NaN or Inf was found, False otherwise.
    """
    # isfinite is False for NaN and for +/-Inf; .all() reduces a
    # multi-element loss to one truth value instead of raising.
    if not torch.isfinite(loss).all():
        print("Warning: Loss is NaN or Inf!")
        return True

    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            continue
        if torch.isnan(grad).any():
            print(f"Warning: NaN gradient in {name}")
            return True
        if torch.isinf(grad).any():
            print(f"Warning: Inf gradient in {name}")
            return True
    return False

# Usage in training loop:
# if check_for_nan(model, loss):
#     print("Skipping batch due to numerical issues")
#     continue

In [None]:
# Use anomaly detection during debugging
# (SLOW - only use for debugging)
#
# Anomaly mode raises when a backward function produces NaN. An Inf gradient
# is not reported, so the example needs a 0 / 0 in the backward pass.

x = torch.tensor([0.0], requires_grad=True)
anomaly_detected = False

# The context manager restores the previous anomaly-mode setting on exit,
# even if something other than the expected RuntimeError is raised.
with torch.autograd.detect_anomaly():
    try:
        y = torch.log(x)  # log(0) = -inf
        z = y * 0         # incoming grad for log is 0, so d/dx = 0 / 0 = nan
        z.backward()
    except RuntimeError as e:
        anomaly_detected = True
        print(f"Anomaly detected: {str(e)[:100]}...")

---
## Exercises

### Exercise 1: Implement SGD from Scratch

In [None]:
class MySGD:
    """
    Implement SGD with momentum from scratch.

    Update rule:
    if momentum > 0:
        v = momentum * v + grad
        param = param - lr * v
    else:
        param = param - lr * grad

    Args:
        parameters: Iterable of parameters to optimize (materialized into a
            list so it can be iterated on every step).
        lr: Learning rate.
        momentum: Momentum factor; 0 means plain SGD.
    """
    def __init__(self, parameters, lr, momentum=0):
        self.parameters = list(parameters)
        self.lr = lr
        self.momentum = momentum
        # YOUR CODE: Initialize velocity buffers if momentum > 0
        # (the update rule needs one `v` per parameter)
        self.velocities = None

    def zero_grad(self):
        """Reset the gradient of every parameter (exercise stub)."""
        # YOUR CODE: Set all gradients to zero
        pass

    def step(self):
        """Apply one update following the rule in the class docstring (exercise stub)."""
        # YOUR CODE: Update parameters using SGD with momentum
        pass

# Test
# model = nn.Linear(10, 5)
# my_sgd = MySGD(model.parameters(), lr=0.01, momentum=0.9)
# x = torch.randn(3, 10)
# model(x).sum().backward()
# my_sgd.step()
# print("Custom SGD step completed")

### Exercise 2: Custom Learning Rate Scheduler

In [None]:
class CyclicLR:
    """
    Implement triangular cyclic learning rate.

    LR oscillates between base_lr and max_lr over step_size*2 iterations.

    Args:
        optimizer: The optimizer
        base_lr: Minimum learning rate
        max_lr: Maximum learning rate
        step_size: Number of iterations to go from base to max
    """
    def __init__(self, optimizer, base_lr, max_lr, step_size):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        # Number of step() calls so far; drives the position inside the cycle.
        self.iteration = 0

    def step(self):
        """Advance the schedule by one iteration (exercise stub)."""
        # YOUR CODE: Calculate and set new learning rate
        # Cycle position goes from 0 to 1 and back to 0
        pass

# Test
# model = nn.Linear(10, 5)
# optimizer = optim.SGD(model.parameters(), lr=0.1)
# scheduler = CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size=100)
# lrs = []
# for i in range(400):
#     lrs.append(optimizer.param_groups[0]['lr'])
#     scheduler.step()
# plt.plot(lrs)
# plt.title('Cyclic LR')
# plt.show()

### Exercise 3: Training with Early Stopping

In [None]:
class EarlyStopping:
    """
    Early stopping to prevent overfitting.

    Stop training if validation loss doesn't improve for `patience` epochs.
    Save the best model weights.

    Args:
        patience: Number of epochs to wait before stopping
        min_delta: Minimum change to consider as improvement
    """
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        # Consecutive calls without improvement.
        self.counter = 0
        # Best validation loss so far; None until the first call.
        self.best_loss = None
        # Weights captured at the best validation loss; None until then.
        self.best_weights = None
        self.should_stop = False

    def __call__(self, val_loss, model):
        """
        Call after each validation.
        Returns True if training should stop.

        Args:
            val_loss: Validation loss for the epoch that just finished.
            model: Model whose weights are saved when the loss improves.
        """
        # YOUR CODE:
        # 1. Check if this is the first call or if val_loss improved
        # 2. If improved, reset counter and save best weights
        # 3. If not improved, increment counter
        # 4. Set should_stop = True if counter >= patience
        pass

    def load_best_weights(self, model):
        """Restore the best weights."""
        # YOUR CODE
        pass

# Usage:
# early_stopping = EarlyStopping(patience=5)
# for epoch in range(100):
#     train_loss = train_epoch(...)
#     val_loss = validate_epoch(...)
#     if early_stopping(val_loss, model):
#         print(f"Early stopping at epoch {epoch}")
#         break
# early_stopping.load_best_weights(model)

---
## Solutions

In [None]:
# Exercise 1 Solution
class MySGDSolution:
    """Reference solution: SGD with optional momentum.

    With ``momentum > 0`` each parameter keeps a velocity
    ``v = momentum * v + grad`` and is updated by ``lr * v``; otherwise the
    update is simply ``lr * grad``.
    """
    def __init__(self, parameters, lr, momentum=0):
        self.parameters = list(parameters)
        self.lr = lr
        self.momentum = momentum
        # One velocity buffer per parameter, only allocated when needed.
        self.velocities = (
            [torch.zeros_like(p) for p in self.parameters] if momentum > 0 else None
        )

    def zero_grad(self):
        """Zero every existing gradient in place."""
        for grad in (p.grad for p in self.parameters):
            if grad is not None:
                grad.zero_()

    def step(self):
        """Apply one update to every parameter that has a gradient."""
        use_momentum = self.momentum > 0
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for idx, param in enumerate(self.parameters):
                grad = param.grad
                if grad is None:
                    continue

                if use_momentum:
                    update = self.momentum * self.velocities[idx] + grad
                    self.velocities[idx] = update
                else:
                    update = grad
                param -= self.lr * update

print("Exercise 1 Solution:")
model = nn.Linear(10, 5)
my_sgd = MySGDSolution(model.parameters(), lr=0.01, momentum=0.9)
x = torch.randn(3, 10)
model(x).sum().backward()
my_sgd.step()
print("Custom SGD step completed")

In [None]:
# Exercise 2 Solution
class CyclicLRSolution:
    """Reference solution: triangular cyclic learning rate.

    The LR rises linearly from ``base_lr`` to ``max_lr`` over ``step_size``
    iterations, falls back over the next ``step_size``, and repeats.

    The schedule is applied as soon as the scheduler is constructed, so the
    optimizer starts at ``base_lr`` rather than at whatever LR it was
    created with. Each ``step()`` then moves to the next iteration.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with
            an ``'lr'`` entry).
        base_lr: Minimum learning rate.
        max_lr: Maximum learning rate.
        step_size: Iterations to go from ``base_lr`` to ``max_lr``.

    Raises:
        ValueError: If ``step_size`` is not positive.
    """
    def __init__(self, optimizer, base_lr, max_lr, step_size):
        if step_size <= 0:
            raise ValueError("step_size must be a positive number of iterations")
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.iteration = 0
        # Apply iteration 0 now so the first training step already uses the
        # scheduled LR instead of the optimizer's constructor value.
        self._apply_lr()

    def _lr_at(self, iteration):
        """Return the triangular-schedule LR for the given iteration."""
        cycle = iteration // (2 * self.step_size)
        # x runs 1 -> 0 -> 1 across one cycle, so (1 - x) is the triangle.
        x = abs(iteration / self.step_size - 2 * cycle - 1)
        return self.base_lr + (self.max_lr - self.base_lr) * max(0, 1 - x)

    def _apply_lr(self):
        """Write the LR for the current iteration into every param group."""
        lr = self._lr_at(self.iteration)
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def step(self):
        """Advance one iteration and update the optimizer's learning rate."""
        self.iteration += 1
        self._apply_lr()

print("\nExercise 2 Solution:")
model = nn.Linear(10, 5)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = CyclicLRSolution(optimizer, base_lr=0.001, max_lr=0.1, step_size=100)
lrs = []
for i in range(400):
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()

plt.figure(figsize=(10, 4))
plt.plot(lrs)
plt.title('Cyclic LR (Triangular)')
plt.xlabel('Iteration')
plt.ylabel('Learning Rate')
plt.show()

In [None]:
# Exercise 3 Solution
import copy

class EarlyStoppingSolution:
    """Reference solution: stop once validation loss stalls for ``patience`` calls.

    A call counts as an improvement when it is the first one, or when
    ``val_loss`` is below the best loss so far by more than ``min_delta``.
    The model's ``state_dict`` is snapshotted on every improvement.
    """
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.best_weights = None
        self.should_stop = False

    def __call__(self, val_loss, model):
        """Record ``val_loss`` and return True if training should stop."""
        first_call = self.best_loss is None
        improved = first_call or val_loss < self.best_loss - self.min_delta

        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
            return self.should_stop

        self.best_loss = val_loss
        # Deep copy: state_dict() is copied so later training cannot change
        # the snapshot.
        self.best_weights = copy.deepcopy(model.state_dict())
        self.counter = 0
        return self.should_stop

    def load_best_weights(self, model):
        """Load the best snapshot into ``model``; do nothing if none was taken."""
        if self.best_weights is None:
            return
        model.load_state_dict(self.best_weights)

print("\nExercise 3 Solution:")
# Simulate training with early stopping
early_stopping = EarlyStoppingSolution(patience=3)
model = nn.Linear(10, 5)

# Simulate validation losses that plateau
fake_val_losses = [1.0, 0.8, 0.6, 0.5, 0.5, 0.51, 0.52, 0.53]

for epoch, val_loss in enumerate(fake_val_losses):
    stop = early_stopping(val_loss, model)
    print(f"Epoch {epoch}: val_loss={val_loss:.2f}, counter={early_stopping.counter}, stop={stop}")
    if stop:
        print(f"\nStopped at epoch {epoch}. Best loss: {early_stopping.best_loss:.2f}")
        break

early_stopping.load_best_weights(model)
print("Best weights restored")

---
## Summary

Key takeaways from this notebook:

1. **Training Loop Steps**: zero_grad → forward → loss → backward → step
2. **Loss Functions**: Choose based on task; use numerically stable versions
3. **Optimizers**: SGD with momentum for simplicity, Adam/AdamW for adaptive learning
4. **Learning Rate Schedulers**: Warmup + decay, cosine annealing, or reduce-on-plateau
5. **Gradient Clipping**: Essential for RNNs and transformers; use `clip_grad_norm_`
6. **Numerical Stability**: Use `log_softmax`, avoid log(0), detect anomalies

---
*Next: Module 2.2 - Data Pipeline Mastery*