# Library Imports

In [1]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from typing import Tuple
from torch.utils.data.dataloader import DataLoader
import numpy as np
from tqdm.notebook import tqdm
from matplotlib import pyplot as plt

# Set up Dataset

In [2]:
def load_CIFAR10_dataset(batch_size=1000) -> Tuple[DataLoader, DataLoader]:
    """Build shuffled DataLoaders over the CIFAR10 train and test splits.

    The CIFAR10 dataset contains 10 classes, and approximately 6000 images per class.
    The data is requested with ``download=True`` and stored under the local ``data``
    directory. Every image goes through ``ToTensor`` (channels-first layout, i.e.
    (3, 32, 32) per image) and is then normalized per channel with the mean and
    standard deviation constants listed below.

    Args:
        batch_size: Number of images per batch, shared by the train and test loaders

    Returns:
        The two Pytorch dataloaders, in the order Train, Test
    """

    def build_loader(is_train):
        # Identical preprocessing for both splits: tensor conversion, then
        # per-channel normalization with (mean), (std).
        preprocessing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.491, 0.482, 0.447), (0.247, 0.243, 0.262)),
        ])
        split = datasets.CIFAR10('data', train=is_train, download=True, transform=preprocessing)
        # Note: both the train and the test loader are shuffled.
        return DataLoader(split, batch_size=batch_size, shuffle=True)

    # Tuple elements are evaluated left to right, so the train split is built first.
    return build_loader(True), build_loader(False)

def plot_results(image, pred_label, gt_label):
    """This function produces a plot showing the input image, the predicted label, and the ground truth label

    The image is expected to be normalized with the same per-channel mean and
    standard deviation that load_CIFAR10_dataset passes to ``transforms.Normalize``;
    that normalization is inverted here so the picture is shown in its original colors.

    Args:
        image: Tensor representing the image used as input to the model, of shape (3, 32, 32)
        pred_label: An integer (0-9) representing the predicted class
        gt_label: An integer (0-9) representing the image's ground truth class
    """

    labels = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Per-channel statistics, shaped (3, 1, 1) so they broadcast over height and width.
    # These must match the values used in the dataset's Normalize transform.
    mean = torch.tensor((0.491, 0.482, 0.447)).view(3, 1, 1)
    std = torch.tensor((0.247, 0.243, 0.262)).view(3, 1, 1)

    # Detach from the computational graph and move to host memory (numpy cannot
    # read tensors that live on another device).
    img = image.detach().cpu()
    # Undo normalization: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    # Clamp to [0, 1] because imshow clips float RGB data outside that range.
    img = (img * std + mean).clamp(0, 1)
    img = img.permute(1, 2, 0).numpy()  # shape to (32, 32, 3), convert to numpy array

    plt.subplot(1, 1, 1)
    plt.title(f"Prediction: {labels[pred_label]}\nGround Truth: {labels[gt_label]}")
    plt.imshow(img)
    # Hide ticks and tick labels on both axes; they carry no information for an image
    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    plt.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)
    plt.show()

# Load the Dataset

In [3]:
# Number of images per batch, used for both the train and the test loader
BATCH_SIZE = 20
train_dataloader, test_dataloader = load_CIFAR10_dataset(batch_size=BATCH_SIZE)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:10<00:00, 16.0MB/s] 


Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified


# Convolutional Forward Pass

In [None]:
def _as_hw_pair(value):
    """Expand a scalar into a (height, width) pair; unpack 2-element tuples/lists as-is."""
    if isinstance(value, (tuple, list)):
        height_value, width_value = value
        return height_value, width_value
    return value, value


