# ResNet-18-style network on MNIST (low-level PyTorch)
This notebook explains the structure of a ResNet (a network built from residual blocks) and walks through the provided mostly-from-scratch implementation.

**Goal:** avoid `torch.nn.Module` / `nn.Conv2d` / `nn.BatchNorm2d` and implement layers manually using tensors + `torch.nn.functional`.
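In this style, a "layer" is a set of plain tensors created with `requires_grad=True` plus a call into `torch.nn.functional`. The sketch below builds one 3×3 convolution layer that way. The helper names `init_conv_weight` and `conv_layer` are hypothetical. The He-normal initialization and the absence of a bias (because BatchNorm follows each conv in a ResNet) are assumptions, not necessarily what the provided code does.

```python
import math
import torch
import torch.nn.functional as F

def init_conv_weight(out_ch, in_ch, k, generator=None):
    # He-normal init for ReLU networks: std = sqrt(2 / fan_in), fan_in = in_ch * k * k
    fan_in = in_ch * k * k
    w = torch.randn(out_ch, in_ch, k, k, generator=generator) * math.sqrt(2.0 / fan_in)
    return w.requires_grad_()  # leaf tensor that autograd will populate with .grad

def conv_layer(x, weight, stride=1, padding=1):
    # No bias: a BatchNorm after the conv would cancel it anyway
    return F.conv2d(x, weight, bias=None, stride=stride, padding=padding)

g = torch.Generator().manual_seed(0)
w = init_conv_weight(16, 1, 3, generator=g)
x = torch.randn(8, 1, 28, 28, generator=g)
y = conv_layer(x, w)
y.square().mean().backward()  # any scalar loss works; gradients flow into w
print(y.shape)       # torch.Size([8, 16, 28, 28])
print(w.grad.shape)  # torch.Size([16, 1, 3, 3])
```

Because the parameters are ordinary tensors, the model must keep track of them itself (for example in a list or dict) so that an optimizer step can update them.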


## 1) Residual learning (core idea)
Instead of learning a desired mapping `H(x)` directly, a residual block learns the residual `F(x) = H(x) - x` and outputs:
`y = F(x) + x` (or `y = F(x) + shortcut(x)` when shapes differ).

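This makes deep networks easier to optimize. If the identity mapping is close to optimal, the block only has to push `F(x)` toward zero instead of fitting an identity through a stack of nonlinear layers, and the shortcut gives gradients a direct path back to earlier layers.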
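A minimal sketch of the identity-shortcut case, assuming `F(x)` is conv3x3 → ReLU → conv3x3 (BatchNorm is left out to keep it short). It shows the point made above: with all of `F`'s weights at zero the block is exactly the identity, so "do nothing" is trivially reachable. The function names are illustrative only.

```python
import torch
import torch.nn.functional as F

def residual_branch(x, w1, w2):
    # F(x): conv3x3 -> ReLU -> conv3x3 (BatchNorm omitted in this sketch)
    return F.conv2d(F.relu(F.conv2d(x, w1, padding=1)), w2, padding=1)

def residual_block(x, w1, w2):
    # y = F(x) + x: the identity shortcut requires F(x) to have the same shape as x
    return residual_branch(x, w1, w2) + x

torch.manual_seed(0)
x = torch.randn(2, 4, 8, 8)
w_zero = torch.zeros(4, 4, 3, 3)
w_rand = 0.1 * torch.randn(4, 4, 3, 3)

# With F's weights at zero, F(x) = 0 and the block reduces to the identity
print(torch.allclose(residual_block(x, w_zero, w_zero), x))  # True
# With nonzero weights, the block outputs x plus a learned correction
print(torch.allclose(residual_block(x, w_rand, w_rand), x))  # False
```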


## 2) ResNet-18 block structure
A basic block has two 3×3 convolutions with BatchNorm + ReLU:

- `conv3x3 -> bn -> relu -> conv3x3 -> bn`
- add the shortcut
- final ReLU
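The block structure above can be sketched functionally as follows. Assumptions: parameters live in a plain dict (`make_block_params`, `basic_block` and the key names are hypothetical), BatchNorm is applied in training mode from batch statistics only, and when the stride or channel count changes the shortcut is a 1×1 strided convolution followed by BatchNorm (the usual "projection" choice in ResNet-18; the provided code may do this differently).

```python
import math
import torch
import torch.nn.functional as F

def kaiming(out_ch, in_ch, k):
    # He-normal init: std = sqrt(2 / fan_in)
    return torch.randn(out_ch, in_ch, k, k) * math.sqrt(2.0 / (in_ch * k * k))

def make_block_params(in_ch, out_ch, stride):
    p = {
        "w1": kaiming(out_ch, in_ch, 3), "g1": torch.ones(out_ch), "b1": torch.zeros(out_ch),
        "w2": kaiming(out_ch, out_ch, 3), "g2": torch.ones(out_ch), "b2": torch.zeros(out_ch),
    }
    if stride != 1 or in_ch != out_ch:
        # Projection shortcut: 1x1 conv + BN so the shortcut matches F(x)'s shape
        p.update(ws=kaiming(out_ch, in_ch, 1), gs=torch.ones(out_ch), bs=torch.zeros(out_ch))
    return p

def bn(x, gamma, beta):
    # Training-mode BatchNorm from batch statistics only (no running stats kept here)
    return F.batch_norm(x, None, None, weight=gamma, bias=beta, training=True)

def basic_block(x, p, stride):
    # conv3x3 -> bn -> relu -> conv3x3 -> bn
    out = F.relu(bn(F.conv2d(x, p["w1"], stride=stride, padding=1), p["g1"], p["b1"]))
    out = bn(F.conv2d(out, p["w2"], stride=1, padding=1), p["g2"], p["b2"])
    # Add the shortcut (identity or projection), then apply the final ReLU
    if "ws" in p:
        shortcut = bn(F.conv2d(x, p["ws"], stride=stride), p["gs"], p["bs"])
    else:
        shortcut = x
    return F.relu(out + shortcut)

torch.manual_seed(0)
x = torch.randn(8, 64, 28, 28)
y_same = basic_block(x, make_block_params(64, 64, stride=1), stride=1)
y_down = basic_block(x, make_block_params(64, 128, stride=2), stride=2)
print(y_same.shape)  # torch.Size([8, 64, 28, 28])
print(y_down.shape)  # torch.Size([8, 128, 14, 14])
```

The second call shows why a projection is needed: the strided 3×3 conv halves the spatial size and doubles the channels, so the raw input can no longer be added directly.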


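A quick forward-pass check of the scratch model on a random MNIST-sized batch:

```python
import torch
from ResNet.resnet18_scratch import ResNet18Scratch

torch.manual_seed(0)  # reproducible parameter init and input
model = ResNet18Scratch(num_classes=10)
x = torch.randn(8, 1, 28, 28)  # batch of 8 single-channel 28x28 images
logits = model.forward(x, training=True)  # training=True selects training-mode behavior (e.g. BatchNorm)
print('logits shape:', logits.shape)  # one row of class scores per image
```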
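The name ResNet-18 counts the weighted layers: one stem convolution, 4 stages × 2 basic blocks × 2 convolutions, and one final fully connected layer (1×1 projection convs are conventionally not counted). The sketch below traces the spatial size of a 28×28 input through that layout. It assumes the standard ResNet-18 stages (64/128/256/512 channels, stride 2 at the start of stages 2 to 4) and a hypothetical MNIST-style stem of one 3×3 stride-1 conv with no max-pool; `ResNet18Scratch` may use a different stem.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Output spatial size of a convolution: floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def resnet18_trace(input_size=28):
    # Assumed stem: one 3x3 conv, stride 1, 64 channels, no max-pool
    size = conv_out(input_size, kernel=3, stride=1, padding=1)
    trace = [("stem", 64, size)]
    layers = 1  # stem conv
    # Standard ResNet-18 stages: (channels, stride of the first block), 2 basic blocks each
    for i, (channels, stride) in enumerate([(64, 1), (128, 2), (256, 2), (512, 2)], start=1):
        for block in range(2):
            s = stride if block == 0 else 1
            size = conv_out(size, 3, s, 1)  # first 3x3 conv (may downsample)
            size = conv_out(size, 3, 1, 1)  # second 3x3 conv keeps the size
            layers += 2
        trace.append((f"stage{i}", channels, size))
    layers += 1  # fully connected layer after global average pooling (512 -> num_classes)
    return trace, layers

trace, n_layers = resnet18_trace()
for name, channels, size in trace:
    print(f"{name}: {channels} channels, {size}x{size}")
print("weighted layers:", n_layers)  # 18
```

Under these assumptions the last stage produces a 512×4×4 feature map, which global average pooling reduces to a 512-vector before the classifier maps it to `num_classes` logits.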

## 3) BatchNorm2dScratch
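The ResNet implementation includes a `BatchNorm2dScratch` class that:
- normalizes each channel using statistics computed over the `(N, H, W)` dimensions
- maintains running estimates of the mean and variance
- uses batch statistics when `training=True` and the running estimates when `training=False`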
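A minimal sketch of those three behaviors as a plain function (the name `batch_norm2d_scratch` is hypothetical, and the learnable per-channel scale `gamma` and shift `beta` are assumed, as in standard BatchNorm). It follows PyTorch's convention of normalizing with the biased batch variance while storing the unbiased variance in the running estimate, and it is checked against `F.batch_norm`.

```python
import torch
import torch.nn.functional as F

def batch_norm2d_scratch(x, gamma, beta, running_mean, running_var,
                         training, momentum=0.1, eps=1e-5):
    if training:
        # Per-channel statistics over (N, H, W); biased variance is used to normalize
        mean = x.mean(dim=(0, 2, 3))
        var = ((x - mean[None, :, None, None]) ** 2).mean(dim=(0, 2, 3))
        n = x.numel() // x.shape[1]  # number of values per channel: N * H * W
        with torch.no_grad():
            # Exponential moving averages (in place); unbiased variance is stored
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        mean, var = running_mean, running_var
    x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

torch.manual_seed(0)
x = torch.randn(8, 3, 5, 5)
gamma, beta = torch.rand(3) + 0.5, torch.randn(3)
rm, rv = torch.zeros(3), torch.ones(3)
rm_ref, rv_ref = rm.clone(), rv.clone()

y = batch_norm2d_scratch(x, gamma, beta, rm, rv, training=True)
y_ref = F.batch_norm(x, rm_ref, rv_ref, gamma, beta, training=True, momentum=0.1, eps=1e-5)
print(torch.allclose(y, y_ref, atol=1e-5), torch.allclose(rv, rv_ref, atol=1e-6))  # True True
```

At evaluation time the same function with `training=False` reads only the running estimates, so a prediction for one image no longer depends on the rest of the batch.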


## 4) Training on MNIST
Use the provided training script. For a fast CPU run, use `--subset` to limit the dataset.
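The training script's internals are not shown here. Because the notebook avoids `nn.Module`, one plausible pattern (an assumption, not the script's actual code) is to keep the parameters as a list of tensors and apply plain SGD by hand. The sketch uses a linear classifier on random data as a stand-in for the ResNet and MNIST; `sgd_step` is a hypothetical helper, and `lr=0.05` simply mirrors the command-line value used below.

```python
import torch
import torch.nn.functional as F

def sgd_step(params, lr):
    # Plain SGD on raw tensors: p <- p - lr * grad, then clear the gradient
    with torch.no_grad():
        for p in params:
            p -= lr * p.grad
            p.grad = None

torch.manual_seed(0)
# Stand-in model and data: linear classifier on 64 random 20-dim samples, 10 classes
W = (0.01 * torch.randn(20, 10)).requires_grad_()
b = torch.zeros(10, requires_grad=True)
x = torch.randn(64, 20)
y = torch.randint(0, 10, (64,))

losses = []
for step in range(50):
    logits = x @ W + b
    loss = F.cross_entropy(logits, y)  # softmax + negative log-likelihood
    loss.backward()                    # fills W.grad and b.grad
    sgd_step([W, b], lr=0.05)
    losses.append(loss.item())
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

For the ResNet, the same loop would call the model's forward pass with `training=True` and pass every weight, `gamma` and `beta` tensor to the update, while the BatchNorm running statistics are updated by the forward pass rather than by the optimizer.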


Run from a terminal:

```bash
python -m ResNet.test_resnet_smoke
python -m ResNet.train_mnist --epochs 1 --subset 5000 --batch-size 128 --lr 0.05
```
