# Day 03: Advanced CNNs and Data Augmentation

**Goal:** Learn advanced CNN techniques including data augmentation, batch normalization, and learning rate scheduling.

**Mathematical Focus:** Batch normalization as affine transformation, dropout as stochastic regularization, LR schedules from optimization theory.

**Time estimate:** 3-4 hours

## Theory: Why These Techniques Matter

### Data Augmentation
- **Mathematical view:** Expanding training distribution by applying label-preserving transformations
- **Effect:** Reduces overfitting by increasing effective dataset size
- **Transformations:** Small rotations, translations, and crops typically preserve class labels; flips do not always (a mirrored digit may no longer be a valid digit)
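To make "label-preserving transformation" concrete, here is a minimal NumPy sketch of one such transform: a random integer translation with zero fill. The function name `random_translate` is hypothetical, and the integer-pixel shift is a simplification of what `transforms.RandomAffine(translate=...)` does. The image changes while its label stays the same.

```python
import numpy as np

def random_translate(img, max_shift, rng):
    """Hypothetical helper: shift a 2-D image by a random integer offset in
    [-max_shift, max_shift] along each axis, filling uncovered pixels with 0.
    The class label is not touched, which is what makes it label-preserving."""
    h, w = img.shape
    if max_shift >= min(h, w):
        raise ValueError("max_shift must be smaller than the image size")
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    out = np.zeros_like(img)
    # Overlapping region: copy source rows/cols into their shifted destination
    src_y = slice(max(0, -dy), min(h, h - dy))
    dst_y = slice(max(0, dy), min(h, h + dy))
    src_x = slice(max(0, -dx), min(w, w - dx))
    dst_x = slice(max(0, dx), min(w, w + dx))
    out[dst_y, dst_x] = img[src_y, src_x]
    return out

rng = np.random.default_rng(0)
digit = np.zeros((28, 28))
digit[10:18, 12:16] = 1.0           # a centered blob standing in for a digit
label = 7                           # the label is reused unchanged
shifted = random_translate(digit, max_shift=2, rng=rng)
print(shifted.shape, shifted.sum(), label)
```

Because the blob sits well away from the border, a shift of at most 2 pixels moves it without cropping any of it.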

### Batch Normalization
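- **Formula:** $\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$, then $y = \gamma\hat{x} + \beta$, where $\mu_B$ and $\sigma_B^2$ are the mini-batch mean and variance (computed per channel for `BatchNorm2d`)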
- **Effect:** Normalizes layer inputs, stabilizes gradient flow
- **Parameters:** $\gamma$ (scale) and $\beta$ (shift) are learned
- **Benefit:** Smoother loss landscape, faster training
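The formula can be checked numerically. The sketch below is a plain NumPy version of training-mode batch norm for a `(N, C)` batch; `batch_norm_train` is a hypothetical name, not a library function. With $\gamma = 1$ and $\beta = 0$ each feature comes out with mean about 0 and standard deviation about 1; other values of $\gamma$ and $\beta$ then set the scale and shift.

```python
import numpy as np

def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: batch norm in training mode for x of shape (N, C).
    Statistics are per feature, taken over the batch axis."""
    mu = x.mean(axis=0)                    # mu_B
    var = x.var(axis=0)                    # sigma_B^2 (biased, divide by N)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize
    return gamma * x_hat + beta            # learned affine transform

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))   # shifted, scaled inputs
y = batch_norm_train(x, gamma=np.ones(4), beta=np.zeros(4))
z = batch_norm_train(x, gamma=2 * np.ones(4), beta=3 * np.ones(4))
print(y.mean(axis=0), y.std(axis=0))
print(z.mean(axis=0), z.std(axis=0))
```

For `BatchNorm2d` the same computation runs per channel, with the statistics taken over the batch and both spatial dimensions.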

### Dropout
- **Mathematical view:** Stochastic regularization by randomly dropping neurons
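- **Effect:** Prevents co-adaptation of features; training approximates an ensemble of thinned subnetworks
- **Training:** Zero each activation with probability $p$ and scale the surviving ones by $1/(1-p)$ (inverted dropout)
- **Testing:** Use all neurons unchanged; the training-time scaling already keeps the expected activation the same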
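A minimal NumPy sketch of inverted dropout, assuming a hypothetical `inverted_dropout` helper. It shows why no rescaling is needed at test time: the $1/(1-p)$ factor keeps the expected value of each activation equal to its input during training.

```python
import numpy as np

def inverted_dropout(x, p, rng, training=True):
    """Hypothetical helper: inverted dropout with drop probability p."""
    if not training or p == 0.0:
        return x                          # test time: identity
    keep = rng.random(x.shape) >= p       # each unit kept with probability 1 - p
    return x * keep / (1.0 - p)           # rescale survivors

rng = np.random.default_rng(0)
x = np.ones(1_000_000)
out = inverted_dropout(x, p=0.5, rng=rng)
print(out.mean())               # close to 1.0: expectation is preserved
print((out == 0).mean())        # close to 0.5: fraction of dropped units
print(np.unique(out))           # survivors are scaled to 1 / (1 - 0.5) = 2.0
```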

### Learning Rate Scheduling
- **Motivation:** Large LR early (fast convergence), small LR late (fine-tuning)
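- **StepLR:** Multiply the LR by $\gamma$ every $n$ scheduler steps (here one step per epoch, with $n = 3$ and $\gamma = 0.5$)
- **ReduceLROnPlateau:** Reduce the LR by a factor once a monitored validation metric has stopped improving for a set number of epochs (the patience)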
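StepLR has a simple closed form, $\eta_e = \eta_0 \, \gamma^{\lfloor e / n \rfloor}$ for 0-indexed epoch $e$. The plain-Python sketch below (the helper name `step_lr` is hypothetical) evaluates it with the settings used later in `train_model` ($\eta_0 = 0.001$, $n = 3$, $\gamma = 0.5$, 10 epochs).

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Hypothetical helper: learning rate used during 0-indexed `epoch` under StepLR."""
    return base_lr * gamma ** (epoch // step_size)

lrs = [step_lr(0.001, 0.5, 3, e) for e in range(10)]
for epoch, lr in enumerate(lrs, start=1):
    print(f"Epoch {epoch}: LR = {lr:.6f}")
```

This predicts 0.001 for epochs 1-3, 0.0005 for 4-6, 0.00025 for 7-9 and 0.000125 for epoch 10, which is the staircase the LR plot in section 5 should show.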

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Data Augmentation

We'll create two datasets:
1. **Standard:** Basic normalization only
2. **Augmented:** Random rotations, translations, and scaling

In [None]:
# Standard transform (baseline from Day 2)
transform_standard = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Augmented transform
transform_augmented = transforms.Compose([
    transforms.RandomRotation(15),              # Rotate ±15 degrees
    transforms.RandomAffine(                    # Affine transformations
        degrees=0,
        translate=(0.1, 0.1),                   # Shift ±10%
        scale=(0.9, 1.1)                        # Scale 90-110%
    ),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load datasets
train_dataset_standard = datasets.MNIST(root='./data', train=True, 
                                       download=True, transform=transform_standard)
train_dataset_augmented = datasets.MNIST(root='./data', train=True,
                                         download=True, transform=transform_augmented)
test_dataset = datasets.MNIST(root='./data', train=False,
                             download=True, transform=transform_standard)

# DataLoaders
train_loader_standard = DataLoader(train_dataset_standard, batch_size=64, shuffle=True)
train_loader_augmented = DataLoader(train_dataset_augmented, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Training samples: {len(train_dataset_standard)}")
print(f"Test samples: {len(test_dataset)}")
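The `Normalize((0.1307,), (0.3081,))` constants are the usual MNIST pixel mean and standard deviation, measured on the training set after `ToTensor` rescales pixels to [0, 1]; the test set is normalized with the same training statistics. A small sketch of that computation, with a hypothetical helper `pixel_mean_std` and a tiny synthetic array standing in for the real images:

```python
import numpy as np

def pixel_mean_std(images_uint8):
    """Hypothetical helper: mean and std over all pixels of uint8 images,
    after scaling to [0, 1] the way ToTensor does."""
    x = images_uint8.astype(np.float64) / 255.0
    return x.mean(), x.std()

# Synthetic stand-in: one 2x2 "image" that is half black, half white
toy = np.array([[[0, 255], [255, 0]]], dtype=np.uint8)
mean, std = pixel_mean_std(toy)
print(mean, std)
```

On the real data, `pixel_mean_std(train_dataset_standard.data.numpy())` should come out close to the two constants used above.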

In [None]:
# Visualize augmentation effects
fig, axes = plt.subplots(2, 5, figsize=(12, 5))

# Get one image
original_img, label = train_dataset_standard[0]

# Show original
axes[0, 0].imshow(original_img.squeeze(), cmap='gray')
axes[0, 0].set_title('Original')
axes[0, 0].axis('off')

# Show 4 augmented versions
for i in range(1, 5):
    aug_img, _ = train_dataset_augmented[0]  # Same index; a new random transform on each access
    axes[0, i].imshow(aug_img.squeeze(), cmap='gray')
    axes[0, i].set_title(f'Augmented {i}')
    axes[0, i].axis('off')

# Show another digit
original_img2, label2 = train_dataset_standard[1]
axes[1, 0].imshow(original_img2.squeeze(), cmap='gray')
axes[1, 0].set_title('Original')
axes[1, 0].axis('off')

for i in range(1, 5):
    aug_img2, _ = train_dataset_augmented[1]
    axes[1, i].imshow(aug_img2.squeeze(), cmap='gray')
    axes[1, i].set_title(f'Augmented {i}')
    axes[1, i].axis('off')

plt.tight_layout()
plt.show()

print("Note: Each time the augmented dataset is accessed, different random transforms are applied.")

## 2. CNN with Batch Normalization and Dropout

We'll build an improved CNN with:
- Batch normalization after conv layers
- Dropout for regularization
- More capacity than Day 2
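The input size of the first fully connected layer follows from the standard output-size rule shared by `Conv2d` and `MaxPool2d`, $\lfloor (H + 2p - d(k-1) - 1)/s \rfloor + 1$. The sketch below (helper name `conv2d_out` is hypothetical) traces a 28x28 MNIST image through the layers of the model defined next.

```python
def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Hypothetical helper: output length along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

h = 28
h = conv2d_out(h, kernel=3, padding=1)   # conv1 (3x3, padding 1): 28 -> 28
h = conv2d_out(h, kernel=2, stride=2)    # max-pool 2x2:           28 -> 14
h = conv2d_out(h, kernel=3, padding=1)   # conv2 (3x3, padding 1): 14 -> 14
h = conv2d_out(h, kernel=2, stride=2)    # max-pool 2x2:           14 -> 7
flat_features = 64 * h * h               # 64 channels of 7x7 maps
print(h, flat_features)
```

Padding 1 with a 3x3 kernel keeps the spatial size unchanged, so only the two pooling layers shrink the maps, giving 64 * 7 * 7 = 3136 features.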

In [None]:
class ImprovedCNN(nn.Module):
    """
    CNN with Batch Normalization and Dropout
    
    Architecture:
    - Conv1: 1 -> 32 channels, 3x3 kernel
    - BatchNorm2d
    - ReLU + MaxPool
    - Conv2: 32 -> 64 channels, 3x3 kernel
    - BatchNorm2d
    - ReLU + MaxPool
    - Flatten
    - FC1: 3136 (64 * 7 * 7) -> 128
    - Dropout(0.5)
    - FC2: 128 -> 10
    """
    def __init__(self, dropout_rate=0.5):
        super(ImprovedCNN, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)  # Normalize 32 channels
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)  # Normalize 64 channels
        
        self.pool = nn.MaxPool2d(2, 2)
        
        # After 2 pooling: 28 -> 14 -> 7
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        # Conv block 1
        x = self.conv1(x)           # [N, 1, 28, 28] -> [N, 32, 28, 28]
        x = self.bn1(x)             # Normalize
        x = F.relu(x)               # Activation
        x = self.pool(x)            # [N, 32, 28, 28] -> [N, 32, 14, 14]
        
        # Conv block 2
        x = self.conv2(x)           # [N, 32, 14, 14] -> [N, 64, 14, 14]
        x = self.bn2(x)             # Normalize
        x = F.relu(x)               # Activation
        x = self.pool(x)            # [N, 64, 14, 14] -> [N, 64, 7, 7]
        
        # Flatten and FC layers
        x = x.view(-1, 64 * 7 * 7)  # [N, 64, 7, 7] -> [N, 3136]
        x = F.relu(self.fc1(x))     # [N, 3136] -> [N, 128]
        x = self.dropout(x)         # Dropout during training
        x = self.fc2(x)             # [N, 128] -> [N, 10]
        
        return x

# Create model and count parameters
model = ImprovedCNN(dropout_rate=0.5).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print("\nModel architecture:")
print(model)
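The printed parameter count can be reproduced by hand. This sketch uses hypothetical helpers and counts only learnable parameters: BatchNorm contributes $\gamma$ and $\beta$, while its running mean and variance are buffers and are not included in `model.parameters()`.

```python
def conv_params(c_in, c_out, k):
    return c_in * c_out * k * k + c_out   # weights + one bias per output channel

def bn_params(c):
    return 2 * c                          # gamma and beta only

def linear_params(n_in, n_out):
    return n_in * n_out + n_out           # weight matrix + bias

counts = {
    'conv1': conv_params(1, 32, 3),
    'bn1': bn_params(32),
    'conv2': conv_params(32, 64, 3),
    'bn2': bn_params(64),
    'fc1': linear_params(64 * 7 * 7, 128),
    'fc2': linear_params(128, 10),
}
total = sum(counts.values())
print(counts)
print(f"Total: {total:,}")
```

The total is 421,834, and `fc1` alone accounts for 401,536 of them (about 95%), so the fully connected layer, not the convolutions, dominates the model size.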

## 3. Training with Learning Rate Scheduling

We'll implement:
1. Model checkpointing (save best model)
2. Learning rate scheduling (StepLR)
3. Training with metrics tracking

In [None]:
def train_model(model, train_loader, test_loader, num_epochs=10, lr=0.001, use_scheduler=True):
    """
    Train model with learning rate scheduling and checkpointing
    
    Returns:
        history: dict with training metrics
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Learning rate scheduler: halve the LR (gamma=0.5) every 3 epochs
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.5) if use_scheduler else None
    
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_acc': [],
        'learning_rates': []
    }
    
    best_acc = 0.0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        
        # Validation phase
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        test_acc = 100 * correct / total
        
        # Save best model (the test set doubles as the validation set here)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'best_model.pth')
        
        # Update learning rate
        current_lr = optimizer.param_groups[0]['lr']
        if scheduler:
            scheduler.step()
        
        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_acc'].append(test_acc)
        history['learning_rates'].append(current_lr)
        
        print(f'Epoch [{epoch+1}/{num_epochs}] '
              f'Train Loss: {train_loss:.4f}, '
              f'Train Acc: {train_acc:.2f}%, '
              f'Test Acc: {test_acc:.2f}%, '
              f'LR: {current_lr:.6f}')
    
    print(f'\nBest Test Accuracy: {best_acc:.2f}%')
    return history
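`train_model` switches between `model.train()` and `model.eval()` because Dropout and BatchNorm behave differently in the two modes. A small standalone check, using tiny layers rather than `ImprovedCNN`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: random zeros plus 1/(1-p) scaling in train mode, identity in eval mode
drop = nn.Dropout(p=0.5)
x = torch.ones(8)
drop.train()
y_train = drop(x)        # entries are 0.0 or 2.0
drop.eval()
y_eval = drop(x)         # unchanged input

# BatchNorm: batch statistics in train mode, running averages in eval mode
bn = nn.BatchNorm1d(3)
batch = torch.randn(16, 3) * 4 + 10
with torch.no_grad():
    bn.train()
    bn(batch)            # updates running_mean with momentum 0.1
    bn.eval()
    out_eval = bn(batch) # normalized with running stats, so not zero-mean yet
print(y_train, y_eval)
print(bn.running_mean, out_eval.mean(dim=0))
```

After a single training batch the running mean has only moved 10% of the way toward the batch mean, so eval-mode outputs differ noticeably from train-mode ones. Forgetting `model.eval()` at test time therefore changes the predictions.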

## 4. Experiment: Standard vs Augmented Data

In [None]:
print("=" * 60)
print("Training with STANDARD data (no augmentation)")
print("=" * 60)

model_standard = ImprovedCNN(dropout_rate=0.5).to(device)
history_standard = train_model(model_standard, train_loader_standard, 
                               test_loader, num_epochs=10, lr=0.001)

In [None]:
print("\n" + "=" * 60)
print("Training with AUGMENTED data")
print("=" * 60)

model_augmented = ImprovedCNN(dropout_rate=0.5).to(device)
history_augmented = train_model(model_augmented, train_loader_augmented,
                               test_loader, num_epochs=10, lr=0.001)

## 5. Analysis and Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

epochs = range(1, len(history_standard['train_loss']) + 1)

# Training loss
axes[0, 0].plot(epochs, history_standard['train_loss'], 'b-', label='Standard', marker='o')
axes[0, 0].plot(epochs, history_augmented['train_loss'], 'r-', label='Augmented', marker='s')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Training Loss')
axes[0, 0].set_title('Training Loss Comparison')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Training accuracy
axes[0, 1].plot(epochs, history_standard['train_acc'], 'b-', label='Standard', marker='o')
axes[0, 1].plot(epochs, history_augmented['train_acc'], 'r-', label='Augmented', marker='s')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Training Accuracy (%)')
axes[0, 1].set_title('Training Accuracy Comparison')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Test accuracy
axes[1, 0].plot(epochs, history_standard['test_acc'], 'b-', label='Standard', marker='o')
axes[1, 0].plot(epochs, history_augmented['test_acc'], 'r-', label='Augmented', marker='s')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Test Accuracy (%)')
axes[1, 0].set_title('Test Accuracy Comparison')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Learning rate schedule
axes[1, 1].plot(epochs, history_standard['learning_rates'], 'g-', marker='o')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Learning Rate')
axes[1, 1].set_title('Learning Rate Schedule (StepLR)')
axes[1, 1].set_yscale('log')
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()

# Print summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Standard Data - Final Test Acc: {history_standard['test_acc'][-1]:.2f}%")
print(f"Augmented Data - Final Test Acc: {history_augmented['test_acc'][-1]:.2f}%")
print(f"\nImprovement from augmentation: {history_augmented['test_acc'][-1] - history_standard['test_acc'][-1]:+.2f} percentage points")

## 6. Experiment: Effect of Dropout Rate

In [None]:
# Test different dropout rates
dropout_rates = [0.0, 0.3, 0.5, 0.7]
dropout_results = {}

for dropout_rate in dropout_rates:
    print(f"\n{'='*60}")
    print(f"Training with Dropout Rate: {dropout_rate}")
    print(f"{'='*60}")
    
    model = ImprovedCNN(dropout_rate=dropout_rate).to(device)
    history = train_model(model, train_loader_standard, test_loader, 
                         num_epochs=5, lr=0.001)
    dropout_results[dropout_rate] = history

# Plot dropout comparison
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for dropout_rate in dropout_rates:
    acc = dropout_results[dropout_rate]['train_acc']
    plt.plot(range(1, len(acc) + 1), acc,  # epochs start at 1
             label=f'Dropout={dropout_rate}', marker='o')
plt.xlabel('Epoch')
plt.ylabel('Training Accuracy (%)')
plt.title('Effect of Dropout on Training Accuracy')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
for dropout_rate in dropout_rates:
    acc = dropout_results[dropout_rate]['test_acc']
    plt.plot(range(1, len(acc) + 1), acc,  # epochs start at 1
             label=f'Dropout={dropout_rate}', marker='s')
plt.xlabel('Epoch')
plt.ylabel('Test Accuracy (%)')
plt.title('Effect of Dropout on Test Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

print("\nFinal Test Accuracies:")
for dropout_rate in dropout_rates:
    final_acc = dropout_results[dropout_rate]['test_acc'][-1]
    print(f"Dropout {dropout_rate}: {final_acc:.2f}%")

## 7. Model Checkpointing and Loading

In [None]:
# Save a checkpoint: weights plus metadata (the architecture is rebuilt from the ImprovedCNN class when loading)
torch.save({
    'model_state_dict': model_augmented.state_dict(),
    'architecture': 'ImprovedCNN',
    'dropout_rate': 0.5,
    'final_test_acc': history_augmented['test_acc'][-1]
}, 'checkpoint_day03.pth')

print("Model checkpoint saved to 'checkpoint_day03.pth'")

# Load model
checkpoint = torch.load('checkpoint_day03.pth', map_location=device)  # map GPU-saved tensors onto the current device
loaded_model = ImprovedCNN(dropout_rate=checkpoint['dropout_rate']).to(device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])

print("Model loaded successfully")
print(f"Saved test accuracy: {checkpoint['final_test_acc']:.2f}%")

# Verify loaded model
loaded_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = loaded_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Verified accuracy: {100 * correct / total:.2f}%")
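The same save/load round trip can be checked without touching disk. The sketch below uses an in-memory `io.BytesIO` buffer and a tiny `nn.Sequential` model as stand-ins for the `.pth` file and `ImprovedCNN`. It confirms that `state_dict()` carries BatchNorm's running statistics (buffers) as well as the learned weights, which is why the reloaded model reproduces eval-mode behavior exactly.

```python
import io
import torch
import torch.nn as nn

def make_net():
    # Small stand-in architecture; rebuilt from code on load, like ImprovedCNN
    return nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

torch.manual_seed(0)
net = make_net()
net.train()
with torch.no_grad():
    net(torch.randn(8, 4))                    # populate BatchNorm running stats

buffer = io.BytesIO()                         # in-memory stand-in for a .pth file
torch.save({'model_state_dict': net.state_dict(), 'dropout_rate': 0.5}, buffer)
buffer.seek(0)

ckpt = torch.load(buffer, map_location='cpu')
restored = make_net()
restored.load_state_dict(ckpt['model_state_dict'])
print(list(ckpt['model_state_dict'].keys()))  # includes running_mean / running_var
```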

## 8. Key Takeaways

### Data Augmentation
- ✅ Increases effective dataset size
- ✅ Reduces overfitting
- ✅ Improves generalization
- ⚠️ May slightly slow training (more computation)
- ⚠️ Can hurt performance if transforms are too aggressive

### Batch Normalization
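- ✅ Stabilizes training (smoother loss landscape)
- ✅ Allows higher learning rates
- ✅ Acts as mild regularization
- ✅ Originally motivated as reducing internal covariate shift (later work attributes the benefit mainly to the smoother optimization landscape)
- ⚠️ Behaves differently in train vs eval mode (batch statistics vs running averages); call `model.eval()` before inference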

### Dropout
- ✅ Strong regularization technique
- ✅ Prevents overfitting
- ✅ Approximates an ensemble of thinned subnetworks
- ⚠️ Very high rates (e.g. 0.7 and above) can slow or hurt training
- ⚠️ Must be disabled during inference (`model.eval()` does this)

### Learning Rate Scheduling
- ✅ Improves final convergence
- ✅ Large LR early = fast training
- ✅ Small LR late = fine-tuning
- 📊 StepLR is simple and effective
- 📊 ReduceLROnPlateau is adaptive to validation

### Best Practices Learned
1. Use label-preserving data augmentation for image tasks, chosen to suit the data
2. Add batch norm after conv layers
3. Use dropout (e.g. 0.5) between the fully connected layers
4. Start with higher LR, decay over time
5. Save best model based on validation accuracy
6. Monitor both training and validation metrics

## 9. Exercises

1. **Try ReduceLROnPlateau scheduler** instead of StepLR
2. **Experiment with batch norm placement**: Before or after ReLU?
3. **Add more augmentation**: ColorJitter, RandomErasing
4. **Compare BatchNorm vs LayerNorm vs GroupNorm**
5. **Implement early stopping**: Stop if validation doesn't improve for N epochs
6. **Try different dropout locations**: After each conv layer vs only FC layers
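As a possible starting point for exercise 1: unlike StepLR, `ReduceLROnPlateau` must be stepped with the monitored metric, and because test accuracy should increase, it needs `mode='max'`. The sketch below drives it with a dummy parameter and made-up accuracies (not results from this notebook) to show when the LR drops. Inside `train_model`, the corresponding change would be calling `scheduler.step(test_acc)` instead of `scheduler.step()`.

```python
import torch
import torch.optim as optim

param = torch.nn.Parameter(torch.zeros(1))    # dummy parameter standing in for a model
optimizer = optim.Adam([param], lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=1)

made_up_accs = [97.0, 98.0, 98.0, 98.0, 98.0]  # improvement stalls after epoch 2
lrs = []
for acc in made_up_accs:
    optimizer.step()                           # training for one epoch would go here
    scheduler.step(acc)                        # pass the monitored metric
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs)
```

With `patience=1`, the LR is halved only once the accuracy has failed to improve for two consecutive epochs, so the drop happens after the fourth value.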