In [None]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
from torchvision import transforms, datasets

from torchinfo import summary


In [None]:
# Load in Data from disk
CIFAR = "CIFAR10"   # or "CIFAR100"
num_classes = 10 if CIFAR == "CIFAR10" else 100
Dataset = getattr(datasets, CIFAR)

# Set the CIFAR_DATA_DIR environment variable to run on another machine;
# falls back to the original hardcoded location.
data_path = os.environ.get(
    "CIFAR_DATA_DIR", "/home/bam/bam_ws/src/bam_brain/bam_gym/data_downloads"
)

# Load CIFAR without normalization so we can compute stats
dataset = Dataset(
    root=data_path, train=True, download=True,
    transform=transforms.ToTensor()
)

loader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=False, num_workers=4)


def _channel_mean_std(batches):
    """Exact per-channel mean and (population) std over every pixel in `batches`.

    Averaging per-image stds underestimates the dataset std because it ignores
    the variance *between* images. Instead, accumulate the per-channel sum and
    sum of squares in float64 and use Var[X] = E[X^2] - E[X]^2.

    Args:
        batches: iterable of (images, labels) with images shaped [B, C, H, W].

    Returns:
        (mean, std): float32 tensors of shape [C].

    Raises:
        ValueError: if `batches` yields no images.
    """
    channel_sum = 0.0
    channel_sq_sum = 0.0
    n_pixels = 0
    for images, _ in batches:
        flat = images.double().flatten(2)  # [B, C, H*W]; float64 avoids precision loss
        channel_sum = channel_sum + flat.sum(dim=(0, 2))
        channel_sq_sum = channel_sq_sum + flat.square().sum(dim=(0, 2))
        n_pixels += flat.size(0) * flat.size(2)

    if n_pixels == 0:
        raise ValueError("cannot compute statistics of an empty dataset")

    channel_mean = channel_sum / n_pixels
    # clamp guards against tiny negative variances from floating-point cancellation
    channel_var = (channel_sq_sum / n_pixels - channel_mean.square()).clamp_min(0)
    return channel_mean.float(), channel_var.sqrt().float()


mean, std = _channel_mean_std(loader)
n_samples = len(dataset)

print("Mean:", mean)
print("Std:", std)

In [None]:
# Create Data loaders with transforms

# Augmentation toggle: off by default (RandomCrop / RandomHorizontalFlip were
# previously commented out). Set True to enable standard CIFAR augmentation.
USE_AUGMENTATION = False

# Normalize with the per-channel statistics computed from the training set above.
normalize = transforms.Normalize(mean, std)
augmentations = (
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if USE_AUGMENTATION
    else []
)

train_transforms = transforms.Compose([*augmentations, transforms.ToTensor(), normalize])
test_transforms = transforms.Compose([transforms.ToTensor(), normalize])

train_dataset = Dataset(root=data_path, train=True, download=True, transform=train_transforms)
test_dataset  = Dataset(root=data_path, train=False, download=True, transform=test_transforms)

# Pinned host memory only helps host->GPU copies; recent PyTorch versions warn
# when pin_memory=True without an accelerator, so enable it only with CUDA.
pin_memory = torch.cuda.is_available()

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=pin_memory)
test_dataloader  = torch.utils.data.DataLoader(test_dataset,  batch_size=256, shuffle=False, num_workers=4, pin_memory=pin_memory)

# Why are larger batch sizes used for testing?

# During training we run stochastic gradient descent: "stochastic" because each mini-batch only gives an estimate of the gradient, since it is not the full dataset. In practice this bit of noise doesn't seem to hurt, and can even help.
# A rule of thumb is to use the largest batch size that fits in GPU memory while still leaving room for activations and gradients. It is a hyperparameter to tune, though; what works best in practice depends on the loss landscape.

# Testing has different requirements: there are no gradients, so memory use is lower, and we just want to get through the data as fast as possible. So you can generally use a larger batch size.

# See chat: https://chatgpt.com/c/68b58cae-31f4-8332-980b-5ceb5065d7ce




In [None]:
# Create model

def resnet18_cifar(num_classes=10):
    """ResNet-18 adapted to 32x32 CIFAR images, randomly initialized (no pretrained weights).

    NOTE: a later cell redefines `resnet18_cifar` with a `pretrained` flag; once
    that cell runs, it shadows this definition.
    """
    net = tv.models.resnet18(weights=None)

    # The ImageNet stem (7x7 stride-2 conv followed by a max-pool) shrinks small
    # inputs too aggressively: use a 3x3 stride-1 conv and skip the pooling.
    net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    net.maxpool = nn.Identity()

    # Classifier head sized for this dataset.
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, num_classes)
    return net

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = resnet18_cifar(num_classes).to(device)
summary(model, input_size=(1, 3, 32, 32))

