# Fashion-MNIST Classification with an MDA Head
This notebook trains a small convolutional encoder with a mixture discriminant analysis (MDA) head 
on Fashion-MNIST, then evaluates calibration plots.


### Setup


In [1]:
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from dnll import MDAHead, DNLLLoss


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device =', device)


device = cuda


### Data


In [3]:
tfm = transforms.ToTensor()
train_ds = datasets.FashionMNIST(root='./data', train=True, transform=tfm, download=True)
test_ds  = datasets.FashionMNIST(root='./data', train=False, transform=tfm, download=True)
train_ld = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=2, pin_memory=True)
test_ld  = DataLoader(test_ds,  batch_size=1024, shuffle=False, num_workers=2, pin_memory=True)
len(train_ds), len(test_ds)


(60000, 10000)

### Model: encoder + MDA head


In [4]:
class Encoder(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.proj = nn.Linear(256, dim)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.proj(x)

class DeepMDA(nn.Module):
    def __init__(self, C, D, K):
        super().__init__()
        self.encoder = Encoder(D)
        self.head = MDAHead(C, D, K)

    def forward(self, x):
        z = self.encoder(x)
        return self.head(z)


### Train & Eval


In [5]:
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    ok = tot = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        ok += (logits.argmax(1) == y).sum().item()
        tot += y.size(0)
    return ok / tot

model = DeepMDA(C=10, D=9, K=3).to(device)
opt = torch.optim.Adam(model.encoder.parameters())
loss_fn = DNLLLoss(lambda_reg=.01)

for epoch in range(1, 101):
    model.train()
    loss_sum = acc_sum = n_sum = 0
    for x, y in train_ld:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = loss_fn(logits, y)
        opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        with torch.no_grad():
            pred = logits.argmax(1)
            acc_sum += (pred == y).sum().item()
            n_sum += y.size(0)
            loss_sum += loss.item() * y.size(0)
    tr_acc = acc_sum / n_sum
    te_acc = evaluate(model, test_ld)
    print(f'[MDA {epoch:02d}] train loss={loss_sum/n_sum:.4f} acc={tr_acc:.4f} | test acc={te_acc:.4f}')


[MDA 01] train loss=4.7913 acc=0.8100 | test acc=0.7888
[MDA 02] train loss=4.2608 acc=0.8908 | test acc=0.8739


KeyboardInterrupt: 

### Component separation inspection


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def component_separation_report(model):
    model.eval()
    mu = model.head.mu.detach().cpu()  # (C, K, D)
    C, K, D = mu.shape
    max_dists = []
    for c in range(C):
        diff = mu[c].unsqueeze(1) - mu[c].unsqueeze(0)
        dist = diff.pow(2).sum(-1).sqrt()
        max_dists.append(dist.max().item())
    max_class = int(np.argmax(max_dists))
    return max_class, max_dists

target_class, class_dists = component_separation_report(model)
print('max-separated class:', target_class, 'max distance:', class_dists[target_class])

with torch.no_grad():
    mix_logits = model.head.mixture_logits[target_class].cpu()
    mix_probs = torch.softmax(mix_logits, dim=0)
    majority_k = int(torch.argmax(mix_probs))
    minority_ks = [k for k in range(model.head.K) if k != majority_k]
    print('mixture probs:', mix_probs.numpy())
    print('majority component:', majority_k, 'minority components:', minority_ks)

@torch.no_grad()
def assign_components(model, loader, target_class, max_per_comp=12):
    model.eval()
    device = next(model.parameters()).device
    head = model.head
    selected = {k: [] for k in range(head.K)}
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        mask = y == target_class
        if not mask.any():
            continue
        x = x[mask]
        z = model.encoder(x)
        mu = head.mu[target_class].to(device=device, dtype=z.dtype)
        diff = z.unsqueeze(1) - mu.unsqueeze(0)
        if head.covariance_type == "full":
            L = head._get_cholesky(z.dtype, device)[target_class]
            solved = torch.linalg.solve_triangular(
                L.unsqueeze(0), diff.unsqueeze(-1), upper=False
            ).squeeze(-1)
            m2 = (solved * solved).sum(dim=-1)
            log_det = 2.0 * torch.log(torch.diagonal(L, dim1=-2, dim2=-1)).sum(dim=-1)
            log_comp = -0.5 * (m2 + log_det.unsqueeze(0))
        elif head.covariance_type == "spherical":
            m2 = (diff * diff).sum(-1)
            log_cov = head.log_cov[target_class].to(device=device, dtype=z.dtype)
            var = torch.exp(log_cov)
            log_det = head.D * log_cov
            log_comp = -0.5 * (m2 / var.unsqueeze(0) + log_det.unsqueeze(0))
        else:
            log_cov_diag = head.log_cov_diag[target_class].to(device=device, dtype=z.dtype)
            var = torch.exp(log_cov_diag)
            m2 = (diff * diff / var).sum(-1)
            log_det = log_cov_diag.sum(dim=-1)
            log_comp = -0.5 * (m2 + log_det.unsqueeze(0))
        log_mix = torch.log_softmax(
            head.mixture_logits[target_class].to(device=device, dtype=z.dtype), dim=-1
        )
        scores = log_mix.unsqueeze(0) + log_comp
        comp = scores.argmax(1)
        for k in range(head.K):
            if len(selected[k]) >= max_per_comp:
                continue
            idx = (comp == k).nonzero(as_tuple=False).squeeze(1)
            for i in idx:
                if len(selected[k]) >= max_per_comp:
                    break
                selected[k].append(x[i].detach().cpu())
        if all(len(selected[k]) >= max_per_comp for k in range(head.K)):
            break
    return selected

selected = assign_components(model, test_ld, target_class, max_per_comp=12)
majority_imgs = selected.get(majority_k, [])
minority_imgs = []
for k in minority_ks:
    minority_imgs.extend(selected.get(k, []))

def plot_grid(images, title, ncols=6):
    if not images:
        print(f'No images found for {title}')
        return
    n = len(images)
    ncols = min(ncols, n)
    nrows = (n + ncols - 1) // ncols
    plt.figure(figsize=(1.6 * ncols, 1.6 * nrows))
    for i, img in enumerate(images):
        ax = plt.subplot(nrows, ncols, i + 1)
        ax.imshow(img.squeeze(0), cmap="gray")
        ax.axis("off")
    plt.suptitle(title)
    plt.tight_layout()

plot_grid(majority_imgs[:12], f"Class {target_class} majority component")
plot_grid(minority_imgs[:12], f"Class {target_class} minority components")


### Confidence histogram


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_confidence_hist(model, loader, out_path, title=None):
    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list)
    labels = torch.cat(label_list)
    acc = (pred == labels).float().mean().item()
    avg_conf = conf.mean().item()

    bins = np.linspace(0.0, 1.0, 21)
    weights = np.ones_like(conf) / conf.size

    plt.figure(figsize=(4.5, 4.5))
    plt.hist(conf, bins=bins, weights=weights, color='blue', edgecolor='black')
    plt.axvline(avg_conf, color='gray', linestyle='--', linewidth=3)
    plt.axvline(acc, color='gray', linestyle='--', linewidth=3)
    plt.text(avg_conf, 0.95 * plt.gca().get_ylim()[1], 'Avg confidence',
             rotation=90, va='top', ha='center')
    plt.text(acc, 0.95 * plt.gca().get_ylim()[1], 'Accuracy',
             rotation=90, va='top', ha='center')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('% of Samples')
    plt.xlim(0, 1)
    plt.tight_layout()
    plt.savefig(out_path, dpi=600)

