---
title: Residual Networks (ResNets)
exports:
  - format: pdf
    template: plain_latex
    output: exports/ResNets.pdf
    logo: false
    link: true
downloads:
  - file: exports/ResNets.pdf
  - file: ResNets.ipynb
math:
    '\calA': '{\cal A}'
    '\calB': '{\cal B}'
    '\calC': '{\cal C}'
    '\calD': '{\cal D}'
    '\calE': '{\cal E}'
    '\calF': '{\cal F}'
    '\calG': '{\cal G}'
    '\calH': '{\cal H}'
    '\calI': '{\cal I}'
    '\calJ': '{\cal J}'
    '\calK': '{\cal K}'
    '\calL': '{\cal L}'
    '\calM': '{\cal M}'
    '\calN': '{\cal N}'
    '\calO': '{\cal O}'
    '\calP': '{\cal P}'
    '\calQ': '{\cal Q}'
    '\calR': '{\cal R}'
    '\calS': '{\cal S}'
    '\calT': '{\cal T}'
    '\calU': '{\cal U}'
    '\calV': '{\cal V}'
    '\calW': '{\cal W}'
    '\calX': '{\cal X}'
    '\calY': '{\cal Y}'
    '\calZ': '{\cal Z}'
    '\bfa': '\mathbf{a}'
    '\bfb': '\mathbf{b}'
    '\bfc': '\mathbf{c}'
    '\bfd': '\mathbf{d}'
    '\bfe': '\mathbf{e}'
    '\bff': '\mathbf{f}'
    '\bfg': '\mathbf{g}'
    '\bfh': '\mathbf{h}'
    '\bfi': '\mathbf{i}'
    '\bfj': '\mathbf{j}'
    '\bfk': '\mathbf{k}'
    '\bfl': '\mathbf{l}'
    '\bfm': '\mathbf{m}'
    '\bfn': '\mathbf{n}'
    '\bfo': '\mathbf{o}'
    '\bfp': '\mathbf{p}'
    '\bfq': '\mathbf{q}'
    '\bfr': '\mathbf{r}'
    '\bfs': '\mathbf{s}'
    '\bft': '\mathbf{t}'
    '\bfu': '\mathbf{u}'
    '\bfv': '\mathbf{v}'
    '\bfw': '\mathbf{w}'
    '\bfx': '\mathbf{x}'
    '\bfy': '\mathbf{y}'
    '\bfz': '\mathbf{z}'
    '\bfW': '\mathbf{W}'
    '\bfX': '\mathbf{X}'
    '\bfY': '\mathbf{Y}'
    '\bfZ': '\mathbf{Z}'
    '\bftheta': '\boldsymbol{\theta}'
    '\bbR': '\mathbb{R}'
    '\bbE': '\mathbb{E}'
    '\p': '\partial'
---

# Residual Networks (ResNets)

In this notebook, we will explore:
1. Why deep networks are hard to train (the degradation problem)
2. How residual connections solve this problem
3. Implementation of ResNet building blocks
4. He initialization for ReLU networks
5. Training ResNets on CIFAR-10

**Prerequisites**: CNN notebook (07-cnn), Vanishing Gradients (09-vanishing)

In [1]:
# Install dependencies
!pip install otter-grader torch torchvision matplotlib


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [2]:
# Setup otter-grader
URL = "https://raw.githubusercontent.com/wecacuee/ECE490-F25-Neural-Networks/refs/heads/master/notebooks/11-regularization-resnets/ResNetsTests.zip"
fname = "ResNetsTests.zip"
import urllib.request
from zipfile import ZipFile
try:
    urllib.request.urlretrieve(URL, fname)
    ZipFile(fname).extractall()
except Exception:
    print("Could not download tests. Grading may not work.")
import otter
grader = otter.Notebook(tests_dir="./tests")

Could not download tests. Grading may not work.


In [3]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Set device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


## 1. The Degradation Problem

A surprising observation: adding more layers to a deep network can actually **increase** training error.

This is NOT overfitting. An overfit model has *lower* training error and higher test error, whereas here the deeper network's training error is higher too. The problem is that deeper plain networks are harder to optimize.

![Degradation Problem](https://miro.medium.com/max/1400/1*vbx2sDfjB9YB5PEI3j2BXg.png)

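**Key insight**: A deeper network can represent everything its shallower counterpart can, by copying the shallower layers and setting the extra layers to identity mappings, so its training error should be no higher. In practice, however, gradient-based training struggles to find such solutions when each extra stack of nonlinear layers must learn the identity.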

In [4]:
# Load CIFAR-10 dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')



### Plain Network (No Skip Connections)

Let's first build a plain CNN without skip connections to demonstrate the degradation problem.

In [5]:
class PlainBlock(nn.Module):
    """A plain convolutional block without skip connections"""
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = F.relu(out)
        return out

class PlainNet(nn.Module):
    """Plain network without skip connections"""
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        
    def _make_layer(self, block, out_channels, num_blocks, stride):
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Create a shallow and a deep plain network
plain_shallow = PlainNet(PlainBlock, [1, 1, 1, 1])  # 10 weighted layers: stem conv + 4*2 convs + fc
plain_deep = PlainNet(PlainBlock, [3, 3, 3, 3])      # 26 weighted layers: stem conv + 12*2 convs + fc

print(f"Shallow PlainNet parameters: {sum(p.numel() for p in plain_shallow.parameters()):,}")
print(f"Deep PlainNet parameters: {sum(p.numel() for p in plain_deep.parameters()):,}")

Shallow PlainNet parameters: 4,729,418
Deep PlainNet parameters: 17,270,858


## 2. Residual Learning: The Solution

Instead of hoping each stack of layers directly fits a desired underlying mapping $\calH(\bfx)$, we explicitly let these layers fit a **residual mapping**:

$$\calF(\bfx) = \calH(\bfx) - \bfx$$

The original mapping becomes:

$$\calH(\bfx) = \calF(\bfx) + \bfx$$

**Key insight**: If the identity mapping is optimal, it's easier to push the residual $\calF(\bfx)$ to zero than to fit an identity mapping with nonlinear layers.

```
x ──┬──────────────────────────────────────────────┐  (identity)
    │                                              ↓
    └──→ [Conv] → [BN] → [ReLU] → [Conv] → [BN] ─→(+)──→ y
         └──────────────── F(x) ──────────────┘
```
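To make the key insight concrete, here is a minimal sketch (the class `TinyResidualBlock` is hypothetical, not part of this notebook's exercises) showing that a residual block becomes an exact identity as soon as the scale $\gamma$ of its last BatchNorm is zero, because then $\calF(\bfx) = 0$. A plain block has no such shortcut: its weights would have to be tuned to reproduce $\bfx$. Zero-initializing this $\gamma$ is a known trick; torchvision's ResNet exposes it as the `zero_init_residual` option.

```python
import torch
import torch.nn as nn

class TinyResidualBlock(nn.Module):
    """Hypothetical minimal block computing y = F(x) + x (no ReLU after the addition)."""
    def __init__(self, channels: int):
        super().__init__()
        self.residual = nn.Sequential(  # F(x): Conv -> BN -> ReLU -> Conv -> BN
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return self.residual(x) + x

block = TinyResidualBlock(8).eval()       # eval: BN uses its running statistics
nn.init.zeros_(block.residual[4].weight)  # last BN scale gamma = 0 (its shift beta is already 0)
x = torch.randn(2, 8, 16, 16)
print(torch.allclose(block(x), x))        # True: F(x) = 0, so the block is the identity
```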

### Gradient Flow Through Residual Connections

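For a residual block $\bfy = \calF(\bfx) + \bfx$, the chain rule gives (treating gradients as row vectors, so $\frac{\p \bfy}{\p \bfx}$ is a Jacobian matrix):

$$\frac{\p \calL}{\p \bfx} = \frac{\p \calL}{\p \bfy} \cdot \frac{\p \bfy}{\p \bfx} = \frac{\p \calL}{\p \bfy} \cdot \left(\mathbf{I} + \frac{\p \calF}{\p \bfx}\right)$$

The **identity path** (the $\mathbf{I}$ term) passes $\frac{\p \calL}{\p \bfy}$ straight back to $\bfx$, even when $\frac{\p \calF}{\p \bfx}$ is small.

For a network of $L$ residual blocks with $\bfx_l = \bfx_{l-1} + \calF_l(\bfx_{l-1})$:

$$\frac{\p \calL}{\p \bfx_0} = \frac{\p \calL}{\p \bfx_L} \cdot \left(\mathbf{I} + \frac{\p \calF_L}{\p \bfx_{L-1}}\right) \cdots \left(\mathbf{I} + \frac{\p \calF_1}{\p \bfx_0}\right)$$

The gradient always has a path through the identity connections: expanding the product yields the term $\frac{\p \calL}{\p \bfx_L}$ itself, so part of the gradient reaches $\bfx_0$ without passing through any $\calF_l$, regardless of depth. A plain network has no identity terms, so when its layer Jacobians are small the gradient can shrink geometrically with depth. Strictly, this argument assumes nothing follows the addition; the blocks below apply a ReLU after the addition, so their identity path is only approximately clean.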
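A small numerical sketch of this argument, under assumed settings: each $\calF_l$ is a `tanh` of a small random linear map (standing in for the Conv/BN/ReLU stack), and the "loss" is simply the sum of the final outputs. It compares the gradient norm at the input for a plain stack ($\bfx_l = \calF_l(\bfx_{l-1})$) and a residual stack ($\bfx_l = \bfx_{l-1} + \calF_l(\bfx_{l-1})$) that share the same weights.

```python
import torch

torch.manual_seed(0)
depth, width = 50, 64
# Assumed setup: small random weights, so each F_l has a small Jacobian
weights = [0.01 * torch.randn(width, width) for _ in range(depth)]

def input_grad_norm(residual: bool) -> float:
    """Norm of d(sum of outputs)/dx after `depth` layers, plain or residual."""
    x = torch.randn(1, width, requires_grad=True)
    h = x
    for W in weights:
        f = torch.tanh(h @ W)          # F_l(h): a small nonlinear map
        h = h + f if residual else f   # residual: x_l = x_{l-1} + F_l(x_{l-1})
    h.sum().backward()                 # scalar stand-in for the loss
    return x.grad.norm().item()

plain = input_grad_norm(residual=False)
res = input_grad_norm(residual=True)
print(f"plain network:    {plain:.3e}")
print(f"residual network: {res:.3e}")
```

In the plain stack the gradient is multiplied by 50 small Jacobians and collapses toward zero, while the residual stack keeps a gradient of order one thanks to the identity terms.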

## 3. Implementing ResNet Blocks

<!-- BEGIN QUESTION -->

### Exercise 1: Implement BasicBlock (15 points)

Implement the basic residual block used in ResNet-18 and ResNet-34.

**Architecture:**
```
x → Conv3x3 → BN → ReLU → Conv3x3 → BN → (+) → ReLU → out
|_______________________________________↑ (skip connection)
```

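When `stride > 1` or the input and output channel counts differ, `x` no longer has the same shape as the block's output and cannot be added directly. In that case the caller passes a projection shortcut (a 1x1 convolution followed by BatchNorm) as `downsample`, and the block applies it to the identity before the addition.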

In [None]:
class BasicBlock(nn.Module):
    """
    Basic residual block: two 3x3 convolutions with skip connection
    
    Args:
        in_channels: Number of input channels
        out_channels: Number of output channels
        stride: Stride for first convolution (default: 1)
        downsample: Module for downsampling identity if dimensions change (default: None)
    """
    expansion = 1  # Output channels = out_channels * expansion
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # TODO: Implement the following layers:
        # self.conv1: 3x3 conv, in_channels -> out_channels, with stride, padding=1, no bias
        # self.bn1: BatchNorm2d for out_channels
        # self.conv2: 3x3 conv, out_channels -> out_channels, stride=1, padding=1, no bias
        # self.bn2: BatchNorm2d for out_channels
        # self.downsample: store the downsample module
        
        self.conv1 = ...  # YOUR CODE HERE
        self.bn1 = ...    # YOUR CODE HERE
        self.conv2 = ...  # YOUR CODE HERE
        self.bn2 = ...    # YOUR CODE HERE
        self.downsample = downsample
        
    def forward(self, x):
        """
        Forward pass with skip connection.
        
        Args:
            x: Input tensor of shape (batch, in_channels, H, W)
            
        Returns:
            Output tensor of shape (batch, out_channels, H', W')
            where H', W' depend on stride
        """
        # TODO: Implement the forward pass
        # 1. Store identity (x) for skip connection
        # 2. Apply conv1 -> bn1 -> relu
        # 3. Apply conv2 -> bn2
        # 4. If downsample is not None, apply it to identity
        # 5. Add identity to output (skip connection)
        # 6. Apply final relu
        
        identity = x
        
        # YOUR CODE HERE
        out = ...
        
        return out

In [None]:
# Test your BasicBlock implementation
def test_basic_block():
    # Test 1: Same dimensions
    block = BasicBlock(64, 64)
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    assert y.shape == (2, 64, 32, 32), f"Expected (2, 64, 32, 32), got {y.shape}"
    
    # Test 2: With downsampling
    downsample = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(128)
    )
    block = BasicBlock(64, 128, stride=2, downsample=downsample)
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    assert y.shape == (2, 128, 16, 16), f"Expected (2, 128, 16, 16), got {y.shape}"
    
    print("BasicBlock tests passed!")

test_basic_block()

In [None]:
grader.check("basic_block")

### Exercise 2: Implement BottleneckBlock (15 points)

For deeper networks (ResNet-50+), we use **bottleneck blocks** to reduce computation:

```
x → Conv1x1 → BN → ReLU → Conv3x3 → BN → ReLU → Conv1x1 → BN → (+) → ReLU → out
|   (reduce)              (process)             (expand)
|_______________________________________________________________↑ (skip)
```

The 1x1 convolutions reduce channels before the expensive 3x3 convolution, then expand back.

- First 1x1: `in_channels` → `out_channels`
- 3x3: `out_channels` → `out_channels`  
- Second 1x1: `out_channels` → `out_channels * expansion` (where expansion=4)
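To see how much the bottleneck saves, the sketch below counts convolution weights for a block whose input and output both have 256 channels (an assumed example width): two full-width 3x3 convolutions versus the 1x1 → 3x3 → 1x1 bottleneck with expansion 4. Biases and BatchNorm parameters are left out for simplicity.

```python
import torch.nn as nn

def count_params(*layers: nn.Module) -> int:
    return sum(p.numel() for layer in layers for p in layer.parameters())

c = 256  # channels entering and leaving the block
two_3x3 = count_params(
    nn.Conv2d(c, c, 3, padding=1, bias=False),
    nn.Conv2d(c, c, 3, padding=1, bias=False),
)
bottleneck = count_params(
    nn.Conv2d(c, c // 4, 1, bias=False),                  # reduce: 256 -> 64
    nn.Conv2d(c // 4, c // 4, 3, padding=1, bias=False),  # process at 64 channels
    nn.Conv2d(c // 4, c, 1, bias=False),                  # expand: 64 -> 256 (expansion = 4)
)
print(two_3x3, bottleneck)  # 1179648 69632
```

That is 2 × 9 × 256² weights versus 256·64 + 9·64² + 64·256, roughly a 17x reduction at the same input/output width.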

In [None]:
class BottleneckBlock(nn.Module):
    """
    Bottleneck residual block: 1x1 -> 3x3 -> 1x1 with skip connection
    
    The expansion factor means output has out_channels * 4 channels.
    """
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # TODO: Implement the following layers:
        # self.conv1: 1x1 conv, in_channels -> out_channels, no bias
        # self.bn1: BatchNorm2d
        # self.conv2: 3x3 conv, out_channels -> out_channels, with stride, padding=1, no bias
        # self.bn2: BatchNorm2d
        # self.conv3: 1x1 conv, out_channels -> out_channels * expansion, no bias
        # self.bn3: BatchNorm2d
        
        self.conv1 = ...  # YOUR CODE HERE
        self.bn1 = ...    # YOUR CODE HERE
        self.conv2 = ...  # YOUR CODE HERE
        self.bn2 = ...    # YOUR CODE HERE
        self.conv3 = ...  # YOUR CODE HERE
        self.bn3 = ...    # YOUR CODE HERE
        self.downsample = downsample
        
    def forward(self, x):
        identity = x
        
        # TODO: Implement forward pass
        # 1. conv1 -> bn1 -> relu
        # 2. conv2 -> bn2 -> relu
        # 3. conv3 -> bn3
        # 4. Apply downsample to identity if needed
        # 5. Add skip connection
        # 6. Final relu
        
        out = ...  # YOUR CODE HERE
        
        return out

In [None]:
# Test BottleneckBlock
def test_bottleneck():
    # Test with downsample (typical first block in a stage)
    downsample = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(256)
    )
    block = BottleneckBlock(64, 64, stride=1, downsample=downsample)
    x = torch.randn(2, 64, 32, 32)
    y = block(x)
    assert y.shape == (2, 256, 32, 32), f"Expected (2, 256, 32, 32), got {y.shape}"
    
    # Test with stride=2
    downsample = nn.Sequential(
        nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),
        nn.BatchNorm2d(512)
    )
    block = BottleneckBlock(256, 128, stride=2, downsample=downsample)
    x = torch.randn(2, 256, 32, 32)
    y = block(x)
    assert y.shape == (2, 512, 16, 16), f"Expected (2, 512, 16, 16), got {y.shape}"
    
    print("BottleneckBlock tests passed!")

test_bottleneck()

In [None]:
grader.check("bottleneck")

<!-- END QUESTION -->

## 4. He Initialization

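For deep networks with ReLU, proper weight initialization is crucial: if every layer multiplies the activation variance by a factor $c$, then after $L$ layers the variance is scaled by $c^L$, which vanishes for $c < 1$ and explodes for $c > 1$.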

### Variance Propagation Analysis

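Consider a linear layer $\bfy = \bfW\bfx$ with $n_{in}$ inputs, where all entries are mutually independent and:
- $x_i$ are i.i.d. with $\bbE[x_i] = 0$, $\text{Var}(x_i) = \sigma_x^2$
- $W_{ij}$ are i.i.d. with $\bbE[W_{ij}] = 0$, $\text{Var}(W_{ij}) = \sigma_W^2$

The output variance is:
$$\text{Var}(y_j) = n_{in} \cdot \sigma_W^2 \cdot \sigma_x^2$$

**For variance preservation**: $\sigma_W^2 = \frac{1}{n_{in}}$ (LeCun initialization; Xavier/Glorot initialization uses the compromise $\sigma_W^2 = \frac{2}{n_{in} + n_{out}}$ to balance the forward and backward passes)

**With ReLU**, each input is $x_i = \max(0, z_i)$ for a zero-mean, symmetric pre-activation $z_i$ of the previous layer. Half of the values are zeroed, so $\bbE[x_i^2] = \frac{1}{2}\text{Var}(z_i)$. Because $\bbE[W_{ij}] = 0$, the formula above still holds with $\bbE[x_i^2]$ in place of $\sigma_x^2$, giving $\text{Var}(y_j) = \frac{1}{2} n_{in} \sigma_W^2 \text{Var}(z_i)$. Keeping the pre-activation variance constant from layer to layer therefore requires:
$$\sigma_W^2 = \frac{2}{n_{in}} \quad \text{(He initialization)}$$

$$W_{ij} \sim \calN\left(0, \frac{2}{n_{in}}\right) \quad \text{(variance } 2/n_{in}\text{, standard deviation } \sqrt{2/n_{in}}\text{)}$$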
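A quick numerical check of this derivation, under assumed settings (20 fully connected layers of width 512, no biases, 1024 unit-variance input vectors): it tracks the second moment $\bbE[h^2]$ of the ReLU outputs for $\sigma_W^2 = 1/n_{in}$ and $\sigma_W^2 = 2/n_{in}$.

```python
import math
import torch

torch.manual_seed(0)
n, depth = 512, 20
x = torch.randn(1024, n)  # 1024 unit-variance input vectors

def second_moment_after(scale: float) -> float:
    """E[h^2] after `depth` Linear+ReLU layers with Var(W_ij) = scale / n_in."""
    h = x
    for _ in range(depth):
        W = torch.randn(n, n) * math.sqrt(scale / n)
        h = torch.relu(h @ W)
    return h.pow(2).mean().item()

lecun = second_moment_after(1.0)  # sigma_W^2 = 1/n_in: roughly halves at every layer
he = second_moment_after(2.0)     # sigma_W^2 = 2/n_in: stays close to 1
print(f"1/n_in: {lecun:.2e}   2/n_in: {he:.2f}")
```

With $1/n_{in}$ the signal shrinks by about $2^{-20}$ over 20 layers, while the factor of 2 in He initialization exactly compensates for the half of the second moment that ReLU removes.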

### Exercise 3: Verify He Initialization (10 points)

Implement a function that records the output variance of every `Conv2d` and `Linear` layer for a unit-variance input, so you can check how variance propagates through a He-initialized network.

In [None]:
def verify_he_init(model, input_shape=(64, 3, 32, 32)):
    """
    Verify that activations maintain approximately unit variance through the network.
    
    Args:
        model: Neural network model
        input_shape: Shape of input tensor (batch, channels, height, width)
        
    Returns:
        variances: List of variance values at each layer
        layer_names: List of layer names
    """
    variances = []
    layer_names = []
    
    # TODO: 
    # 1. Create random input with unit variance
    # 2. Register forward hooks on Conv2d and Linear layers to capture activations
    # 3. Run forward pass
    # 4. Compute variance of each layer's output
    # 5. Return list of variances
    
    # Hint: Use model.named_modules() to iterate over layers
    # Hint: Use register_forward_hook to capture layer outputs
    
    # YOUR CODE HERE
    x = torch.randn(input_shape)  # Unit variance input
    
    # Store activations
    activations = {}
    
    def get_activation(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook
    
    # Register hooks
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            handles.append(module.register_forward_hook(get_activation(name)))
    
    # Forward pass
    with torch.no_grad():
        model(x)
    
    # Remove hooks
    for h in handles:
        h.remove()
    
    # Compute variances
    for name, act in activations.items():
        var = act.var().item()
        variances.append(var)
        layer_names.append(name)
    
    return variances, layer_names

In [None]:
# Visualize variance propagation
def plot_variance_propagation(variances, layer_names, title="Variance through layers"):
    plt.figure(figsize=(12, 4))
    plt.bar(range(len(variances)), variances)
    plt.axhline(y=1.0, color='r', linestyle='--', label='Target variance=1')
    plt.xlabel('Layer')
    plt.ylabel('Variance')
    plt.title(title)
    plt.xticks(range(len(variances)), layer_names, rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Test with a simple model using He init
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.fc = nn.Linear(128 * 16 * 16, 10)
        
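        # Apply He initialization. mode='fan_in' preserves activation variance in the
        # forward pass; mode='fan_out' (used here for the convs, as in torchvision's
        # ResNet) preserves gradient variance in the backward pass instead.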
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

simple_net = SimpleNet()
variances, names = verify_he_init(simple_net)
plot_variance_propagation(variances, names, "Variance with He Initialization")
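The choice of `mode` matters for the plot above. This sketch initializes a convolution with the same shape as `SimpleNet.conv1` using both modes and compares the empirical weight standard deviation with $\sqrt{2/\text{fan}}$, where fan-in is `in_channels * kH * kW` and fan-out is `out_channels * kH * kW`. With `mode='fan_out'` the first layer's weights come out much smaller, so its forward output variance is expected to land well below 1.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 64, kernel_size=3)  # same shape as SimpleNet.conv1
fans = {"fan_in": 3 * 3 * 3,            # in_channels * kH * kW = 27
        "fan_out": 64 * 3 * 3}          # out_channels * kH * kW = 576
measured = {}
for mode, fan in fans.items():
    nn.init.kaiming_normal_(conv.weight, mode=mode, nonlinearity="relu")
    measured[mode] = (conv.weight.std().item(), math.sqrt(2 / fan))
    print(f"{mode}: empirical std {measured[mode][0]:.4f}, sqrt(2/fan) = {measured[mode][1]:.4f}")
```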

In [None]:
grader.check("he_init")

## 5. Building the Full ResNet

Now let's assemble our blocks into a complete ResNet architecture.

In [None]:
class ResNet(nn.Module):
    """
    ResNet architecture for CIFAR-10.
    
    Args:
        block: BasicBlock or BottleneckBlock
        layers: Number of blocks in each of the four stages, e.g. [2, 2, 2, 2] for ResNet-18
        num_classes: Number of output classes
    """
    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.in_channels = 64
        
        # Initial convolution: a 3x3, stride-1 stem for CIFAR-10's 32x32 images
        # (ImageNet ResNets use a 7x7 stride-2 conv followed by max pooling)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        
        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
    def _make_layer(self, block, out_channels, num_blocks, stride):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion)
            )
            
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels * block.expansion
        
        for _ in range(1, num_blocks):
            layers.append(block(self.in_channels, out_channels))
            
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        
        return x

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

def ResNet50():
    return ResNet(BottleneckBlock, [3, 4, 6, 3])
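The number in each model's name counts its weighted layers: the stem convolution, every convolution inside the residual blocks, and the final fully connected layer, while the 1x1 projection shortcuts are not counted. A short sketch reproducing the names from the stage configurations above:

```python
# Layers counted in the name: stem conv + convs inside the residual blocks + final fc
# (the 1x1 projection shortcuts are not counted)
configs = {
    "ResNet-18": ([2, 2, 2, 2], 2),  # BasicBlock: 2 convs per block
    "ResNet-34": ([3, 4, 6, 3], 2),
    "ResNet-50": ([3, 4, 6, 3], 3),  # BottleneckBlock: 3 convs per block
}
depths = {name: 1 + sum(blocks) * convs_per_block + 1
          for name, (blocks, convs_per_block) in configs.items()}
print(depths)  # {'ResNet-18': 18, 'ResNet-34': 34, 'ResNet-50': 50}
```

ResNet-34 and ResNet-50 share the same stage layout; the extra depth comes entirely from the third convolution in each bottleneck block.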

In [None]:
# Create ResNet-18
resnet18 = ResNet18()
print(f"ResNet-18 parameters: {sum(p.numel() for p in resnet18.parameters()):,}")

# Test forward pass
x = torch.randn(2, 3, 32, 32)
y = resnet18(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

## 6. Training ResNet on CIFAR-10

<!-- BEGIN QUESTION -->

### Exercise 4: Train ResNet-18 (15 points)

Train ResNet-18 on CIFAR-10 to achieve >75% test accuracy.

In [None]:
def train_epoch(model, trainloader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
    return running_loss / len(trainloader), 100. * correct / total

def evaluate(model, testloader, criterion, device):
    """Evaluate on test set."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
    return running_loss / len(testloader), 100. * correct / total
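`train_epoch` calls `model.train()` and `evaluate` calls `model.eval()` because BatchNorm behaves differently in the two modes. This standalone sketch (toy shapes chosen for illustration) feeds the same batch through one `nn.BatchNorm2d` layer in each mode: in training mode it normalizes with the batch's own statistics, while in eval mode it uses running averages, which after a single update are still far from this batch's mean and variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)               # running_mean=0, running_var=1, momentum=0.1
x = torch.randn(8, 4, 5, 5) * 3 + 2  # a batch with mean about 2 and std about 3

bn.train()
y_train = bn(x)   # uses this batch's statistics; also updates running_mean/var once
bn.eval()
with torch.no_grad():
    y_eval = bn(x)  # uses the running statistics instead
print(f"train-mode output mean: {y_train.mean().item():.3f}")
print(f"eval-mode output mean:  {y_eval.mean().item():.3f}")
```

Forgetting `model.eval()` during evaluation would therefore make test predictions depend on the composition of each test batch.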

In [None]:
def train_resnet(epochs=10):
    """
    Train ResNet-18 on CIFAR-10.
    
    Args:
        epochs: Number of training epochs
        
    Returns:
        model: Trained ResNet-18 model
        test_acc: Final test accuracy (must be > 75%)
        history: Dict with 'train_loss', 'train_acc', 'test_loss', 'test_acc' lists
    """
    # TODO: Implement training
    # 1. Create ResNet-18 model and move to device
    # 2. Define loss function (CrossEntropyLoss)
    # 3. Define optimizer (SGD with momentum=0.9, weight_decay=5e-4)
    # 4. Optionally use learning rate scheduler
    # 5. Train for specified epochs, tracking metrics
    
    model = ResNet18().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}
    
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, DEVICE)
        test_loss, test_acc = evaluate(model, testloader, criterion, DEVICE)
        scheduler.step()
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        
        print(f"Epoch {epoch+1}/{epochs}: "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    
    return model, test_acc, history

# Train the model (set epochs based on available compute)
# For quick testing, use fewer epochs; for best results, use 50-100
trained_model, final_acc, history = train_resnet(epochs=20)
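`train_resnet` decays the learning rate with `CosineAnnealingLR`. The sketch below steps the same scheduler on a dummy parameter (no training happens; the parameter exists only so the optimizer has something to hold) and compares the learning rate used in each epoch with the closed form $\eta_t = \eta_0 \cdot \frac{1 + \cos(\pi t / T_{max})}{2}$, which applies here because the minimum learning rate `eta_min` defaults to 0.

```python
import math
import torch
import torch.optim as optim

epochs, base_lr = 5, 0.1
param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, never actually trained
optimizer = optim.SGD([param], lr=base_lr, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

lrs = []
for epoch in range(epochs):
    lrs.append(optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    optimizer.step()                             # a real loop would train here
    scheduler.step()

# Closed form with eta_min = 0
closed_form = [base_lr * (1 + math.cos(math.pi * t / epochs)) / 2 for t in range(epochs)]
print([round(lr, 4) for lr in lrs])
```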

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(history['train_loss'], label='Train')
axes[0].plot(history['test_loss'], label='Test')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].set_title('Loss Curves')

axes[1].plot(history['train_acc'], label='Train')
axes[1].plot(history['test_acc'], label='Test')
axes[1].axhline(y=75, color='r', linestyle='--', label='Target (75%)')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].legend()
axes[1].set_title('Accuracy Curves')

plt.tight_layout()
plt.show()

print(f"\nFinal Test Accuracy: {final_acc:.2f}%")

In [None]:
grader.check("train_resnet")

### Exercise 5: Ablation Study (10 points)

Compare different configurations to understand the impact of BatchNorm and Dropout.
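For the "without BatchNorm" configuration, one possible approach (a sketch; the helper name `replace_bn_with_identity` is hypothetical) is to swap every `nn.BatchNorm2d` inside a model for `nn.Identity`, recursing into nested containers such as the `downsample` shortcuts. It is demonstrated here on a small stand-in model; be aware that a network without BatchNorm may need a lower learning rate than 0.1 to train stably.

```python
import torch
import torch.nn as nn

def replace_bn_with_identity(module: nn.Module) -> nn.Module:
    """Recursively swap every nn.BatchNorm2d submodule for nn.Identity (in place)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            replace_bn_with_identity(child)
    return module

# Stand-in model with a nested container, like a ResNet stage
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)),
)
replace_bn_with_identity(net)
n_bn = sum(isinstance(m, nn.BatchNorm2d) for m in net.modules())
out = net(torch.randn(1, 3, 8, 8))
print(n_bn, out.shape)
```

Because the blocks call `self.bn1(...)` by attribute name, replacing the attribute with `nn.Identity` leaves their `forward` methods untouched.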

In [None]:
def ablation_study(epochs=10):
    """
    Compare ResNet configurations:
    1. Standard ResNet with BatchNorm
    2. ResNet without BatchNorm (replace with Identity)
    3. ResNet with Dropout after each block
    
    Returns:
        results: Dict with accuracy for each configuration
                 {'with_bn': float, 'without_bn': float, 'with_dropout': float}
    """
    results = {}
    
    # TODO: Implement ablation study
    # Train each configuration for specified epochs
    # Store final test accuracy in results dict
    
    # Configuration 1: Standard ResNet with BatchNorm
    print("Training ResNet with BatchNorm...")
    model_bn = ResNet18().to(DEVICE)
    # ... train and evaluate
    
    # YOUR CODE HERE
    
    return results

# Run ablation study (use fewer epochs for quick testing)
# ablation_results = ablation_study(epochs=10)
# print("\nAblation Study Results:")
# for config, acc in ablation_results.items():
#     print(f"  {config}: {acc:.2f}%")

In [None]:
grader.check("ablation")

<!-- END QUESTION -->

## 7. Loss Landscape Visualization

One proposed explanation for why ResNets train more easily is that skip connections **smooth the loss landscape**, an effect observed empirically in loss-surface visualizations.

In [None]:
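def plot_loss_landscape_1d(model, dataloader, device, n_points=21):
    """
    Compute a 1D slice of the loss landscape along a random direction.

    The direction is normalized to unit L2 norm over all parameters, and the loss
    is evaluated on a single batch. Returns (alphas, losses); plotting is left to the caller.
    """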
    model.eval()
    criterion = nn.CrossEntropyLoss()
    
    # Save original parameters
    original_params = [p.clone() for p in model.parameters()]
    
    # Generate random direction
    direction = [torch.randn_like(p) for p in model.parameters()]
    # Normalize direction
    d_norm = sum((d**2).sum() for d in direction).sqrt()
    direction = [d / d_norm for d in direction]
    
    # Compute loss along direction
    alphas = np.linspace(-1, 1, n_points)
    losses = []
    
    # Get a single batch
    inputs, labels = next(iter(dataloader))
    inputs, labels = inputs.to(device), labels.to(device)
    
    for alpha in alphas:
        # Perturb parameters
        with torch.no_grad():
            for p, p0, d in zip(model.parameters(), original_params, direction):
                p.copy_(p0 + alpha * d)
        
        # Compute loss
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            losses.append(loss.item())
    
    # Restore original parameters
    with torch.no_grad():
        for p, p0 in zip(model.parameters(), original_params):
            p.copy_(p0)
    
    return alphas, losses

# Compare loss landscapes
if 'trained_model' in dir():
    plt.figure(figsize=(10, 4))
    
    # ResNet loss landscape
    alphas, losses = plot_loss_landscape_1d(trained_model, testloader, DEVICE)
    plt.plot(alphas, losses, label='ResNet-18')
    
    plt.xlabel('Step along random direction')
    plt.ylabel('Loss')
    plt.title('Loss Landscape (1D slice)')
    plt.legend()
    plt.show()
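`plot_loss_landscape_1d` scales its random direction to unit norm over *all* parameters, so for a model with millions of weights, $\alpha \in [-1, 1]$ is a tiny perturbation and the curve may look almost flat. Li et al. (2018) instead rescale the random direction filter by filter so that each filter has the same norm as the corresponding filter of the trained weights. A sketch of that filter normalization follows; zeroing the direction for 1-D parameters (biases, BatchNorm scale and shift) is an assumption of this sketch. Its output could replace the `direction` list inside `plot_loss_landscape_1d`.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def filter_normalized_direction(model: nn.Module) -> list:
    """Random direction whose filters are rescaled to the norms of the model's filters.

    Assumption: 1-D parameters (biases, BatchNorm scale/shift) get a zero direction.
    """
    direction = []
    for p in model.parameters():
        if p.dim() <= 1:
            direction.append(torch.zeros_like(p))
            continue
        d = torch.randn_like(p)
        dims = tuple(range(1, p.dim()))  # every axis except the output-filter axis
        p_norm = torch.linalg.vector_norm(p, dim=dims, keepdim=True)
        d_norm = torch.linalg.vector_norm(d, dim=dims, keepdim=True)
        direction.append(d * p_norm / (d_norm + 1e-10))
    return direction

# Usage on a small stand-in model
model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.ReLU(),
                      nn.Flatten(), nn.Linear(4 * 6 * 6, 2))
direction = filter_normalized_direction(model)
filter_norms = torch.linalg.vector_norm(direction[0], dim=(1, 2, 3))
weight_norms = torch.linalg.vector_norm(model[0].weight.detach(), dim=(1, 2, 3))
print(filter_norms, weight_norms)
```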

## 8. Regularization Techniques Recap

### Dropout
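During training, randomly zero each activation with probability $p$ and scale the survivors by $1/(1-p)$ (inverted dropout), so the expected activation is unchanged; at test time dropout is turned off: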
$$\tilde{h}_i = \begin{cases} 0 & \text{with prob } p \\ h_i / (1-p) & \text{with prob } 1-p \end{cases}$$
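A minimal sketch of this formula (the function name `inverted_dropout` is illustrative; in practice `nn.Dropout` does this for you). Feeding in a vector of ones makes the two properties easy to check: about a fraction $p$ of the outputs are zero, and the mean stays close to 1.

```python
import torch

def inverted_dropout(h: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Zero each entry with prob p and scale survivors by 1/(1-p); identity at eval time."""
    if not training or p == 0.0:
        return h
    keep = (torch.rand_like(h) >= p).to(h.dtype)  # 1 with probability 1-p
    return h * keep / (1.0 - p)

torch.manual_seed(0)
h = torch.ones(100_000)
out = inverted_dropout(h, p=0.3)
mean_out = out.mean().item()                 # close to 1.0, since E[h~_i] = h_i
frac_zero = (out == 0).float().mean().item()  # close to p = 0.3
print(mean_out, frac_zero)
```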

### Batch Normalization
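Normalize each feature (each channel, for convolutional layers) using the mini-batch mean $\mu_B$ and variance $\sigma_B^2$, then apply a learned scale $\gamma$ and shift $\beta$. At test time, running averages of $\mu_B$ and $\sigma_B^2$ replace the batch statistics: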
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \quad y_i = \gamma \hat{x}_i + \beta$$
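The formula can be checked against PyTorch directly. This sketch computes $\mu_B$ and $\sigma_B^2$ per channel over the batch and spatial axes, applies $\gamma = 1.5$ and $\beta = -0.5$ (arbitrary values chosen for the example), and compares the result with `nn.BatchNorm2d` in training mode, which normalizes with the biased batch variance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3, 8, 8) * 2 + 5  # (N, C, H, W)
bn = nn.BatchNorm2d(3).train()        # training mode: batch statistics
with torch.no_grad():
    bn.weight.fill_(1.5)              # gamma
    bn.bias.fill_(-0.5)               # beta

dims = (0, 2, 3)                      # statistics per channel over N, H, W
mu_B = x.mean(dim=dims, keepdim=True)
var_B = ((x - mu_B) ** 2).mean(dim=dims, keepdim=True)  # biased batch variance
x_hat = (x - mu_B) / torch.sqrt(var_B + bn.eps)
y_manual = 1.5 * x_hat - 0.5
print(torch.allclose(bn(x), y_manual, atol=1e-4))       # True
```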

### Weight Decay (L2 regularization)
Add penalty on weight magnitude:
$$\calL_{reg} = \calL + \frac{\lambda}{2} \|\bfW\|_2^2$$
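For SGD, the `weight_decay=λ` argument used in `train_resnet` adds $\lambda \bfW$ to the gradient, which is exactly the gradient of the penalty $\frac{\lambda}{2}\|\bfW\|_2^2$. The sketch below takes one step both ways on a toy linear loss (values chosen for illustration) and checks that the updated weights agree. This equivalence is specific to SGD-style updates; with Adam, the decoupled variant is `AdamW`.

```python
import torch

lam, lr = 5e-4, 0.1
w0 = torch.tensor([1.0, -2.0, 3.0])
x = torch.tensor([0.5, 0.5, 0.5])

# (a) weight_decay argument of SGD
w_a = w0.clone().requires_grad_(True)
opt_a = torch.optim.SGD([w_a], lr=lr, weight_decay=lam)
opt_a.zero_grad()
(w_a @ x).backward()  # data loss L = w . x, whose gradient is x
opt_a.step()

# (b) explicit penalty lam/2 * ||w||^2 added to the loss
w_b = w0.clone().requires_grad_(True)
opt_b = torch.optim.SGD([w_b], lr=lr)
opt_b.zero_grad()
(w_b @ x + lam / 2 * w_b.pow(2).sum()).backward()
opt_b.step()

print(torch.allclose(w_a, w_b))  # True: both compute w - lr * (x + lam * w)
```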

## 9. Extensions: DenseNet and U-Net

### DenseNet
Each layer receives the concatenated feature maps of ALL preceding layers within its dense block:
$$\bfx_l = H_l([\bfx_0, \bfx_1, \ldots, \bfx_{l-1}])$$
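A compact sketch of one dense block (the class `TinyDenseBlock` and its sizes are illustrative; $H_l$ is taken to be BN → ReLU → 3x3 conv). Each layer adds `growth_rate` new channels, so the block output has `in_channels + num_layers * growth_rate` channels. Note the contrast with ResNet: features are concatenated rather than added.

```python
import torch
import torch.nn as nn

class TinyDenseBlock(nn.Module):
    """Hypothetical dense block: each layer sees the block input and all earlier outputs."""
    def __init__(self, in_channels: int, growth_rate: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            channels = in_channels + i * growth_rate  # width of [x_0, ..., x_{l-1}]
            self.layers.append(nn.Sequential(           # H_l: BN -> ReLU -> 3x3 conv
                nn.BatchNorm2d(channels), nn.ReLU(),
                nn.Conv2d(channels, growth_rate, 3, padding=1, bias=False)))

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            features.append(layer(torch.cat(features, dim=1)))  # x_l = H_l([x_0, ..., x_{l-1}])
        return torch.cat(features, dim=1)

block = TinyDenseBlock(in_channels=16, growth_rate=12, num_layers=4)
y = block(torch.randn(2, 16, 8, 8))
print(y.shape)  # torch.Size([2, 64, 8, 8]): 16 + 4 * 12 channels
```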

### U-Net
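Encoder-decoder architecture with skip connections between corresponding layers:
- **Encoder**: Downsampling path that extracts increasingly coarse features
- **Decoder**: Upsampling path that restores the input resolution
- **Skip connections**: Concatenate encoder feature maps with the decoder feature maps at the same resolution (rather than adding them, as ResNet does)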

Used extensively in image segmentation tasks.
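A deliberately tiny sketch of the U-Net pattern, with one downsampling and one upsampling level (the class `TinyUNet` and all sizes are illustrative assumptions, far shallower than a real U-Net). The decoder upsamples with a transposed convolution, concatenates the full-resolution encoder features, and a 1x1 convolution produces per-pixel class logits, which is the output format segmentation needs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUNet(nn.Module):
    """Hypothetical one-level U-Net: encoder, bottom, decoder with a concatenation skip."""
    def __init__(self, in_ch: int = 3, base: int = 16, num_classes: int = 2):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(base, 2 * base, 3, padding=1), nn.ReLU())
        self.up = nn.ConvTranspose2d(2 * base, base, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(nn.Conv2d(2 * base, base, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                           # full resolution
        e2 = self.enc2(F.max_pool2d(e1, 2))         # half resolution
        d1 = self.up(e2)                            # back to full resolution
        d1 = self.dec1(torch.cat([d1, e1], dim=1))  # skip: concatenate encoder features
        return self.head(d1)                        # per-pixel class logits

y = TinyUNet()(torch.randn(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 2, 32, 32])
```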

## Summary

1. **Degradation Problem**: Deeper networks can have higher training error
2. **Residual Learning**: Learn $\calF(\bfx) = \calH(\bfx) - \bfx$ instead of $\calH(\bfx)$
3. **Skip Connections**: Enable direct gradient flow through identity path
4. **He Initialization**: Use $\sigma_W = \sqrt{2/n_{in}}$ for ReLU networks
5. **Building Blocks**: BasicBlock for ResNet-18/34, BottleneckBlock for deeper
6. **Loss Landscape**: Skip connections empirically smooth the optimization landscape

## Submission

Make sure you have run all cells in order before exporting.

In [None]:
# Save your notebook first, then run this cell to export
grader.export(run_tests=True)