In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib
from collections import defaultdict
from update_ratio_tracker import UpdateRatioTracker
from cifar_common import (
    DEVICE,
    get_data,
    train,
    evaluate,
    check_model_outputs,
    plot_loss
)

In [None]:
class PreActResidualBlock(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice, plus a skip path.

    The main path is ``bn1 -> relu -> conv1 -> bn2 -> relu -> conv2``. Only
    ``conv1`` carries ``stride``, so an input of shape ``[B, I, H, W]`` yields an
    output of shape ``[B, O, ceil(H / stride), ceil(W / stride)]``.

    The skip path is the identity when the block keeps both the channel count
    and the resolution. Otherwise it is a bias-free 1x1 convolution with the
    same stride, applied to the raw (not pre-activated) block input.

    Args:
        in_channels: channel count ``I`` of the incoming feature map.
        out_channels: channel count ``O`` produced by the block.
        stride: stride of the first 3x3 convolution and of the 1x1 projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Submodules are created in the order in which forward() applies them.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        # One ReLU instance (it has no parameters) serves both activations.
        self.relu = nn.ReLU(inplace=True)

        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.shortcut = nn.Identity()
        else:
            # 1x1 projection so the skip path matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False),
            )

    def forward(self, x):
        """Return the main-path output summed with ``shortcut(x)``."""
        residual = self.conv1(self.relu(self.bn1(x)))
        residual = self.conv2(self.relu(self.bn2(residual)))

        # Accumulate in place into the conv2 output; `x` itself is left untouched.
        residual += self.shortcut(x)
        return residual
        
        

