# CIFAR-100 Classification with a ResNet-50 Encoder + Simplex LDA Head

This mirrors the CIFAR100.ipynb workflow but swaps the custom encoder for an ImageNet-pretrained ResNet-50.


In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet50, ResNet50_Weights

from lda import SimplexLDAHead

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [2]:
weights = ResNet50_Weights.IMAGENET1K_V2
weights_tfm = weights.transforms()
mean = weights_tfm.mean
std = weights_tfm.std
pin_memory = torch.cuda.is_available()

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_tfm = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = datasets.CIFAR100(root='./data', train=True, transform=train_tfm, download=True)
test_ds = datasets.CIFAR100(root='./data', train=False, transform=test_tfm, download=True)
train_ld = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, pin_memory=pin_memory)
test_ld = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=4, pin_memory=pin_memory)
len(train_ds), len(test_ds)


(50000, 10000)

In [3]:
class ResNet50Encoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        backbone = resnet50(weights=weights)
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        self.proj = nn.Linear(2048, dim)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.proj(x)


class DeepLDA(nn.Module):
    def __init__(self, C, D):
        super().__init__()
        self.encoder = ResNet50Encoder(D)
        self.head = SimplexLDAHead(C, D)

    def forward(self, x):
        z = self.encoder(x)
        return self.head(z)


In [4]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    ok = tot = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        ok += (logits.argmax(1) == y).sum().item()
        tot += y.size(0)
    return ok / tot

model = DeepLDA(C=100, D=99).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.NLLLoss()

train_acc = []
test_acc = []

for epoch in range(1, 101):
    model.train()
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        with torch.no_grad():
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(model, test_ld)
    train_acc.append(tr_acc)
    test_acc.append(te_acc)
    print(f"[{epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:01<00:00, 94.9MB/s]


[01] train loss=11.6478 acc=0.2683 | test acc=0.6353
[02] train loss=7.3769 acc=0.5623 | test acc=0.7482
[03] train loss=4.7326 acc=0.6280 | test acc=0.7725
[04] train loss=2.4883 acc=0.6603 | test acc=0.7868
[05] train loss=0.3790 acc=0.6845 | test acc=0.7984
[06] train loss=-1.6789 acc=0.7051 | test acc=0.8025
[07] train loss=-3.7068 acc=0.7237 | test acc=0.8061
[08] train loss=-5.6339 acc=0.7330 | test acc=0.8107
[09] train loss=-7.5708 acc=0.7440 | test acc=0.8148
[10] train loss=-9.5064 acc=0.7536 | test acc=0.8145
[11] train loss=-11.4378 acc=0.7638 | test acc=0.8160
[12] train loss=-13.3463 acc=0.7707 | test acc=0.8161
[13] train loss=-15.2264 acc=0.7766 | test acc=0.8166
[14] train loss=-17.1009 acc=0.7821 | test acc=0.8129
[15] train loss=-19.0411 acc=0.7915 | test acc=0.8174
[16] train loss=-20.8744 acc=0.7950 | test acc=0.8224
[17] train loss=-22.6975 acc=0.7969 | test acc=0.8203
[18] train loss=-24.5914 acc=0.8058 | test acc=0.8185
[19] train loss=-26.4057 acc=0.8084 | test

KeyboardInterrupt: 