## Setup

In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn.functional as F

from src.data import DataLoaderScratch
from src.trainer import TrainerScratch
from src.optimizers import SGDScratch
from src.functions import conv2d, maxpool2d

## Data Loading and Preprocessing

In [2]:
# Preprocessing pipeline handed to the torchvision datasets.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
mnist_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# NOTE(review): the tensors below are read from the datasets' `.data`
# attribute instead of by indexing the datasets, so the `transform` pipeline
# above presumably never touches them; the only scaling applied here is the
# division by 255 -- confirm this is the intended preprocessing.

# Training split: scale the pixels and insert the input-channel dimension at axis 1.
X_train = (mnist_trainset.data.float() / 255.0).unsqueeze(1)
y_train = mnist_trainset.targets

# Validation split (taken from the MNIST test set), prepared the same way.
X_val = (mnist_testset.data.float() / 255.0).unsqueeze(1)
y_val = mnist_testset.targets

# Shuffle the training batches only; keep validation order fixed.
train_dataloader = DataLoaderScratch(X_train, y_train, batch_size=256, shuffle=True)
val_dataloader = DataLoaderScratch(X_val, y_val, batch_size=256, shuffle=False)

### Single Batch Iteration

In [3]:
def relu(x):
    """Apply the rectified linear unit element-wise: ``max(x, 0)``.

    Args:
        x: Input tensor of any shape (including 0-dim).

    Returns:
        A tensor with the same shape, dtype and device as ``x`` in which
        every negative entry is replaced by zero.
    """
    # Derive the zero from `x` as a 0-dim tensor. A fixed `torch.zeros(1)`
    # would broadcast a 0-dim input up to shape [1], promote integer inputs
    # to float32, and fail for inputs that live on another device.
    zero = x.new_zeros(())
    return torch.maximum(x, zero)

def softmax(X):
    """Compute the softmax along dimension 1 (one distribution per row).

    Args:
        X: Tensor of logits with at least two dimensions; dimension 1 indexes
            the classes.

    Returns:
        A tensor shaped like ``X`` whose entries along dimension 1 are
        non-negative and sum to 1.
    """
    # Softmax is invariant to adding a constant per row, so shift each row by
    # its maximum before exponentiating. The largest exponent is then 0, which
    # keeps `exp` from overflowing to inf (and the ratio from becoming NaN) on
    # large logits. The shift is detached so it acts as a plain constant.
    row_max = X.max(dim=1, keepdim=True).values.detach()
    X_exp = torch.exp(X - row_max)
    X_softmax = X_exp / X_exp.sum(dim=1, keepdim=True)
    return X_softmax

def log_loss(y_pred, y):
    """Mean negative log-likelihood of the true classes (cross-entropy).

    Args:
        y_pred: Tensor of shape ``[batch, num_classes]`` holding predicted
            class probabilities.
        y: Integer (int64) tensor of shape ``[batch]`` holding the true class
            index of each sample.

    Returns:
        A 0-dim tensor: the average over the batch of
        ``-log(y_pred[i, y[i]])``.
    """
    # Pick the predicted probability of the true class directly instead of
    # multiplying by a one-hot mask:
    #  * `one_hot(y)` without `num_classes` sizes itself from the largest label
    #    in the batch, so it mismatches `y_pred` whenever that class is absent;
    #  * masking computes 0 * log(0) = NaN when any other class has
    #    probability exactly 0.
    true_class_prob = y_pred.gather(1, y.unsqueeze(1)).squeeze(1)
    loss = -torch.log(true_class_prob).mean()
    return loss

def calculate_same_padding(input_size, kernel_size, stride):
    """Return the per-side padding for a "same"-style convolution.

    Args:
        input_size: Length of the input along one spatial dimension.
        kernel_size: Length of the kernel along that dimension.
        stride: Stride of the convolution along that dimension.

    Returns:
        Half of the total padding, rounded down. When the total is odd the
        extra unit is dropped.
    """
    leftover = input_size % stride
    # When the stride divides the input evenly the kernel must cover a full
    # stride; otherwise it only has to cover the leftover positions.
    covered = leftover if leftover != 0 else stride
    return max(kernel_size - covered, 0) // 2

In [4]:
# Create a batch
batch_size = 128
perm = torch.randperm(len(X_train))
# Index with the first `batch_size` shuffled positions only. This selects the
# same rows as `X_train[perm][:batch_size]` without materialising a shuffled
# copy of the entire training set first.
batch_idx = perm[:batch_size]
X_batch = X_train[batch_idx]
y_batch = y_train[batch_idx]

