In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from mda import MDAHead

In [None]:
# --- Data ---
# 0.1307 / 0.3081 are presumably the MNIST training-set pixel mean/std (the
# values used in the official PyTorch MNIST example); applied after ToTensor().
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)
DATA_ROOT = '.'
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 64, 1000

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True,
)
# download=True lets the test split load on its own (e.g. if the train
# dataset above is changed or skipped); it is skipped when files already exist.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_ROOT, train=False, download=True, transform=transform),
    batch_size=TEST_BATCH_SIZE, shuffle=False,
)

In [None]:
class Net(nn.Module):
    """MLP encoder to a low-dimensional embedding, followed by an MDA head.

    The defaults reproduce the original MNIST setup: 28*28 inputs -> 128
    hidden units (ReLU) -> 2-D embedding -> MDAHead over 10 classes, K=1.

    Args:
        in_dim: number of input features after flattening.
        hidden_dim: width of the hidden layer.
        emb_dim: embedding size passed to the head as ``d``.
        num_classes: number of classes for the MDA head.
        K: passed through to ``MDAHead`` (presumably components per class).
    """

    def __init__(self, in_dim=28*28, hidden_dim=128, emb_dim=2, num_classes=10, K=1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, emb_dim)
        self.mda = MDAHead(d=emb_dim, num_classes=num_classes, K=K)

    def encode(self, X):
        """Map inputs to embeddings z of shape (batch, emb_dim).

        X is flattened to (-1, in_dim). ``reshape`` is used instead of
        ``view`` so non-contiguous inputs (e.g. permuted tensors) also work.
        """
        x = X.reshape(-1, self.fc1.in_features)
        h = torch.relu(self.fc1(x))
        return self.fc2(h)

    def forward(self, X):
        """Return the class logits produced by the MDA head on ``encode(X)``."""
        z = self.encode(X)
        return self.mda(z)

In [None]:
# Seed torch's global RNG before building the model so weight init and the
# shuffle order of train_loader are repeatable across Restart & Run All
# (CUDA kernels may still add some nondeterminism).
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
# NOTE(review): Adam optimizes all of model.parameters(), including any
# parameters MDAHead registers. An earlier variant optimized only fc1/fc2
# (presumably leaving the head to em_update) — confirm which is intended.
opt = torch.optim.Adam(model.parameters())

In [None]:
NUM_EPOCHS = 3

for epoch in range(NUM_EPOCHS):
    model.train()
    # Track the sample-weighted sum of batch losses so the printed value is the
    # epoch's mean training loss, not just the final mini-batch's loss.
    running_loss, n_seen = 0.0, 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        opt.zero_grad()
        logits = model(X)
        loss = criterion(logits, y)
        loss.backward()
        opt.step()
        # criterion uses the default reduction='mean', so scale by batch size.
        running_loss += loss.item() * y.size(0)
        n_seen += y.size(0)

    # After the gradient pass, update the MDA head from encoder outputs over
    # the training set (presumably an EM refit — see mda.MDAHead.em_update).
    model.mda.em_update(model.encode, train_loader, device)

    epoch_loss = running_loss / max(n_seen, 1)  # guard against an empty loader
    print(f"Epoch {epoch+1}: loss={epoch_loss:.4f}")

In [None]:
# --- Evaluation ---
# Top-1 accuracy of the full model on the test split; no gradients needed.
model.eval()
correct = total = 0
with torch.no_grad():
    for X, y in test_loader:
        X = X.to(device)
        y = y.to(device)
        preds = torch.argmax(model(X), dim=1)
        correct += int((preds == y).sum())
        total += y.shape[0]
accuracy_pct = 100 * correct / total
print(f"Test accuracy: {accuracy_pct:.2f}%")

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# eval() returns the module itself; the trailing ';' keeps Jupyter from echoing its repr.
model.eval();

In [None]:
# Collect a ~5% sample of training-set embeddings for plotting. train_loader
# was built with shuffle=True, so the leading batches form a random subset.
emb_list, y_list = [], []
max_points = int(len(train_loader.dataset) * 0.05)
n_collected = 0

with torch.no_grad():
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        # Reuse the model's own encoder instead of duplicating fc1 -> relu -> fc2.
        z = model.encode(X)          # encoder output (2D)
        emb_list.append(z)
        y_list.append(y)
        # Running count instead of re-summing every collected batch each step.
        n_collected += z.shape[0]
        if n_collected >= max_points:
            break

Z = torch.cat(emb_list, dim=0).cpu().numpy()[:max_points]
Y = torch.cat(y_list, dim=0).cpu().numpy()[:max_points]

In [None]:
# The MDA head's `mu` tensor, normalized to (C, K, D) because the plot indexes mu[c, k].
# NOTE(review): the head may store mu as (C, K, D) or flattened as (C*K, D);
# the flattened case is assumed class-major (all K rows of class 0 first) —
# confirm against MDAHead. 10 matches Net's num_classes default.
mu = model.mda.mu.detach().cpu()
if mu.dim() == 2:
    mu = mu.reshape(10, -1, mu.shape[-1])
assert mu.dim() == 3, f"unexpected mu shape {tuple(mu.shape)}"

In [None]:
num_classes = 10
cmap = plt.get_cmap("tab10", num_classes)
markers = ['o', 's', '^', 'v', '<', '>', 'P', 'X', 'D', '*']

# Centroids are indexed as [class, k]; also accept a flattened (C*K, D) mu
# (assumed class-major — confirm against MDAHead).
centroids = mu if mu.dim() == 3 else mu.reshape(num_classes, -1, mu.shape[-1])

fig, ax = plt.subplots(figsize=(7, 6))
for c in range(num_classes):
    idx = (Y == c)
    ax.scatter(Z[idx, 0], Z[idx, 1], s=8, alpha=0.6,
               c=[cmap(c)], marker=markers[c], label=f"{c}")
    # Overlay this class's row(s) of mu as large black-edged markers.
    for m in centroids[c]:
        ax.scatter(m[0], m[1], s=100, marker='X',
                   c=[cmap(c)], edgecolor='k', linewidths=1.2)

ax.set_title("Deep MDA on MNIST: sampled training embeddings")
ax.set_xlabel("Embedding dim 1")
ax.set_ylabel("Embedding dim 2")
ax.legend(fontsize=8, ncol=2, frameon=False)
fig.tight_layout()
plt.show()