# Day 9: Advanced CNN Architectures

**Time:** 3-4 hours

**Mathematical Prerequisites:**
- Linear algebra (matrix operations, dimensions)
- Calculus (gradient flow through networks)
- Understanding of convolutions from signal processing
- Optimization theory (gradient vanishing/explosion)

---

## Objectives

Today we implement famous CNN architectures from their original papers:
1. **VGG Network** (Simonyan & Zisserman, 2014)
2. **ResNet** (He et al., 2015)
3. **Inception Module** (Szegedy et al., 2014)

We'll understand:
- Why deeper networks were difficult to train
- How skip connections mitigate vanishing gradients
- Mathematical analysis of gradient flow
- Trade-offs between depth, width, and parameters

**Goal:** Build ResNet from scratch and achieve competitive performance on CIFAR-10

---

## Part 1: Theory - The Depth Problem

### 1.1 The Vanishing/Exploding Gradient Problem

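For a deep network with $L$ layers, the chain rule gives the gradient for the first layer's parameters $\theta_1$ (note the overloaded symbol: in the derivatives $L$ is the loss, in the subscripts it is the layer count):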
$$
\frac{\partial L}{\partial \theta_1} = \frac{\partial L}{\partial h_L} \cdot \frac{\partial h_L}{\partial h_{L-1}} \cdot \ldots \cdot \frac{\partial h_2}{\partial h_1} \cdot \frac{\partial h_1}{\partial \theta_1}
$$

**Problem:** This is a product of roughly $L$ Jacobian terms. If each term has magnitude:
- $< 1$: Gradients vanish exponentially $\rightarrow$ Early layers don't learn
- $> 1$: Gradients explode exponentially $\rightarrow$ Training diverges
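
To get a feel for how quickly such a product shrinks or grows, here is a minimal sketch that treats every factor as the same scalar. That is an assumption made purely for illustration (real Jacobians are matrices that differ per layer), and the helper name `gradient_scale` is hypothetical.

```python
def gradient_scale(per_layer_factor, num_layers):
    """Magnitude of a product of `num_layers` identical scalar factors.

    Illustrative assumption: every chain-rule term is the same scalar.
    """
    return per_layer_factor ** num_layers


for factor in (0.9, 1.0, 1.1):
    shallow = gradient_scale(factor, 20)
    deep = gradient_scale(factor, 100)
    print(f"factor={factor}: 20 layers -> {shallow:.2e}, 100 layers -> {deep:.2e}")
```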

**Empirical observation:** Very deep plain networks (no skip connections) have **higher training error** than shallower networks. Because the gap appears on the training set itself, this is NOT overfitting; it is an optimization failure!

### 1.2 The Degradation Problem

**Hypothesis (He et al., 2015):** If deeper networks can represent identity mappings, they should at worst perform as well as shallower networks.

**Reality:** Optimization algorithms struggle to learn identity mappings in deep plain networks.

**Solution:** Make identity mapping **explicit** via skip connections.

### 1.3 Residual Learning

Instead of learning $H(x)$, learn $F(x) = H(x) - x$ (the residual).

**Output:** $H(x) = F(x) + x$

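**Key insight:** If the identity is optimal, the block only has to push $F(x)$ toward $0$ (drive its weights toward zero), which is easier than fitting $H(x) = x$ through a stack of non-linear layers.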

**Gradient flow:**
$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial H} \cdot \frac{\partial H}{\partial x} = \frac{\partial L}{\partial H} \cdot \left(1 + \frac{\partial F}{\partial x}\right)
$$

The "1" term (the identity matrix when $x$ is a vector) ensures the gradient always has a direct path through the skip connection, even if $\frac{\partial F}{\partial x}$ is small!
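
The effect of the extra "1" can be checked numerically on a toy model. The sketch below assumes a scalar "layer" $F(h) = w \cdot h$ with a small weight, which is far simpler than a real conv block; the function names are hypothetical and the gradient is estimated by finite differences rather than autograd.

```python
def plain_chain(x, w, depth):
    """Apply h <- w * h `depth` times (no skip connection)."""
    for _ in range(depth):
        x = w * x
    return x


def residual_chain(x, w, depth):
    """Apply h <- h + w * h `depth` times (skip connection around F(h) = w * h)."""
    for _ in range(depth):
        x = x + w * x
    return x


def numerical_grad(fn, x, eps=1e-6):
    """Central finite-difference estimate of d fn / d x."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)


w, depth = 0.01, 10  # small dF/dh, ten stacked layers (illustrative values)
plain_grad = numerical_grad(lambda x: plain_chain(x, w, depth), 1.0)
residual_grad = numerical_grad(lambda x: residual_chain(x, w, depth), 1.0)

print(f"plain:    d(out)/d(in) = {plain_grad:.3e}")     # analytically w ** depth
print(f"residual: d(out)/d(in) = {residual_grad:.3f}")  # analytically (1 + w) ** depth
```

In this toy setting the plain chain's gradient collapses toward zero, while the residual chain's gradient stays close to 1.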

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from tqdm import tqdm
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

np.random.seed(42)
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Part 2: Load CIFAR-10 Dataset

In [None]:
# CIFAR-10 specific normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Load datasets
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Create data loaders
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
           'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Batch size: {batch_size}")
print(f"Training batches: {len(train_loader)}")

## Part 3: VGG Network

### 3.1 VGG Architecture Philosophy

**Key ideas from Simonyan & Zisserman (2014):**
1. Use very small (3×3) convolutional filters
2. Stack multiple conv layers before pooling
3. Increase depth systematically

**Why 3×3?**
- Two 3×3 layers = one 5×5 receptive field
- Three 3×3 layers = one 7×7 receptive field
- **Fewer parameters:** $3 \times (3^2 \cdot C^2) = 27C^2$ vs $7^2 \cdot C^2 = 49C^2$
- **More non-linearities:** Each layer adds ReLU
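
The receptive-field and parameter arithmetic above can be reproduced in a few lines. This is a small sketch assuming stride-1 convolutions, equal input and output channel counts $C$, and no biases; the helper names are hypothetical.

```python
def stacked_receptive_field(kernel_size, num_layers):
    """Receptive field of `num_layers` stacked stride-1 convs with the same kernel size."""
    return 1 + num_layers * (kernel_size - 1)


def conv_weight_count(kernel_size, channels, num_layers=1):
    """Weights (no biases) in `num_layers` convs mapping `channels` -> `channels`."""
    return num_layers * kernel_size ** 2 * channels ** 2


C = 64  # illustrative channel width
print("two 3x3 layers see:  ", stacked_receptive_field(3, 2))
print("three 3x3 layers see:", stacked_receptive_field(3, 3))
print("weights:", conv_weight_count(3, C, num_layers=3), "vs", conv_weight_count(7, C))
```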

### 3.2 VGG16 Architecture

```
Input: 224×224×3
|
Conv3-64 × 2  →  MaxPool  →  112×112×64
|
Conv3-128 × 2  →  MaxPool  →  56×56×128
|
Conv3-256 × 3  →  MaxPool  →  28×28×256
|
Conv3-512 × 3  →  MaxPool  →  14×14×512
|
Conv3-512 × 3  →  MaxPool  →  7×7×512
|
FC-4096  →  FC-4096  →  FC-1000  →  Softmax
```

**Total parameters:** ~138 million (mostly in FC layers!)
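
The ~138M figure, and the claim that most of it sits in the FC layers, can be verified by counting weights and biases layer by layer. The sketch below assumes the standard VGG16 configuration drawn in the diagram (224×224×3 input, 1000 classes, no batch normalization); the names `VGG16_BLOCKS` and `vgg16_param_counts` are hypothetical.

```python
# (out_channels, number of 3x3 convs) for the five VGG16 blocks shown above
VGG16_BLOCKS = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]


def vgg16_param_counts(num_classes=1000):
    """Return (conv_params, fc_params) for VGG16 on 224x224x3 inputs, biases included."""
    conv_params = 0
    in_channels = 3
    for out_channels, num_convs in VGG16_BLOCKS:
        for _ in range(num_convs):
            conv_params += 3 * 3 * in_channels * out_channels + out_channels
            in_channels = out_channels

    fc_params = 0
    in_features = 7 * 7 * 512  # flattened output of the last pooling layer
    for out_features in (4096, 4096, num_classes):
        fc_params += in_features * out_features + out_features
        in_features = out_features
    return conv_params, fc_params


conv_params, fc_params = vgg16_param_counts()
total = conv_params + fc_params
print(f"conv: {conv_params:,}  fc: {fc_params:,}  total: {total:,}")
print(f"FC share: {fc_params / total:.1%}")
```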

In [None]:
class VGGBlock(nn.Module):
    """VGG building block: stack of conv layers followed by maxpool."""
    
    def __init__(self, in_channels, out_channels, num_convs):
        super(VGGBlock, self).__init__()
        
        layers = []
        for i in range(num_convs):
            if i == 0:
                layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1))
            else:
                layers.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        
        layers.append(nn.MaxPool2d(2, 2))
        
        self.block = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.block(x)


class VGGNet(nn.Module):
    """VGG-style network adapted for CIFAR-10 (32×32 images).
    
    Architecture:
    - 2 blocks of two 3×3 convs (64 and 128 channels)
    - 3 blocks of three 3×3 convs (256, 512, 512 channels)
    - 3 FC layers
    """
    
    def __init__(self, num_classes=10):
        super(VGGNet, self).__init__()
        
        # Feature extractor (conv blocks)
        self.features = nn.Sequential(
            VGGBlock(3, 64, 2),      # 32 → 16
            VGGBlock(64, 128, 2),    # 16 → 8
            VGGBlock(128, 256, 3),   # 8 → 4
            VGGBlock(256, 512, 3),   # 4 → 2
            VGGBlock(512, 512, 3),   # 2 → 1
        )
        
        # Classifier
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

# Create VGG model
vgg_model = VGGNet(num_classes=10).to(device)

# Count parameters
total_params = sum(p.numel() for p in vgg_model.parameters())
print(f"VGGNet Parameters: {total_params:,}")

# Test forward pass
dummy_input = torch.randn(1, 3, 32, 32).to(device)
output = vgg_model(dummy_input)
print(f"Output shape: {output.shape}")

## Part 4: ResNet (Residual Networks)

### 4.1 Residual Block

**Basic Block (ResNet-18, ResNet-34):**
```
x ───────────────────┐
|                    |
Conv3x3 → BN → ReLU  |
|                    |
Conv3x3 → BN         |
|                    |
+  ←─────────────────┘
|
ReLU
```

**Bottleneck Block (ResNet-50, 101, 152):**
```
x ───────────────────────┐
|                        |
Conv1x1 → BN → ReLU      | (reduce channels)
|                        |
Conv3x3 → BN → ReLU      | (process)
|                        |
Conv1x1 → BN             | (expand channels)
|                        |
+  ←─────────────────────┘
|
ReLU
```
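
To see why the bottleneck design is cheaper at large widths, one can compare weight counts. The sketch below assumes a 256-channel input with a 4× reduction inside the bottleneck and ignores biases and batch-norm parameters; the widths and the helper `conv_weights` are illustrative, not taken from the code in this notebook.

```python
def conv_weights(kernel_size, in_channels, out_channels):
    """Weight count of one bias-free conv layer."""
    return kernel_size ** 2 * in_channels * out_channels


width, reduced = 256, 64  # illustrative: 256-channel input, 4x reduction inside the bottleneck

# Basic-block style: two 3x3 convs at full width
two_3x3 = 2 * conv_weights(3, width, width)

# Bottleneck style: reduce, process, expand
bottleneck = (conv_weights(1, width, reduced)      # 1x1 reduce
              + conv_weights(3, reduced, reduced)  # 3x3 process
              + conv_weights(1, reduced, width))   # 1x1 expand

print(f"two 3x3 convs at {width} channels: {two_3x3:,}")
print(f"bottleneck {width}->{reduced}->{width}: {bottleneck:,}")
print(f"ratio: {two_3x3 / bottleneck:.1f}x")
```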

### 4.2 Handling Dimension Mismatch

When the spatial dimensions or the number of channels change, we need a **projection shortcut**:
$$
y = F(x, W_i) + W_s x
$$
where $W_s$ is a 1×1 conv that matches dimensions.
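
For the addition to be valid, both branches must produce the same spatial size. A quick check with the standard convolution output-size formula, assuming the main branch downsamples with a 3×3 conv (padding 1) and the shortcut with a 1×1 conv (padding 0) at the same stride; the helper name is hypothetical.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


for size in (32, 16, 8):
    main = conv_output_size(size, kernel_size=3, stride=2, padding=1)      # first conv of F
    shortcut = conv_output_size(size, kernel_size=1, stride=2, padding=0)  # W_s
    print(f"{size} -> main branch {main}, shortcut {shortcut}")
```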

In [None]:
class BasicBlock(nn.Module):
    """Basic residual block for ResNet-18/34.
    
    Two 3x3 conv layers with skip connection.
    """
    
    expansion = 1  # Output channels = planes * expansion
    
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        
        # First conv (may downsample spatial dimensions)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        
        # Second conv (maintains dimensions)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        # Skip connection (identity or projection)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            # Projection shortcut to match dimensions
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        # Main branch (residual)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        
        # Skip connection
        out += self.shortcut(x)
        
        # Final ReLU
        out = F.relu(out)
        
        return out


class BottleneckBlock(nn.Module):
    """Bottleneck residual block for ResNet-50/101/152.
    
    1x1 reduce → 3x3 process → 1x1 expand.
    More efficient for deeper networks.
    """
    
    expansion = 4  # Output channels = planes * 4
    
    def __init__(self, in_planes, planes, stride=1):
        super(BottleneckBlock, self).__init__()
        
        # 1x1 conv to reduce channels
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        
        # 3x3 conv (may downsample)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        
        # 1x1 conv to expand channels
        self.conv3 = nn.Conv2d(planes, self.expansion * planes,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)
        
        # Skip connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                         stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        
        out += self.shortcut(x)
        out = F.relu(out)
        
        return out

### 4.3 Complete ResNet Architecture

In [None]:
class ResNet(nn.Module):
    """ResNet for CIFAR-10 (32×32 images).
    
    Modified from original (224×224) by:
    - Smaller initial conv (3×3 instead of 7×7)
    - No initial maxpool
    """
    
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        
        self.in_planes = 64
        
        # Initial conv layer (no maxpool for CIFAR)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        # Residual layers
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        # Global average pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        """Create a layer with multiple residual blocks."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        # Initial conv
        out = F.relu(self.bn1(self.conv1(x)))
        
        # Residual layers
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        # Global average pooling
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        
        # Classifier
        out = self.fc(out)
        
        return out


def ResNet18():
    """ResNet-18: 18 layers with basic blocks."""
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    """ResNet-34: 34 layers with basic blocks."""
    return ResNet(BasicBlock, [3, 4, 6, 3])

def ResNet50():
    """ResNet-50: 50 layers with bottleneck blocks."""
    return ResNet(BottleneckBlock, [3, 4, 6, 3])


# Create ResNet-18
resnet18 = ResNet18().to(device)
resnet18_params = sum(p.numel() for p in resnet18.parameters())
print(f"ResNet-18 Parameters: {resnet18_params:,}")

# Create ResNet-34
resnet34 = ResNet34().to(device)
resnet34_params = sum(p.numel() for p in resnet34.parameters())
print(f"ResNet-34 Parameters: {resnet34_params:,}")

# Create ResNet-50
resnet50 = ResNet50().to(device)
resnet50_params = sum(p.numel() for p in resnet50.parameters())
print(f"ResNet-50 Parameters: {resnet50_params:,}")

# Test forward pass
dummy_input = torch.randn(2, 3, 32, 32).to(device)
output = resnet18(dummy_input)
print(f"\nResNet-18 output shape: {output.shape}")
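
The "18", "34" and "50" in the model names can be recovered from the block counts. The sketch below assumes the usual counting convention (initial conv, plus the convs inside each residual block, plus the final FC layer, with projection-shortcut convs left out); the helper names are hypothetical.

```python
def resnet_depth(num_blocks, convs_per_block):
    """Weighted layers: initial conv + convs in all residual blocks + final FC.

    Projection-shortcut convs are, by convention, not counted.
    """
    return 1 + convs_per_block * sum(num_blocks) + 1


def layer_strides(stride, num_blocks):
    """Per-block strides as built in `_make_layer`: only the first block may downsample."""
    return [stride] + [1] * (num_blocks - 1)


print(resnet_depth([2, 2, 2, 2], convs_per_block=2))  # basic blocks
print(resnet_depth([3, 4, 6, 3], convs_per_block=2))  # basic blocks
print(resnet_depth([3, 4, 6, 3], convs_per_block=3))  # bottleneck blocks
print(layer_strides(2, 4))
```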

## Part 5: Plain Network (No Skip Connections)

For comparison, let's build a plain network with the same depth but NO skip connections.

In [None]:
class PlainBlock(nn.Module):
    """Plain block WITHOUT skip connection."""
    
    def __init__(self, in_planes, planes, stride=1):
        super(PlainBlock, self).__init__()
        
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out  # NO skip connection!


class PlainNet(nn.Module):
    """Plain network (no skip connections) for comparison."""
    
    def __init__(self, block, num_blocks, num_classes=10):
        super(PlainNet, self).__init__()
        
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
    
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def PlainNet18():
    return PlainNet(PlainBlock, [2, 2, 2, 2])

# Create plain network
plain_net = PlainNet18().to(device)
plain_params = sum(p.numel() for p in plain_net.parameters())
print(f"PlainNet-18 Parameters: {plain_params:,}")
print(f"\nNote: PlainNet has a similar parameter count to ResNet-18 (it only lacks the 1x1 projection shortcuts) and NO skip connections.")

## Part 6: Inception Module (Optional)

### 6.1 Multi-Scale Feature Extraction

**Key insight:** Different objects/features may have different scales. Instead of choosing one filter size, use multiple in parallel!

**Inception Module with dimension reduction (GoogLeNet):**
```
Input
  ├─ 1×1 conv ─────────────┐
  ├─ 1×1 conv → 3×3 conv ──┤
  ├─ 1×1 conv → 5×5 conv ──┤  Concatenate
  └─ 3×3 maxpool → 1×1 conv┘
```

**1×1 convolutions:**
- Reduce channels (dimensionality reduction)
- Mix information across channels
- Add non-linearity
- Reduce computational cost
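
The cost saving from a 1×1 reduction is easy to quantify. This sketch compares a direct 5×5 convolution with a 1×1 reduction followed by a 5×5 convolution, counting weights only (no biases or batch norm); the channel widths and the helper `conv_weights` are illustrative assumptions.

```python
def conv_weights(kernel_size, in_channels, out_channels):
    """Weight count of one bias-free conv layer (also multiply-adds per output position)."""
    return kernel_size ** 2 * in_channels * out_channels


in_channels, reduce_channels, out_channels = 64, 8, 16  # illustrative widths

direct_5x5 = conv_weights(5, in_channels, out_channels)
reduced_5x5 = (conv_weights(1, in_channels, reduce_channels)      # 1x1 reduce
               + conv_weights(5, reduce_channels, out_channels))  # 5x5 on fewer channels

print(f"direct 5x5:        {direct_5x5:,} weights")
print(f"1x1 reduce -> 5x5: {reduced_5x5:,} weights")
```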

In [None]:
class InceptionModule(nn.Module):
    """Simplified Inception module."""
    
    def __init__(self, in_channels, ch1x1, ch3x3_reduce, ch3x3, 
                 ch5x5_reduce, ch5x5, pool_proj):
        super(InceptionModule, self).__init__()
        
        # 1×1 branch
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )
        
        # 3×3 branch (with 1×1 reduction)
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3_reduce, kernel_size=1),
            nn.BatchNorm2d(ch3x3_reduce),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3_reduce, ch3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch3x3),
            nn.ReLU(inplace=True)
        )
        
        # 5×5 branch (with 1×1 reduction)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5_reduce, kernel_size=1),
            nn.BatchNorm2d(ch5x5_reduce),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5_reduce, ch5x5, kernel_size=5, padding=2),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )
        
        # Pool branch
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        out1 = self.branch1(x)
        out2 = self.branch2(x)
        out3 = self.branch3(x)
        out4 = self.branch4(x)
        
        # Concatenate along channel dimension
        return torch.cat([out1, out2, out3, out4], dim=1)


# Test Inception module
inception = InceptionModule(64, ch1x1=16, ch3x3_reduce=24, ch3x3=32,
                           ch5x5_reduce=8, ch5x5=16, pool_proj=16).to(device)
dummy = torch.randn(2, 64, 32, 32).to(device)
out = inception(dummy)
print(f"Inception module output shape: {out.shape}")
print(f"Output channels: 16 + 32 + 16 + 16 = {16+32+16+16}")

## Part 7: Training and Comparison

### 7.1 Training Helper Functions

In [None]:
def train_model(model, train_loader, test_loader, num_epochs=20, lr=0.1):
    """Train model and track metrics."""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    train_losses = []
    train_accs = []
    test_accs = []
    
    print(f"Training for {num_epochs} epochs...")
    start_time = time.time()
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        # enumerate gives an exact batch count; pbar.n is only refreshed when the bar redraws
        for batch_idx, (inputs, labels) in enumerate(pbar):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': running_loss / (batch_idx + 1),
                'acc': 100. * correct / total
            })
        
        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        
        # Testing
        model.eval()
        correct = 0
        total = 0
        
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        
        test_acc = 100. * correct / total
        test_accs.append(test_acc)
        
        print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, '
              f'Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%')
        
        scheduler.step()
    
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time:.2f} seconds")
    
    return {
        'train_losses': train_losses,
        'train_accs': train_accs,
        'test_accs': test_accs,
        'total_time': total_time
    }
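
To see what learning rates the cosine schedule produces over the 20 epochs, here is a standalone sketch of the closed-form cosine annealing formula, assuming a minimum learning rate of 0 and one scheduler step per epoch. The function `cosine_annealing_lr` is a hypothetical helper written for illustration, not a PyTorch API.

```python
import math


def cosine_annealing_lr(epoch, base_lr=0.1, total_epochs=20, eta_min=0.0):
    """Closed form: eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T)) / 2."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / total_epochs))


for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.4f}")
```

The rate starts at `base_lr`, passes through half of it at the midpoint, and decays smoothly to `eta_min` at the end.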

### 7.2 Train ResNet-18 (Main Focus)

In [None]:
# Train ResNet-18
print("="*60)
print("Training ResNet-18")
print("="*60)

resnet18 = ResNet18().to(device)
resnet18_results = train_model(resnet18, train_loader, test_loader, num_epochs=20)

print(f"\nFinal Test Accuracy: {resnet18_results['test_accs'][-1]:.2f}%")
print(f"Best Test Accuracy: {max(resnet18_results['test_accs']):.2f}%")

### 7.3 Train Plain Network (for Comparison)

In [None]:
# Train PlainNet-18 (no skip connections)
print("\n" + "="*60)
print("Training PlainNet-18 (No Skip Connections)")
print("="*60)

plain_net = PlainNet18().to(device)
plain_results = train_model(plain_net, train_loader, test_loader, num_epochs=20)

print(f"\nFinal Test Accuracy: {plain_results['test_accs'][-1]:.2f}%")
print(f"Best Test Accuracy: {max(plain_results['test_accs']):.2f}%")

## Part 8: Comparison and Analysis

### 8.1 ResNet vs PlainNet

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Training loss
axes[0].plot(resnet18_results['train_losses'], label='ResNet-18', linewidth=2)
axes[0].plot(plain_results['train_losses'], label='PlainNet-18', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Training Loss', fontsize=12)
axes[0].set_title('Training Loss Comparison', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Training accuracy
axes[1].plot(resnet18_results['train_accs'], label='ResNet-18', linewidth=2)
axes[1].plot(plain_results['train_accs'], label='PlainNet-18', linewidth=2)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Training Accuracy (%)', fontsize=12)
axes[1].set_title('Training Accuracy Comparison', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=11)
axes[1].grid(True, alpha=0.3)

# Test accuracy
axes[2].plot(resnet18_results['test_accs'], label='ResNet-18', linewidth=2)
axes[2].plot(plain_results['test_accs'], label='PlainNet-18', linewidth=2)
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Test Accuracy (%)', fontsize=12)
axes[2].set_title('Test Accuracy Comparison', fontsize=14, fontweight='bold')
axes[2].legend(fontsize=11)
axes[2].grid(True, alpha=0.3)

plt.suptitle('Skip Connections: ResNet vs Plain Network', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Summary statistics
print("\n" + "="*60)
print("COMPARISON SUMMARY")
print("="*60)
print(f"{'Metric':<25} {'ResNet-18':>15} {'PlainNet-18':>15}")
print("-"*60)
print(f"{'Parameters':.<25} {resnet18_params:>15,} {plain_params:>15,}")
print(f"{'Best Test Acc (%)':.<25} {max(resnet18_results['test_accs']):>15.2f} {max(plain_results['test_accs']):>15.2f}")
print(f"{'Final Test Acc (%)':.<25} {resnet18_results['test_accs'][-1]:>15.2f} {plain_results['test_accs'][-1]:>15.2f}")
print(f"{'Training Time (s)':.<25} {resnet18_results['total_time']:>15.2f} {plain_results['total_time']:>15.2f}")
print("="*60)

improvement = max(resnet18_results['test_accs']) - max(plain_results['test_accs'])
print(f"\nSkip connections: {improvement:+.2f} percentage points (best test accuracy)")

### 8.2 Gradient Flow Analysis

Let's analyze how gradients flow through the networks.

In [None]:
def get_gradient_norms(model, train_loader):
    """Get gradient norms for each layer."""
    model.train()
    criterion = nn.CrossEntropyLoss()
    
    # Get one batch
    inputs, labels = next(iter(train_loader))
    inputs, labels = inputs.to(device), labels.to(device)
    
    # Forward and backward
    model.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    
    # Collect gradient norms
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms[name] = param.grad.norm().item()
    
    return grad_norms

# Get gradients for both models
resnet_grads = get_gradient_norms(resnet18, train_loader)
plain_grads = get_gradient_norms(plain_net, train_loader)

# Filter conv layer gradients
resnet_conv_grads = {k: v for k, v in resnet_grads.items() if 'conv' in k and 'weight' in k}
plain_conv_grads = {k: v for k, v in plain_grads.items() if 'conv' in k and 'weight' in k}

# Plot gradient norms by layer depth
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# ResNet gradients
resnet_values = list(resnet_conv_grads.values())
ax1.bar(range(len(resnet_values)), resnet_values, alpha=0.7)
ax1.set_xlabel('Layer (depth)', fontsize=12)
ax1.set_ylabel('Gradient Norm', fontsize=12)
ax1.set_title('ResNet-18: Gradient Norms by Layer', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3, axis='y')

# PlainNet gradients
plain_values = list(plain_conv_grads.values())
ax2.bar(range(len(plain_values)), plain_values, alpha=0.7, color='orange')
ax2.set_xlabel('Layer (depth)', fontsize=12)
ax2.set_ylabel('Gradient Norm', fontsize=12)
ax2.set_title('PlainNet-18: Gradient Norms by Layer', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print(f"ResNet gradient norm range: [{min(resnet_values):.6f}, {max(resnet_values):.6f}]")
print(f"PlainNet gradient norm range: [{min(plain_values):.6f}, {max(plain_values):.6f}]")
print(f"\nGradient range ratio: {max(resnet_values)/min(resnet_values):.2f} (ResNet) vs "
      f"{max(plain_values)/min(plain_values):.2f} (PlainNet)")

## Part 9: Architecture Comparison Summary

### 9.1 Parameters vs Performance

In [None]:
# Create comparison table
architectures = {
    'VGGNet': {
        'params': total_params,  # VGG params from earlier
        'depth': '13 conv + 3 FC',
        'key_feature': 'Stacked 3x3 convs',
        'year': 2014
    },
    'ResNet-18': {
        'params': resnet18_params,
        'depth': '18 layers',
        'key_feature': 'Skip connections',
        'year': 2015
    },
    'ResNet-34': {
        'params': resnet34_params,
        'depth': '34 layers',
        'key_feature': 'Skip connections',
        'year': 2015
    },
    'ResNet-50': {
        'params': resnet50_params,
        'depth': '50 layers',
        'key_feature': 'Bottleneck blocks',
        'year': 2015
    },
    'PlainNet-18': {
        'params': plain_params,
        'depth': '18 layers',
        'key_feature': 'No skip connections',
        'year': '-'
    }
}

print("\nArchitecture Comparison:")
print("="*90)
print(f"{'Architecture':<15} {'Parameters':>15} {'Depth':<20} {'Key Feature':<25} {'Year':>6}")
print("-"*90)
for name, info in architectures.items():
    print(f"{name:<15} {info['params']:>15,} {info['depth']:<20} {info['key_feature']:<25} {info['year']:>6}")
print("="*90)

# Visualize parameters
fig, ax = plt.subplots(figsize=(12, 6))

names = list(architectures.keys())
params = [architectures[n]['params'] for n in names]

bars = ax.bar(names, [p/1e6 for p in params], alpha=0.7, edgecolor='black')
ax.set_ylabel('Parameters (Millions)', fontsize=12)
ax.set_title('Model Complexity: Parameter Count', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Add value labels
for bar, p in zip(bars, params):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.1,
            f'{p/1e6:.2f}M', ha='center', va='bottom', fontsize=10, fontweight='bold')

plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Part 10: Key Takeaways and Mathematical Insights

### 10.1 Why Skip Connections Work

**Mathematical Perspective:**

For a residual block $H(x) = F(x) + x$:

$$
\frac{\partial L}{\partial x_l} = \frac{\partial L}{\partial x_L} \cdot \frac{\partial x_L}{\partial x_l}
$$

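For ResNets with identity shortcuts (and ignoring the ReLU applied after each addition), unrolling the recursion $x_{i+1} = x_i + F(x_i)$ gives: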
$$
x_L = x_l + \sum_{i=l}^{L-1} F(x_i)
$$

Therefore:
$$
\frac{\partial x_L}{\partial x_l} = 1 + \frac{\partial}{\partial x_l}\sum_{i=l}^{L-1} F(x_i)
$$

**The "+1" term ensures gradient always flows**, regardless of how small the residual gradients are!
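
A short worked expansion may make this concrete. Assuming identity shortcuts and treating the derivatives as scalars for readability, each block computes $x_{i+1} = x_i + F(x_i)$, so the same derivative can also be written as a product of per-block factors:

$$
\frac{\partial x_L}{\partial x_l} = \prod_{i=l}^{L-1} \left(1 + \frac{\partial F(x_i)}{\partial x_i}\right)
$$

For two blocks ($L = l + 2$) this expands to:

$$
\left(1 + \frac{\partial F(x_{l+1})}{\partial x_{l+1}}\right)\left(1 + \frac{\partial F(x_l)}{\partial x_l}\right) = 1 + \frac{\partial F(x_l)}{\partial x_l} + \frac{\partial F(x_{l+1})}{\partial x_{l+1}} + \frac{\partial F(x_{l+1})}{\partial x_{l+1}} \cdot \frac{\partial F(x_l)}{\partial x_l}
$$

The leading $1$ survives however small the remaining terms become, whereas in a plain network only the last product term would be present.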

### 10.2 Architecture Design Principles

**VGG:**
- Simplicity: All 3×3 convolutions
- Pros: Easy to understand and implement
- Cons: Very large (138M parameters), slow to train

**ResNet:**
- Skip connections enable very deep networks (100+ layers)
- Pros: Efficient, scalable, state-of-the-art performance
- Cons: More complex design

**Inception:**
- Multi-scale feature extraction
- Pros: Captures features at different scales efficiently
- Cons: Complex architecture, harder to tune

### 10.3 Modern Best Practices

1. **Use skip connections** for networks deeper than ~20 layers
2. **Batch Normalization** after every conv layer
3. **Global Average Pooling** instead of FC layers (reduces parameters)
4. **Bottleneck blocks** for very deep networks (computational efficiency)
5. **Data augmentation** is crucial for generalization
6. **Learning rate scheduling** (cosine annealing, step decay)
7. **Regularization** (weight decay, dropout)
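
As a rough illustration of point 3, the sketch below compares the first FC layer of a flatten-based classifier with a global-average-pooling head. It assumes a 7×7×512 final feature map and 1000 classes (the VGG16 setting from earlier); the helper `linear_params` is hypothetical.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


channels, height, width = 512, 7, 7  # final feature map of VGG16 on 224x224 inputs
num_classes = 1000

flatten_head = linear_params(channels * height * width, 4096)  # first FC layer only
gap_head = linear_params(channels, num_classes)                # GAP -> single linear layer

print(f"flatten -> FC-4096: {flatten_head:,} parameters")
print(f"GAP -> FC-{num_classes}:     {gap_head:,} parameters")
```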

---

## Summary

Congratulations! You've implemented famous CNN architectures from their original papers. You now understand:

✅ **VGG Network:** Simplicity of stacked 3×3 convolutions  
✅ **ResNet:** Skip connections mitigate vanishing gradients and ease optimization  
✅ **Inception Module:** Multi-scale feature extraction  
✅ **The Degradation Problem:** Why plain deep networks fail  
✅ **Residual Learning:** Learn $F(x) = H(x) - x$ instead of $H(x)$  
✅ **Gradient Flow Analysis:** Mathematical argument for why skip connections help  
✅ **Architecture Trade-offs:** Parameters vs depth vs performance  

**Key Achievement:**
- Implemented ResNet-18, ResNet-34, ResNet-50 from scratch
- Demonstrated skip connection benefit (ResNet vs PlainNet)
- Achieved competitive accuracy on CIFAR-10

**Mathematical Insight:**
$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial H} \cdot \left(1 + \frac{\partial F}{\partial x}\right)
$$
The "+1" ensures gradient flow even when residual gradients are small!

**Time spent:** ~3-4 hours

**Next:** Day 10 - Project 1: MNIST Classification (portfolio-quality project combining all learnings from Days 1-9)