# 01 · Exploring the Limits of Linearity on MNIST

In [None]:

# %pip install -r ../requirements.txt
import torch, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
import sys
sys.path.append(str(Path('..').resolve()))

from src.data import get_mnist_dataloaders
from src.models import LinearClassifier, MLP
from src.utils import get_device, select_loss, accuracy_from_logits, to_onehot

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
backend = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(backend)
device


In [None]:

# Build the train / validation / test loaders with a batch size of 256.
train_loader, val_loader, test_loader = get_mnist_dataloaders(
    batch_size=256,
)
# Peek at one training batch; the cell displays the shape of its first element.
sample_batch = next(iter(train_loader))
sample_batch[0].shape


In [None]:

import torch.nn as nn, torch.optim as optim
from tqdm import tqdm

def _batch_loss(criterion, loss_name, logits, y):
    """Return the loss for one batch, handling the MSE special case.

    For ``loss_name == 'mse'`` the criterion is applied to the softmax of the
    logits against one-hot targets; for any other name it is applied directly
    to ``(logits, y)``.
    """
    if loss_name == 'mse':
        # Derive the one-hot width from the logits so that the targets always
        # match the model's output width (10 for the MNIST models used here).
        y_oh = to_onehot(y, logits.size(1)).to(device)
        return criterion(torch.softmax(logits, dim=1), y_oh)
    return criterion(logits, y)


def _run_epoch(model, loader, criterion, loss_name, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch (with a tqdm progress bar); otherwise the model is put
    in eval mode and gradients are disabled.
    Both means are weighted by batch size.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    batches = tqdm(loader, leave=False) if training else loader
    with torch.set_grad_enabled(training):
        for x, y in batches:
            x, y = x.to(device), y.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(x)
            loss = _batch_loss(criterion, loss_name, logits, y)
            if training:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller final batch does not skew
            # the epoch average.
            b = y.size(0)
            loss_sum += loss.item() * b
            acc_sum += accuracy_from_logits(logits, y) * b
            seen += b
    return loss_sum / seen, acc_sum / seen


def run_experiment(model, loss_name='mse', epochs=2, lr=1e-3):
    """Train ``model`` with Adam and record per-epoch train/validation metrics.

    Uses the notebook-level ``train_loader``, ``val_loader`` and ``device``.

    Args:
        model: Model to optimise in place; called as ``model(x)`` to produce
            logits.
        loss_name: Name passed to ``select_loss``. ``'mse'`` additionally
            switches the loss inputs to softmax probabilities vs. one-hot
            targets.
        epochs: Number of train + validation passes.
        lr: Learning rate for the Adam optimizer.

    Returns:
        Dict with keys ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'``, each a list with one entry per epoch.
    """
    criterion = select_loss(loss_name)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    for _ in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, loss_name, optimizer
        )
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, loss_name)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
    return history


In [None]:

# Stage 1: Linear + MSE
# Fresh linear classifier trained for two epochs with the 'mse' loss path.
model = LinearClassifier().to(device)
hist_linear_mse = run_experiment(
    model,
    loss_name='mse',
    epochs=2,
)
# Last expression of the cell: display the recorded history dict.
hist_linear_mse


In [None]:

# Stage 2: Linear + CrossEntropy
# Same architecture as Stage 1, trained with the 'crossentropy' loss instead.
model = LinearClassifier().to(device)
hist_linear_ce = run_experiment(
    model,
    loss_name='crossentropy',
    epochs=2,
)

import matplotlib.pyplot as plt

# Overlay the validation-loss curves of the two linear runs.
plt.figure()
curves = (
    (hist_linear_mse, 'Linear MSE'),
    (hist_linear_ce, 'Linear CE'),
)
for run_history, curve_label in curves:
    plt.plot(run_history['val_loss'], label=curve_label)
plt.legend()
plt.title('Validation Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()


In [None]:

# Stage 3: MLP + ReLU
# MLP with hidden=256, trained for two epochs with the 'crossentropy' loss.
model = MLP(hidden=256).to(device)
hist_mlp = run_experiment(
    model,
    loss_name='crossentropy',
    epochs=2,
)

# Plot the MLP's validation loss per epoch.
plt.figure()
plt.plot(hist_mlp['val_loss'], label='MLP + ReLU (CE)')
plt.legend()
plt.title('Validation Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()


In [None]:

# Visualize linear weights as 28x28 templates
# Train a fresh linear classifier for one epoch; only its weights are needed,
# so the returned history is discarded.
model = LinearClassifier().to(device)
_ = run_experiment(model, loss_name='crossentropy', epochs=1)
# Pull the `fc` layer's weight matrix off the device as a NumPy array.
W = model.fc.weight.detach().cpu().numpy()

import math

# Lay the ten templates out on a grid that is five columns wide.
cols = 5
rows = math.ceil(10 / cols)
plt.figure(figsize=(10, 4))
for digit in range(10):
    plt.subplot(rows, cols, digit + 1)
    # Each weight row is reshaped to 28x28 so it can be shown as an image.
    template = W[digit].reshape(28, 28)
    plt.imshow(template)
    plt.axis('off')
    plt.title(str(digit))
plt.suptitle('Linear Class Templates (Weights)')
plt.show()