def conv_forward(input_tensor, weights, bias, padding=0, stride=1):
    """
    Manual implementation of convolutional forward pass

    The input is zero-padded, every receptive field is extracted with
    ``Tensor.unfold``, and the convolution is evaluated as a single matrix
    product between the flattened patches and the flattened kernels.

    Args:
        input_tensor: Input tensor of shape (batch_size, in_channels, height, width)
        weights: Kernel weights of shape (out_channels, in_channels, kernel_height, kernel_width)
        bias: Bias tensor of shape (out_channels,), or None to skip the bias term
        padding: Zero-padding size; an int, or a (pad_height, pad_width) pair (default 0)
        stride: Stride size; an int, or a (stride_height, stride_width) pair (default 1)

    Returns:
        output: Convolved output tensor of shape
            (batch_size, out_channels, output_height, output_width), where
            output_height = (height + 2*pad_height - kernel_height) // stride_height + 1
            and output_width is computed analogously.

    Raises:
        ValueError: If the channel count of ``input_tensor`` does not match
            the ``in_channels`` dimension of ``weights``.
    """
    batch_size, in_channels, _, _ = input_tensor.shape
    out_channels, weight_in_channels, kernel_height, kernel_width = weights.shape

    if weight_in_channels != in_channels:
        raise ValueError(
            f"input has {in_channels} channels but weights expect {weight_in_channels}"
        )

    pad_h, pad_w = _as_hw_pair(padding)
    stride_h, stride_w = _as_hw_pair(stride)

    # F.pad takes the padding for the last dimension first: (left, right, top, bottom)
    padded_input = F.pad(input_tensor, (pad_w, pad_w, pad_h, pad_h))

    # Sliding windows over height, then width -> (B, C, out_H, out_W, kH, kW)
    patches = padded_input.unfold(2, kernel_height, stride_h).unfold(3, kernel_width, stride_w)
    output_height, output_width = patches.shape[2], patches.shape[3]

    # Bring the channel axis next to the kernel axes and flatten (C, kH, kW) so the
    # element order matches the flattened weights: (B, out_H*out_W, C*kH*kW)
    patch_size = in_channels * kernel_height * kernel_width
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
        batch_size, output_height * output_width, patch_size
    )

    # reshape (not view) so that non-contiguous weights are accepted as well
    weights_flat = weights.reshape(out_channels, patch_size)  # (out_channels, C*kH*kW)

    # Batch matrix multiply: (B, H*W, C*kH*kW) @ (C*kH*kW, out_channels) -> (B, H*W, out_channels)
    output = patches @ weights_flat.T

    # Back to channels-first image layout: (B, out_channels, out_H, out_W)
    output = output.permute(0, 2, 1).reshape(batch_size, out_channels, output_height, output_width)

    if bias is not None:
        # One bias value per output channel, broadcast over batch and spatial dims
        output = output + bias.reshape(1, -1, 1, 1)

    return output


# Test it!

Run the following tests to check your implementation. Passing them does not guarantee correctness: the tests are far from exhaustive, so edge-case errors may still remain.

In [8]:
# Convolution cloning test
# An identity ("cloning") kernel: each output channel copies its matching input channel.
weights = torch.zeros(3, 3, 3, 3)
bias = torch.zeros(3)
# A single 1 at the center tap of the 3x3 filter that links channel c to itself
for channel in range(3):
    weights[channel, channel, 1, 1] = 1
batch_sum = 0
out_sum = 0
print('Testing cloning convolution')
with tqdm(total=4) as pbar:
    # Pairing with range(4) restricts the check to the first four batches
    for _, (batch, labels) in zip(range(4), train_dataloader):
        batch_sum += batch.detach().flatten().sum()
        out = conv_forward(batch, weights, bias, 1, 1)
        out_sum += out.flatten().sum()
        pbar.update(1)
# Check correctness: the cloned output must sum to the same total as the input
cloning_ok = np.isclose(batch_sum, out_sum, atol=0.5)
print('Hooray! Your implementation matches the correct output!' if cloning_ok
      else 'Womp Womp :( Your implementation is not correct')

print()

# Convolution test
# Seed first: the random kernel, the random bias and the batches drawn below
# all follow from this seed, in exactly this order.
torch.manual_seed(27)
# Create randomized kernel and biases
weights = torch.randn(12, 3, 3, 3)
bias = torch.randn(12)
out_sum = 0
print('Testing randomly initialized weights convolution')
with tqdm(total=4) as pbar:
    for _, (batch, labels) in zip(range(4), train_dataloader):
        out = conv_forward(batch, weights, bias, 0, 1)
        out_sum += out.flatten().sum()
        pbar.update(1)
# Check correctness against the hard-coded reference total
# (presumably recorded from a reference implementation with this seed and batch size -- TODO confirm)
random_ok = np.isclose(out_sum, -212331.1562, atol=0.5)
print('Hooray! Your implementation matches the correct output!' if random_ok
      else 'Womp Womp :( Your implementation is not correct')


Testing cloning convolution


  0%|          | 0/4 [00:00<?, ?it/s]

Hooray! Your implementation matches the correct output!

Testing randomly initialized weights convolution


  0%|          | 0/4 [00:00<?, ?it/s]

Hooray! Your implementation matches the correct output!
