# Standard CIFAR10/100 Setup
A single place to keep the standard CIFAR-10/100 models and hyperparameters, so I don't have to keep looking up epoch counts and normalisation values.

In [1]:
import torch
from torch.cuda import amp
from torch import nn, optim
from torchvision import transforms
from torchvision.datasets import CIFAR100, CIFAR10
from torch.utils.data import DataLoader
from torchvision.models import vgg16_bn, resnet50
from torchmetrics import Accuracy
from collections import defaultdict

## Data

In [2]:
# Normalisation statistics handed to transforms.Normalize, looked up as
# NORMVALS[<"mean" | "std">][<dataset name>]. Each entry holds three values,
# one per image channel.
NORMVALS = dict(
    mean=dict(
        cifar10=[0.4914, 0.4822, 0.4465],
        cifar100=[0.5071, 0.4867, 0.4408],
    ),
    std=dict(
        cifar10=[0.2023, 0.1994, 0.2010],
        cifar100=[0.2675, 0.2565, 0.2761],
    ),
)

# Parent directory of the per-dataset folders; get_loaders appends the
# upper-cased dataset name to it.
DATASET_ROOT = "/data/datasets"


def get_transforms(dataset: str):
    """Build the train and test preprocessing pipelines for one dataset.

    Args:
        dataset: Key into ``NORMVALS["mean"]`` and ``NORMVALS["std"]``.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``
        objects. The training pipeline applies a random 32x32 crop with 4 px
        of padding and a random horizontal flip before tensor conversion and
        normalisation; the test pipeline only converts and normalises.

    Raises:
        KeyError: If ``dataset`` has no entry in ``NORMVALS``.
    """
    mean = NORMVALS["mean"][dataset]
    std = NORMVALS["std"][dataset]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    # No augmentation at test time: conversion and normalisation only.
    test_steps = [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(train_steps), transforms.Compose(test_steps)


def get_loaders(dataset: str, batch_size=64, num_workers=1):
    """Create the train and test ``DataLoader`` objects for a CIFAR dataset.

    Args:
        dataset: Dataset name. ``"cifar10"`` selects ``CIFAR10``; any other
            value selects ``CIFAR100``. The name must also be a key of
            ``NORMVALS``, because ``get_transforms`` looks it up there.
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.
            Defaults to 1, the value that was previously hard-coded.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles; both use pinned memory.

    Raises:
        KeyError: Propagated from ``get_transforms`` for a name that is
            missing from ``NORMVALS``.
    """
    train_transform, test_transform = get_transforms(dataset)
    ds_cls = CIFAR10 if dataset == "cifar10" else CIFAR100
    # Data lives in DATASET_ROOT/<UPPER-CASED NAME>. No download is requested
    # here, so the files are expected to already be on disk.
    ds_path = DATASET_ROOT + "/" + dataset.upper()

    train_ds = ds_cls(root=ds_path, transform=train_transform, train=True)
    test_ds = ds_cls(root=ds_path, transform=test_transform, train=False)

    # Settings shared by both loaders; only `shuffle` differs between them.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, shuffle=True, **shared)
    test_loader = DataLoader(test_ds, shuffle=False, **shared)

    return train_loader, test_loader

## Model
- Wrapped torchvision models: vgg16 (built from `vgg16_bn`) and resnet50

In [3]:
def get_model(dataset="cifar10", model="vgg16", init=True):
    """
    Model factory that wraps classic torchvision models for cifar datasets
    torchvision models are defined for imagenet and require slight adjustments to work on cifar.
    When quickly prototyping one does not always want to define models from scratch.

    As of now does not include the "correct" resnets for cifar e.g resnet32

    Args:
        dataset: "cifar10" (10 output classes) or "cifar100" (100 output classes).
        model: "vgg16" (built from ``vgg16_bn``) or "resnet50".
        init: When true, the adapted network is passed through ``init_model``
            before being returned.

    Returns:
        The adapted ``nn.Module``, built without pretrained weights.

    Raises:
        ValueError: If ``dataset`` or ``model`` is not one of the supported
            names. (Raised explicitly rather than via ``assert`` so the check
            survives ``python -O``.)
    """
    if dataset not in ("cifar10", "cifar100"):
        raise ValueError(f"unsupported dataset: {dataset!r}")
    if model not in ("vgg16", "resnet50"):
        raise ValueError(f"unsupported model: {model!r}")
    n_classes = 10 if dataset == "cifar10" else 100

    # The network is kept in its own local so the `model` argument stays a
    # string and the branch below never compares a Module against a name.
    if model == "vgg16":
        # Called without arguments: the constructor default is "no pretrained
        # weights", and this avoids the deprecated `pretrained=` keyword.
        net = vgg16_bn()
        net.features = net.features[:-1]  # dropping the last maxpool
        net.avgpool = nn.AvgPool2d(kernel_size=2)
        net.classifier = nn.Sequential(nn.Linear(512, n_classes))
    else:  # model == "resnet50"
        net = resnet50()
        # 3x3 stride-1 stem convolution in place of the ImageNet stem.
        net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        # Empty Sequential modules act as no-ops.
        # NOTE(review): this swaps out `net.relu` as well as the max-pool, which
        # presumably removes the activation right after the stem batch-norm --
        # confirm that is intended.
        net.relu = nn.Sequential()
        net.maxpool = nn.Sequential()
        repr_size = 512 * net.layer4[0].expansion  # resnet18/34: 512; resnet50:2048
        net.fc = nn.Linear(in_features=repr_size, out_features=n_classes, bias=True)

    if init:
        net = init_model(net)
    return net

def init_model(model):
    """Re-initialise the conv, linear and batch-norm layers of ``model`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``fan_out``, ReLU gain). Conv
      biases are left untouched.
    * ``nn.Linear``: weights drawn from N(0, 0.01); bias zeroed when present.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0 when the layer has
      affine parameters.

    Layers built without a bias (``nn.Linear(..., bias=False)``) or without
    affine parameters (``nn.BatchNorm2d(..., affine=False)``) are handled
    instead of raising on the missing tensor.

    Args:
        model: Module whose sub-modules are re-initialised.

    Returns:
        The same ``model`` object, for call chaining.
    """
    modules = list(model.modules())

    # Three separate passes (conv, then linear, then batch-norm) so random
    # numbers are drawn in the same order regardless of how layer types are
    # interleaved inside the model.
    for m in modules:
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

    for m in modules:
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    for m in modules:
        if isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was built with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, val=1.0)
            if m.bias is not None:
                nn.init.constant_(m.bias, val=0.0)
    return model

