In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib
from collections import defaultdict
from update_ratio_tracker import UpdateRatioTracker
from cifar_common import (
    DEVICE,
    get_data,
    train_step,
    train,
    evaluate,
    check_model_outputs,
    plot_loss
)

In [None]:
class DenseLayer(nn.Module):
    """One BatchNorm -> ReLU -> 3x3 convolution unit producing `growth_rate` channels.

    The convolution uses stride 1 and padding 1, so height and width are
    preserved; only the channel count changes (in_channels -> growth_rate).
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_channels,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return only the newly computed feature maps; `x` is not concatenated here."""
        # The in-place ReLU overwrites the fresh BatchNorm output, never `x` itself.
        return self.conv1(self.relu(self.bn(x)))

class DenseBlock(nn.Module):
    """Stack of DenseLayers in which every layer sees all earlier feature maps.

    Layer `i` receives `in_channels + i * growth_rate` channels. The block's
    output is the input concatenated with every layer's output along dim 1,
    i.e. `in_channels + num_layers * growth_rate` channels.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                DenseLayer(in_channels + idx * growth_rate, growth_rate)
                for idx in range(num_layers)
            ]
        )

    def forward(self, x):
        """Run the layers in order, appending each layer's output to the running features."""
        features = x
        for layer in self.layers:
            features = torch.cat([features, layer(features)], dim=1)
        return features

class TransitionLayer(nn.Module):
    """BatchNorm -> ReLU -> 1x1 convolution -> 2x2 average pool.

    The 1x1 convolution reduces the channel count to `int(theta * in_channels)`
    and the pooling (kernel 2, stride 2) halves height and width.
    """

    def __init__(self, in_channels, theta):
        super().__init__()
        compressed_channels = int(theta * in_channels)
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(
            in_channels,
            compressed_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the four stages in order and return the compressed, downsampled maps."""
        out = x
        for stage in (self.bn, self.relu, self.conv, self.avgpool):
            out = stage(out)
        return out
        

class DenseNet(nn.Module):
    """Three dense blocks separated by two transition layers, with a linear classifier head.

    Args:
        in_channels: number of channels produced by the 3x3 stem convolution
            (the stem itself always consumes 3 input channels).
        theta: compression factor handed to each TransitionLayer.
        growth_rate: channels added by every layer inside a dense block.
        num_layers: number of layers in each of the three dense blocks.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, in_channels, theta, growth_rate, num_layers, num_classes):
        super().__init__()
        # Each dense block appends this many channels to whatever it receives.
        block_growth = num_layers * growth_rate

        self.conv = nn.Conv2d(3, in_channels, kernel_size=3, padding=1, stride=1, bias=False)

        # Stage 1: dense block, then compression + downsampling.
        stage1_out = in_channels + block_growth
        self.db1 = DenseBlock(in_channels, growth_rate, num_layers)
        self.transition1 = TransitionLayer(stage1_out, theta)

        # Mirrors the channel count TransitionLayer computes internally.
        stage2_in = int(stage1_out * theta)

        # Stage 2: same pattern on the compressed features.
        stage2_out = stage2_in + block_growth
        self.db2 = DenseBlock(stage2_in, growth_rate, num_layers)
        self.transition2 = TransitionLayer(stage2_out, theta)

        stage3_in = int(stage2_out * theta)

        # Stage 3: the last dense block is not followed by a transition.
        self.db3 = DenseBlock(stage3_in, growth_rate, num_layers)
        head_channels = int(stage3_in + block_growth)

        self.bn = nn.BatchNorm2d(head_channels)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(head_channels, num_classes)

    def forward(self, x):
        """Return the `num_classes` outputs of the final linear layer for input batch `x`."""
        feature_stages = (
            self.conv,
            self.db1,
            self.transition1,
            self.db2,
            self.transition2,
            self.db3,
            self.bn,
            self.relu,
            self.pool,
        )
        out = x
        for stage in feature_stages:
            out = stage(out)
        # Global pooling leaves a 1x1 map per channel; collapse it to (batch, channels).
        out = out.view(out.size(0), -1)
        return self.fc(out)

In [None]:
# Seed torch's global RNG before the model is built so weight initialisation
# (and anything else drawing from that RNG) repeats across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = DenseNet(in_channels=32, theta=0.5, growth_rate=12, num_layers=6, num_classes=10)
model = model.to(DEVICE)
num_epochs = 4
# NOTE(review): this was written as `log_every=00`, which Python reads as 0.
# Whether 0 means "log every step", "never log", or is a typo (e.g. for 100)
# depends on UpdateRatioTracker, which is not visible here — confirm.
tracker = UpdateRatioTracker(log_every=0)

# Reset BatchNorm running statistics to ensure a fresh start.
# The model is rebuilt at the top of this cell, so this is only a safeguard.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

model.train()  # Ensure model is in training mode
# NOTE(review): the first positional argument (500) is presumably the batch size — confirm in cifar_common.
train_loader, test_loader = get_data(500, num_workers=8, prefetch_factor=8)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=4e-4)
# NOTE(review): T_max=num_epochs only spans the whole run if train() steps the
# scheduler once per epoch — confirm against cifar_common.train.
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-4)


# Total number of parameter elements in the model, reported in millions.
size = sum(p.numel() for p in model.parameters())
print(f"Model size: {size/1e6:.2f}M")

# Diagnostic: Check initial model outputs before training
# Expected: Loss should be around 2.3 (which is -log(1/10) for random 10-class prediction)
# If loss is much lower, BatchNorm running stats might be from a previous run
print("=== Checking initial model state ===")
check_model_outputs(model, train_loader, criterion)
print("\nExpected initial loss: ~2.3 (random guessing)")
print("If loss is much lower, the model may have been trained before or BatchNorm stats are stale.\n")