In [None]:
# Create Model with pretrained weights

def resnet18_cifar(num_classes=10, pretrained=True):
    """Build a CIFAR-sized ResNet-18, optionally starting from ImageNet weights.

    The ImageNet stem (7x7 stride-2 conv + max-pool) is replaced by a 3x3
    stride-1 conv and no pooling. With ``pretrained=True`` the new conv is
    initialized from the central 3x3 window of each pretrained 7x7 kernel. The
    classifier head is always a freshly initialized ``num_classes``-way Linear.
    """
    weights = "IMAGENET1K_V1" if pretrained else None
    net = tv.models.resnet18(weights=weights)

    # ---- 1) Smaller, stride-1 stem conv ----
    stem = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    if pretrained:
        # Pretrained kernels are [64, 3, 7, 7]; keep their centre 3x3 window.
        k = net.conv1.kernel_size[0]
        lo = (k - 3) // 2
        with torch.no_grad():
            stem.weight.copy_(net.conv1.weight[:, :, lo:lo + 3, lo:lo + 3])
    net.conv1 = stem

    # ---- 2) No max-pool: 32x32 inputs are already small ----
    net.maxpool = nn.Identity()

    # ---- 3) New classifier head ----
    net.fc = nn.Linear(net.fc.in_features, num_classes)

    return net

# Example usage
# Size the head from the `num_classes` derived from the CIFAR choice (not a
# hard-coded 10) so switching to CIFAR100 doesn't break the loss on labels >= 10.
# torch.device keeps `device` the same type as in the from-scratch model cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet18_cifar(num_classes=num_classes, pretrained=True).to(device)
summary(model, input_size=(1, 3, 32, 32))

In [None]:
# Print model with no freezing: the architecture first, then the trainability
# of every named parameter (all should be requires_grad=True at this point).
print(model)

for param_name, param_tensor in model.named_parameters():
    trainable = param_tensor.requires_grad
    print(f"{param_name:30} requires_grad={trainable}")

In [None]:
# Freeze Layers
def freeze_module(module, freeze_bn: bool = False):
    """
    Freeze a module by disabling gradients for all of its parameters.

    Args:
        module (nn.Module): The PyTorch module to freeze.
        freeze_bn (bool): If True, also put every BatchNorm layer inside
            ``module`` (any ``_BatchNorm`` subclass: BatchNorm1d/2d/3d,
            SyncBatchNorm) into eval() mode. It then normalizes with its
            running statistics instead of updating them. Its gamma/beta are
            already frozen by the first step, like every other parameter.

    Note:
        eval() only sets a mode flag. A later ``model.train()`` recursively
        switches these BN layers back to training mode, so a training loop has
        to re-apply eval() to frozen BN layers after calling ``model.train()``.
    """
    # Recurses over every parameter in the module, BN gamma/beta included.
    module.requires_grad_(False)

    if freeze_bn:
        # Walk all nested submodules; the _BatchNorm base covers every BN variant.
        for m in module.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.eval()

# Freeze the first three residual stages, including their BatchNorm layers;
# all other parameters (e.g. conv1 and fc) remain trainable.
for stage_name in ("layer1", "layer2", "layer3"):
    freeze_module(getattr(model, stage_name), freeze_bn=True)

for param_name, param_tensor in model.named_parameters():
    print(f"{param_name:30} requires_grad={param_tensor.requires_grad}")

In [None]:
# Check that BN layers are frozen properly.
#   training=False                -> eval mode: running stats are not updated.
#   weight_grad/bias_grad=False   -> gamma/beta won't be trained.
# `model.train()` recursively sets training=True on every submodule, including
# these BN layers. This check therefore shows the state right after freezing;
# anything that calls `model.train()` must re-apply eval() to frozen BN layers.
for name, module in model.named_modules():
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        # weight/bias are None for BN layers built with affine=False
        weight_grad = None if module.weight is None else module.weight.requires_grad
        bias_grad = None if module.bias is None else module.bias.requires_grad
        print(f"{name:30} training={module.training}  "
              f"weight_grad={weight_grad}  "
              f"bias_grad={bias_grad}")

In [None]:
# Hand the optimizer only the parameters that are still trainable. The stages
# frozen above never receive gradients, so passing them would only add dead
# entries to the param groups. If you unfreeze layers later, rebuild the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.1, momentum=0.9, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()



