In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import v2
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib
from cifar_common import (
    DEVICE,
    get_data,
    train_step,
    train,
    evaluate,
    check_model_outputs,
    plot_loss
)

In [None]:
class ConvNet(nn.Module):
    """Small convolutional image classifier.

    The network is three convolutional stages followed by global average
    pooling and a single linear layer. Every stage is
    ``[Conv3x3 -> BatchNorm -> ReLU] x 2 -> MaxPool2x2``, so each stage halves
    the spatial resolution. For a 32x32 RGB input the shapes are::

        input   : B,   3, 32, 32
        layer1  : B,  64, 16, 16
        layer2  : B, 128,  8,  8
        layer3  : B, 256,  4,  4
        gap     : B, 256,  1,  1
        fc      : B, num_classes

    Args:
        num_classes: Number of output logits produced by the final linear layer.
    """

    @staticmethod
    def _make_stage(c_in, c_out):
        """Build one stage: two 3x3 conv/BN/ReLU units, then a 2x2 max-pool."""
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, **conv_kwargs),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.Conv2d(c_out, c_out, **conv_kwargs),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def __init__(self, num_classes):
        super().__init__()

        # Stages are created in layer1 -> layer2 -> layer3 order so that
        # parameter initialisation draws from the RNG in a fixed sequence.
        self.layer1 = self._make_stage(3, 64)
        self.layer2 = self._make_stage(64, 128)
        self.layer3 = self._make_stage(128, 256)

        # Collapse each 256-channel feature map to a single value.
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        features = x
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)

        pooled = self.gap(features)
        # Drop the trailing 1x1 spatial dims: (B, 256, 1, 1) -> (B, 256).
        flat = pooled.reshape(pooled.size(0), -1)
        return self.fc(flat)

In [None]:
# Run configuration, gathered in one place so it is easy to tune.
SEED = 42
num_epochs = 40
batch_size = 100
base_lr = 0.1   # initial SGD learning rate
min_lr = 1e-4   # floor of the cosine schedule

# Seed the global RNG before the model is built so that weight initialisation
# is reproducible under Restart Kernel -> Run All. torch.manual_seed seeds the
# CPU and CUDA generators.
# NOTE(review): whether data shuffling/augmentation inside get_data is also
# covered depends on cifar_common — confirm there.
torch.manual_seed(SEED)

model = ConvNet(10)
model = model.to(DEVICE)

# Defensive reset of the BatchNorm running statistics. On the freshly
# constructed model above this is a no-op (the stats are already at their
# initial values); it is kept as a guard in case this code is ever pointed at
# a reused model.
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.reset_running_stats()

model.train()  # Ensure model is in training mode
train_loader, test_loader = get_data(batch_size, num_workers=4, prefetch_factor=4)
criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=4e-4)
# One cosine half-period spans the whole run: lr goes base_lr -> min_lr over
# num_epochs scheduler steps.
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

# Total number of parameter elements, reported in millions.
size = sum(param.numel() for param in model.parameters())
print(f"Model size: {size/1e6:.2f}M")

# Diagnostic: check the model outputs before any training.
# For 10 classes an untrained model should score a loss near ln(10) ~= 2.3.
print("=== Checking initial model state ===")
check_model_outputs(model, train_loader, criterion)
print("\nExpected initial loss: ~2.3 (random guessing)")
print("If loss is much lower, the model may have been trained before or BatchNorm stats are stale.\n")


In [None]:
# Run the training loop from cifar_common for `num_epochs` epochs.
# The two returned values are kept as the loss history and the matching step
# indices (presumably one entry per logged step — confirm in cifar_common.train).
# tracker=None: no tracker object is supplied for this run.
losses, steps = train(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
    num_epochs,
    tracker=None,
)

# Plot the recorded losses against their steps once training has finished.
plot_loss(losses, steps)

In [None]:
# Evaluate the trained model on the test loader (helper from cifar_common).
evaluate(model, test_loader)

# Report the learning rate currently held by the optimizer's first param group.
print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

# Show the first ten recorded training losses as the cell output.
losses[:10]