# CIFAR-10 Classification with an LDA Head
This notebook trains a lightweight convolutional encoder with an LDA head on CIFAR-10, alongside a softmax baseline, then visualises the learned embedding spaces side by side.

### Setup


In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from dnll import LDAHead, DNLLLoss

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device =', device)

device = cuda


### Data


In [3]:
# Per-channel mean / std handed to transforms.Normalize below.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)
# Pin host memory only when a CUDA device is available to receive the batches.
pin_memory = torch.cuda.is_available()

# Steps shared by both pipelines: tensor conversion followed by standardisation.
to_normalised_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]
# Training-only augmentation: padded random crop plus random horizontal flip.
augmentations = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]

train_tfm = transforms.Compose(augmentations + to_normalised_tensor)
test_tfm = transforms.Compose(list(to_normalised_tensor))

dataset_kwargs = dict(root='./data', download=True)
train_ds = datasets.CIFAR10(train=True, transform=train_tfm, **dataset_kwargs)
test_ds = datasets.CIFAR10(train=False, transform=test_tfm, **dataset_kwargs)

loader_kwargs = dict(num_workers=4, pin_memory=pin_memory)
train_ld = DataLoader(train_ds, batch_size=256, shuffle=True, **loader_kwargs)
test_ld = DataLoader(test_ds, batch_size=1024, shuffle=False, **loader_kwargs)
len(train_ds), len(test_ds)


100%|██████████| 170M/170M [00:13<00:00, 12.9MB/s] 
  entry = pickle.load(f, encoding="latin1")


(50000, 10000)

### Model: encoder + heads (LDA + softmax)

In [4]:
class Encoder(nn.Module):
    """Small convolutional network mapping 3-channel images to ``dim``-dimensional embeddings.

    Three stages of two (3x3 conv, batch norm, ReLU) layers with 64, 128 and 256
    channels. The first two stages end in 2x2 max pooling, the last in global average
    pooling; a linear layer then projects the 256 pooled features to ``dim``.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return [Conv2d(3x3, padding=1), BatchNorm2d, ReLU(inplace)] as a flat list."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, dim):
        super().__init__()
        block = self._conv_block
        layers = [
            *block(3, 64), *block(64, 64), nn.MaxPool2d(2),
            *block(64, 128), *block(128, 128), nn.MaxPool2d(2),
            *block(128, 256), *block(256, 256), nn.AdaptiveAvgPool2d(1),
        ]
        # Unpacked flat so the Sequential indices (and state_dict keys) run 0..20.
        self.features = nn.Sequential(*layers)
        self.proj = nn.Linear(256, dim)

    def forward(self, x):
        """Return the projected embedding, one row per input image."""
        pooled = self.features(x)
        return self.proj(pooled.flatten(1))

class DeepLDA(nn.Module):
    """Encoder followed by an ``LDAHead`` built with ``covariance_type='spherical'``.

    Args:
        C: First argument forwarded to ``LDAHead`` (presumably the number of
            classes — confirm against ``dnll.LDAHead``).
        D: Embedding width; used for both the encoder output and the head input.
    """

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = LDAHead(C, D, covariance_type='spherical')

    def forward(self, x):
        """Embed ``x`` with the encoder and return the head's output for it."""
        return self.head(self.encoder(x))

class SoftmaxHead(nn.Module):
    """Single linear layer turning ``D``-dimensional embeddings into ``C`` raw logits.

    No softmax is applied here; callers pass the logits to a loss or to ``F.softmax``.
    """

    def __init__(self, D, C):
        super().__init__()
        self.linear = nn.Linear(D, C)

    def forward(self, z):
        """Return the unnormalised class scores for embeddings ``z``."""
        logits = self.linear(z)
        return logits

class DeepClassifier(nn.Module):
    """Softmax baseline: an ``Encoder`` with output width ``D`` feeding a ``SoftmaxHead`` with ``C`` outputs."""

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = SoftmaxHead(D, C)

    def forward(self, x):
        """Return raw class logits for the image batch ``x``."""
        return self.head(self.encoder(x))


### Train & Eval


In [5]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is switched to eval mode and left there. Batches are moved to the
    notebook-level ``device``. Raises ``ZeroDivisionError`` if the loader yields no
    samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for x, y in loader:
        y = y.to(device)
        predictions = model(x.to(device)).argmax(1)
        n_correct += (predictions == y).sum().item()
        n_seen += y.size(0)
    return n_correct / n_seen

# Deep LDA model: D=9 is the encoder's output width; C=10 is forwarded to LDAHead
# (presumably the number of CIFAR-10 classes — confirm against dnll.LDAHead).
lda_model = DeepLDA(C=10, D=9).to(device)
# Adam with its default hyper-parameters over every parameter of the model.
lda_opt = torch.optim.Adam(lda_model.parameters())
lda_loss_fn = DNLLLoss(lambda_reg=.01)

# Per-epoch accuracy histories; one value is appended at the end of each epoch.
lda_train_acc = []
lda_test_acc = []

for epoch in range(1, 101):
    lda_model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = lda_model(x)
        loss = lda_loss_fn(logits, y)
        lda_opt.zero_grad(set_to_none=True)
        loss.backward()
        lda_opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(lda_model, test_ld)
    lda_train_acc.append(tr_acc)
    lda_test_acc.append(te_acc)
    print(f"[LDA {epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


[LDA 01] train loss=7.2210 acc=0.2897 | test acc=0.4180
[LDA 02] train loss=5.0649 acc=0.4884 | test acc=0.4962
[LDA 03] train loss=3.5229 acc=0.5741 | test acc=0.5552
[LDA 04] train loss=2.0474 acc=0.6374 | test acc=0.6409
[LDA 05] train loss=0.6675 acc=0.6762 | test acc=0.6214
[LDA 06] train loss=-0.6390 acc=0.7049 | test acc=0.6410
[LDA 07] train loss=-1.6781 acc=0.7375 | test acc=0.7392
[LDA 08] train loss=-2.3930 acc=0.7740 | test acc=0.7309
[LDA 09] train loss=-2.7373 acc=0.8021 | test acc=0.7463
[LDA 10] train loss=-2.9030 acc=0.8215 | test acc=0.7991
[LDA 11] train loss=-3.0165 acc=0.8406 | test acc=0.8233
[LDA 12] train loss=-3.0935 acc=0.8564 | test acc=0.7695
[LDA 13] train loss=-3.1336 acc=0.8659 | test acc=0.8123
[LDA 14] train loss=-3.1631 acc=0.8738 | test acc=0.8281
[LDA 15] train loss=-3.1936 acc=0.8828 | test acc=0.8374
[LDA 16] train loss=-3.2186 acc=0.8894 | test acc=0.8328
[LDA 17] train loss=-3.2356 acc=0.8961 | test acc=0.8527
[LDA 18] train loss=-3.2483 acc=0.90

### Softmax head baseline
Train a second model with the same encoder but a standard softmax classifier for a side-by-side comparison.

In [None]:
# Softmax baseline with the same encoder architecture and embedding width (D=9) as the LDA model.
softmax_model = DeepClassifier(C=10, D=9).to(device)
# Adam with its default hyper-parameters over every parameter of the model.
softmax_opt = torch.optim.Adam(softmax_model.parameters())
softmax_loss_fn = nn.CrossEntropyLoss()

# Per-epoch accuracy histories; one value is appended at the end of each epoch.
softmax_train_acc = []
softmax_test_acc = []

for epoch in range(1, 101):
    softmax_model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = softmax_model(x)
        loss = softmax_loss_fn(logits, y)
        softmax_opt.zero_grad(set_to_none=True)
        loss.backward()
        softmax_opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(softmax_model, test_ld)
    softmax_train_acc.append(tr_acc)
    softmax_test_acc.append(te_acc)
    print(f"[Softmax {epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


### Confidence histograms

In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_confidence_hist(model, loader, out_path, title=None):
    """Plot a histogram of the model's top-class confidences and save it to ``out_path``.

    Confidence is the largest softmax probability per sample (softmax over dim 1 of
    the model output). Dashed vertical lines mark the mean confidence and the overall
    accuracy.

    Args:
        model: Module whose output has one row per sample and one column per class;
            it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        out_path: Destination image file. Missing parent directories are created.
        title: Optional figure title.
    """
    import os  # function-scope import so the cell does not rely on a file-level one

    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list)
    labels = torch.cat(label_list)
    acc = (pred == labels).float().mean().item()
    avg_conf = conf.mean().item()

    bins = np.linspace(0.0, 1.0, 21)  # 20 equal-width confidence bins on [0, 1]
    # Weight every sample by 1/N so bar heights are fractions of the evaluated set.
    weights = np.ones_like(conf) / conf.size

    plt.figure(figsize=(4.5, 4.5))
    plt.hist(conf, bins=bins, weights=weights, color='blue', edgecolor='black')
    plt.axvline(avg_conf, color='gray', linestyle='--', linewidth=3)
    plt.axvline(acc, color='gray', linestyle='--', linewidth=3)
    # Labels are placed just under the current top of the y-axis.
    plt.text(avg_conf, 0.95 * plt.gca().get_ylim()[1], 'Avg confidence',
             rotation=90, va='top', ha='center')
    plt.text(acc, 0.95 * plt.gca().get_ylim()[1], 'Accuracy',
             rotation=90, va='top', ha='center')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('% of Samples')
    plt.xlim(0, 1)
    plt.tight_layout()

    # savefig does not create directories; make sure the target folder exists first.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=600)

plot_confidence_hist(lda_model, test_ld, 'plots/cifar10_lda_confidence_hist.png', title='LDA')
plot_confidence_hist(softmax_model, test_ld, 'plots/cifar10_softmax_confidence_hist.png', title='Softmax')


### Reliability diagrams


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_reliability_diagram(model, loader, out_path, title=None, n_bins=10):
    """Draw a reliability diagram annotated with the ECE and save it to ``out_path``.

    Samples are grouped into ``n_bins`` equal-width bins by top-class softmax
    confidence. Each bin's accuracy is drawn as a bar, with the gap to its mean
    confidence hatched on top. The expected calibration error is the
    bin-fraction-weighted mean of ``|accuracy - confidence|``.

    Args:
        model: Module whose output has one row per sample and one column per class;
            it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        out_path: Destination image file. Missing parent directories are created.
        title: Optional figure title.
        n_bins: Number of confidence bins on [0, 1].
    """
    import os  # function-scope import so the cell does not rely on a file-level one

    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list).numpy()
    labels = torch.cat(label_list).numpy()
    correct = (pred == labels).astype(np.float32)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # Only interior edges are passed, so ids fall in 0..n_bins-1; right=True puts a
    # confidence equal to an edge into the lower bin.
    bin_ids = np.digitize(conf, bins[1:-1], right=True)

    # Empty bins keep zeros in all three arrays and therefore add nothing to the ECE.
    bin_acc = np.zeros(n_bins, dtype=np.float32)
    bin_conf = np.zeros(n_bins, dtype=np.float32)
    bin_frac = np.zeros(n_bins, dtype=np.float32)

    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            bin_acc[b] = correct[mask].mean()
            bin_conf[b] = conf[mask].mean()
            bin_frac[b] = mask.mean()

    ece = np.sum(np.abs(bin_acc - bin_conf) * bin_frac)

    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_width = bins[1] - bins[0]

    plt.figure(figsize=(4.5, 4.5))
    plt.plot([0, 1], [0, 1], '--', color='gray', linewidth=3)  # perfect-calibration diagonal
    plt.bar(bin_centers, bin_acc, width=bin_width, color='blue', edgecolor='black', label='Outputs')
    gap = bin_conf - bin_acc
    plt.bar(bin_centers, gap, bottom=bin_acc, width=bin_width, color='salmon', edgecolor='red',
            alpha=0.4, hatch='//', label='Gap')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.legend(loc='upper left')
    plt.text(0.95, 0.05, f'ECE={ece*100:.2f}', ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='lightsteelblue', alpha=0.8))
    plt.tight_layout()

    # savefig does not create directories; make sure the target folder exists first.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=600)

plot_reliability_diagram(lda_model, test_ld, 'plots/cifar10_lda_reliability_diagram.png', title='LDA')
plot_reliability_diagram(softmax_model, test_ld, 'plots/cifar10_softmax_reliability_diagram.png', title='Softmax')


### PCA over Embeddings

In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_embeddings(model, loader, batches=10):
    """Embed the first ``batches`` batches of ``loader`` and project them to 2-D with PCA.

    Args:
        model: Module exposing an ``encoder`` sub-module; it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches. The number of points returned
            is ``batches`` times the loader's batch size (fewer if the loader is shorter).
        batches: How many batches to consume before stopping.

    Returns:
        ``(z2, y)``: the mean-centred embeddings projected onto the two directions
        returned by ``torch.pca_lowrank``, and the concatenated labels.
    """
    model.eval()
    feats, targets = [], []
    for step, (x, y) in enumerate(loader):
        feats.append(model.encoder(x.to(device)).cpu())
        targets.append(y)
        if step + 1 >= batches:
            break

    z = torch.cat(feats)
    y = torch.cat(targets)

    # Centre, then keep the two leading directions from the low-rank PCA.
    centred = z - z.mean(0, keepdim=True)
    V = torch.pca_lowrank(centred, q=2)[2]
    return centred @ V[:, :2], y

import os

# NOTE: train_ld shuffles and applies random crop/flip, so each call below embeds its
# own random, augmented sample of the training set (10 batches of 256 images).
lda_z2, lda_y = collect_embeddings(lda_model, train_ld)
softmax_z2, softmax_y = collect_embeddings(softmax_model, train_ld)

plt.rcParams.update({
    "font.size": 14, "axes.labelsize": 16, "legend.fontsize": 13,
    "xtick.labelsize": 13, "ytick.labelsize": 13
})

palette = plt.cm.tab10.colors  # one colour per class index
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
plots = [
    (axes[0], lda_z2, lda_y, "Deep LDA embeddings"),
    (axes[1], softmax_z2, softmax_y, "Softmax head embeddings"),
]

for ax, z2, y, title in plots:
    # One scatter call per class so every class gets its own legend entry.
    for c in range(10):
        idx = y == c
        ax.scatter(z2[idx, 0], z2[idx, 1], s=8, alpha=0.6, color=palette[c], label=train_ds.classes[c])
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title(title)

# Both panels share the same classes, so a single figure-level legend built from the
# first panel is enough.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(0.5, -0.02), loc="lower center", ncol=5)
plt.tight_layout(rect=(0, 0.07, 1, 1))  # leave room at the bottom for the legend
os.makedirs('plots', exist_ok=True)  # savefig does not create the folder itself
plt.savefig('plots/cifar10_lda_vs_softmax_embeddings.png', dpi=600)


# CIFAR-100 Classification with an LDA Head
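This notebook trains a lightweight convolutional encoder with a linear discriminant analysis (LDA) head on CIFAR-100 alongside a softmax baseline, plots the LDA model's confidence histogram and reliability diagram, and visualises both learned embedding spaces side by side.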


### Setup


In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from dnll import LDAHead, DNLLLoss

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device =', device)


### Data


In [None]:
# Per-channel mean / std handed to transforms.Normalize below.
mean = (0.5071, 0.4867, 0.4408)
std = (0.2675, 0.2565, 0.2761)
# Pin host memory only when a CUDA device is available to receive the batches.
pin_memory = torch.cuda.is_available()

# Steps shared by both pipelines: tensor conversion followed by standardisation.
to_normalised_tensor = [transforms.ToTensor(), transforms.Normalize(mean, std)]
# Training-only augmentation: padded random crop plus random horizontal flip.
augmentations = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]

train_tfm = transforms.Compose(augmentations + to_normalised_tensor)
test_tfm = transforms.Compose(list(to_normalised_tensor))

dataset_kwargs = dict(root='./data', download=True)
train_ds = datasets.CIFAR100(train=True, transform=train_tfm, **dataset_kwargs)
test_ds = datasets.CIFAR100(train=False, transform=test_tfm, **dataset_kwargs)

loader_kwargs = dict(num_workers=4, pin_memory=pin_memory)
train_ld = DataLoader(train_ds, batch_size=256, shuffle=True, **loader_kwargs)
test_ld = DataLoader(test_ds, batch_size=1024, shuffle=False, **loader_kwargs)
len(train_ds), len(test_ds)


### Model: encoder + LDA head (on-the-fly stats)


In [None]:
class Encoder(nn.Module):
    """Small convolutional network mapping 3-channel images to ``dim``-dimensional embeddings.

    Three stages of two (3x3 conv, batch norm, ReLU) layers with 64, 128 and 256
    channels. The first two stages end in 2x2 max pooling, the last in global average
    pooling; a linear layer then projects the 256 pooled features to ``dim``.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return [Conv2d(3x3, padding=1), BatchNorm2d, ReLU(inplace)] as a flat list."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, dim):
        super().__init__()
        block = self._conv_block
        layers = [
            *block(3, 64), *block(64, 64), nn.MaxPool2d(2),
            *block(64, 128), *block(128, 128), nn.MaxPool2d(2),
            *block(128, 256), *block(256, 256), nn.AdaptiveAvgPool2d(1),
        ]
        # Unpacked flat so the Sequential indices (and state_dict keys) run 0..20.
        self.features = nn.Sequential(*layers)
        self.proj = nn.Linear(256, dim)

    def forward(self, x):
        """Return the projected embedding, one row per input image."""
        pooled = self.features(x)
        return self.proj(pooled.flatten(1))

class DeepLDA(nn.Module):
    """Encoder followed by an ``LDAHead`` constructed with its default options.

    Args:
        C: First argument forwarded to ``LDAHead`` (presumably the number of
            classes — confirm against ``dnll.LDAHead``).
        D: Embedding width; used for both the encoder output and the head input.
    """

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = LDAHead(C, D)

    def forward(self, x):
        """Embed ``x`` with the encoder and return the head's output for it."""
        return self.head(self.encoder(x))

class SoftmaxHead(nn.Module):
    """Single linear layer turning ``D``-dimensional embeddings into ``C`` raw logits.

    No softmax is applied here; callers pass the logits to a loss or to ``F.softmax``.
    """

    def __init__(self, D, C):
        super().__init__()
        self.linear = nn.Linear(D, C)

    def forward(self, z):
        """Return the unnormalised class scores for embeddings ``z``."""
        logits = self.linear(z)
        return logits

class DeepClassifier(nn.Module):
    """Softmax baseline: an ``Encoder`` with output width ``D`` feeding a ``SoftmaxHead`` with ``C`` outputs."""

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = SoftmaxHead(D, C)

    def forward(self, x):
        """Return raw class logits for the image batch ``x``."""
        return self.head(self.encoder(x))


### Train & Eval


In [None]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is switched to eval mode and left there. Batches are moved to the
    notebook-level ``device``. Raises ``ZeroDivisionError`` if the loader yields no
    samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for x, y in loader:
        y = y.to(device)
        predictions = model(x.to(device)).argmax(1)
        n_correct += (predictions == y).sum().item()
        n_seen += y.size(0)
    return n_correct / n_seen

# Deep LDA model: D=99 is the encoder's output width; C=100 is forwarded to LDAHead
# (presumably the number of CIFAR-100 classes — confirm against dnll.LDAHead).
model = DeepLDA(C=100, D=99).to(device)
# Adam with its default hyper-parameters over every parameter of the model.
opt = torch.optim.Adam(model.parameters())
# Other criteria tried during development (CrossEntropyLoss, NLLLoss, LDALoss,
# PoissonNLLLoss, LogisticLoss) are not used; DNLLLoss is the one in effect.
loss_fn = DNLLLoss(lambda_reg=.01)

# Per-epoch accuracy histories; one value is appended at the end of each epoch.
train_acc = []
test_acc = []

for epoch in range(1, 101):
    model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(model, test_ld)
    train_acc.append(tr_acc)
    test_acc.append(te_acc)
    print(f"[{epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


### Softmax head baseline
Train a second model with the same encoder but a standard softmax classifier for a side-by-side comparison.


In [None]:
# Softmax baseline with the same encoder architecture and embedding width (D=99) as the LDA model.
softmax_model = DeepClassifier(C=100, D=99).to(device)
# Adam with its default hyper-parameters over every parameter of the model.
softmax_opt = torch.optim.Adam(softmax_model.parameters())
softmax_loss_fn = nn.CrossEntropyLoss()

# Per-epoch accuracy histories; one value is appended at the end of each epoch.
softmax_train_acc = []
softmax_test_acc = []

for epoch in range(1, 101):
    softmax_model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = softmax_model(x)
        loss = softmax_loss_fn(logits, y)
        softmax_opt.zero_grad(set_to_none=True)
        loss.backward()
        softmax_opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(softmax_model, test_ld)
    softmax_train_acc.append(tr_acc)
    softmax_test_acc.append(te_acc)
    print(f"[Softmax {epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import os

# Confidence histogram for the LDA model on the test set: confidence is the largest
# softmax probability per sample.
model.eval()
conf_list = []
pred_list = []
label_list = []
with torch.no_grad():
    for x, y in test_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

conf = torch.cat(conf_list).numpy()
pred = torch.cat(pred_list)
labels = torch.cat(label_list)
acc = (pred == labels).float().mean().item()
avg_conf = conf.mean().item()

bins = np.linspace(0.0, 1.0, 21)  # 20 equal-width confidence bins on [0, 1]
# Weight every sample by 1/N so bar heights are fractions of the test set.
weights = np.ones_like(conf) / conf.size

plt.figure(figsize=(4.5, 4.5))
plt.hist(conf, bins=bins, weights=weights, color='blue', edgecolor='black')
plt.axvline(avg_conf, color='gray', linestyle='--', linewidth=3)
plt.axvline(acc, color='gray', linestyle='--', linewidth=3)
# Labels are placed just under the current top of the y-axis.
plt.text(avg_conf, 0.95 * plt.gca().get_ylim()[1], 'Avg confidence',
         rotation=90, va='top', ha='center')
plt.text(acc, 0.95 * plt.gca().get_ylim()[1], 'Accuracy',
         rotation=90, va='top', ha='center')
plt.xlabel('Confidence')
plt.ylabel('% of Samples')
plt.xlim(0, 1)
plt.tight_layout()
os.makedirs('plots', exist_ok=True)  # savefig does not create the folder itself
plt.savefig('plots/cifar100_confidence_hist.png', dpi=600)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

import os

# Reliability diagram for the LDA model on the test set, annotated with the expected
# calibration error (ECE).
model.eval()
conf_list = []
pred_list = []
label_list = []
with torch.no_grad():
    for x, y in test_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

conf = torch.cat(conf_list).numpy()
pred = torch.cat(pred_list).numpy()
labels = torch.cat(label_list).numpy()
correct = (pred == labels).astype(np.float32)

n_bins = 10
bins = np.linspace(0.0, 1.0, n_bins + 1)
# Only interior edges are passed, so ids fall in 0..n_bins-1; right=True puts a
# confidence equal to an edge into the lower bin.
bin_ids = np.digitize(conf, bins[1:-1], right=True)

# Empty bins keep zeros in all three arrays and therefore add nothing to the ECE.
bin_acc = np.zeros(n_bins, dtype=np.float32)
bin_conf = np.zeros(n_bins, dtype=np.float32)
bin_frac = np.zeros(n_bins, dtype=np.float32)

for b in range(n_bins):
    mask = bin_ids == b
    if mask.any():
        bin_acc[b] = correct[mask].mean()
        bin_conf[b] = conf[mask].mean()
        bin_frac[b] = mask.mean()

# ECE: bin-fraction-weighted mean of |accuracy - confidence|.
ece = np.sum(np.abs(bin_acc - bin_conf) * bin_frac)

bin_centers = (bins[:-1] + bins[1:]) / 2
bin_width = bins[1] - bins[0]

plt.figure(figsize=(4.5, 4.5))
plt.plot([0, 1], [0, 1], '--', color='gray', linewidth=3)  # perfect-calibration diagonal
plt.bar(bin_centers, bin_acc, width=bin_width, color='blue', edgecolor='black', label='Outputs')
gap = bin_conf - bin_acc
plt.bar(bin_centers, gap, bottom=bin_acc, width=bin_width, color='salmon', edgecolor='red',
        alpha=0.4, hatch='//', label='Gap')
plt.xlabel('Confidence')
plt.ylabel('Accuracy')
plt.xlim(0, 1)
plt.ylim(0, 1)
plt.legend(loc='upper left')
plt.text(0.95, 0.05, f'ECE={ece*100:.2f}', ha='right', va='bottom',
         bbox=dict(boxstyle='round', facecolor='lightsteelblue', alpha=0.8))
plt.tight_layout()
os.makedirs('plots', exist_ok=True)  # savefig does not create the folder itself
plt.savefig('plots/cifar100_reliability_diagram.png', dpi=600)


In [None]:
import matplotlib.pyplot as plt

# Test accuracy of the LDA model per epoch (the train curve is intentionally not drawn).
epochs = range(1, len(train_acc) + 1)
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(epochs, test_acc, label='Test accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('CIFAR-100 accuracy during training')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()


In [None]:
# Persist the LDA model's weights together with the optimiser state in one file.
save_path = 'cifar100_model.pth'
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
}
torch.save(checkpoint, save_path)
print(f'Saved model checkpoint to {save_path}')

### PCA over Embeddings


In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_embeddings(model, loader, batches=10):
    """Embed the first ``batches`` batches of ``loader`` and project them to 2-D with PCA.

    Args:
        model: Module exposing an ``encoder`` sub-module; it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches. The number of points returned
            is ``batches`` times the loader's batch size (fewer if the loader is shorter).
        batches: How many batches to consume before stopping.

    Returns:
        ``(z2, y)``: the mean-centred embeddings projected onto the two directions
        returned by ``torch.pca_lowrank``, and the concatenated labels.
    """
    model.eval()
    feats, targets = [], []
    for step, (x, y) in enumerate(loader):
        feats.append(model.encoder(x.to(device)).cpu())
        targets.append(y)
        if step + 1 >= batches:
            break

    z = torch.cat(feats)
    y = torch.cat(targets)

    # Centre, then keep the two leading directions from the low-rank PCA.
    centred = z - z.mean(0, keepdim=True)
    V = torch.pca_lowrank(centred, q=2)[2]
    return centred @ V[:, :2], y

import os

lda_model = model
# Fail early with a clear message when the baseline cell has not been run.
if 'softmax_model' not in globals():
    raise RuntimeError('softmax_model not found. Train or load a Softmax baseline before running this cell.')
# NOTE: train_ld shuffles and applies random crop/flip, so each call below embeds its
# own random, augmented sample of the training set (10 batches of 256 images).
lda_z2, lda_y = collect_embeddings(lda_model, train_ld)
softmax_z2, softmax_y = collect_embeddings(softmax_model, train_ld)

plt.rcParams.update({
    "font.size": 14, "axes.labelsize": 16, "legend.fontsize": 13,
    "xtick.labelsize": 13, "ytick.labelsize": 13
})

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
plots = [
    (axes[0], lda_z2, lda_y, "Deep LDA embeddings"),
    (axes[1], softmax_z2, softmax_y, "Softmax head embeddings"),
]

for ax, z2, y, title in plots:
    # Colour points by class index on a shared 0..99 scale so both panels and the
    # colour bar agree.
    sc = ax.scatter(z2[:, 0], z2[:, 1], c=y, cmap="nipy_spectral",
                    s=8, alpha=0.6, vmin=0, vmax=99, rasterized=True)
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title(title)

# One colour bar for both panels, built from the last scatter.
cbar = fig.colorbar(sc, ax=axes, fraction=0.046, pad=0.04)
cbar.set_label("Class index")
plt.tight_layout()
os.makedirs('plots', exist_ok=True)  # savefig does not create the folder itself
plt.savefig('plots/cifar100_lda_vs_softmax_embeddings.png', dpi=600)


In [None]:
# Show a few examples from the two closest classes
import numpy as np

# Find the closest pair of classes by Euclidean distance between class means.
if 'class_means' not in globals():
    # NOTE(review): no cell in this notebook defines class_means, so a fresh
    # "Restart & Run All" would stop here with a NameError. As a stand-in, use the
    # per-class mean of the LDA model's encoder embeddings over the un-augmented test
    # set — confirm this matches the intended definition.
    n_classes = len(train_ds.classes)
    model.eval()
    emb_sums, emb_counts = None, torch.zeros(n_classes)
    with torch.no_grad():
        for xb, yb in test_ld:
            zb = model.encoder(xb.to(device)).cpu()
            if emb_sums is None:
                emb_sums = torch.zeros(n_classes, zb.size(1))
            emb_sums.index_add_(0, yb, zb)
            emb_counts += torch.bincount(yb, minlength=n_classes).float()
    # clamp avoids 0/0 for a class that never appears in the loader
    class_means = emb_sums / emb_counts.clamp_min(1).unsqueeze(1)

with torch.no_grad():
    dist_mat = torch.cdist(class_means, class_means)
    dist_mat.fill_diagonal_(float('inf'))  # a class is never its own nearest neighbour
    # Nearest neighbour per class, then the class whose nearest neighbour is closest.
    row_min, row_argmin = dist_mat.min(dim=1)
    cls_a = row_min.argmin().item()
    cls_b = row_argmin[cls_a].item()

cls_names = train_ds.classes
print(f"Closest classes: {cls_a} ({cls_names[cls_a]}) and {cls_b} ({cls_names[cls_b]}), distance={dist_mat[cls_a, cls_b]:.4f}")

# Take the first few raw images of each class; train_ds.data is indexed directly, so
# no transform or augmentation is applied.
n_per_class = 4
indices_a = [i for i, t in enumerate(train_ds.targets) if t == cls_a][:n_per_class]
indices_b = [i for i, t in enumerate(train_ds.targets) if t == cls_b][:n_per_class]

# Top row: examples of cls_a; bottom row: examples of cls_b.
fig, axes = plt.subplots(2, n_per_class, figsize=(2 * n_per_class, 4))
for j, idx in enumerate(indices_a):
    axes[0, j].imshow(train_ds.data[idx])
    axes[0, j].axis('off')
    axes[0, j].set_title(f"{cls_names[cls_a]}")
for j, idx in enumerate(indices_b):
    axes[1, j].imshow(train_ds.data[idx])
    axes[1, j].axis('off')
    axes[1, j].set_title(f"{cls_names[cls_b]}")
plt.tight_layout()
plt.show()


In [None]:
# Show examples from the 2nd closest pair of classes
if 'class_means' not in globals():
    # NOTE(review): no cell in this notebook defines class_means, so a fresh
    # "Restart & Run All" would stop here with a NameError. As a stand-in, use the
    # per-class mean of the LDA model's encoder embeddings over the un-augmented test
    # set — confirm this matches the intended definition.
    n_classes = len(train_ds.classes)
    model.eval()
    emb_sums, emb_counts = None, torch.zeros(n_classes)
    with torch.no_grad():
        for xb, yb in test_ld:
            zb = model.encoder(xb.to(device)).cpu()
            if emb_sums is None:
                emb_sums = torch.zeros(n_classes, zb.size(1))
            emb_sums.index_add_(0, yb, zb)
            emb_counts += torch.bincount(yb, minlength=n_classes).float()
    # clamp avoids 0/0 for a class that never appears in the loader
    class_means = emb_sums / emb_counts.clamp_min(1).unsqueeze(1)

with torch.no_grad():
    dist_mat = torch.cdist(class_means, class_means)
    dist_mat.fill_diagonal_(float('inf'))
    # Strict upper triangle: every unordered class pair exactly once.
    i_idx, j_idx = torch.triu_indices(dist_mat.size(0), dist_mat.size(1), offset=1)
    pair_dists = dist_mat[i_idx, j_idx]
    if pair_dists.numel() < 2:
        raise RuntimeError('Need at least two class pairs to find the 2nd closest pair.')
    # Ascending sort: position 0 is the closest pair, position 1 the runner-up.
    sorted_vals, sorted_idx = torch.sort(pair_dists)
    second_idx = sorted_idx[1].item()
    cls_a = i_idx[second_idx].item()
    cls_b = j_idx[second_idx].item()
    pair_dist = sorted_vals[1].item()

cls_names = train_ds.classes
print(f"2nd closest classes: {cls_a} ({cls_names[cls_a]}) and {cls_b} ({cls_names[cls_b]}), distance={pair_dist:.4f}")

# Take the first few raw images of each class; train_ds.data is indexed directly, so
# no transform or augmentation is applied.
n_per_class = 4
indices_a = [i for i, t in enumerate(train_ds.targets) if t == cls_a][:n_per_class]
indices_b = [i for i, t in enumerate(train_ds.targets) if t == cls_b][:n_per_class]

# Top row: examples of cls_a; bottom row: examples of cls_b.
fig, axes = plt.subplots(2, n_per_class, figsize=(2 * n_per_class, 4))
for j, idx in enumerate(indices_a):
    axes[0, j].imshow(train_ds.data[idx])
    axes[0, j].axis('off')
    axes[0, j].set_title(f"{cls_names[cls_a]}")
for j, idx in enumerate(indices_b):
    axes[1, j].imshow(train_ds.data[idx])
    axes[1, j].axis('off')
    axes[1, j].set_title(f"{cls_names[cls_b]}")
plt.tight_layout()
plt.show()


# Fashion-MNIST Classification with an LDA Head
This notebook trains a small convolutional encoder with a linear discriminant analysis (LDA) head on Fashion-MNIST, alongside a softmax baseline, then visualises the learned embedding space.


### Setup


In [None]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from dnll import LDAHead, DNLLLoss

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device =', device)

### Data


In [None]:
# Images are only converted to tensors; no normalisation or augmentation is applied.
tfm = transforms.ToTensor()
# Pin host memory only when a CUDA device is available, as in the CIFAR notebooks;
# requesting it on a CPU-only machine just triggers a DataLoader warning.
pin_memory = torch.cuda.is_available()
train_ds = datasets.FashionMNIST(root='./data', train=True, transform=tfm, download=True)
test_ds  = datasets.FashionMNIST(root='./data', train=False, transform=tfm, download=True)
train_ld = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2, pin_memory=pin_memory)
test_ld  = DataLoader(test_ds,  batch_size=1024, shuffle=False, num_workers=2, pin_memory=pin_memory)
len(train_ds), len(test_ds)

### Model: encoder + LDA head (on-the-fly stats)


In [None]:
class Encoder(nn.Module):
    """Small convolutional network mapping 1-channel images to ``dim``-dimensional embeddings.

    Three stages of two (3x3 conv, batch norm, ReLU) layers with 64, 128 and 256
    channels. The first two stages end in 2x2 max pooling, the last in global average
    pooling; a linear layer then projects the 256 pooled features to ``dim``.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch):
        """Return [Conv2d(3x3, padding=1), BatchNorm2d, ReLU(inplace)] as a flat list."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]

    def __init__(self, dim):
        super().__init__()
        block = self._conv_block
        layers = [
            *block(1, 64), *block(64, 64), nn.MaxPool2d(2),
            *block(64, 128), *block(128, 128), nn.MaxPool2d(2),
            *block(128, 256), *block(256, 256), nn.AdaptiveAvgPool2d(1),
        ]
        # Unpacked flat so the Sequential indices (and state_dict keys) run 0..20.
        self.features = nn.Sequential(*layers)
        self.proj = nn.Linear(256, dim)

    def forward(self, x):
        """Return the projected embedding, one row per input image."""
        pooled = self.features(x)
        return self.proj(pooled.flatten(1))

class DeepLDA(nn.Module):
    """Encoder followed by an ``LDAHead`` constructed with its default options.

    Args:
        C: First argument forwarded to ``LDAHead`` (presumably the number of
            classes — confirm against ``dnll.LDAHead``).
        D: Embedding width; used for both the encoder output and the head input.
    """

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = LDAHead(C, D)

    def forward(self, x):
        """Embed ``x`` with the encoder and return the head's output for it."""
        return self.head(self.encoder(x))

class SoftmaxHead(nn.Module):
    """Single linear layer turning ``D``-dimensional embeddings into ``C`` raw logits.

    No softmax is applied here; callers pass the logits to a loss or to ``F.softmax``.
    """

    def __init__(self, D, C):
        super().__init__()
        self.linear = nn.Linear(D, C)

    def forward(self, z):
        """Return the unnormalised class scores for embeddings ``z``."""
        logits = self.linear(z)
        return logits

class DeepClassifier(nn.Module):
    """Softmax baseline: an ``Encoder`` with output width ``D`` feeding a ``SoftmaxHead`` with ``C`` outputs."""

    def __init__(self, C, D):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = SoftmaxHead(D, C)

    def forward(self, x):
        """Return raw class logits for the image batch ``x``."""
        return self.head(self.encoder(x))


### Train & Eval


In [None]:
@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch yielded by ``loader``.

    The model is switched to eval mode and left there. Batches are moved to the
    notebook-level ``device``. Raises ``ZeroDivisionError`` if the loader yields no
    samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    for x, y in loader:
        y = y.to(device)
        predictions = model(x.to(device)).argmax(1)
        n_correct += (predictions == y).sum().item()
        n_seen += y.size(0)
    return n_correct / n_seen

# Deep LDA model: D=9 is the encoder's output width; C=10 is forwarded to LDAHead
# (presumably the number of Fashion-MNIST classes — confirm against dnll.LDAHead).
model = DeepLDA(C=10, D=9).to(device)
# Optimise the whole model (encoder and head), matching the CIFAR notebooks.
# Restricting the optimiser to model.encoder would leave any learnable head
# parameters frozen at their initial values.
opt = torch.optim.Adam(model.parameters())
loss_fn = DNLLLoss(lambda_reg=.01)

for epoch in range(1, 101):
    model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(model, test_ld)
    print(f"[{epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")

### Softmax head baseline
Train a second model with the same encoder but a standard softmax classifier for a side-by-side comparison.


In [None]:
# Softmax baseline with the same encoder architecture and embedding width (D=9) as the LDA model.
softmax_model = DeepClassifier(C=10, D=9).to(device)
# Adam with its default hyper-parameters over every parameter of the model.
softmax_opt = torch.optim.Adam(softmax_model.parameters())
softmax_loss_fn = nn.CrossEntropyLoss()

# Per-epoch accuracy histories; one value is appended at the end of each epoch.
softmax_train_acc = []
softmax_test_acc = []

for epoch in range(1, 101):
    softmax_model.train()  # evaluate() below leaves the model in eval mode
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = softmax_model(x)
        loss = softmax_loss_fn(logits, y)
        softmax_opt.zero_grad(set_to_none=True)
        loss.backward()
        softmax_opt.step()
        with torch.no_grad():
            # Running statistics use the logits computed before this step's update.
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(softmax_model, test_ld)
    softmax_train_acc.append(tr_acc)
    softmax_test_acc.append(te_acc)
    print(f"[Softmax {epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_confidence_hist(model, loader, out_path, title=None):
    """Plot a histogram of the model's top-class confidences and save it to ``out_path``.

    Confidence is the largest softmax probability per sample (softmax over dim 1 of
    the model output). Dashed vertical lines mark the mean confidence and the overall
    accuracy.

    Args:
        model: Module whose output has one row per sample and one column per class;
            it is switched to eval mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        out_path: Destination image file. Missing parent directories are created.
        title: Optional figure title.
    """
    import os  # function-scope import so the cell does not rely on a file-level one

    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list)
    labels = torch.cat(label_list)
    acc = (pred == labels).float().mean().item()
    avg_conf = conf.mean().item()

    bins = np.linspace(0.0, 1.0, 21)  # 20 equal-width confidence bins on [0, 1]
    # Weight every sample by 1/N so bar heights are fractions of the evaluated set.
    weights = np.ones_like(conf) / conf.size

    plt.figure(figsize=(4.5, 4.5))
    plt.hist(conf, bins=bins, weights=weights, color='blue', edgecolor='black')
    plt.axvline(avg_conf, color='gray', linestyle='--', linewidth=3)
    plt.axvline(acc, color='gray', linestyle='--', linewidth=3)
    # Labels are placed just under the current top of the y-axis.
    plt.text(avg_conf, 0.95 * plt.gca().get_ylim()[1], 'Avg confidence',
             rotation=90, va='top', ha='center')
    plt.text(acc, 0.95 * plt.gca().get_ylim()[1], 'Accuracy',
             rotation=90, va='top', ha='center')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('% of Samples')
    plt.xlim(0, 1)
    plt.tight_layout()

    # savefig does not create directories; make sure the target folder exists first.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=600)

plot_confidence_hist(model, test_ld, 'plots/fashion_mnist_lda_confidence_hist.png', title='LDA')
plot_confidence_hist(softmax_model, test_ld, 'plots/fashion_mnist_softmax_confidence_hist.png', title='Softmax')


### Reliability diagram


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_reliability_diagram(model, loader, out_path, title=None, n_bins=10):
    """Plot a reliability diagram with the ECE over ``loader`` and save it.

    Samples are grouped into ``n_bins`` equal-width confidence bins on [0, 1],
    where confidence is the largest softmax probability from ``model(x)``.
    For each non-empty bin the mean accuracy and mean confidence are computed;
    the expected calibration error is ``sum_b |acc_b - conf_b| * frac_b`` with
    ``frac_b`` the fraction of all samples in bin ``b``. Empty bins contribute
    zero. The figure shows per-bin accuracy as bars, the confidence/accuracy
    gap stacked on top, the diagonal as reference, and the ECE (as a
    percentage) in a text box.

    Args:
        model: Module whose ``model(x)`` returns per-class scores with the
            classes along dim 1 (softmax is applied over that dim).
        loader: Iterable yielding ``(x, y)`` batches.
        out_path: Destination handed to ``plt.savefig``. When it is a
            filesystem path, missing parent directories are created first.
        title: Optional figure title.
        n_bins: Number of confidence bins; must be at least 1.

    Raises:
        ValueError: If ``n_bins`` is smaller than 1.

    Side effects:
        Puts ``model`` in eval mode (it is not switched back) and writes the
        figure to ``out_path`` at 600 dpi.
    """
    import os  # local import: only needed to prepare the output directory

    # Fail before the (potentially slow) inference pass instead of hitting an
    # IndexError on the bin edges afterwards.
    if n_bins < 1:
        raise ValueError(f'n_bins must be at least 1, got {n_bins}')

    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        # Top-1 probability and the class index it belongs to.
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list).numpy()
    labels = torch.cat(label_list).numpy()
    correct = (pred == labels).astype(np.float32)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # Digitising against the interior edges only (right-closed) yields bin ids
    # in 0..n_bins-1: values up to the first interior edge land in bin 0 and
    # values above the last interior edge land in the final bin.
    bin_ids = np.digitize(conf, bins[1:-1], right=True)

    bin_acc = np.zeros(n_bins, dtype=np.float32)
    bin_conf = np.zeros(n_bins, dtype=np.float32)
    bin_frac = np.zeros(n_bins, dtype=np.float32)

    for b in range(n_bins):
        mask = bin_ids == b
        # Empty bins keep their zero entries and therefore add nothing to ECE.
        if mask.any():
            bin_acc[b] = correct[mask].mean()
            bin_conf[b] = conf[mask].mean()
            bin_frac[b] = mask.mean()

    ece = np.sum(np.abs(bin_acc - bin_conf) * bin_frac)

    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_width = bins[1] - bins[0]

    plt.figure(figsize=(4.5, 4.5))
    plt.plot([0, 1], [0, 1], '--', color='gray', linewidth=3)
    plt.bar(bin_centers, bin_acc, width=bin_width, color='blue', edgecolor='black', label='Outputs')
    # Stack the signed gap on top of the accuracy bar so the combined bar
    # ends at the bin's mean confidence.
    gap = bin_conf - bin_acc
    plt.bar(bin_centers, gap, bottom=bin_acc, width=bin_width, color='salmon', edgecolor='red',
            alpha=0.4, hatch='//', label='Gap')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.legend(loc='upper left')
    plt.text(0.95, 0.05, f'ECE={ece*100:.2f}', ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='lightsteelblue', alpha=0.8))
    plt.tight_layout()

    # savefig does not create directories; make sure the target folder exists
    # so the cell also works on a fresh checkout. File-like targets are left
    # untouched.
    if isinstance(out_path, (str, os.PathLike)):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_path, dpi=600)

# Reliability diagrams on the CIFAR-10 test set for both heads.
# The files are named after the dataset this notebook actually uses.
plot_reliability_diagram(model, test_ld, 'plots/cifar10_lda_reliability_diagram.png', title='LDA')
plot_reliability_diagram(softmax_model, test_ld, 'plots/cifar10_softmax_reliability_diagram.png', title='Softmax')


In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def collect_embeddings(model, loader, batches=10):
    """Encode the first batches of ``loader`` and project them to 2-D via PCA.

    Args:
        model: Module exposing an ``encoder`` attribute; ``model.encoder(x)``
            is used as the embedding. Assumes it returns one flat vector per
            sample (2-D output) — TODO confirm against the Encoder definition.
        loader: Iterable yielding ``(x, y)`` batches.
        batches: Number of batches to consume. The stop check runs after a
            batch has been processed, so at least one batch is always used.

    Returns:
        ``(z2, y)``: the mean-centred embeddings projected onto the first two
        principal directions found by ``torch.pca_lowrank``, and the
        concatenated labels in the same order.

    Side effects:
        Puts ``model`` in eval mode (it is not switched back).
    """
    model.eval()
    feature_chunks, label_chunks = [], []
    for batch_idx, (images, targets) in enumerate(loader):
        # Embed on the compute device, then keep the result on the CPU.
        feature_chunks.append(model.encoder(images.to(device)).cpu())
        label_chunks.append(targets)
        if batch_idx + 1 >= batches:
            break

    features = torch.cat(feature_chunks)
    y = torch.cat(label_chunks)

    # Centre the embeddings, then project onto the top-2 PCA directions.
    centered = features - features.mean(0, keepdim=True)
    basis = torch.pca_lowrank(centered, q=2)[2]
    return centered @ basis[:, :2], y

import os  # local to this cell: only used to prepare the output directory

# Side-by-side 2-D PCA view of the LDA and Softmax encoders' embeddings.
lda_model = model
if 'softmax_model' not in globals():
    raise RuntimeError('softmax_model not found. Train or load a Softmax baseline before running this cell.')
# NOTE: train_ld shuffles and applies random augmentation, so the two models
# are embedded on different random draws of training batches.
lda_z2, lda_y = collect_embeddings(lda_model, train_ld)
softmax_z2, softmax_y = collect_embeddings(softmax_model, train_ld)

plt.rcParams.update({
    "font.size": 14, "axes.labelsize": 16, "legend.fontsize": 13,
    "xtick.labelsize": 13, "ytick.labelsize": 13
})

palette = plt.cm.tab10.colors
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Softmax goes on the left panel, LDA on the right.
plots = [
    (axes[1], lda_z2, lda_y, "LDA"),
    (axes[0], softmax_z2, softmax_y, "Softmax"),
]

for ax, z2, y, title in plots:
    # One scatter call per class so each class gets its own colour and
    # legend entry; class names come from the dataset itself.
    for c, class_name in enumerate(train_ds.classes):
        idx = y == c
        ax.scatter(z2[idx, 0], z2[idx, 1], s=8, alpha=0.6, color=palette[c], label=class_name)
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title(title)

# Both panels share the same classes, so a single figure-level legend
# built from the left panel is enough.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(0.5, -0.02), loc="lower center", ncol=5)
# Reserve the bottom strip of the figure for the shared legend.
plt.tight_layout(rect=(0, 0.07, 1, 1))

# savefig does not create directories; make sure the folder exists so the
# cell also works on a fresh checkout. The file is named after the dataset
# this notebook actually uses.
os.makedirs('plots', exist_ok=True)
plt.savefig('plots/cifar10_lda_vs_softmax_embeddings.png', dpi=600)