plot_confidence_hist(model, test_ld, 'plots/fashion_mnist_mda_confidence_hist.png', title='MDA')


### Reliability diagram


In [None]:
import numpy as np
import matplotlib.pyplot as plt

@torch.no_grad()
def plot_reliability_diagram(model, loader, out_path, title=None, n_bins=10):
    model.eval()
    conf_list, pred_list, label_list = [], [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        probs = F.softmax(logits, dim=1)
        conf, pred = probs.max(1)
        conf_list.append(conf.cpu())
        pred_list.append(pred.cpu())
        label_list.append(y.cpu())

    conf = torch.cat(conf_list).numpy()
    pred = torch.cat(pred_list).numpy()
    labels = torch.cat(label_list).numpy()
    correct = (pred == labels).astype(np.float32)

    bins = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(conf, bins[1:-1], right=True)

    bin_acc = np.zeros(n_bins, dtype=np.float32)
    bin_conf = np.zeros(n_bins, dtype=np.float32)
    bin_frac = np.zeros(n_bins, dtype=np.float32)

    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            bin_acc[b] = correct[mask].mean()
            bin_conf[b] = conf[mask].mean()
            bin_frac[b] = mask.mean()

    ece = np.sum(np.abs(bin_acc - bin_conf) * bin_frac)

    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_width = bins[1] - bins[0]

    plt.figure(figsize=(4.5, 4.5))
    plt.plot([0, 1], [0, 1], '--', color='gray', linewidth=3)
    plt.bar(bin_centers, bin_acc, width=bin_width, color='blue', edgecolor='black', label='Outputs')
    gap = bin_conf - bin_acc
    plt.bar(bin_centers, gap, bottom=bin_acc, width=bin_width, color='salmon', edgecolor='red',
            alpha=0.4, hatch='//', label='Gap')
    if title:
        plt.title(title)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.legend(loc='upper left')
    plt.text(0.95, 0.05, f'ECE={ece*100:.2f}', ha='right', va='bottom',
             bbox=dict(boxstyle='round', facecolor='lightsteelblue', alpha=0.8))
    plt.tight_layout()
    plt.savefig(out_path, dpi=600)

plot_reliability_diagram(model, test_ld, 'plots/fashion_mnist_mda_reliability_diagram.png', title='MDA')
