In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# How a Residual Block Works Internally

In [None]:
class PlainBlock(nn.Module):
    """
    Two-layer, bias-free MLP block computing H(x) = fc2(ReLU(fc1(x))).

    Both linear layers map ``d`` features to ``d`` features, so the output
    keeps the same trailing dimension as the input.
    """
    def __init__(self, d: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features=d, out_features=d, bias=False)
        self.fc2 = nn.Linear(in_features=d, out_features=d, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First projection followed by the non-linearity ...
        hidden = F.relu(self.fc1(x))
        # ... then the second projection back to d features.
        return self.fc2(hidden)


class ResidualBlock(nn.Module):
    """
    Skip-connection wrapper: y = x + F(x), with F implemented by a PlainBlock
    stored as ``self.f``.
    """
    def __init__(self, d: int):
        super().__init__()
        self.f = PlainBlock(d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F(x): the branch that is added on top of the identity path.
        branch = self.f(x)
        return x + branch


# ----------------------------
# Helpers to measure gradients
# ----------------------------

def grad_norm(t: torch.Tensor) -> float:
    """Return ``t.grad.norm()`` as a Python float, or NaN when ``t`` has no gradient."""
    g = t.grad
    return float("nan") if g is None else float(g.norm().detach().cpu())


def weight_norm(m: nn.Module) -> float:
    """
    Global parameter norm of ``m``: the square root of the sum of squared
    entries taken over every tensor yielded by ``m.parameters()``.
    """
    total_sq = 0.0
    for param in m.parameters():
        squared = param.detach() ** 2
        # Accumulate in Python floats, one parameter tensor at a time.
        total_sq = total_sq + squared.sum().item()
    return math.sqrt(total_sq)


def cosine(a: torch.Tensor, b: torch.Tensor) -> float:
    """
    Cosine similarity between ``a`` and ``b`` after flattening both to 1-D.

    A constant 1e-12 is added to the product of the norms so that zero
    vectors do not cause a division by zero.
    """
    u = a.flatten()
    v = b.flatten()
    numerator = torch.dot(u, v).detach().cpu()
    denominator = u.norm() * v.norm() + 1e-12
    return float(numerator / denominator)


# ----------------------------
# Experiment
# ----------------------------

def run_once(block: nn.Module, *, d=128, batch=32, small_init=1e-3, seed=0) -> dict:
    """
    Run one forward/backward pass through ``block`` and report gradient statistics.

    Side effects: the block's parameters are overwritten in place with
    ``small_init``-scaled Gaussian noise, and any gradients left on them by an
    earlier backward pass are cleared before the new backward pass. Clearing
    matters because ``backward()`` accumulates into ``.grad``; without it a
    second call on the same block would report an inflated ``param_grad_norm``.

    Args:
        block: A ``PlainBlock`` or a ``ResidualBlock``. Its ``fc1``/``fc2``
            layers are called directly so the first pre-activation can be
            inspected.
        d: Feature dimension of the random input; it is fed straight into
            ``fc1``, so it has to match the block's width.
        batch: Number of rows in the random input.
        small_init: Scale applied to the standard-normal noise used to
            re-initialise every parameter.
        seed: Value passed to ``torch.manual_seed`` before the input and the
            weights are sampled.

    Returns:
        dict with float entries, in this order: ``loss``, ``x_grad_norm``,
        ``a1_grad_norm``, ``w_norm``, ``param_grad_norm``.
    """
    torch.manual_seed(seed)

    x = torch.randn(batch, d, requires_grad=True)

    # Small init: replace every parameter with scaled Gaussian noise.
    with torch.no_grad():
        for p in block.parameters():
            p.copy_(small_init * torch.randn_like(p))

    # Drop gradients from any previous backward pass on this block so the
    # statistics below reflect this run only.
    block.zero_grad()

    # Both variants share the same inner stack fc2(relu(fc1(x))); the residual
    # variant keeps it under ``block.f`` and adds the input back at the end.
    is_residual = isinstance(block, ResidualBlock)
    core = block.f if is_residual else block

    a1 = core.fc1(x)
    # a1 is an intermediate (non-leaf) tensor, so its gradient is only kept
    # after backward() if we ask for it explicitly.
    a1.retain_grad()

    fx = core.fc2(F.relu(a1))
    out = x + fx if is_residual else fx

    loss = out.pow(2).mean()
    loss.backward()

    info = {
        "loss": float(loss.detach().cpu()),
        "x_grad_norm": grad_norm(x),
        "a1_grad_norm": grad_norm(a1),
        "w_norm": weight_norm(block),
    }

    # Global gradient norm over all parameters that received a gradient.
    total = 0.0
    for p in block.parameters():
        if p.grad is not None:
            total += float(p.grad.detach().pow(2).sum().cpu())
    info["param_grad_norm"] = math.sqrt(total)

    return info


In [None]:
d = 256
batch = 64
small_init = 1e-4  # try 1e-2, 1e-4, 1e-6 to see the difference more clearly
seed = 123

# Build both variants, then run each through the experiment with the same
# settings (including the same seed).
plain = PlainBlock(d)
resid = ResidualBlock(d)

experiment_kwargs = dict(d=d, batch=batch, small_init=small_init, seed=seed)
plain_info = run_once(plain, **experiment_kwargs)
resid_info = run_once(resid, **experiment_kwargs)

# One report per variant: a heading line followed by right-aligned statistics.
reports = (
    ("=== Plain block H(x) ===", plain_info),
    ("\n=== Residual block x + F(x) ===", resid_info),
)
for heading, stats in reports:
    print(heading)
    for key, value in stats.items():
        print(f"{key:>16s}: {value:.6g}")

In [None]:
torch.manual_seed(seed)
x = torch.randn(batch, d, requires_grad=True)

# Re-init weights small again for a controlled comparison
def small_init_(m: nn.Module):
    """Overwrite every parameter of ``m`` in place with noise scaled by the notebook-level ``small_init``."""
    with torch.no_grad():
        for p in m.parameters():
            p.zero_()
            p.add_(small_init * torch.randn_like(p))

plain2 = PlainBlock(d)
resid2 = ResidualBlock(d)
small_init_(plain2)
# Give the residual branch exactly the same weights as the plain block.
# Calling small_init_ on each module separately would draw different noise for
# each of them, so the two gradients would belong to unrelated functions and
# the comparison would not isolate the effect of the skip connection.
resid2.f.load_state_dict(plain2.state_dict())

# Plain gradient
y_plain = plain2(x)
loss_plain = y_plain.pow(2).mean()
# The plain graph is back-propagated only once, so it does not need to be retained.
loss_plain.backward()
gx_plain = x.grad.detach().clone()

# backward() accumulates into x.grad; reset it so the residual pass starts from zero.
x.grad.zero_()

# Residual gradient
y_resid = resid2(x)
loss_resid = y_resid.pow(2).mean()
loss_resid.backward()
gx_resid = x.grad.detach().clone()

print("\n=== Input gradient comparison (same x) ===")
print(f"||dL/dx|| plain : {gx_plain.norm().item():.6g}")
print(f"||dL/dx|| resid : {gx_resid.norm().item():.6g}")
print(f"cos(dL/dx_plain, dL/dx_resid): {cosine(gx_plain, gx_resid):.6g}")