batch_size, in_channels, input_height, input_width = X_batch.shape

filter_size = 3 # Filter size for all layers
pool_size = 2 # Pool size for all layers

# CONV + POOL Layer 1
out_channels1 = 16 # Number of filters in the first conv layer
# Small random weights (std 0.01) and zero biases; the bias has one entry per
# output channel and broadcasts over the batch and spatial dimensions.
W1 = nn.Parameter(torch.randn(out_channels1, in_channels, filter_size, filter_size) * 0.01)
b1 = nn.Parameter(torch.zeros(size=(1, out_channels1, 1, 1)))
# After Conv1 + Same Padding + Stride 1 => Shape remains [batch_size, out_channels1, input_height, input_width]
# After Pooling1 with pool_size 2 and stride 2 => Shape: [batch_size, out_channels1, input_height // 2, input_width // 2]

# CONV + POOL Layer 2
out_channels2 = 32 # Number of filters in the second conv layer
W2 = nn.Parameter(torch.randn(out_channels2, out_channels1, filter_size, filter_size) * 0.01)
b2 = nn.Parameter(torch.zeros(size=(1, out_channels2, 1, 1)))
# After Conv2 + Same Padding + Stride 1 => Shape: [batch_size, out_channels2, input_height // 2, input_width // 2]
# After Pooling2 with pool_size 2 and stride 2 => Shape: [batch_size, out_channels2, pooled_height, pooled_width]

# Spatial size left after the two pooling stages. Each stage is floored
# separately and height and width are tracked independently, so the FC input
# size stays correct for non-square inputs and for sizes not divisible by 4.
# NOTE(review): assumes `maxpool2d` (src.functions) floors the output size
# like torch's max pooling does by default -- confirm.
pooled_height = input_height // pool_size // pool_size
pooled_width = input_width // pool_size // pool_size

# FC Layer
# Before the FC layer, the output from the last pooling layer is flattened
# Flattened shape: [batch_size, out_channels2 * pooled_height * pooled_width]
num_classes = 10  # MNIST has ten digit classes
# Initialize the fc layer weights
W3 = nn.Parameter(torch.randn(out_channels2 * pooled_height * pooled_width, num_classes) * 0.01)
b3 = nn.Parameter(torch.zeros(num_classes))
# The FC layer maps from the flattened size to the number of classes
# After FC => Shape: [batch_size, num_classes]

# Hand every trainable tensor to the optimizer.
parameters = [W1, b1, W2, b2, W3, b3]
optimizer = SGDScratch(parameters, lr=0.1)

In [5]:
# One complete training step on the single batch prepared in the previous
# cell: forward pass -> loss -> backward pass -> parameter update.

# Zero gradients
# (presumably resets the gradients of the parameters registered with the
# optimizer; SGDScratch lives in src.optimizers -- confirm there)
optimizer.zero_grad()

# CONV + POOL Layer 1
# Padding is derived from the height only and reused for both spatial
# dimensions (stride is 1).
padding = calculate_same_padding(input_height, filter_size, 1)
# b1 has shape (1, out_channels1, 1, 1), so it broadcasts as a per-channel bias.
Z1 = conv2d(X_batch, W1, padding=padding) + b1
A1 = relu(Z1)
P1 = maxpool2d(A1, kernel_size=pool_size, stride=2)

# CONV + POOL Layer 2
# Recompute the padding from the last dimension of the pooled activations.
padding = calculate_same_padding(P1.size(3), filter_size, 1)
Z2 = conv2d(P1, W2, padding=padding) + b2
A2 = relu(Z2)
P2 = maxpool2d(A2, kernel_size=pool_size, stride=2)

# FC Layer
# Keep the batch dimension and flatten everything after it, so each sample
# becomes one row that can be multiplied with W3.
P2_flat = P2.flatten(start_dim=1)
Z3 = P2_flat @ W3 + b3
# Turn the logits into per-class probabilities.
y_pred = softmax(Z3)

# Calculate Loss
loss = log_loss(y_pred, y_batch)

# Compute gradients
# (autograd backpropagates from the scalar loss to the nn.Parameter tensors)
loss.backward()

# Update parameters
# (the update rule itself is implemented by SGDScratch, not shown in this notebook)
optimizer.step()