In [None]:
def _set_frozen_bn_eval(net):
    """Put BatchNorm layers whose affine parameters are all frozen back into eval mode.

    `nn.Module.train()` recursively sets training=True on every submodule. That
    undoes `eval()` on BN layers frozen via `freeze_module(..., freeze_bn=True)`,
    so their running statistics would keep updating. A BN layer with parameters,
    none of which require grad, is treated as frozen. NOTE(review): this
    heuristic also catches BN layers frozen with freeze_bn=False. BN layers
    without affine parameters are left alone because frozen-ness can't be
    inferred for them.
    """
    for m in net.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            params = list(m.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                m.eval()


def train_one_epoch(net=None, dataloader=None, criterion=None, opt=None, dev=None, log_every=100):
    """Run one optimization pass over the training data.

    Called with no arguments, it uses the notebook globals `model`,
    `train_dataloader`, `loss_fn`, `optimizer` and `device`; each can be
    overridden explicitly.

    Args:
        net: model to train (put in train mode; frozen BN layers are kept in eval mode).
        dataloader: iterable of (inputs, targets) batches with a `.dataset` attribute.
        criterion: loss function; assumed to return the batch *mean* (the
            nn.CrossEntropyLoss default), which is re-weighted by batch size below.
        opt: optimizer stepping the trainable parameters of `net`.
        dev: device inputs and targets are moved to.
        log_every: print the current loss every `log_every` batches.

    Returns:
        (avg_loss, accuracy) over all samples seen this epoch, accuracy in [0, 1].
        Both are nan if the dataloader yielded no samples.
    """
    net = model if net is None else net
    dataloader = train_dataloader if dataloader is None else dataloader
    criterion = loss_fn if criterion is None else criterion
    opt = optimizer if opt is None else opt
    dev = device if dev is None else dev

    net.train()
    # net.train() just re-enabled training mode on frozen BN layers; undo that.
    _set_frozen_bn_eval(net)

    size = len(dataloader.dataset)

    total = 0
    correct = 0
    loss_sum = 0.0
    for batch, (x, y) in enumerate(dataloader):
        x, y = x.to(dev, non_blocking=True), y.to(dev, non_blocking=True)

        logits = net(x)
        loss = criterion(logits, y)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Accumulate as detached device tensors to avoid a host sync every batch.
        n = y.size(0)
        total += n
        loss_sum = loss_sum + loss.detach() * n
        correct = correct + (logits.argmax(1) == y).sum()

        if batch % log_every == 0:
            current = (batch + 1) * len(x)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

    if total == 0:
        return float("nan"), float("nan")
    return float(loss_sum) / total, int(correct) / total


In [None]:
def evaluate(net=None, dataloader=None, criterion=None, dev=None):
    """Compute and print the average loss and accuracy on held-out data.

    Called with no arguments, it uses the notebook globals `model`,
    `test_dataloader`, `loss_fn` and `device`; each can be overridden explicitly.

    Args:
        net: model to evaluate (switched to eval mode and left there).
        dataloader: iterable of (inputs, targets) batches.
        criterion: loss function; assumed to return the batch *mean* (the
            nn.CrossEntropyLoss default). It is re-weighted by batch size so the
            average stays exact when the last batch is smaller.
        dev: device inputs and targets are moved to.

    Returns:
        (avg_loss, accuracy) with accuracy in [0, 1].

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    net = model if net is None else net
    dataloader = test_dataloader if dataloader is None else dataloader
    criterion = loss_fn if criterion is None else criterion
    dev = device if dev is None else dev

    net.eval()
    total, correct, loss_sum = 0, 0, 0.0
    # inference_mode: like no_grad, but also skips autograd version tracking.
    with torch.inference_mode():
        for x, y in dataloader:
            x, y = x.to(dev, non_blocking=True), y.to(dev, non_blocking=True)

            logits = net(x)
            batch_size = x.size(0)

            total += batch_size
            loss_sum += criterion(logits, y).item() * batch_size
            correct += (logits.argmax(1) == y).sum().item()

    if total == 0:
        raise ValueError("evaluate(): dataloader yielded no samples")

    avg_loss = loss_sum / total
    accuracy = correct / total
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")

    return avg_loss, accuracy



In [None]:
# Fixed-length run: after each pass over the training set, report loss and
# accuracy on the test set.
for epoch in range(10):
    print(f"Epoch {epoch+1}\n-------------------------------")
    train_one_epoch()  # one optimization pass over train_dataloader
    evaluate()         # test-set metrics; return value unused here