# Fashion-MNIST Classification with an LDA Head
This notebook trains a small fully connected (MLP) encoder with a linear discriminant analysis (LDA) head on Fashion-MNIST, then visualises the learned embedding space.


### Setup


In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from src.lda import LDAHead, LDALoss

In [2]:
# Fix the global RNG seed so weight initialisation and DataLoader shuffling
# are repeatable across "Restart Kernel -> Run All".
# NOTE: some CUDA kernels are still nondeterministic, so GPU runs may differ
# slightly from one another even with a fixed seed.
SEED = 0
torch.manual_seed(SEED)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)

device = cuda


### Data


In [3]:
# Single preprocessing step: convert each image to a tensor.
tfm = transforms.ToTensor()

# Open the train split first, then the test split, both rooted at ./data
# (download=True asks torchvision to fetch the files when needed).
train_ds, test_ds = (
    datasets.FashionMNIST(root='./data', train=is_train, transform=tfm, download=True)
    for is_train in (True, False)
)

# Shuffled mini-batches for training; larger, ordered batches for evaluation.
train_ld = DataLoader(
    train_ds,
    batch_size=256,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
test_ld = DataLoader(
    test_ds,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Cell output: number of examples in each split.
(len(train_ds), len(test_ds))

(60000, 10000)

### Model: encoder + LDA head (on-the-fly stats)


In [4]:
class Encoder(nn.Module):
    """Fully connected encoder producing a `dim`-dimensional embedding.

    The input batch is flattened to 28*28 = 784 features per sample and
    passed through two hidden layers (256 and 64 units, ReLU) followed by
    a linear projection to `dim` outputs.

    Args:
        dim: size of the output embedding.
    """

    def __init__(self, dim):
        super().__init__()
        # Layers are created in the same order as they run, so the
        # Sequential indices (net.0 ... net.5) follow this list.
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch `x` to its embedding of shape (batch, dim)."""
        embedding = self.net(x)
        return embedding

class DeepLDA(nn.Module):
    """Encoder followed by an LDA head.

    Args:
        C: first argument forwarded to `LDAHead`
            (presumably the number of classes; confirm against src.lda).
        D: embedding size; used as the encoder's output dimension and as
            the second argument to `LDAHead`.
    """

    def __init__(self, C, D):
        super().__init__()
        # Sub-modules are registered encoder-first, then head.
        self.encoder = Encoder(D)
        self.head = LDAHead(C, D)

    def forward(self, x):
        """Encode `x` and return whatever the LDA head produces for it."""
        return self.head(self.encoder(x))

### Train & Eval


In [5]:
@torch.no_grad()
def evaluate(model, loader, device=None):
    """Return the top-1 accuracy of `model` over every batch in `loader`.

    Args:
        model: module whose forward pass returns per-class scores of shape
            (batch, classes); the prediction is the argmax over dim 1.
        loader: iterable yielding `(inputs, labels)` tensor pairs.
        device: device the batches are moved to before the forward pass.
            When omitted, the device of the model's first parameter is
            used (CPU if the model has no parameters), so the function
            does not depend on a notebook-global `device` variable.

    Returns:
        Fraction of correctly classified samples as a float, or NaN when
        `loader` yields no samples (instead of raising ZeroDivisionError).

    Note:
        The model is switched to eval mode and is left in eval mode on
        return; callers that continue training must call `model.train()`.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    ok = tot = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        # Original author's note: in eval mode the head scores with its EMA stats.
        logits = model(x)
        ok += (logits.argmax(1) == y).sum().item()
        tot += y.size(0)

    if tot == 0:
        # Accuracy over zero samples is undefined.
        return float('nan')
    return ok / tot

# Model with C=10 and a 9-dimensional embedding (D=9), moved to the chosen device.
model = DeepLDA(C=10, D=9).to(device)
# Only the encoder's parameters are handed to Adam (default hyper-parameters).
opt = torch.optim.Adam(model.encoder.parameters())
# Alternatives kept for reference: nn.CrossEntropyLoss(), nn.NLLLoss()
loss_fn = LDALoss()

for epoch in range(1, 41):  # 40 epochs, numbered from 1 for the log line
    model.train()
    # Running totals over the epoch: summed loss, correct predictions, samples.
    loss_sum, acc_sum, n_sum = 0, 0, 0

    for x, y in train_ld:
        x, y = x.to(device), y.to(device)

        # Original author's note: uses batch stats + EMA update.
        logits = model(x)
        loss = loss_fn(logits, y)

        # Standard optimisation step: clear old grads, backprop, update.
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        # Bookkeeping only; the logits come from before the parameter update.
        with torch.no_grad():
            pred = logits.argmax(dim=1)
            acc_sum += pred.eq(y).sum().item()
            n_sum += y.size(0)
            # Weight the batch loss by batch size so the epoch figure is a per-sample mean.
            loss_sum += loss.item() * y.size(0)

    tr_acc = acc_sum / n_sum
    te_acc = evaluate(model, test_ld)
    print(f"[{epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}")

[01] train loss=-14.8614 acc=0.7711 | test acc=0.8300
[02] train loss=-16.1347 acc=0.8559 | test acc=0.8574
[03] train loss=-16.4379 acc=0.8727 | test acc=0.8654
[04] train loss=-16.6159 acc=0.8816 | test acc=0.8670
[05] train loss=-16.7528 acc=0.8885 | test acc=0.8746
[06] train loss=-16.8628 acc=0.8941 | test acc=0.8766
[07] train loss=-16.9395 acc=0.8982 | test acc=0.8800
[08] train loss=-17.0027 acc=0.9026 | test acc=0.8788
[09] train loss=-17.0557 acc=0.9047 | test acc=0.8818
[10] train loss=-17.1113 acc=0.9086 | test acc=0.8807
[11] train loss=-17.1489 acc=0.9113 | test acc=0.8851
[12] train loss=-17.1944 acc=0.9147 | test acc=0.8886
[13] train loss=-17.2362 acc=0.9177 | test acc=0.8882
[14] train loss=-17.2516 acc=0.9189 | test acc=0.8893
[15] train loss=-17.3022 acc=0.9228 | test acc=0.8879
[16] train loss=-17.3224 acc=0.9256 | test acc=0.8893
[17] train loss=-17.3514 acc=0.9262 | test acc=0.8909
[18] train loss=-17.3747 acc=0.9294 | test acc=0.8917
[19] train loss=-17.4066 acc