## Metrics
- Test accuracy only, computed with torchmetrics

In [4]:
@torch.no_grad()
def evaluate(loader, model) -> dict:
    """Compute multiclass accuracy of ``model`` over every batch in ``loader``.

    The model is put in eval mode for the duration of the pass and then
    returned to whichever mode (train or eval) it was in on entry, even if an
    exception interrupts the loop. Gradients are disabled throughout.

    Args:
        loader: Iterable of ``(X, y)`` batches whose ``dataset`` exposes a
            ``classes`` attribute; its length sets the number of classes.
        model: Module to evaluate. Batches and the metric are moved to the
            device of its first parameter.

    Returns:
        ``{"acc": <float>}`` with the accuracy accumulated over all batches.
    """
    device = next(model.parameters()).device
    n_classes = len(loader.dataset.classes)
    metric = Accuracy(task="multiclass", num_classes=n_classes).to(device)

    was_training = model.training
    model.eval()
    try:
        for X, y in loader:
            with amp.autocast():
                X, y = X.to(device), y.to(device)
                out = model(X)
                metric.update(out, y)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        model.train(was_training)

    res = {"acc": metric.compute().item()}
    return res

## Trainer

In [5]:
def trainer(
    dataset="cifar10",
    model="resnet50",
    n_epochs=200,
    batch_size=64,
    device="cuda:0",
    lr=1e-1,
    momentum=0.9,
    weight_decay=5e-4,
):
    """Train a model on a CIFAR dataset and record test accuracy per epoch.

    Uses SGD with Nesterov momentum, a cosine-annealed learning rate stepped
    once per epoch, cross-entropy loss, and mixed precision through
    ``amp.autocast`` and a ``GradScaler``.

    Args:
        dataset: Dataset name forwarded to ``get_model`` and ``get_loaders``.
        model: Model name forwarded to ``get_model``.
        n_epochs: Number of training epochs; also the cosine schedule's ``T_max``.
        batch_size: Batch size forwarded to ``get_loaders``.
        device: Device the model and batches are moved to.
        lr: Initial SGD learning rate.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.

    Returns:
        A ``defaultdict(list)`` whose ``"acc"`` entry holds the test accuracy
        measured after each epoch, in order.
    """
    logs = defaultdict(list)
    gscaler = amp.GradScaler()

    # The network gets its own name so the `model` argument stays a string.
    net = get_model(dataset, model).to(device)
    # Optional: the network can be wrapped with torch.compile(net) here.

    train_loader, test_loader = get_loaders(dataset, batch_size)

    criterion = nn.CrossEntropyLoss()
    optimiser = optim.SGD(
        net.parameters(),
        lr=lr,
        momentum=momentum,
        nesterov=True,
        weight_decay=weight_decay,
    )
    # Cosine schedule, intended for 200-epoch runs. Alternative for 160-epoch
    # runs: MultiStepLR(optimiser, milestones=[80, 120], gamma=0.1).
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=n_epochs)

    for epoch in range(n_epochs):
        for X, y in train_loader:
            X, y = X.to(device), y.to(device).long()

            with amp.autocast():
                out = net(X)
                loss = criterion(out, y)

            # Scale the loss before backward, let the scaler drive the
            # optimiser step, then update the scale factor.
            gscaler.scale(loss).backward()
            gscaler.step(optimiser)
            gscaler.update()
            optimiser.zero_grad()

        logs["acc"].append(evaluate(test_loader, net)["acc"])
        lr_scheduler.step()
    return logs


## Test

In [11]:
# Run the default configuration and report the highest per-epoch test
# accuracy from the returned logs.
res = trainer()
best_acc = max(res["acc"])
print(f"Best Acc: {best_acc}")

Best Acc: 0.9476000070571899
