In [None]:
import math
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


# Theory

## Why CNNs exist

Suppose you have a 224×224 RGB image:

- Inputs = 224×224×3 = 150,528 numbers.

- A single fully-connected layer to 1,000 hidden units would need:
  - 150,528 × 1,000 ≈ **150 million** weights (+ biases)

- Such a layer is:
  - expensive (memory/compute),

  - statistically inefficient (too many degrees of freedom),

  - blind to a strong prior: **nearby pixels are strongly related**, and the same local pattern can appear anywhere in the image.


## What is convolution

For 2D convolution on an image tensor `x` shaped `(C_in, H, W)`:

- We learn `C_out` filters (kernels).

- Each filter is shaped `(C_in, K, K)`, e.g. `(3, 3, 3)`.

- For each output channel `o` and spatial position `(i, j)`:

$$
y[o,i,j] = b[o] + \sum_{c=0}^{C_{in}-1} \sum_{u=0}^{K-1} \sum_{v=0}^{K-1}
W[o,c,u,v]\cdot x[c, i\cdot s + u - p,\; j\cdot s + v - p]
$$

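Where:

- `s` = stride
- `p` = padding; indices that fall outside the input read as zero
- `b[o]` = bias of output channel `o`

Strictly speaking this is *cross-correlation* (the kernel is not flipped), but deep-learning libraries call it convolution.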

### Output size

For input height `H`, kernel `K`, padding `P`, stride `S`:

$$
H_{out} = \left\lfloor \frac{H + 2P - K}{S} \right\rfloor + 1
$$

Same for width.


## Interpretation

<div align="center">
<img src="images/convs.png"  width="320"/>
</div>


## Stride, Padding, Dilation

- **Stride**: the step between successive kernel positions; stride 2 roughly halves `H` and `W`.
    <div align="center">
    <img src="images/stride.png"  width="480"/>
    </div>

- **Padding**: extra (usually zero) rows/columns added around the input border.
    <div align="center">
    <img src="images/padding.png"  width="480"/>
    </div>
    <br>

    > With padding, a convolution can preserve the spatial size of its input (e.g. `K=3, P=1, S=1` keeps `H×W`), so many conv + activation layers can be stacked without the feature map shrinking at every layer.

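- **Dilation**: spacing between kernel taps; it enlarges the receptive field without adding weights (a 3×3 kernel with dilation 2 covers a 5×5 area).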
    <div align="center">
    <img src="images/dilation.png"  width="480"/>
    </div>


## What is a “convolutional block”

A convolutional block is a standard mini-architecture that improves training stability and performance.

### The most common block

Conv $\to$ BatchNorm $\to$ ReLU

- **Conv** learns features

- **BatchNorm** stabilizes gradients and training

- **ReLU** adds nonlinearity (lets networks approximate complex functions)

Often you also include:

- **Pooling** to downsample the spatial size (alternatively, use a strided conv)

- **Dropout** (regularization)

- Sometimes multiple convs per block

### Typical patterns

1. **VGG-style**: (`Conv`$\to$ `ReLU`) $\times 2$ $\to$ `MaxPool`

2. **Modern**: `Conv`$\to$`BN`$\to$`ReLU` repeated, downsample using stride-2 conv

3. **Residual blocks (ResNet)**: add skip connection to ease optimization

4. **Depthwise-separable blocks (MobileNet)**: reduce compute dramatically


# Implementations

## I. Implementation of Convolution via `numpy`

In [None]:
def conv2d_forward_numpy(
    x: np.ndarray,
    w: np.ndarray,
    stride: int = 1,
    padding: int = 0,
    b: "np.ndarray | None" = None,
) -> np.ndarray:
    """
    Educational 2D convolution forward pass (NCHW), using explicit loops.

    Computes the cross-correlation that deep-learning "convolution" layers use
    (the kernel is not flipped), i.e. the same result as
    ``torch.nn.functional.conv2d`` with dilation 1 and groups 1.

    Args:
        x: Input tensor of shape (N, C_in, H, W)
        w: Weights of shape (C_out, C_in, K_h, K_w). Square kernels are the
            common case, but K_h != K_w is supported.
        stride: Stride (>= 1), applied to both H and W.
        padding: Zero-padding (>= 0) added on each side of H and W.
        b: Optional bias of shape (C_out,), added to every spatial position of
            the matching output channel. It is the last parameter so existing
            positional calls keep working.

    Returns:
        y: Output tensor of shape (N, C_out, H_out, W_out), where
           H_out = (H + 2*padding - K_h) // stride + 1 (same for W). Its dtype
           is the NumPy promotion of the input, weight (and bias) dtypes.

    Raises:
        ValueError: on wrong ranks, channel or bias shape mismatch, invalid
            stride/padding, or a kernel larger than the padded input.
    """
    if x.ndim != 4:
        raise ValueError(f"x must be NCHW (4D). Got shape {x.shape}")
    if w.ndim != 4:
        raise ValueError(f"w must be (C_out, C_in, K, K). Got shape {w.shape}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1. Got {stride}")
    if padding < 0:
        raise ValueError(f"padding must be >= 0. Got {padding}")

    N, C_in, H, W = x.shape
    C_out, C_in_w, K_h, K_w = w.shape

    if C_in_w != C_in:
        raise ValueError("Mismatch: w C_in does not match x C_in")
    if b is not None:
        b = np.asarray(b)
        if b.shape != (C_out,):
            raise ValueError(f"b must have shape ({C_out},). Got shape {b.shape}")

    # Pad input (zeros on both sides of the two spatial axes only)
    if padding > 0:
        x = np.pad(
            x,
            pad_width=((0, 0), (0, 0), (padding, padding), (padding, padding)),
            mode="constant",
            constant_values=0.0,
        )

    H_pad, W_pad = x.shape[2], x.shape[3]
    if K_h > H_pad or K_w > W_pad:
        raise ValueError(
            f"Kernel ({K_h}x{K_w}) is larger than the padded input ({H_pad}x{W_pad})"
        )
    H_out = (H_pad - K_h) // stride + 1
    W_out = (W_pad - K_w) // stride + 1

    # Promote dtypes so e.g. integer inputs with float weights are not truncated.
    dtypes = [x.dtype, w.dtype] if b is None else [x.dtype, w.dtype, b.dtype]
    y = np.zeros((N, C_out, H_out, W_out), dtype=np.result_type(*dtypes))

    # Convolution: one dot product between a (C_in, K_h, K_w) patch and one filter
    for n in range(N):
        for co in range(C_out):
            for i in range(H_out):
                h0 = i * stride
                for j in range(W_out):
                    w0 = j * stride
                    patch = x[n, :, h0:h0 + K_h, w0:w0 + K_w]   # (C_in, K_h, K_w)
                    y[n, co, i, j] = np.sum(patch * w[co])
            if b is not None:
                y[n, co] += b[co]

    return y


rng = np.random.default_rng(0)  # fixed seed -> reproducible demo

# Draw the input first, then the weights:
#   x: N=2 images, C=3 channels, H=W=8
#   w: C_out=4 filters, each (C_in=3, K=3, K=3)
x, w = (
    rng.normal(size=shape).astype(np.float32)
    for shape in ((2, 3, 8, 8), (4, 3, 3, 3))
)

# A 3x3 kernel with padding=1 and stride=1 keeps the 8x8 size; a ReLU follows.
y = np.maximum(conv2d_forward_numpy(x, w, stride=1, padding=1), 0.0)
print("Output shape:", y.shape)


## II. PyTorch's `nn.Conv2d`


In [None]:
# 3 input channels -> 32 output channels with a 3x3 kernel.
# stride=1 with padding=1 preserves the spatial size. bias=False is the usual
# choice when a BatchNorm follows, since its shift makes the bias redundant.
nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)

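- Conv1x1: no spatial context; mixes channels at each pixel independently (a per-pixel linear layer), often used to change the number of channels cheaply

- Conv3x3: the standard building block; two stacked 3×3 convs cover a 5×5 receptive field with fewer weights than one 5×5

- Conv5x5: larger receptive field per layer, at 25/9 ≈ 2.8× the weights of a 3×3

- Conv7x7: typically used only once, in the network stem, directly on the raw image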


## III. Residual Blocks in Brief

In [None]:
class PlainBlock(nn.Module):
    """
    Two-layer MLP with no shortcut: H(x) = fc2(relu(fc1(x))).

    Both layers are bias-free d -> d linear maps, so the output has the same
    width d as the input.
    """

    def __init__(self, d: int):
        super().__init__()
        # The names fc1/fc2 are kept because other code in this notebook
        # accesses these layers by attribute name.
        self.fc1 = nn.Linear(d, d, bias=False)
        self.fc2 = nn.Linear(d, d, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


class ResidualBlock(nn.Module):
    """
    Residual wrapper around a PlainBlock: y = x + F(x).

    Because of the identity shortcut, the block stays close to the identity
    map while F's (bias-free) weights are small.
    """

    def __init__(self, d: int):
        super().__init__()
        self.f = PlainBlock(d)  # the residual branch F

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.f(x)
        return x + residual


# ----------------------------
# Helpers to measure gradients
# ----------------------------

def grad_norm(t: torch.Tensor) -> float:
    """Return the L2 norm of ``t.grad`` as a Python float, or NaN if no gradient is stored."""
    g = t.grad
    return float("nan") if g is None else float(g.norm().detach().cpu())


def weight_norm(m: nn.Module) -> float:
    """
    Return the global L2 norm over all parameters of ``m``.

    This is the square root of the sum of squared entries across every tensor
    in ``m.parameters()``. A module with no parameters gives 0.0.
    """
    per_param_sq = (float(param.detach().pow(2).sum().cpu()) for param in m.parameters())
    acc = 0.0
    for sq in per_param_sq:
        acc += sq
    return math.sqrt(acc)


def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Return the cosine similarity of ``a`` and ``b`` after flattening both.

    A 1e-12 term in the denominator keeps all-zero inputs from dividing by zero.
    Those inputs yield 0.0.
    """
    flat_a, flat_b = a.flatten(), b.flatten()
    dot = (flat_a @ flat_b).detach().cpu()
    denom = flat_a.norm() * flat_b.norm() + 1e-12
    return float(dot / denom)


# ----------------------------
# Experiment
# ----------------------------

def run_once(block: nn.Module, *, d=128, batch=32, small_init=1e-3, seed=0) -> dict:
    """
    Run one forward/backward pass of ``block`` on random data and report gradient statistics.

    Steps:
      1. Seed torch's global RNG and draw an input ``x`` of shape (batch, d)
         that requires grad.
      2. Overwrite every parameter of ``block`` in place with
         ``small_init * N(0, 1)`` noise. This mutates ``block``.
      3. Clear any gradients left on the block's parameters.
      4. Run the block's real ``forward`` while capturing the pre-activation
         ``a1 = fc1(x)`` through a temporary forward hook.
      5. Backpropagate ``loss = mean(out ** 2)``.

    Args:
        block: A ``PlainBlock`` (exposes ``fc1``) or a ``ResidualBlock``
            (exposes ``f.fc1``).
        d: Feature width of the random input. It must match the block's width.
        batch: Number of input rows.
        small_init: Scale of the Gaussian parameter re-initialization.
        seed: Value passed to ``torch.manual_seed``.

    Returns:
        dict with keys "loss", "x_grad_norm", "a1_grad_norm", "w_norm",
        "param_grad_norm" (all Python floats).
    """
    torch.manual_seed(seed)

    x = torch.randn(batch, d, requires_grad=True)

    # Small init: replace every parameter with scaled Gaussian noise.
    with torch.no_grad():
        for p in block.parameters():
            p.copy_(small_init * torch.randn_like(p))

    # Without this, a second call on the same block would accumulate into the
    # old .grad tensors and inflate param_grad_norm.
    block.zero_grad(set_to_none=True)

    fc1 = block.f.fc1 if isinstance(block, ResidualBlock) else block.fc1

    captured = {}

    def keep_a1(module, inputs, output):
        # a1 is a non-leaf tensor; retain_grad keeps its .grad after backward.
        output.retain_grad()
        captured["a1"] = output

    # Use the block's own forward (rather than re-implementing it here) so the
    # measurement cannot drift from what the module actually computes.
    handle = fc1.register_forward_hook(keep_a1)
    try:
        out = block(x)
    finally:
        handle.remove()
    a1 = captured["a1"]

    loss = out.pow(2).mean()
    loss.backward()

    info = {
        "loss": float(loss.detach().cpu()),
        "x_grad_norm": grad_norm(x),
        "a1_grad_norm": grad_norm(a1),
        "w_norm": weight_norm(block),
    }

    total = 0.0
    for p in block.parameters():
        if p.grad is not None:
            total += float(p.grad.detach().pow(2).sum().cpu())
    info["param_grad_norm"] = math.sqrt(total)

    return info


In [None]:
d = 256
batch = 64
small_init = 1e-4  # try 1e-2, 1e-4, 1e-6 to see the difference more clearly
seed = 123

plain = PlainBlock(d)
resid = ResidualBlock(d)

# run_once reseeds with the same seed, so both blocks see the same input x and
# (having identically shaped parameters) the same re-initialized weights.
plain_info = run_once(plain, d=d, batch=batch, small_init=small_init, seed=seed)
resid_info = run_once(resid, d=d, batch=batch, small_init=small_init, seed=seed)

for header, stats in (
    ("=== Plain block H(x) ===", plain_info),
    ("\n=== Residual block x + F(x) ===", resid_info),
):
    print(header)
    for k, v in stats.items():
        print(f"{k:>16s}: {v:.6g}")