In [None]:
# Run the training loop via `train` from cifar_common, passing the model,
# loss, optimizer, scheduler, training loader, epoch count and update-ratio
# tracker. The two returned values are unpacked as `losses` and `steps` and
# are reused by later cells.
# NOTE(review): the earlier comment here said `train` plots the loss live while
# it runs; that behaviour lives in cifar_common and is not visible here — confirm.
losses, steps = train(model, criterion, optimizer, scheduler, train_loader, num_epochs, tracker)

# Draw the loss curve from the values returned above.
plot_loss(losses, steps)

In [None]:
# Evaluate the model on `test_loader`, requesting per-class accuracy.
# Any return value is discarded: this call is not the cell's last expression.
evaluate(model, test_loader, per_class_accuracy=True)

# Print the current learning rate of the optimizer's first parameter group.
print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

# Cell output: the first ten entries of `losses` from the training cell.
losses[0:10]

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_update_ratios(tracker, names, max_steps=None, figsize=(100, 40)):
    """Plot the recorded update-ratio series for each name on a log-scale y axis.

    Args:
        tracker: object with a `history` mapping from name to a sliceable
            sequence of recorded values.
        names: iterable of keys into `tracker.history`; one line is drawn per key.
        max_steps: if truthy, only the first `max_steps` recorded values of each
            series are drawn; `None` (or 0) draws everything.
        figsize: (width, height) in inches passed to `plt.figure`. The default
            keeps the very large canvas this notebook has always used; pass
            something like (12, 6) for a normal-sized figure.
    """
    plt.figure(figsize=figsize)

    for name in names:
        vals = tracker.history[name]
        if max_steps:
            vals = vals[:max_steps]
        plt.plot(vals, label=name)
    plt.yscale("log")
    plt.xlabel("training step")
    # Written as an escape so the Greek delta survives any file re-encoding
    # (the label had been corrupted into mojibake).
    plt.ylabel("||\u0394W|| / ||W||")
    plt.legend()
    plt.title("Relative update size over training")
    plt.show()


# Plot every series present in tracker.history, with no step limit.
plot_update_ratios(tracker, list(tracker.history.keys()))

In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_weights(model, layer_types=(nn.Conv2d, nn.Linear, nn.BatchNorm2d)):
    """Gather the `weight` tensors of selected layer types into one flat CPU tensor.

    Args:
        model: module whose sub-modules (including itself) are scanned.
        layer_types: module class or tuple of classes whose `weight` is
            collected. Defaults to Conv2d, Linear and BatchNorm2d.

    Returns:
        A 1-D float32 CPU tensor with all selected weights concatenated in
        `model.modules()` order, or an empty tensor if nothing matched.
    """
    flat_weights = []
    for module in model.modules():
        if not isinstance(module, layer_types):
            continue
        weight = getattr(module, "weight", None)
        if weight is None:
            # e.g. BatchNorm2d(affine=False) has `weight` set to None.
            continue
        flat_weights.append(weight.detach().flatten().float().cpu())
    return torch.cat(flat_weights) if flat_weights else torch.tensor([])

# Histogram of every weight value returned by collect_weights.
w = collect_weights(model)
plt.figure()
plt.hist(w.numpy(), bins=200)
# collect_weights also gathers BatchNorm2d weights, so the title names them too.
plt.title("All Conv/Linear/BatchNorm weights")
plt.xlabel("weight value")
plt.ylabel("count")
plt.show()

In [None]:
@torch.no_grad()
def layer_weight_tensors(model):
    """Return `(module_name, flat_weight)` pairs for every Conv2d and Linear layer.

    Each weight is detached, flattened to 1-D, cast to float32 and moved to
    the CPU. Pairs are ordered by element count, largest first; layers of
    equal size keep their `named_modules()` order.
    """
    named_weights = [
        (layer_name, layer.weight.detach().flatten().float().cpu())
        for layer_name, layer in model.named_modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]
    # A stable sort with reverse=True keeps ties in their original order.
    return sorted(named_weights, key=lambda pair: pair[1].numel(), reverse=True)

items = layer_weight_tensors(model)

# One histogram per layer for the six layers with the most weight elements.
for name, w in items[:6]:
    fig, ax = plt.subplots()
    ax.hist(w.numpy(), bins=150)
    ax.set_title(f"{name}.weight  (n={w.numel()})")
    ax.set_xlabel("weight value")
    ax.set_ylabel("count")
    plt.show()


In [None]:
@torch.no_grad()
def grad_over_weight(model):
    """Return `(param_name, ||grad|| / ||weight||)` pairs sorted by ascending ratio.

    Parameters whose `.grad` is None are skipped. Norms are taken over the
    flattened float32 CPU copies, and 1e-12 is added to the weight norm so an
    all-zero parameter does not divide by zero.
    """

    def _flat_norm(tensor):
        # Norm of the flattened float32 CPU copy, as a Python float.
        return tensor.detach().flatten().float().cpu().norm().item()

    ratios = [
        (param_name, _flat_norm(param.grad) / (_flat_norm(param.data) + 1e-12))
        for param_name, param in model.named_parameters()
        if param.grad is not None
    ]
    return sorted(ratios, key=lambda pair: pair[1])

# Run this only after a backward pass: grad_over_weight skips every parameter
# whose .grad is None, so the list is empty when no gradients are present.
# NOTE(review): the gradients seen here are whatever the last training step
# left on the parameters; whether train() clears them afterwards is not visible — confirm.
ratios = grad_over_weight(model)
# `ratios` is sorted ascending, so the head holds the smallest ratios and the
# tail the largest; with fewer than 20 entries the two slices overlap.
print("Smallest grad/weight ratios:", ratios[:10])
print("Largest grad/weight ratios:", ratios[-10:])