In [10]:
import os
import sys
import numpy as np
import einops
from typing import Union, Optional, Tuple, List, Dict
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import Float, Int
import functools
from pathlib import Path
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from tqdm.notebook import tqdm
from dataclasses import dataclass
from PIL import Image
import json

# Make sure exercises are in the path
chapter = r"chapter0_fundamentals"
# Everything in the current working directory before the first occurrence of `chapter` is treated as
# the repo root; the exercises live at <root>/<chapter>/exercises. (If `chapter` does not appear in the
# cwd, the whole cwd is used as the root.)
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part2_cnns"
# Append (once) so that `plotly_utils` and the `part2_cnns` package, imported next, can be found.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, bar
import part2_cnns.tests as tests
from part2_cnns.utils import print_param_count

# True when this code runs as the top-level module (as it does in a notebook kernel).
MAIN = __name__ == "__main__"

# Prefer CUDA, then MPS, and fall back to CPU, so the notebook uses an accelerator on either kind of machine.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")

## Making your own Modules

### Exercise - Implement ReLU

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.

In [11]:
class ReLU(nn.Module):
    '''
    Elementwise rectified linear unit.

    Negative entries become 0; every other entry (including NaN) passes through unchanged.
    '''
    def forward(self, x: t.Tensor) -> t.Tensor:
        # Out-of-place on purpose: writing into `x` would silently modify the caller's tensor and
        # raises an autograd error when `x` is a leaf tensor that requires grad.
        return t.where(x < 0, t.zeros_like(x), x)


# Check the ReLU implementation against the exercise's provided tests.
tests.test_relu(ReLU)

All tests in `test_relu` passed!


### Exercise - implement Linear

In [28]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias=True):
        '''
        A simple linear (technically, affine) transformation.

        The fields should be named `weight` and `bias` for compatibility with PyTorch.
        If `bias` is False, set `self.bias` to None.

        Both weight and bias are initialised Uniform(-1/sqrt(in_features), 1/sqrt(in_features)).
        '''
        super().__init__()
        # Stored so extra_repr can report them.
        self.in_features = in_features
        self.out_features = out_features

        bound = 1 / in_features ** 0.5
        self.weight = nn.Parameter(t.empty(out_features, in_features).uniform_(-bound, bound))
        self.bias = nn.Parameter(t.empty(out_features).uniform_(-bound, bound)) if bias else None

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (*, in_features)
        Return: shape (*, out_features)
        '''
        # matmul broadcasts over any number of leading dims, honouring the (*, in_features) contract.
        wx = x @ self.weight.T
        if self.bias is None:
            return wx
        return wx + self.bias

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"


# Run the exercise's provided Linear tests: forward output, parameter setup, and the bias=False case.
tests.test_linear_forward(Linear)
tests.test_linear_parameters(Linear)
tests.test_linear_no_bias(Linear)

All tests in `test_linear_forward` passed!
All tests in `test_linear_parameters` passed!
All tests in `test_linear_no_bias` passed!


### Exercise - Implement flatten

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.

In [52]:
class Flatten(nn.Module):
    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        '''Merge dims start_dim..end_dim (inclusive) into one; negative dims count from the end.'''
        super().__init__()
        # Stored under these names so extra_repr can read them back.
        self.start_dim = start_dim
        self.end_dim = end_dim

    @staticmethod
    def _normalize_dim(dim: int, ndim: int) -> int:
        '''Map a possibly-negative dim into 0..ndim-1, raising IndexError if it is out of range.'''
        if not -ndim <= dim < ndim:
            raise IndexError(f"Dimension {dim} out of range for a tensor with {ndim} dims")
        return dim % ndim

    def forward(self, input: t.Tensor) -> t.Tensor:
        '''
        Flatten out dimensions from start_dim to end_dim, inclusive of both.
        '''
        # A 0-d tensor has no dims to merge; it is returned with shape (1,).
        if input.ndim == 0:
            return t.reshape(input, (1,))
        start = self._normalize_dim(self.start_dim, input.ndim)
        end = self._normalize_dim(self.end_dim, input.ndim)
        if start > end:
            raise ValueError(f"start_dim ({self.start_dim}) must not come after end_dim ({self.end_dim})")
        shape = input.shape
        # torch.Size.numel() is the product of the sliced sizes, as a plain int.
        merged = shape[start:end + 1].numel()
        return t.reshape(input, (*shape[:start], merged, *shape[end + 1:]))

    def extra_repr(self) -> str:
        return ", ".join([f"{key}={getattr(self, key)}" for key in ["start_dim", "end_dim"]])


# Check the Flatten implementation against the exercise's provided tests.
tests.test_flatten(Flatten)

All tests in `test_flatten` passed!


### Exercise - implement the simple MLP

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to ~20 minutes on this exercise.

In [60]:
class SimpleMLP(nn.Module):
    '''
    Flatten -> Linear(28**2, 100) -> ReLU -> Linear(100, 10).

    Expects inputs whose dims after the first (batch) dim multiply to 28**2, e.g. (batch, 1, 28, 28)
    MNIST images; returns logits of shape (batch, 10).
    '''
    def __init__(self):
        super().__init__()
        self.layer_0 = Flatten(1, -1)
        self.layer_1 = Linear(28**2, 100)
        self.layer_2 = ReLU()
        self.layer_3 = Linear(100, 10)
        # A plain list is not registered by nn.Module, so parameters come only from the layer_* attributes
        # above (no double counting); the list just fixes the order the layers are applied in.
        self.layers = [self.layer_0, self.layer_1, self.layer_2, self.layer_3]

    def forward(self, x: t.Tensor) -> t.Tensor:
        for layer in self.layers:
            # Call the module rather than `.forward` so any registered hooks still run.
            x = layer(x)
        return x



# Check the SimpleMLP implementation against the exercise's provided tests.
tests.test_mlp(SimpleMLP)

All tests in `test_mlp` passed!


## Training

In [61]:
# Per-image preprocessing applied to both MNIST splits by get_mnist.
MNIST_TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        # Single-channel normalisation; presumably MNIST's pixel mean / std — confirm if changing datasets.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

def get_mnist(subset: int = 1):
    '''
    Load the MNIST train and test splits (root "./data", download=True, transformed by MNIST_TRANSFORM).

    When `subset` > 1, each split is thinned to every `subset`-th example (indices 0, subset, 2*subset, ...).
    Returns (trainset, testset).
    '''
    def load(train: bool):
        return datasets.MNIST(root="./data", train=train, download=True, transform=MNIST_TRANSFORM)

    splits = (load(True), load(False))
    if subset > 1:
        splits = tuple(Subset(ds, indices=range(0, len(ds), subset)) for ds in splits)
    return splits


# Full (non-subsampled) MNIST splits; only the training loader shuffles.
mnist_trainset, mnist_testset = get_mnist(subset=1)
mnist_trainloader = DataLoader(dataset=mnist_trainset, batch_size=64, shuffle=True)
mnist_testloader = DataLoader(dataset=mnist_testset, batch_size=64, shuffle=False)

#### Training Loop

In [62]:
# First, hand-written training run of SimpleMLP (the `train` function further down wraps the same loop
# with configurable hyperparameters and a validation pass).
model = SimpleMLP().to(device)

batch_size = 64
epochs = 3

# NOTE(review): this rebinds the global `mnist_trainset` / `mnist_trainloader` from the previous cell
# (full dataset) to a 1-in-10 subset, so what later cells see depends on execution order.
# NOTE(review): no random seed is set in this notebook, so the loss curve differs between runs.
mnist_trainset, _ = get_mnist(subset = 10)
mnist_trainloader = DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True)

optimizer = t.optim.Adam(model.parameters(), lr=1e-3)
loss_list = []  # one training-loss value per batch, for the plot below

for epoch in tqdm(range(epochs)):
    for imgs, labels in mnist_trainloader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        logits = model(imgs)
        loss = F.cross_entropy(logits, labels)
        # Clear gradients after each update so they don't accumulate across batches.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_list.append(loss.item())   

# Plot the per-batch loss curve.
line(
    loss_list, 
    yaxis_range=[0, max(loss_list) + 0.1],
    labels={"x": "Num batches seen", "y": "Cross entropy loss"}, 
    title="SimpleMLP training on MNIST",
    width=700
)

  0%|          | 0/3 [00:00<?, ?it/s]

### Exercise - Add a validation loop

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵🔵

You should spend up to ~20 minutes on this exercise.

It is very important that you understand training loops and how they work, because we'll be doing a lot of model training in this way.

In [109]:
@dataclass
class SimpleMLPTrainingArgs:
    '''
    Hyperparameters for `train`, bundled into one object.

    The @dataclass decorator generates an __init__ taking each field below (by keyword, or positionally
    in this order) and falling back to the default shown, so e.g.
    SimpleMLPTrainingArgs(batch_size=128) overrides only the batch size.
    '''
    batch_size: int = 64         # batch size for both the train and test DataLoaders
    epochs: int = 3              # passes over the training set
    learning_rate: float = 1e-3  # Adam learning rate
    subset: int = 10             # forwarded to get_mnist: keep every `subset`-th example


def _evaluate(model: nn.Module, loader: DataLoader) -> Tuple[int, int]:
    '''Return (number of correct argmax predictions, number of examples) for `model` over `loader`.'''
    model.eval()
    correct = 0
    total = 0
    # No gradients are needed to score predictions; inference_mode skips building autograd graphs.
    with t.inference_mode():
        for imgs, labels in loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            predicted = model(imgs).argmax(dim=-1)
            correct += (predicted == labels).sum().item()
            total += labels.shape[0]
    return correct, total


def train(args: SimpleMLPTrainingArgs):
    '''
    Train a fresh SimpleMLP on MNIST with the hyperparameters in `args`.

    After every epoch the test-set accuracy is printed; at the end the per-batch training loss is plotted.
    '''
    model = SimpleMLP().to(device)

    mnist_trainset, mnist_testset = get_mnist(subset=args.subset)
    mnist_trainloader = DataLoader(mnist_trainset, batch_size=args.batch_size, shuffle=True)
    mnist_testloader = DataLoader(mnist_testset, batch_size=args.batch_size)

    optimizer = t.optim.Adam(model.parameters(), lr=args.learning_rate)
    loss_list = []

    for epoch in tqdm(range(args.epochs)):
        # Re-enter training mode each epoch because _evaluate switches the model to eval mode.
        model.train()
        for imgs, labels in mnist_trainloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            logits = model(imgs)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.item())

        validated, total_data = _evaluate(model, mnist_testloader)
        print('Validation loop #{0}: {1} correctly predicted out of {2}. Accuracy: {3}%'.format(epoch, validated, total_data, validated/total_data *100))

    line(
        loss_list,
        yaxis_range=[0, max(loss_list) + 0.1],
        labels={"x": "Num batches seen", "y": "Cross entropy loss"},
        title="SimpleMLP training on MNIST",
        width=700
    )


# Train with the default hyperparameters; override fields (e.g. SimpleMLPTrainingArgs(epochs=5)) to experiment.
args = SimpleMLPTrainingArgs()
train(args=args)

  0%|          | 0/3 [00:00<?, ?it/s]

Validation loop #0: 873 correctly predicted out of 1000. Accuracy: 87.3%
Validation loop #1: 878 correctly predicted out of 1000. Accuracy: 87.8%
Validation loop #2: 895 correctly predicted out of 1000. Accuracy: 89.5%


## Convolutions

### Exercise - implement Conv2d

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to ~20 minutes on this exercise.

Make sure you understand what operation is taking place here, and how the dimensions are changing.

In [112]:
class Conv2d(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0
    ):
        '''
        Bias-free 2D convolution, i.e. the counterpart of torch.nn.Conv2d(..., bias=False).

        The kernel lives in `self.weight` with shape (out_channels, in_channels, kernel_size, kernel_size),
        initialised uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)) where fan_in = in_channels * kernel_size**2.
        '''
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

        fan_in = in_channels * kernel_size ** 2
        bound = 1 / np.sqrt(fan_in)
        # t.rand samples [0, 1); 2u - 1 maps that to [-1, 1) before scaling by `bound`.
        unit_samples = t.rand(out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(bound * (2 * unit_samples - 1))

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Convolve `x` with `self.weight` via torch's functional conv2d (no bias term).'''
        conv = t.nn.functional.conv2d
        return conv(x, self.weight, stride=self.stride, padding=self.padding)

    def extra_repr(self) -> str:
        fields = ("in_channels", "out_channels", "kernel_size", "stride", "padding")
        return ", ".join(f"{name}={getattr(self, name)}" for name in fields)


# Run the provided Conv2d tests, then build an example instance so its repr (from extra_repr) can be eyeballed.
tests.test_conv2d_module(Conv2d)
m = Conv2d(in_channels=24, out_channels=12, kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")

All tests in `test_conv2d_module` passed!
Manually verify that this is an informative repr: Conv2d(in_channels=24, out_channels=12, kernel_size=3, stride=2, padding=1)


### Exercise - implement MaxPool2d

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to ~10 minutes on this exercise.

In [114]:
class MaxPool2d(nn.Module):
    def __init__(self, kernel_size: int, stride: Optional[int] = None, padding: int = 1):
        '''
        Module wrapper around torch's functional max_pool2d.

        `stride` is passed through unchanged, so None means whatever max_pool2d does for stride=None
        (per the PyTorch docs, it falls back to kernel_size).
        NOTE(review): the default padding here is 1; torch.nn.MaxPool2d defaults to 0 — confirm this is intended.
        '''
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Max-pool `x` with the stored kernel_size / stride / padding.'''
        return t.nn.functional.max_pool2d(
            x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding
        )

    def extra_repr(self) -> str:
        '''Report kernel_size, stride and padding in the module's printed representation.'''
        parts = []
        for name in ('kernel_size', 'stride', 'padding'):
            parts.append(f"{name}={getattr(self, name)}")
        return ", ".join(parts)


# Run the provided MaxPool2d tests, then build an example instance so its repr can be eyeballed.
tests.test_maxpool2d_module(MaxPool2d)
m = MaxPool2d(kernel_size=3, stride=2, padding=1)
print(f"Manually verify that this is an informative repr: {m}")

All tests in `test_maxpool2d_module` passed!
Manually verify that this is an informative repr: MaxPool2d(kernel_size=3, stride=2, padding=1)


## ResNets

### nn.Sequential

- Chains layers together: the output of each module is fed in as the input to the next. The modules are kept, in order, in the module's `_modules` dict.

In [None]:
class Sequential(nn.Module):
    '''
    Minimal version of nn.Sequential: applies its child modules in order, feeding each output into the next.

    Children are stored in `self._modules` under the keys "0", "1", ..., which registers them as submodules.
    Indexing follows list semantics: negative indices count from the end, and out-of-range indices raise
    IndexError (which also lets Python's fallback iteration over __getitem__ terminate).
    '''
    _modules: Dict[str, nn.Module]

    def __init__(self, *modules: nn.Module):
        super().__init__()
        for index, mod in enumerate(modules):
            self._modules[str(index)] = mod

    def _key(self, index: int) -> str:
        '''Map a (possibly negative) list-style index to its `_modules` key, raising IndexError if out of range.'''
        n = len(self._modules)
        # Checked before the modulo: wrapping an out-of-range index would silently return the wrong module,
        # and n == 0 would raise ZeroDivisionError.
        if not -n <= index < n:
            raise IndexError(f"Sequential index {index} out of range for {n} module(s)")
        return str(index % n)

    def __getitem__(self, index: int) -> nn.Module:
        return self._modules[self._key(index)]

    def __setitem__(self, index: int, module: nn.Module) -> None:
        self._modules[self._key(index)] = module

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''Chain each module together, with the output from one feeding into the next one.'''
        for mod in self._modules.values():
            x = mod(x)
        return x

### Exercise - implement BatchNorm2d

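- Batch norm at inference (prediction) time is awkward, since you can't compute batch statistics from a single data point. Instead, during training we keep running estimates of the per-channel mean and variance in *buffers*, and use those at inference.
- A buffer is a regular tensor that is not included in `module.parameters()` (so the optimizer never updates it); the module itself updates it during `forward`.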

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 20-40 minutes on this exercise.

This is the most challenging module you'll have implemented so far. Getting all the dimensions and operations right can be tricky.

In [119]:
class BatchNorm2d(nn.Module):
    # The type hints below aren't functional, they're just for documentation
    # (quoted so they are never evaluated at runtime).
    running_mean: "Float[Tensor, 'num_features']"
    running_var: "Float[Tensor, 'num_features']"
    num_batches_tracked: "Int[Tensor, '']"  # This is how we denote a scalar tensor

    def __init__(self, num_features: int, eps=1e-05, momentum=0.1):
        '''
        Like nn.BatchNorm2d with track_running_stats=True and affine=True.

        Name the learnable affine parameters `weight` and `bias` in that order.
        '''
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel scale (init 1) and shift (init 0), applied after normalisation.
        self.weight = nn.Parameter(t.ones(num_features))
        self.bias = nn.Parameter(t.zeros(num_features))

        # Buffers: saved in state_dict and moved with .to(), but not returned by .parameters().
        self.register_buffer("running_mean", t.zeros(num_features))
        self.register_buffer("running_var", t.ones(num_features))
        self.register_buffer("num_batches_tracked", t.tensor(0))

    @staticmethod
    def _per_channel(v: t.Tensor) -> t.Tensor:
        '''Reshape a (channels,) tensor to (1, channels, 1, 1) so it broadcasts against (N, C, H, W) input.'''
        return v.view(1, -1, 1, 1)

    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        Normalize each channel.

        In training mode the batch's per-channel mean / (biased) variance are used and folded into the
        running estimates with `momentum`; in eval mode the running estimates are used instead.

        x: shape (batch, channels, height, width)
        Return: shape (batch, channels, height, width)
        '''
        if self.training:
            batch_mean = t.mean(x, dim=(0, 2, 3), keepdim=True)
            batch_var = t.var(x, dim=(0, 2, 3), unbiased=False, keepdim=True)
            # Running statistics are bookkeeping, not part of the differentiable computation: without
            # no_grad the buffers would carry a grad_fn and keep every previous batch's graph alive.
            # flatten() (not squeeze()) keeps shape (channels,) even when channels == 1.
            with t.no_grad():
                self.running_mean = self.running_mean * (1 - self.momentum) + batch_mean.flatten() * self.momentum
                self.running_var = self.running_var * (1 - self.momentum) + batch_var.flatten() * self.momentum
                self.num_batches_tracked += 1
        else:
            batch_mean = self._per_channel(self.running_mean)
            batch_var = self._per_channel(self.running_var)

        weight = self._per_channel(self.weight)
        bias = self._per_channel(self.bias)
        return ((x - batch_mean) / t.sqrt(batch_var + self.eps)) * weight + bias

    def extra_repr(self) -> str:
        return ', '.join([f"{key}={getattr(self, key)}" for key in ["num_features", "eps", "momentum"]])


# Run the exercise's provided BatchNorm2d tests: module setup, forward output, and running-mean updates.
tests.test_batchnorm2d_module(BatchNorm2d)
tests.test_batchnorm2d_forward(BatchNorm2d)
tests.test_batchnorm2d_running_mean(BatchNorm2d)

All tests in `test_batchnorm2d_module` passed!
All tests in `test_batchnorm2d_forward` passed!
All tests in `test_batchnorm2d_running_mean` passed!


### Exercise - implement AveragePool

Difficulty: 🔴⚪⚪⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.

In [120]:
class AveragePool(nn.Module):
    '''Global average pooling: the mean over the two trailing (spatial) dimensions.'''
    def forward(self, x: t.Tensor) -> t.Tensor:
        '''
        x: shape (batch, channels, height, width) — or unbatched (channels, height, width)
        Return: shape (batch, channels) — or (channels,) for unbatched input
        '''
        # Reducing over the trailing dims (rather than naming all four) lets unbatched input work too.
        return x.mean(dim=(-2, -1))