class ResNet(nn.Module):
    """Small pre-activation ResNet: three stages of three residual blocks each.

    Stage widths are 32, 64 and 128 channels. ``layer1`` keeps the input
    resolution, while ``layer2`` and ``layer3`` each start with a stride-2 block.
    For a ``[B, 3, 32, 32]`` input (the size is assumed, not enforced) the
    feature maps are::

        layer1 -> [B,  32, 32, 32]
        layer2 -> [B,  64, 16, 16]
        layer3 -> [B, 128,  8,  8]

    A final BatchNorm + ReLU closes the last pre-activation block, followed by
    global average pooling and a linear classifier.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        width = 32  # channels of the first stage; doubled for each later stage

        # The first stage consumes the 3-channel input directly (no stem conv).
        self.layer1 = self._make_stage(3, width, stride=1)
        self.layer2 = self._make_stage(width, width * 2, stride=2)
        self.layer3 = self._make_stage(width * 2, width * 4, stride=2)

        # Closing normalisation/activation for the last block's raw output.
        self.bn = nn.BatchNorm2d(width * 4)
        self.relu = nn.ReLU(inplace=True)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(width * 4, num_classes)

    @staticmethod
    def _make_stage(in_channels, out_channels, stride):
        """Build one stage: a (possibly strided) entry block plus two shape-preserving blocks."""
        return nn.Sequential(
            PreActResidualBlock(in_channels=in_channels, out_channels=out_channels, stride=stride),
            PreActResidualBlock(in_channels=out_channels, out_channels=out_channels, stride=1),
            PreActResidualBlock(in_channels=out_channels, out_channels=out_channels, stride=1),
        )

    def forward(self, x):
        """Map a batch of images to ``[B, num_classes]`` logits."""
        feats = x
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)

        feats = self.relu(self.bn(feats))
        pooled = self.gap(feats)  # [B, C, 1, 1]
        return self.fc(pooled.reshape(pooled.size(0), -1))


In [None]:
# Seed the global torch RNG before the model is built so that weight
# initialisation is repeatable across "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

num_classes = 10
model = ResNet(num_classes)
model = model.to(DEVICE)
num_epochs = 4
tracker = UpdateRatioTracker(log_every=100)

# Reset BatchNorm running statistics to ensure fresh start
# This is important if the model was previously trained or if you're re-running cells
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

model.train()  # Ensure model is in training mode
train_loader, test_loader = get_data(100, num_workers=8, prefetch_factor=8)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,  momentum=0.9, weight_decay=4e-4)
# T_max is given in epochs, which matches a scheduler stepped once per epoch.
# NOTE(review): confirm how train() steps the scheduler; if it steps per batch,
# T_max should be num_epochs * len(train_loader) instead.
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-4)


# Total number of parameter elements, reported in millions.
size = sum(param.numel() for param in model.parameters())
print(f"Model size: {size/1e6:.2f}M")

# Diagnostic: Check initial model outputs before training
# Expected: Loss should be around 2.3 (which is -log(1/10) for random 10-class prediction)
# If loss is much lower, BatchNorm running stats might be from a previous run
print("=== Checking initial model state ===")
check_model_outputs(model, train_loader, criterion)
print("\nExpected initial loss: ~2.3 (random guessing)")
print("If loss is much lower, the model may have been trained before or BatchNorm stats are stale.\n")


In [None]:
# Train the model with live loss plotting
# The plot will update in real-time as training progresses
# (train() comes from cifar_common; its two return values are unpacked here
# as `losses` and `steps` and reused by the cells below.)
losses, steps = train(model, criterion, optimizer, scheduler, train_loader, num_epochs, tracker)

# Optional: Create a final static plot if needed
plot_loss(losses, steps)

In [None]:
# Evaluate on the held-out loader, requesting a per-class accuracy breakdown.
evaluate(model, test_loader, per_class_accuracy=True, num_classes=num_classes)

# Print the learning rate currently stored in the optimizer's first parameter group
# (i.e. the value the scheduler left behind after training).
print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

# Last expression of the cell: the notebook displays the first ten recorded losses.
losses[0:10]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_update_ratios(tracker, names, max_steps=None, figsize=(100, 40)):
    """Plot the recorded relative update size of each named entry on a log axis.

    Args:
        tracker: object exposing ``history``, a mapping from name to a
            sliceable sequence of recorded ratios.
        names: iterable of keys of ``tracker.history`` to draw, one line each.
        max_steps: if not ``None``, only the first ``max_steps`` recorded
            entries of every series are drawn (``0`` draws empty series).
        figsize: matplotlib figure size in inches. The default keeps the
            original, very large canvas; pass something like ``(12, 6)`` for
            a figure that fits on screen.

    NOTE(review): the x axis is the index into each recorded series. If the
    tracker records only every N-th optimisation step, the "training step"
    label is off by that factor -- confirm against the tracker.
    """
    plt.figure(figsize=figsize)

    for name in names:
        vals = tracker.history[name]
        # Compare against None so that max_steps=0 truncates instead of being ignored.
        if max_steps is not None:
            vals = vals[:max_steps]
        plt.plot(vals, label=name)
    plt.yscale("log")
    plt.xlabel("training step")
    # \u0394 is the Greek capital delta; the escape keeps the label immune to
    # the file-encoding mix-up that previously garbled it.
    plt.ylabel("||\u0394W|| / ||W||")
    plt.legend()
    plt.title("Relative update size over training")
    plt.show()


# Draw one line for every entry the tracker recorded during training.
plot_update_ratios(tracker, list(tracker.history.keys()))

In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_weights(model):
    """Gather every Conv2d / Linear / BatchNorm2d weight into one flat CPU tensor.

    Modules are visited in ``model.modules()`` order. Each ``weight`` is
    detached, flattened, cast to float32 and moved to the CPU before the
    pieces are concatenated.

    Modules of a matching type that have no weight tensor (for example
    ``nn.BatchNorm2d(affine=False)``, whose ``weight`` is ``None``) are
    skipped instead of raising ``AttributeError``.

    Args:
        model: any ``nn.Module``.

    Returns:
        A 1-D float32 CPU tensor with all collected weights, or an empty
        tensor when the model has no matching weighted module.
    """
    all_w = [
        m.weight.detach().flatten().float().cpu()
        for m in model.modules()
        if isinstance(m, (nn.Conv2d, nn.Linear, nn.BatchNorm2d))
        and getattr(m, "weight", None) is not None
    ]
    return torch.cat(all_w) if all_w else torch.tensor([])

# Histogram of every weight value returned by collect_weights().
w = collect_weights(model)
plt.figure()
plt.hist(w.numpy(), bins=200)
# collect_weights() also gathers BatchNorm2d scale parameters, so the title
# names them too instead of claiming the plot holds Conv/Linear weights only.
plt.title("All Conv/Linear/BatchNorm weights")
plt.xlabel("weight value")
plt.ylabel("count")
plt.show()

In [None]:
@torch.no_grad()
def layer_weight_tensors(model):
    """Return ``(module_name, flat_weight)`` pairs for every Conv2d / Linear layer.

    Each weight is detached, flattened, cast to float32 and moved to the CPU.
    The list is ordered from the layer with the most weight elements to the
    one with the fewest; layers of equal size keep their
    ``model.named_modules()`` order.
    """
    pairs = [
        (layer_name, layer.weight.detach().flatten().float().cpu())
        for layer_name, layer in model.named_modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    # Largest first; sorted() is stable, so ties stay in traversal order.
    return sorted(pairs, key=lambda pair: pair[1].numel(), reverse=True)

# (name, flat weight) pairs, already sorted from largest to smallest layer.
items = layer_weight_tensors(model)

# One histogram figure per layer for the six layers with the most weights.
for name, w in items[:6]:  # first 6 largest layers
    plt.figure()
    plt.hist(w.numpy(), bins=150)
    plt.title(f"{name}.weight  (n={w.numel()})")
    plt.xlabel("weight value")
    plt.ylabel("count")
    plt.show()


In [None]:
@torch.no_grad()
def grad_over_weight(model):
    """Compute ``||grad|| / ||param||`` for every parameter that has a gradient.

    Norms are taken over the flattened float32 CPU copy of each tensor.
    Parameters whose ``.grad`` is ``None`` are skipped. A ``1e-12`` term in
    the denominator guards against division by zero for all-zero parameters.

    Returns:
        A list of ``(parameter_name, ratio)`` tuples sorted by ascending ratio.
    """
    def _flat_norm(tensor):
        # L2 norm of the tensor viewed as one long float32 vector on the CPU.
        return tensor.detach().flatten().float().cpu().norm().item()

    ratios = [
        (param_name, _flat_norm(param.grad) / (_flat_norm(param.data) + 1e-12))
        for param_name, param in model.named_parameters()
        if param.grad is not None
    ]
    return sorted(ratios, key=lambda pair: pair[1])

# after backward():
# NOTE(review): this relies on parameter gradients still being populated from
# the last backward pass run inside train(). grad_over_weight() skips any
# parameter whose .grad is None, so the lists below may hold fewer than ten
# entries (or none at all) if the gradients were cleared.
ratios = grad_over_weight(model)
# `ratios` is sorted ascending, so the head holds the smallest ratios and the tail the largest.
print("Smallest grad/weight ratios:", ratios[:10])
print("Largest grad/weight ratios:", ratios[-10:])