# Convolutional Neural Networks (CNNs) for Image Data (from scratch NumPy + PyTorch)

CNNs are neural networks designed for **grid-like data** (images, audio spectrograms, some time-series). They work because they encode three key inductive biases:

- **Locality**: patterns are local (edges, corners).
- **Translation equivariance**: the same pattern can appear anywhere, and shifting the input shifts the filter response by the same amount.
- **Weight sharing**: we reuse the same filter across the whole image.
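
To make the "same pattern can appear anywhere" idea concrete, here is a minimal sketch (not part of the notebook's pipeline) that slides one kernel over a toy image and over a shifted copy of it. The helper name `cross_correlate_valid` and the toy images are assumptions made for illustration.

```python
import numpy as np


def cross_correlate_valid(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """Cross-correlate a 2D image with a 2D kernel (stride 1, no padding)."""
    kH, kW = kernel.shape
    out_h = img.shape[0] - kH + 1
    out_w = img.shape[1] - kW + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(img[i:i + kH, j:j + kW] * kernel)
    return out


# Toy image: a vertical bar in column 2, and the same bar moved to column 4.
img_a = np.zeros((6, 8))
img_a[1:5, 2] = 1.0
img_b = np.roll(img_a, shift=2, axis=1)

# One shared kernel that responds to left/right intensity changes.
kernel = np.array([[1.0, -1.0]])

resp_a = cross_correlate_valid(img_a, kernel)
resp_b = cross_correlate_valid(img_b, kernel)

# Shifting the input by two columns shifts the response by two columns.
shift_equivariant = np.allclose(np.roll(resp_a, shift=2, axis=1), resp_b)
print(shift_equivariant)
```

The bar is kept away from the image border so that the wrap-around behaviour of `np.roll` does not affect the comparison.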

This notebook builds intuition with Plotly visuals, implements a tiny CNN **from scratch in NumPy** (including convolution backprop), then shows the equivalent workflow in **PyTorch**.

---

## Learning goals

By the end you should be able to:

- explain convolution with **channels, kernels, padding, stride**
- understand why convolutions use far fewer parameters than dense layers
- implement and train a small CNN in **pure NumPy** (manual backprop)
- train the same model in **PyTorch** and visualize results

---

## Prerequisites

- backpropagation for MLPs
- NumPy fundamentals
- (optional) basic PyTorch


## Notebook roadmap

1. Data: load a small image dataset (no downloads)
2. Intuition: convolution and filters (Plotly heatmaps)
3. From scratch (NumPy): Conv2D + ReLU + MaxPool + Linear
4. Visual diagnostics (NumPy): filters, feature maps, confusion matrix
5. Practical (PyTorch): same model idea with autograd
6. Pitfalls + exercises


In [None]:
import os
import time

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False
    _TORCH_IMPORT_ERROR = e


pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
np.set_printoptions(precision=4, suppress=True)

SEED = 42
rng = np.random.default_rng(SEED)


In [None]:
# --- Run configuration ---
FAST_RUN = True  # set False for better accuracy / more epochs

EPOCHS_NUMPY = 10 if FAST_RUN else 40
EPOCHS_TORCH = 10 if FAST_RUN else 30

BATCH_SIZE = 64
LR_NUMPY = 0.05
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4

LR_TORCH = 1e-3

RUN_GRAD_CHECK = False  # optional (slow)


## 1) Data: a tiny image dataset (8×8 handwritten digits)

To keep this notebook **offline-friendly**, we use `sklearn.datasets.load_digits()`:

- 1 channel (grayscale)
- 8×8 images
- 10 classes (0–9)


In [None]:
digits = load_digits()
X = digits.images.astype(np.float32)  # (N, 8, 8)
y = digits.target.astype(np.int64)

# Pixel intensities are integers in 0..16, so dividing by 16 scales them into [0, 1]
X = X / 16.0

# Add channel dimension: (N, C=1, H, W)
X = X[:, None, :, :]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=SEED, stratify=y
)

print('X_train', X_train.shape, 'X_test', X_test.shape)


In [None]:
# Visualize a grid of sample images

idx = rng.choice(len(X_train), size=36, replace=False)
imgs = X_train[idx, 0]  # (36, 8, 8)

fig = px.imshow(
    imgs,
    facet_col=0,
    facet_col_wrap=9,
    color_continuous_scale='gray',
    title='Sample training images (8×8)',
)

for a in fig.layout.annotations:
    a.text = ''

fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=10, r=10, t=60, b=10))
fig.show()


## 2) Convolution intuition (filters as pattern detectors)

In deep learning, "convolution" is usually implemented as **cross-correlation** (we don’t flip the kernel). For stride 1 and no padding:

\[
y[n, c_{out}, i, j] = b[c_{out}] + \sum_{c_{in}} \sum_{u=0}^{kH-1} \sum_{v=0}^{kW-1}
  x[n, c_{in}, i+u, j+v]\, W[c_{out}, c_{in}, u, v]
\]

Why this is powerful:

- **Local receptive fields**: each output sees a small patch.
- **Weight sharing**: the *same* kernel is used everywhere → far fewer parameters.
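
Padding and stride change how many positions the kernel visits, and therefore the output size. The sketch below uses the standard relation `out = (n + 2*pad - k) // stride + 1` along one spatial axis; the function name `conv_output_size` is an assumption made for illustration.

```python
def conv_output_size(n: int, k: int, pad: int = 0, stride: int = 1) -> int:
    """Output size along one axis for input size n and kernel size k."""
    return (n + 2 * pad - k) // stride + 1


# 8-pixel axis, 3-wide kernel, three common settings.
same = conv_output_size(8, k=3, pad=1, stride=1)     # padding 1 keeps the size
valid = conv_output_size(8, k=3, pad=0, stride=1)    # no padding shrinks the map
strided = conv_output_size(8, k=3, pad=1, stride=2)  # stride 2 halves it here

print(same, valid, strided)  # 8 6 4
```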


In [None]:
from plotly.subplots import make_subplots

In [None]:
def conv2d_single_channel_naive(img: np.ndarray, kernel: np.ndarray, padding: int = 0, stride: int = 1) -> np.ndarray:
    img = np.asarray(img, dtype=np.float32)
    kernel = np.asarray(kernel, dtype=np.float32)

    H, W = img.shape
    kH, kW = kernel.shape

    if padding > 0:
        img_p = np.pad(img, ((padding, padding), (padding, padding)), mode='constant')
    else:
        img_p = img

    Hp, Wp = img_p.shape
    out_h = (Hp - kH) // stride + 1
    out_w = (Wp - kW) // stride + 1

    out = np.zeros((out_h, out_w), dtype=np.float32)
    for i in range(out_h):
        for j in range(out_w):
            patch = img_p[i*stride:i*stride+kH, j*stride:j*stride+kW]
            out[i, j] = float(np.sum(patch * kernel))
    return out


# Pick one image
img = X_train[0, 0]

# Sobel kernels: kx differentiates along x (columns), ky along y (rows)
kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=np.float32)
ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=np.float32)

fx = conv2d_single_channel_naive(img, kx, padding=1)
fy = conv2d_single_channel_naive(img, ky, padding=1)

fig = make_subplots(rows=1, cols=3, subplot_titles=['Input', 'kx: vertical edges', 'ky: horizontal edges'])
fig.add_trace(go.Heatmap(z=img, colorscale='gray', showscale=False), row=1, col=1)
fig.add_trace(go.Heatmap(z=fx, colorscale='RdBu', zmid=0, showscale=False), row=1, col=2)
fig.add_trace(go.Heatmap(z=fy, colorscale='RdBu', zmid=0, showscale=False), row=1, col=3)
fig.update_layout(title='A toy convolution demo (naive implementation)', height=320, margin=dict(l=10, r=10, t=60, b=10))
fig.show()


In [None]:
# Parameter count: dense vs conv

# Suppose we map an 8×8 image (64 pixels) to 8 feature maps (still 8×8).
# - A dense layer connects every input pixel to each of the 8·64 = 512 output units.
# - A conv layer reuses one small 3×3 kernel per feature map at every position.

def n_params_dense(fan_in: int, fan_out: int) -> int:
    return fan_in * fan_out + fan_out


def n_params_conv2d(c_in: int, c_out: int, k: int) -> int:
    return c_out * (c_in * k * k) + c_out


dense = n_params_dense(64, 64 * 8)
conv = n_params_conv2d(c_in=1, c_out=8, k=3)

_df = pd.DataFrame(
    [
        {'layer': 'Dense (64 → 512)', 'params': dense},
        {'layer': 'Conv2D (1→8, 3×3)', 'params': conv},
    ]
)

fig = px.bar(
    _df,
    x='layer',
    y='params',
    text='params',
    title='Parameter count: dense vs convolution (weight sharing)',
)
fig.update_traces(textposition='outside')
fig.update_layout(yaxis_title='number of parameters')
fig.show()


## 3) From scratch in NumPy: a tiny CNN

We implement a small CNN **without autograd**:

- `Conv2D` forward + backward (via `im2col` / `col2im`)
- `ReLU`, `MaxPool2D`, `Flatten`, `Linear`
- softmax cross-entropy loss
- SGD with momentum

The goal is clarity and correctness (not speed).
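
Before reading the full implementation, it may help to see the `im2col` idea on a single-channel toy input: every patch becomes one row, so the whole convolution is one matrix product. This is a minimal sketch that uses `numpy.lib.stride_tricks.sliding_window_view` in place of the notebook's own `im2col`; the toy input and kernel are made up for illustration.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

x = np.arange(16, dtype=np.float64).reshape(4, 4)  # toy 4×4 input
k = np.array([[1.0, 0.0], [0.0, -1.0]])            # toy 2×2 kernel

# Reference: explicit sliding-window cross-correlation (stride 1, no padding).
direct = np.zeros((3, 3))
for i in range(3):
    for j in range(3):
        direct[i, j] = np.sum(x[i:i + 2, j:j + 2] * k)

# im2col: gather every 2×2 patch, flatten each patch into one row.
windows = sliding_window_view(x, (2, 2))  # (3, 3, 2, 2)
cols = windows.reshape(9, 4)              # (out_h*out_w, kH*kW)

# One matrix-vector product now computes all nine outputs.
via_matmul = (cols @ k.reshape(-1)).reshape(3, 3)

print(np.allclose(direct, via_matmul))
```

The same trick extends to batches and multiple channels by making each row hold `C*kH*kW` values and each kernel one column of a weight matrix.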


In [None]:
def _pad2d(x: np.ndarray, pad: int) -> np.ndarray:
    if pad <= 0:
        return x
    return np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')


def im2col(x: np.ndarray, kH: int, kW: int, stride: int, pad: int) -> tuple[np.ndarray, tuple[int,int,int,int]]:
    '''
    x: (N, C, H, W)
    Returns:
      cols: (N*out_h*out_w, C*kH*kW)
      out_shape: (out_h, out_w, H_padded, W_padded) for col2im
    '''
    N, C, H, W = x.shape
    x_p = _pad2d(x, pad)
    _, _, Hp, Wp = x_p.shape

    out_h = (Hp - kH) // stride + 1
    out_w = (Wp - kW) // stride + 1

    cols = np.empty((N, C, kH, kW, out_h, out_w), dtype=x.dtype)

    for i in range(kH):
        i_end = i + stride * out_h
        for j in range(kW):
            j_end = j + stride * out_w
            cols[:, :, i, j, :, :] = x_p[:, :, i:i_end:stride, j:j_end:stride]

    cols = cols.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, C * kH * kW)
    return cols, (out_h, out_w, Hp, Wp)


def col2im(cols: np.ndarray, x_shape: tuple[int,int,int,int], kH: int, kW: int, stride: int, pad: int, out_shape: tuple[int,int,int,int]) -> np.ndarray:
    N, C, H, W = x_shape
    out_h, out_w, Hp, Wp = out_shape

    cols_reshaped = cols.reshape(N, out_h, out_w, C, kH, kW).transpose(0, 3, 4, 5, 1, 2)

    x_p = np.zeros((N, C, Hp, Wp), dtype=cols.dtype)

    for i in range(kH):
        i_end = i + stride * out_h
        for j in range(kW):
            j_end = j + stride * out_w
            x_p[:, :, i:i_end:stride, j:j_end:stride] += cols_reshaped[:, :, i, j, :, :]

    if pad <= 0:
        return x_p
    return x_p[:, :, pad:-pad, pad:-pad]


class Conv2D:
    def __init__(self, c_in: int, c_out: int, k: int, stride: int = 1, pad: int = 0, rng: np.random.Generator | None = None):
        self.c_in = c_in
        self.c_out = c_out
        self.k = k
        self.stride = stride
        self.pad = pad

        rng = rng or np.random.default_rng(0)
        scale = np.sqrt(2.0 / (c_in * k * k))
        self.W = (rng.standard_normal((c_out, c_in, k, k)).astype(np.float32) * scale)
        self.b = np.zeros((c_out,), dtype=np.float32)

        self.dW = np.zeros_like(self.W)
        self.db = np.zeros_like(self.b)

        self._cache = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        # cols: (N*out_h*out_w, c_in*k*k); each row is one flattened input patch.
        cols, out_shape = im2col(x, self.k, self.k, self.stride, self.pad)
        W_col = self.W.reshape(self.c_out, -1)  # (c_out, c_in*k*k)
        out = cols @ W_col.T + self.b[None, :]  # (N*out_h*out_w, c_out)

        N = x.shape[0]
        out_h, out_w, _, _ = out_shape
        # Back to image layout: (N, c_out, out_h, out_w).
        out = out.reshape(N, out_h, out_w, self.c_out).transpose(0, 3, 1, 2)

        self._cache = (x, cols, out_shape)
        return out

    def backward(self, dout: np.ndarray) -> np.ndarray:
        x, cols, out_shape = self._cache

        # Match the row layout used in forward: (N*out_h*out_w, c_out).
        dout_col = dout.transpose(0, 2, 3, 1).reshape(-1, self.c_out)
        self.db[...] = dout_col.sum(axis=0)

        W_col = self.W.reshape(self.c_out, -1)
        # dL/dW: combine upstream gradients with the cached input patches.
        self.dW[...] = (dout_col.T @ cols).reshape(self.W.shape)

        # dL/dx: project gradients back onto patches, then sum overlapping patches.
        dcols = dout_col @ W_col
        dx = col2im(dcols, x.shape, self.k, self.k, self.stride, self.pad, out_shape)
        return dx

    def step(self, lr: float, vW: np.ndarray, vb: np.ndarray, weight_decay: float, momentum: float) -> tuple[np.ndarray, np.ndarray]:
        vW = momentum * vW - lr * (self.dW + weight_decay * self.W)
        vb = momentum * vb - lr * self.db
        self.W += vW
        self.b += vb
        return vW, vb


class ReLU:
    def __init__(self):
        self._mask = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._mask = x > 0
        return x * self._mask

    def backward(self, dout: np.ndarray) -> np.ndarray:
        return dout * self._mask


class MaxPool2D:
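    def __init__(self, pool: int = 2, stride: int = 2):
        # The reshape-based forward pass only supports non-overlapping windows.
        assert stride == pool, 'This implementation assumes stride == pool.'
        self.pool = pool
        self.stride = stride
        self._cache = None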

    def forward(self, x: np.ndarray) -> np.ndarray:
        N, C, H, W = x.shape
        p = self.pool
        assert H % p == 0 and W % p == 0, 'For simplicity we assume H and W divisible by pool size.'

        out_h, out_w = H // p, W // p
        x_reshaped = x.reshape(N, C, out_h, p, out_w, p)
        out = x_reshaped.max(axis=(3, 5))

        self._cache = (x, x_reshaped, out)
        return out

    def backward(self, dout: np.ndarray) -> np.ndarray:
        x, x_reshaped, out = self._cache

        # Mark the position(s) that produced each window's maximum.
        mask = x_reshaped == out[:, :, :, None, :, None]
        # If several positions tie for the max, split the gradient equally among them.
        denom = mask.sum(axis=(3, 5), keepdims=True)
        denom = np.where(denom == 0, 1, denom)

        dx_reshaped = mask * (dout[:, :, :, None, :, None] / denom)
        dx = dx_reshaped.reshape(x.shape)
        return dx


class Flatten:
    def __init__(self):
        self._shape = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._shape = x.shape
        return x.reshape(x.shape[0], -1)

    def backward(self, dout: np.ndarray) -> np.ndarray:
        return dout.reshape(self._shape)


class Linear:
    def __init__(self, fan_in: int, fan_out: int, rng: np.random.Generator | None = None):
        rng = rng or np.random.default_rng(0)
        scale = np.sqrt(2.0 / fan_in)
        self.W = (rng.standard_normal((fan_in, fan_out)).astype(np.float32) * scale)
        self.b = np.zeros((fan_out,), dtype=np.float32)

        self.dW = np.zeros_like(self.W)
        self.db = np.zeros_like(self.b)

        self._cache = None

    def forward(self, x: np.ndarray) -> np.ndarray:
        self._cache = x
        return x @ self.W + self.b

    def backward(self, dout: np.ndarray) -> np.ndarray:
        x = self._cache
        self.dW[...] = x.T @ dout
        self.db[...] = dout.sum(axis=0)
        return dout @ self.W.T

    def step(self, lr: float, vW: np.ndarray, vb: np.ndarray, weight_decay: float, momentum: float) -> tuple[np.ndarray, np.ndarray]:
        vW = momentum * vW - lr * (self.dW + weight_decay * self.W)
        vb = momentum * vb - lr * self.db
        self.W += vW
        self.b += vb
        return vW, vb


def softmax_cross_entropy(logits: np.ndarray, y: np.ndarray) -> tuple[float, np.ndarray]:
    logits = logits.astype(np.float32)
    logits = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(logits)
    probs = exp / exp.sum(axis=1, keepdims=True)

    N = logits.shape[0]
    loss = -np.log(probs[np.arange(N), y] + 1e-12).mean()

    dlogits = probs
    dlogits[np.arange(N), y] -= 1
    dlogits /= N
    return float(loss), dlogits


### Optional: quick gradient check (small and slow)

Finite-difference gradient checks are a good way to verify your `Conv2D.backward()`.

The check is off by default; set `RUN_GRAD_CHECK = True` in the run configuration cell to enable it.


In [None]:
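def grad_check_conv2d():
    """Compare Conv2D.dW with central finite differences at 10 random weights.

    The scalar loss is sum(out * dout), so its gradient w.r.t. out is exactly dout.
    """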
    rng_local = np.random.default_rng(0)
    x = rng_local.standard_normal((2, 1, 6, 6)).astype(np.float32)
    conv = Conv2D(1, 2, k=3, stride=1, pad=1, rng=rng_local)

    out = conv.forward(x)
    dout = rng_local.standard_normal(out.shape).astype(np.float32)
    _ = conv.backward(dout)

    eps = 1e-3
    for _ in range(10):
        oc = int(rng_local.integers(0, conv.W.shape[0]))
        ic = int(rng_local.integers(0, conv.W.shape[1]))
        i = int(rng_local.integers(0, conv.W.shape[2]))
        j = int(rng_local.integers(0, conv.W.shape[3]))

        old = conv.W[oc, ic, i, j]

        conv.W[oc, ic, i, j] = old + eps
        out_p = conv.forward(x)
        loss_p = float((out_p * dout).sum())

        conv.W[oc, ic, i, j] = old - eps
        out_m = conv.forward(x)
        loss_m = float((out_m * dout).sum())

        conv.W[oc, ic, i, j] = old

        num = (loss_p - loss_m) / (2 * eps)
        ana = float(conv.dW[oc, ic, i, j])
        rel_err = abs(num - ana) / (abs(num) + abs(ana) + 1e-12)
        print(f'W[{oc},{ic},{i},{j}]  num={num:.5f}  ana={ana:.5f}  rel_err={rel_err:.3e}')


if RUN_GRAD_CHECK:
    grad_check_conv2d()


### Train the NumPy CNN

Architecture:

- Conv(1→8, 3×3, pad=1) + ReLU + MaxPool(2×2)
- Conv(8→16, 3×3, pad=1) + ReLU + MaxPool(2×2)
- Flatten
- Linear(16·2·2 → 10)
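
The `16·2·2` input size of the linear layer follows from tracing the spatial size through the stack. Here is a small sketch of that bookkeeping, along with the resulting parameter count; the helper names `conv_out` and `conv_params` are assumptions made for illustration.

```python
def conv_out(n: int, k: int = 3, pad: int = 1, stride: int = 1) -> int:
    """Spatial size after a convolution along one axis."""
    return (n + 2 * pad - k) // stride + 1


def conv_params(c_in: int, c_out: int, k: int = 3) -> int:
    """Weights plus biases of a k×k convolution."""
    return c_out * c_in * k * k + c_out


size = 8                           # input: 1 × 8 × 8
trace = [("input", 1, size)]

size = conv_out(size)              # Conv(1→8, 3×3, pad=1) keeps 8×8
trace.append(("conv1", 8, size))
size //= 2                         # MaxPool(2×2) halves it
trace.append(("pool1", 8, size))

size = conv_out(size)              # Conv(8→16, 3×3, pad=1) keeps 4×4
trace.append(("conv2", 16, size))
size //= 2                         # MaxPool(2×2) halves it again
trace.append(("pool2", 16, size))

flat_features = 16 * size * size   # input width of the linear layer
total_params = (
    conv_params(1, 8)
    + conv_params(8, 16)
    + flat_features * 10 + 10      # Linear(flat_features → 10)
)

for name, channels, s in trace:
    print(f"{name:6s} {channels:2d} × {s} × {s}")
print("flat features:", flat_features, "| total params:", total_params)
```

Under these assumptions the whole network has fewer than 2,000 trainable parameters.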


In [None]:
rng_model = np.random.default_rng(SEED)

conv1 = Conv2D(1, 8, k=3, stride=1, pad=1, rng=rng_model)
relu1 = ReLU()
pool1 = MaxPool2D(pool=2, stride=2)

conv2 = Conv2D(8, 16, k=3, stride=1, pad=1, rng=rng_model)
relu2 = ReLU()
pool2 = MaxPool2D(pool=2, stride=2)

flat = Flatten()
fc = Linear(16 * 2 * 2, 10, rng=rng_model)

# Momentum buffers
v_c1_W = np.zeros_like(conv1.W)
v_c1_b = np.zeros_like(conv1.b)
v_c2_W = np.zeros_like(conv2.W)
v_c2_b = np.zeros_like(conv2.b)
v_fc_W = np.zeros_like(fc.W)
v_fc_b = np.zeros_like(fc.b)


def forward_numpy(xb: np.ndarray) -> np.ndarray:
    out = conv1.forward(xb)
    out = relu1.forward(out)
    out = pool1.forward(out)

    out = conv2.forward(out)
    out = relu2.forward(out)
    out = pool2.forward(out)

    out = flat.forward(out)
    out = fc.forward(out)
    return out


def backward_numpy(dlogits: np.ndarray) -> None:
    dout = fc.backward(dlogits)
    dout = flat.backward(dout)

    dout = pool2.backward(dout)
    dout = relu2.backward(dout)
    dout = conv2.backward(dout)

    dout = pool1.backward(dout)
    dout = relu1.backward(dout)
    _ = conv1.backward(dout)


def step_numpy(lr: float) -> None:
    global v_c1_W, v_c1_b, v_c2_W, v_c2_b, v_fc_W, v_fc_b

    v_c1_W, v_c1_b = conv1.step(lr, v_c1_W, v_c1_b, WEIGHT_DECAY, MOMENTUM)
    v_c2_W, v_c2_b = conv2.step(lr, v_c2_W, v_c2_b, WEIGHT_DECAY, MOMENTUM)
    v_fc_W, v_fc_b = fc.step(lr, v_fc_W, v_fc_b, WEIGHT_DECAY, MOMENTUM)


def predict_numpy(x: np.ndarray) -> np.ndarray:
    logits = forward_numpy(x)
    return logits.argmax(axis=1)


def iterate_minibatches(X_: np.ndarray, y_: np.ndarray, batch_size: int, rng_: np.random.Generator):
    idx = rng_.permutation(len(X_))
    for start in range(0, len(X_), batch_size):
        sl = idx[start:start + batch_size]
        yield X_[sl], y_[sl]


history_numpy = []
start = time.time()
for epoch in range(1, EPOCHS_NUMPY + 1):
    losses = []
    for xb, yb in iterate_minibatches(X_train, y_train, BATCH_SIZE, rng):
        logits = forward_numpy(xb)
        loss, dlogits = softmax_cross_entropy(logits, yb)
        losses.append(loss)

        backward_numpy(dlogits)
        step_numpy(LR_NUMPY)

    yhat_train = predict_numpy(X_train)
    yhat_test = predict_numpy(X_test)

    train_acc = accuracy_score(y_train, yhat_train)
    test_acc = accuracy_score(y_test, yhat_test)

    history_numpy.append({
        'epoch': epoch,
        'loss': float(np.mean(losses)),
        'train_acc': float(train_acc),
        'test_acc': float(test_acc),
    })

    print(f"[NumPy] epoch {epoch:02d}/{EPOCHS_NUMPY}  loss={history_numpy[-1]['loss']:.4f}  train_acc={train_acc:.3f}  test_acc={test_acc:.3f}")

elapsed = time.time() - start
print(f"NumPy training time: {elapsed:.2f}s")


In [None]:
dfn = pd.DataFrame(history_numpy)

fig = go.Figure()
fig.add_trace(go.Scatter(x=dfn['epoch'], y=dfn['loss'], mode='lines+markers', name='loss'))
fig.update_layout(title='NumPy training loss', xaxis_title='epoch', yaxis_title='cross-entropy')
fig.show()

fig = go.Figure()
fig.add_trace(go.Scatter(x=dfn['epoch'], y=dfn['train_acc'], mode='lines+markers', name='train'))
fig.add_trace(go.Scatter(x=dfn['epoch'], y=dfn['test_acc'], mode='lines+markers', name='test'))
fig.update_layout(title='NumPy accuracy', xaxis_title='epoch', yaxis_title='accuracy', yaxis=dict(range=[0, 1]))
fig.show()


## 4) Visual diagnostics (NumPy)

- learned **filters** in the first conv layer
- **feature maps** (activations) after Conv1 + ReLU
- confusion matrix and misclassifications
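
A confusion matrix can also be summarised numerically. The sketch below computes per-class recall and the most frequent confusion from a **made-up** 3-class matrix (rows are true labels, columns are predictions); the numbers are illustrative only and are not results from this notebook.

```python
import numpy as np

# Hypothetical confusion matrix: rows = true class, columns = predicted class.
cm_toy = np.array([
    [5, 0, 1],
    [0, 4, 0],
    [2, 0, 3],
])

# Per-class recall: correct predictions divided by the number of true samples.
recall = np.diag(cm_toy) / cm_toy.sum(axis=1)

# Most frequent mistake: largest off-diagonal entry.
off_diag = cm_toy.copy()
np.fill_diagonal(off_diag, 0)
true_cls, pred_cls = np.unravel_index(off_diag.argmax(), off_diag.shape)

print("recall per class:", np.round(recall, 3))
print(f"most common confusion: true={true_cls} predicted as {pred_cls}")
```

The same two lines of arithmetic apply unchanged to the 10×10 matrix returned by `confusion_matrix`.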


In [None]:
W1 = conv1.W[:, 0]  # (8, 3, 3)

fig = px.imshow(
    W1,
    facet_col=0,
    facet_col_wrap=4,
    color_continuous_scale='RdBu',
    zmin=float(W1.min()),
    zmax=float(W1.max()),
    title='NumPy: learned Conv1 filters (3×3)',
)
for a in fig.layout.annotations:
    a.text = ''
fig.update_layout(coloraxis_showscale=False)
fig.show()


In [None]:
x0 = X_test[0:1]

out1 = relu1.forward(conv1.forward(x0))  # (1, 8, 8, 8)

fig = px.imshow(
    out1[0],
    facet_col=0,
    facet_col_wrap=4,
    color_continuous_scale='Viridis',
    title='NumPy: feature maps after Conv1 + ReLU (one test image)',
)
for a in fig.layout.annotations:
    a.text = ''
fig.update_layout(coloraxis_showscale=False)
fig.show()


In [None]:
yhat = predict_numpy(X_test)
cm = confusion_matrix(y_test, yhat, labels=list(range(10)))

fig = px.imshow(
    cm,
    text_auto=True,
    color_continuous_scale='Blues',
    title='NumPy: confusion matrix (test set)',
    labels=dict(x='pred', y='true', color='count'),
)
fig.update_xaxes(tickmode='array', tickvals=list(range(10)))
fig.update_yaxes(tickmode='array', tickvals=list(range(10)))
fig.show()


In [None]:
mis = np.where(yhat != y_test)[0]

n_show = min(36, len(mis))
mis = mis[:n_show]

imgs = X_test[mis, 0]
texts = [f"true={int(y_test[i])}, pred={int(yhat[i])}" for i in mis]

fig = px.imshow(
    imgs,
    facet_col=0,
    facet_col_wrap=9,
    color_continuous_scale='gray',
    title=f'NumPy: misclassified test images (first {n_show})',
)

for a, t in zip(fig.layout.annotations, texts):
    a.text = t

fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_layout(coloraxis_showscale=False, margin=dict(l=10, r=10, t=60, b=10))
fig.show()


## 5) Practical implementation in PyTorch

PyTorch gives you:

- automatic differentiation (autograd)
- optimized kernels
- convenient layers (`nn.Conv2d`, `nn.MaxPool2d`, etc.)

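We train a small CNN with the same conv → ReLU → pool layout. Two differences from the NumPy version: each convolution is followed by `nn.BatchNorm2d` (so the conv bias is dropped), and the optimizer is Adam rather than SGD with momentum.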


In [None]:
if not TORCH_AVAILABLE:
    raise RuntimeError(f"PyTorch not available: {_TORCH_IMPORT_ERROR}")

# Device (suppress noisy CUDA init warnings in restricted environments)
import warnings

with warnings.catch_warnings():
    warnings.filterwarnings('ignore', message='CUDA initialization:.*')
    cuda_ok = bool(torch.cuda.is_available())

device = torch.device('cuda' if cuda_ok else 'cpu')
print('device:', device)

Xtr = torch.from_numpy(X_train).to(device)
ytr = torch.from_numpy(y_train).to(device)
Xte = torch.from_numpy(X_test).to(device)
yte = torch.from_numpy(y_test).to(device)

train_loader = DataLoader(TensorDataset(Xtr, ytr), batch_size=BATCH_SIZE, shuffle=True)
train_eval_loader = DataLoader(TensorDataset(Xtr, ytr), batch_size=256, shuffle=False)
test_loader = DataLoader(TensorDataset(Xte, yte), batch_size=256, shuffle=False)


class TinyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.head = nn.Linear(16 * 2 * 2, 10)

    def forward(self, x):
        x = self.features(x)
        x = x.flatten(1)
        return self.head(x)


torch.manual_seed(SEED)
model = TinyCNN().to(device)
opt = torch.optim.Adam(model.parameters(), lr=LR_TORCH, weight_decay=WEIGHT_DECAY)


def eval_loader(loader):
    """Return (mean per-sample loss, accuracy) over every batch in `loader`."""
    model.eval()
    ys = []
    yhats = []
    total_loss = 0.0
    n_seen = 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb)
            # Sum (not mean) so a smaller final batch is weighted correctly.
            total_loss += F.cross_entropy(logits, yb, reduction='sum').item()
            n_seen += yb.shape[0]
            ys.append(yb.cpu().numpy())
            yhats.append(logits.argmax(dim=1).cpu().numpy())
    y_all = np.concatenate(ys)
    yhat_all = np.concatenate(yhats)
    return total_loss / n_seen, float(accuracy_score(y_all, yhat_all))


history_torch = []
start = time.time()
for epoch in range(1, EPOCHS_TORCH + 1):
    model.train()
    batch_losses = []
    for xb, yb in train_loader:
        opt.zero_grad(set_to_none=True)
        logits = model(xb)
        loss = F.cross_entropy(logits, yb)
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())

    train_loss, train_acc = eval_loader(train_eval_loader)
    test_loss, test_acc = eval_loader(test_loader)

    history_torch.append({
        'epoch': epoch,
        'loss': float(np.mean(batch_losses)),
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_loss_eval': train_loss,
        'test_loss_eval': test_loss,
    })
    print(f"[Torch] epoch {epoch:02d}/{EPOCHS_TORCH}  loss={history_torch[-1]['loss']:.4f}  train_acc={train_acc:.3f}  test_acc={test_acc:.3f}")

elapsed = time.time() - start
print(f"Torch training time: {elapsed:.2f}s")


In [None]:
dft = pd.DataFrame(history_torch)

fig = go.Figure()
fig.add_trace(go.Scatter(x=dft['epoch'], y=dft['loss'], mode='lines+markers', name='train_loss(batch)'))
fig.add_trace(go.Scatter(x=dft['epoch'], y=dft['test_loss_eval'], mode='lines+markers', name='test_loss'))
fig.update_layout(title='Torch loss', xaxis_title='epoch', yaxis_title='cross-entropy')
fig.show()

fig = go.Figure()
fig.add_trace(go.Scatter(x=dft['epoch'], y=dft['train_acc'], mode='lines+markers', name='train'))
fig.add_trace(go.Scatter(x=dft['epoch'], y=dft['test_acc'], mode='lines+markers', name='test'))
fig.update_layout(title='Torch accuracy', xaxis_title='epoch', yaxis_title='accuracy', yaxis=dict(range=[0, 1]))
fig.show()


In [None]:
model.eval()
with torch.no_grad():
    logits = model(Xte)
    yhat = logits.argmax(dim=1).detach().cpu().numpy()

cm = confusion_matrix(y_test, yhat, labels=list(range(10)))

fig = px.imshow(
    cm,
    text_auto=True,
    color_continuous_scale='Blues',
    title='Torch: confusion matrix (test set)',
    labels=dict(x='pred', y='true', color='count'),
)
fig.update_xaxes(tickmode='array', tickvals=list(range(10)))
fig.update_yaxes(tickmode='array', tickvals=list(range(10)))
fig.show()


## 6) Pitfalls + exercises

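- If training is unstable (the loss spikes or becomes NaN): lower the learning rate (`LR_NUMPY` or `LR_TORCH`), check the scale of the weight initialization, and verify the tensor shapes after each layer.
- If accuracy stalls: add channels, add another conv layer, or train longer (set `FAST_RUN = False`).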

---

### Exercises

1. Add a third convolution layer and compare curves.
2. Replace MaxPool with a strided convolution.
3. Visualize feature maps deeper in the network.

---

## References

- CS231n notes on CNNs
