# CIFAR ResNets
- I couldn't find a unified factory for ResNets that combines Imagenet and CIFAR model designs, so here is the cleanest version I came up with.
- To avoid rewriting boilerplate, torchvision models are patched to fit the CIFAR specification:
    - To make Imagenet models CIFAR-compatible: replace the first conv with a 3x3 stride-1 conv, remove the max-pooling and resize the fc layer to the number of classes.
    - To build CIFAR-native models: instantiate an Imagenet model and rebuild its layers to specification, which is a lot cleaner than rewriting `_make_layer`.
- Compared against the Akamaster implementation: same output, same training performance
- Benchmarking models on CIFAR10/CIFAR100
- Benchmarking training strategies and batch sizes for ResNet56 on CIFAR100
- Other notes:
    - Optimisation (optimizer + scheduler) and batch size are the major hyper-parameters
    - Mixed precision, num_workers and deterministic/benchmark mode don't affect accuracy in PyTorch 2
    - Small architecture changes and initialisation have little impact
    - It is very difficult to get ResNet56 to 73% on CIFAR100: only bs=64 with a 200-epoch schedule comes close (0.7272 with cos-200)


## step-200 bs=64

| ds/model   | cifar10 | cifar100 |
|-----------|---------|----------|
| resnet18  | 0.95    | 0.78     |
| resnet20  | 0.92    | 0.68     |
| resnet32  | 0.93    | 0.69     |
| resnet34  | 0.96    | 0.79     |
| resnet44  | 0.92    | 0.70     |
| resnet50  | 0.95    | 0.78     |
| resnet56  | 0.92    | 0.71     |

## step-160 bs=256

| ds/model   | cifar10 | cifar100 |
|------------|---------|----------|
| resnet18   | 0.92    | 0.71     |
| resnet20   | 0.91    | 0.67     |
| resnet32   | 0.91    | 0.67     |
| resnet34   | 0.92    | 0.67     |
| resnet44   | 0.92    | 0.68     |
| resnet50   | 0.91    | 0.65     |
| resnet56   | 0.89    | 0.68     |

## ResNet56 CIFAR100

| bs / opt | 64     | 128    | 256    |
|------------|--------|--------|--------|
| cos-200    | 0.7272 | 0.7162 | 0.7147 |
| step-160   | 0.6997 | 0.6861 | 0.6740 |
| step-200   | 0.7127 | 0.7107 | 0.7117 |


# Implementation

```python
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet

CIFAR_MODELS = ["resnet20", "resnet32", "resnet44", "resnet56", "resnet110", "resnet1202"]
IMAGENET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]


class ShortcutA(nn.Module):
    """Option A: downsample via stride 2 and zero pad for extra channels"""
    def __init__(self, out_channels):
        super(ShortcutA, self).__init__()
        self.out_channels = out_channels

    def forward(self, x):
        # Channels to add, split evenly before and after the existing ones
        missing = self.out_channels - x.size(1)
        # Keep every other pixel, then zero-pad only the channel dim: (W, W, H, H, C, C)
        return F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, missing // 2, missing // 2), "constant", 0)


def get_resnet(model: str, num_classes: int):
    """
    :author: @xapharius
    ResNet factory for CIFAR compatible models. Supports native CIFAR models and patched Imagenet models.
    Weights are not initialised.
    CIFAR Models: resnet{20, 32, 44, 56, 110, 1202}
    Imagenet Models: resnet{18, 34, 50, 101, 152}
    """
    if model in IMAGENET_MODELS:
        model = resnet.__dict__[model]()
        # 3x3 stride-1 stem instead of the 7x7 stride-2 conv, and no max-pooling,
        # so 32x32 inputs keep their resolution going into layer1
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False)
        model.maxpool = nn.Identity()
        # expansion is 1 for BasicBlock (resnet18/34) and 4 for Bottleneck (resnet50+)
        model.fc = nn.Linear(512 * model.layer1[0].expansion, num_classes)
        return model
    if model in CIFAR_MODELS:
        # depth = 6n + 2: stem conv + 3 stages of n BasicBlocks (2 convs each) + fc
        n_blocks = (int(model.replace("resnet", "")) - 2) // 6
        model = resnet.resnet18()  # layers will be replaced
        model.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        model.bn1 = nn.BatchNorm2d(16)
        model.maxpool = nn.Identity()
        # _make_layer reads the input width from inplanes and updates it as it goes
        model.inplanes = 16
        model.layer1 = model._make_layer(block=resnet.BasicBlock, planes=16, blocks=n_blocks, stride=1)
        # _make_layer gives the first strided block a 1x1 conv + BN downsample;
        # swap it for the parameter-free option-A shortcut
        model.layer2 = model._make_layer(block=resnet.BasicBlock, planes=32, blocks=n_blocks, stride=2)
        model.layer2[0].downsample = ShortcutA(32)
        model.layer3 = model._make_layer(block=resnet.BasicBlock, planes=64, blocks=n_blocks, stride=2)
        model.layer3[0].downsample = ShortcutA(64)
        # Only 3 stages: drop layer4, so the classifier sees 64 channels
        model.layer4 = nn.Identity()
        model.fc = nn.Linear(64, num_classes)
        return model
    raise ValueError(f"Unknown model: {model}")
```
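
To make the arithmetic in `get_resnet` concrete, here is a small pure-Python sketch. The helpers `cifar_resnet_plan` and `shortcut_a_padding` are hypothetical (not part of the code above) and assume the CIFAR-native layout used there: depth = 6n + 2, three stages of widths 16/32/64, and a stride-2 first block in stages 2 and 3 whose option-A shortcut keeps every other pixel and zero-pads the missing channels.

```python
# Hypothetical helpers, assuming the CIFAR-native layout built by get_resnet


def cifar_resnet_plan(name: str, input_hw: int = 32):
    """Return (n, [(width, blocks, output_hw) per stage]) for a CIFAR ResNet name."""
    depth = int(name.replace("resnet", ""))
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError(f"{name}: depth must be 6n + 2 with n >= 1")
    n = (depth - 2) // 6  # same formula as get_resnet
    plan = []
    hw = input_hw
    for stage, width in enumerate((16, 32, 64)):
        if stage > 0:
            # a stride-2 3x3 conv (padding=1) and x[:, :, ::2, ::2] both give ceil(hw / 2)
            hw = (hw + 1) // 2
        plan.append((width, n, hw))
    return n, plan


def shortcut_a_padding(in_channels: int, out_channels: int):
    """Zero channels ShortcutA adds before and after the existing ones."""
    missing = out_channels - in_channels
    return missing // 2, missing // 2


for name in ("resnet20", "resnet56", "resnet110"):
    print(name, cifar_resnet_plan(name))
print("16 -> 32 channels:", shortcut_a_padding(16, 32))
print("32 -> 64 channels:", shortcut_a_padding(32, 64))
```

Every stage transition doubles the width, so `missing` is always even here and the split is symmetric.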

# Evaluation
- Standard crop+flip augmentation: 4-pixel padding with a random 32x32 crop, plus a random horizontal flip
- In the literature there are 3 main optimisation strategies:
    - MultiStep 160 - milestones: [80, 120], gamma: 0.1
    - MultiStep 200 - milestones: [60, 120, 160], gamma: 0.2
    - CosineAnnealing 200
- The major hyper-parameters are:
    - Optimisation strategy
    - Batch size
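
The three strategies differ mainly in how the learning rate decays. Below is a minimal pure-Python sketch of the learning rate in effect at each (0-indexed) epoch, assuming a base LR of 0.1 and the milestones/gammas used in `configure_optimizers`, with one scheduler step per epoch (Lightning's default interval). `multistep_lr` and `cosine_lr` are hypothetical helpers that reproduce the closed forms documented for PyTorch's `MultiStepLR` and `CosineAnnealingLR` (with `eta_min=0`); they are not the schedulers themselves.

```python
import bisect
import math


def multistep_lr(epoch, base_lr, milestones, gamma):
    # Closed form of MultiStepLR: multiply by gamma once per milestone already reached
    return base_lr * gamma ** bisect.bisect_right(milestones, epoch)


def cosine_lr(epoch, base_lr, t_max, eta_min=0.0):
    # Closed form of CosineAnnealingLR: half a cosine from base_lr down to eta_min
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Assumed to mirror the three opt_strat settings in configure_optimizers
SCHEDULES = {
    "step-160": lambda e: multistep_lr(e, 0.1, [80, 120], 0.1),
    "step-200": lambda e: multistep_lr(e, 0.1, [60, 120, 160], 0.2),
    "cos-200": lambda e: cosine_lr(e, 0.1, 200),
}

for name, lr_at in SCHEDULES.items():
    print(name, [round(lr_at(e), 5) for e in (0, 60, 80, 120, 160, 199)])
```

The step schedules keep the LR flat and then drop it sharply, while the cosine schedule lowers it a little every epoch; step-200 ends at 0.1 * 0.2^3 = 0.0008.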

```python
import torch
from torch import nn
import lightning as L
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torchmetrics.classification import Accuracy


# Per-channel normalisation statistics
NORMVALS = {
    "mean": {
        "cifar10": [0.4914, 0.4822, 0.4465],
        "cifar100": [0.5071, 0.4867, 0.4408],
    },
    "std": {
        "cifar10": [0.2023, 0.1994, 0.2010],
        "cifar100": [0.2675, 0.2565, 0.2761],
    },
}

DATASET_ROOT = "/data/datasets"


def get_datamodule(dataset: str, batch_size=256, num_workers=10):
    ds_cls = CIFAR10 if dataset == "cifar10" else CIFAR100
    ds_path = DATASET_ROOT + "/" + dataset.upper()

    # Standard CIFAR augmentation: pad by 4, random 32x32 crop, random horizontal flip
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(NORMVALS["mean"][dataset], NORMVALS["std"][dataset]),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(NORMVALS["mean"][dataset], NORMVALS["std"][dataset]),
        ]
    )

    # Expects the datasets to already be downloaded under DATASET_ROOT
    train_ds = ds_cls(root=ds_path, transform=train_transform, train=True)
    test_ds = ds_cls(root=ds_path, transform=test_transform, train=False)

    # The test split doubles as the validation set
    dm = L.LightningDataModule.from_datasets(train_ds, val_dataset=test_ds, batch_size=batch_size, num_workers=num_workers)
    return dm


def init_weights(m):
    # Kaiming-normal weights and zero biases for conv and linear layers; BatchNorm keeps its defaults
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class LitModel(L.LightningModule):

    def __init__(self, model: nn.Module, opt_strat="step-200"):
        super().__init__()
        self.model = model.apply(init_weights)
        self.criterion = nn.CrossEntropyLoss()
        # The last linear layer is the classifier: read the number of classes from it
        num_classes = [m for m in model.modules() if isinstance(m, nn.Linear)][-1].out_features
        self.metric = Accuracy(num_classes=num_classes, task="multiclass")
        self.opt_strat = opt_strat

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)
        loss = self.criterion(out, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self.model(x)
        # Accumulate over the whole validation set; logging the metric object lets
        # Lightning compute the epoch accuracy and reset the state every epoch
        self.metric.update(out, y)
        self.log("val_acc", self.metric, prog_bar=True)

    def configure_optimizers(self):
        if self.opt_strat == "step-160":
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)
        elif self.opt_strat == "step-200":
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1, weight_decay=5e-4, momentum=0.9)
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)
        elif self.opt_strat == "cos-200":
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1, weight_decay=5e-4, momentum=0.9)
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
        else:
            raise ValueError(f"Unknown opt_strat: {self.opt_strat}")
        # The scheduler steps once per epoch (Lightning's default interval)
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}


def train(model: nn.Module, dataset: str = "cifar100", opt_strat="step-160", batch_size=64, seed=0):
    L.seed_everything(seed, workers=True)
    datamodule = get_datamodule(dataset, batch_size=batch_size)
    module = LitModel(model, opt_strat=opt_strat)

    trainer = L.Trainer(
        # "step-160" -> 160 epochs, "cos-200" -> 200 epochs
        max_epochs=int(opt_strat.split("-")[-1]),
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=None,
        precision="16-mixed",
        deterministic=True,
    )
    trainer.fit(module, datamodule=datamodule)
    return trainer.logged_metrics["val_acc"].item()
```
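
The training transform in `get_datamodule` is the standard crop+flip augmentation. A pure-Python sketch of what it does to a single-channel image stored as nested lists: `pad_crop_flip` is a hypothetical helper that assumes the behaviour of `RandomCrop(32, padding=4)` with the default zero fill followed by `RandomHorizontalFlip()` with p=0.5. torchvision's own transforms work on PIL images or tensors and draw from torch's RNG instead of `random.Random`.

```python
import random


def pad_crop_flip(img, padding=4, rng=None):
    """Hypothetical sketch: zero-pad, random crop back to the input size, flip with p=0.5."""
    rng = rng if rng is not None else random.Random()
    h, w = len(img), len(img[0])
    # Zero padding on all four sides: (h + 2p) x (w + 2p)
    padded = [[0] * (w + 2 * padding) for _ in range(h + 2 * padding)]
    for i in range(h):
        for j in range(w):
            padded[i + padding][j + padding] = img[i][j]
    # Top-left corner of the crop, uniform in [0, 2p] with both ends inclusive
    top = rng.randint(0, 2 * padding)
    left = rng.randint(0, 2 * padding)
    crop = [row[left:left + w] for row in padded[top:top + h]]
    # Horizontal flip: reverse every row
    if rng.random() < 0.5:
        crop = [row[::-1] for row in crop]
    return crop


image = [[r * 32 + c for c in range(32)] for r in range(32)]
out = pad_crop_flip(image, rng=random.Random(0))
print(len(out), len(out[0]))  # 32 32: the crop always has the input size
```

Each crop shifts the image by up to 4 pixels in each direction, filling the uncovered border with zeros; the test transform applies neither step.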

# Vs Akamaster
- Yerlan Idelbayev's implementation (akamaster/pytorch_resnet_cifar10) seems to be the standard reference
- Comparing by:
    - initialising both networks with the same random seed and checking that they produce the same output for the same input
    - training both with the same training loop and checking that they reach the same accuracy
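
The first check relies on both models drawing their random weights in the same order: `nn.Module.apply` visits submodules in registration order, so re-seeding before `apply(init_weights)` gives identical weights only if both networks register layers of the same shapes in the same sequence. A pure-Python illustration of that point, using a hypothetical `init_in_order` with `random.Random` standing in for torch's RNG and flat sizes standing in for tensor shapes:

```python
import random


def init_in_order(layers, seed):
    """Hypothetical stand-in for seed + apply(init_weights): draw layer by layer."""
    rng = random.Random(seed)  # re-seeded right before initialisation
    # Layers are filled in the order given, like apply() walking registered modules
    return {name: [rng.gauss(0.0, 1.0) for _ in range(size)] for name, size in layers}


# Same layers and sizes; only the registration order differs
order_a = [("conv", 4), ("fc", 2)]
order_b = [("fc", 2), ("conv", 4)]

same = init_in_order(order_a, seed=0) == init_in_order(list(order_a), seed=0)
swapped = init_in_order(order_a, seed=0)["conv"] == init_in_order(order_b, seed=0)["conv"]
print("same order, same weights:", same)
print("different order, same conv weights:", swapped)
```

So a matching output under the same seed is also evidence that the two networks build their layers in the same order with the same shapes.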

```python
import torch
from torch import nn

import httpimport
import pandas as pd

def get_resnet_akamaster(model: str, num_classes: int = 10):
    """CIFAR-native models"""
    # Import resnet.py straight from the akamaster/pytorch_resnet_cifar10 GitHub repo
    with httpimport.github_repo("akamaster", "pytorch_resnet_cifar10"):
        import resnet as resnet_akamaster
        model = resnet_akamaster.__dict__[model]()
        # Swap the final linear layer for one with num_classes outputs
        model.linear = nn.Linear(model.linear.in_features, num_classes)
        return model

# Re-seed right before init_weights (Evaluation cell) so both networks draw the same weights
akamaster_resnet = get_resnet_akamaster("resnet56")
torch.manual_seed(0)
akamaster_resnet = akamaster_resnet.apply(init_weights)

my_resnet = get_resnet("resnet56", 10)
torch.manual_seed(0)
my_resnet = my_resnet.apply(init_weights)

X = torch.randn(1, 3, 32, 32)
print("Outputs match:", torch.allclose(akamaster_resnet(X), my_resnet(X)))
```

Outputs match: True


```python
akamaster_resnet = get_resnet_akamaster("resnet56", num_classes=100)
my_resnet = get_resnet("resnet56", num_classes=100)
res = {
    "akamaster": train(akamaster_resnet, dataset="cifar100", opt_strat="step-200"),
    "mine": train(my_resnet, dataset="cifar100", opt_strat="step-200"),
}
pd.DataFrame(res, index=[0])
```

|   | akamaster | mine   |
|---|-----------|--------|
| 0 | 0.7202    | 0.7198 |


# Model Benchmarks
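- The Imagenet-style models (resnet18/34/50) and the CIFAR-native ones differ a lot in performance: with step-200 and bs=64 the Imagenet-style models reach 0.78 to 0.79 on CIFAR100 vs 0.68 to 0.71, while with step-160 and bs=256 the gap largely disappears (resnet50 even drops to 0.65 on CIFAR100)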
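
One factor worth keeping in mind when reading these numbers is model size, since the Imagenet-style models are much wider (64 to 512 channels vs 16 to 64); these experiments do not isolate its effect. A pure-Python sketch that counts the learnable parameters of the BasicBlock models produced by `get_resnet`. The helpers are hypothetical and assume bias-free convs, affine BatchNorm (weight + bias), torchvision's 1x1 conv + BN downsample for resnet18/34 and the parameter-free `ShortcutA` for the CIFAR models; resnet50 uses Bottleneck blocks and is not covered.

```python
# Hypothetical parameter counter for the BasicBlock models built by get_resnet


def conv3x3(c_in, c_out):
    return c_in * c_out * 9  # bias=False


def conv1x1(c_in, c_out):
    return c_in * c_out  # bias=False


def bn(c):
    return 2 * c  # affine weight + bias (running stats are buffers, not parameters)


def basic_resnet_params(stem, widths, blocks, num_classes, conv_shortcut):
    """Learnable parameters of a BasicBlock ResNet with a 3x3 stem and a linear head."""
    total = conv3x3(3, stem) + bn(stem)
    c_in = stem
    for width, n in zip(widths, blocks):
        for _ in range(n):
            total += conv3x3(c_in, width) + bn(width) + conv3x3(width, width) + bn(width)
            # In these configs a block changes width exactly when it downsamples,
            # which is when torchvision adds its 1x1 conv + BN shortcut
            if conv_shortcut and c_in != width:
                total += conv1x1(c_in, width) + bn(width)
            c_in = width
    return total + c_in * num_classes + num_classes  # fc weight + bias


def cifar_resnet_params(depth, num_classes=10):
    n = (depth - 2) // 6
    return basic_resnet_params(16, (16, 32, 64), (n, n, n), num_classes, conv_shortcut=False)


def imagenet_resnet_params(blocks, num_classes=10):
    return basic_resnet_params(64, (64, 128, 256, 512), blocks, num_classes, conv_shortcut=True)


for name, count in [
    ("resnet20", cifar_resnet_params(20)),
    ("resnet56", cifar_resnet_params(56)),
    ("resnet18", imagenet_resnet_params((2, 2, 2, 2))),
    ("resnet34", imagenet_resnet_params((3, 4, 6, 3))),
]:
    print(f"{name}: {count:,} parameters (10 classes)")
```

Under these assumptions resnet18 has roughly 13x the parameters of resnet56 (about 11.2M vs 0.85M).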

```python
benchmarks = []
for dataset in ["cifar10", "cifar100"]:
    for model in ["resnet18", "resnet20", "resnet32", "resnet34", "resnet44", "resnet50", "resnet56"]:
        net = get_resnet(model, num_classes=10 if dataset == "cifar10" else 100)
        acc = train(net, dataset, opt_strat="step-200", batch_size=64)
        benchmarks.append({"dataset": dataset, "model": model, "val_acc": acc})
```

```python
# step-200 bs=64
pd.DataFrame(benchmarks).pivot(index="model", columns="dataset", values="val_acc").round(2)
```

| model    | cifar10 | cifar100 |
|----------|---------|----------|
| resnet18 | 0.95    | 0.78     |
| resnet20 | 0.92    | 0.68     |
| resnet32 | 0.93    | 0.69     |
| resnet34 | 0.96    | 0.79     |
| resnet44 | 0.92    | 0.70     |
| resnet50 | 0.95    | 0.78     |
| resnet56 | 0.92    | 0.71     |


```python
# step-160 bs=256 (output of an earlier run of the loop above with opt_strat="step-160", batch_size=256)
pd.DataFrame(benchmarks).pivot(index="model", columns="dataset", values="val_acc").round(2)
```

| model    | cifar10 | cifar100 |
|----------|---------|----------|
| resnet18 | 0.92    | 0.71     |
| resnet20 | 0.91    | 0.67     |
| resnet32 | 0.91    | 0.67     |
| resnet34 | 0.92    | 0.67     |
| resnet44 | 0.92    | 0.68     |
| resnet50 | 0.91    | 0.65     |
| resnet56 | 0.89    | 0.68     |


# HP Benchmarks
- Optimisation strategy and batch size analysis for ResNet56 on CIFAR100
- Batch size has a major impact on step-160 performance (0.6997 at bs=64 vs 0.6740 at bs=256); the 200-epoch schedules are a lot less sensitive to it (step-200 stays between 0.7107 and 0.7127)
- Only the bs=64, 200-epoch runs get close to 73% (seed dependent)
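
Since the schedules are defined in epochs and `configure_optimizers` uses LR 0.1 for every batch size, changing the batch size also changes how many SGD updates a run performs. A quick arithmetic sketch, assuming the 50,000-image CIFAR training split and the `DataLoader` default `drop_last=False` (so the last partial batch still counts as a step); `steps_per_epoch` and `updates_per_run` are hypothetical helpers:

```python
import math

TRAIN_IMAGES = 50_000  # size of the CIFAR10 and CIFAR100 training splits


def steps_per_epoch(batch_size, n=TRAIN_IMAGES):
    # drop_last=False: the final partial batch is still one optimizer step
    return math.ceil(n / batch_size)


def updates_per_run(batch_size, epochs):
    return steps_per_epoch(batch_size) * epochs


for bs in (64, 128, 256):
    print(bs, steps_per_epoch(bs), updates_per_run(bs, 160), updates_per_run(bs, 200))
```

At bs=64 a 200-epoch run performs 156,400 updates, four times as many as at bs=256.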

```python
dataset = "cifar100"
model = "resnet56"
model_benchmarks = []
for batch_size in [64, 128, 256]:
    for opt_strat in ["step-160", "step-200", "cos-200"]:
        net = get_resnet(model, num_classes=10 if dataset == "cifar10" else 100)
        acc = train(net, dataset, opt_strat=opt_strat, batch_size=batch_size)
        model_benchmarks.append({"dataset": dataset, "model": model, "val_acc": acc, "opt_strat": opt_strat, "batch_size": batch_size})
```

```python
pd.DataFrame(model_benchmarks).pivot(index="opt_strat", columns="batch_size", values="val_acc").round(4)
```

| opt_strat | 64     | 128    | 256    |
|-----------|--------|--------|--------|
| cos-200   | 0.7272 | 0.7162 | 0.7147 |
| step-160  | 0.6997 | 0.6861 | 0.6740 |
| step-200  | 0.7127 | 0.7107 | 0.7117 |
