## Setup

In [None]:
from collections import defaultdict
from itertools import islice, cycle

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt

# Figure defaults tuned for embedding in a blog post: high-resolution
# rendering with transparent axes and figure backgrounds.
plt.rcParams.update({
    'figure.dpi': 300,
    'axes.facecolor': 'none',
    'figure.facecolor': 'none',
})

In [None]:
# Hyper-parameters shared by every experiment in this notebook.
config = dict(
    max_iters=10000,           # optimizer steps per training run
    batch_size=64,
    lr=1e-3,                   # learning rate of the default Adam optimizer
    in_dim=28*28,              # flattened 28x28 input image
    out_dim=10,                # one output per digit class
    hidden_dims=[32, 32, 16],
    pixels_allowed=[16, 32, 64, 128, 256, 512, 768],  # pixel budgets to compare
)
# Evaluate ten times per run. Clamp to at least 1 so a small `max_iters`
# (< 10) cannot yield 0 and break the `% config['eval_iters']` check in the
# training loops.
config['eval_iters'] = max(1, config['max_iters'] // 10)

device = 'cpu'

In [None]:
# Preprocessing for training/validation: convert to a tensor, then normalize
# with mean 0.5 and std 0.5 (single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# MNIST train split and test split (used as validation), stored under data/mnist.
train_dataset = datasets.MNIST(root='data/mnist', download=True, train=True, transform=transform)
val_dataset = datasets.MNIST(root='data/mnist', download=True, train=False, transform=transform)

# Shuffled training batches.
train_loader = DataLoader(
    train_dataset,
    batch_size=config['batch_size'],
    shuffle=True,
)

# Validation batches in fixed order.
val_loader = DataLoader(
    val_dataset,
    batch_size=config['batch_size'],
    shuffle=False,
)

# Cell output: number of passes over the training set that `max_iters` steps
# correspond to (len(train_loader) is the number of batches per pass).
f"{config['max_iters']/len(train_loader):.2f} Epochs"

In [None]:
class MLP(nn.Module):
    """Fully connected network with an activation between consecutive layers.

    Args:
        dims: Layer widths ``[in_dim, hidden_1, ..., out_dim]``. One
            ``nn.Linear`` is created per consecutive pair, so at least two
            entries are required.
        activation: Callable applied after every layer except the last
            (default ``F.gelu``).

    Raises:
        ValueError: If ``dims`` has fewer than two entries.
    """

    def __init__(self, dims, activation=F.gelu):
        super().__init__()

        dims = list(dims)
        if len(dims) < 2:
            # Without this check the module would build with no layers and
            # only fail later, inside forward(), with an IndexError.
            raise ValueError(
                f"dims needs at least an input and an output width, got {dims!r}"
            )

        self.linear_layers = nn.ModuleList(
            nn.Linear(i, o) for i, o in zip(dims[:-1], dims[1:])
        )
        self.act = activation

    def forward(self, x):
        """Flatten ``x`` to ``(batch, -1)`` and return the last layer's raw output."""
        # reshape (unlike view) also accepts non-contiguous inputs such as a
        # transposed or sliced batch; for contiguous inputs it is equivalent.
        x = x.reshape(x.size(0), -1)

        *hidden_layers, output_layer = self.linear_layers
        for layer in hidden_layers:
            x = self.act(layer(x))

        # No activation on the final layer.
        return output_layer(x)

In [None]:
def loop(iterable, max_iters):
    """Return a tqdm progress bar yielding exactly ``max_iters`` items.

    ``iterable`` is restarted whenever it is exhausted. A re-iterable such as
    a ``DataLoader`` is re-entered on every pass, so per-pass behaviour like
    reshuffling takes effect and nothing from earlier passes is cached.
    ``itertools.cycle`` alone would store every item of the first pass and
    replay those stored items, in the same order, for all later passes.

    Args:
        iterable: Source of items (e.g. a ``DataLoader``).
        max_iters: Total number of items to yield.

    Returns:
        A ``tqdm`` iterator over the items with ``total=max_iters``.
    """
    def restart_forever():
        while True:
            produced = False
            for item in iterable:
                produced = True
                yield item
            if not produced:
                # Empty source: stop instead of spinning forever.
                return

    if hasattr(iterable, '__next__'):
        # One-shot iterators cannot be restarted; replaying cached items is
        # the only way to repeat them, so keep itertools.cycle for those.
        source = cycle(iterable)
    else:
        source = restart_forever()

    gen = islice(source, max_iters)
    return tqdm(gen, total=max_iters)

def eval(model, loss_fn):
    """Score ``model`` on the global ``val_loader``.

    Loss and accuracy are averaged per sample, so a smaller final batch is
    weighted by its actual size rather than counting as a full batch.

    NOTE: this name shadows the builtin ``eval``; it is kept because the
    training functions call it by this name.

    Args:
        model: Module to evaluate. It is switched to eval mode for the pass
            and its previous train/eval mode is restored afterwards.
        loss_fn: Callable ``loss_fn(predictions, targets)``. Assumed to return
            the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).

    Returns:
        ``{'loss': float, 'accuracy': float}``, or ``{}`` if ``val_loader``
        yields no batches.
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            y_ = model(x)
            batch_size = y.size(0)
            # Undo the per-batch mean so the final division is per sample.
            total_loss += loss_fn(y_, y).item() * batch_size
            n_correct += (y_.argmax(-1) == y).sum().item()
            n_seen += batch_size

    # Restore whatever mode the caller had instead of forcing train mode.
    model.train(was_training)

    if n_seen == 0:
        return {}
    return {'loss': total_loss / n_seen, 'accuracy': n_correct / n_seen}

def train(model, optimizer=None, loss_fn=None, verbose=True):
    """Run ``config['max_iters']`` optimisation steps over ``train_loader``.

    Every ``config['eval_iters']`` steps the model is scored with ``eval`` and
    the returned metrics are recorded.

    Args:
        model: Module to train; moved to ``device`` and put in train mode.
        optimizer: Optimizer to step. When None, Adam over
            ``model.parameters()`` with learning rate ``config['lr']``.
        loss_fn: Loss applied to ``(model(x), y)``. When None,
            ``nn.CrossEntropyLoss()``.
        verbose: When true, print the metrics at every evaluation.

    Returns:
        dict mapping each metric name to the list of its recorded values,
        plus ``'steps'``: the zero-based step index of each evaluation.
    """
    model.to(device)
    model.train()

    optimizer = (
        optim.Adam(model.parameters(), lr=config['lr'])
        if optimizer is None else optimizer
    )
    loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn

    history = defaultdict(list)
    batches = loop(train_loader, config['max_iters'])

    for step, (x, y) in enumerate(batches):
        x, y = x.to(device), y.to(device)

        # One gradient step.
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

        # Only evaluate on every eval_iters-th step.
        if (step + 1) % config['eval_iters'] != 0:
            continue

        metrics = eval(model, loss_fn)

        if verbose:
            print(
                f'step {step+1:{len(str(config["max_iters"]))}d}:',
                ' | '.join(f"{metric}: {value:.4f}" for metric, value in metrics.items()),
            )

        for metric, value in metrics.items():
            history[metric].append(value)
        history['steps'].append(step)

    return dict(history)

## Experiments

In [None]:
# Results registry: experiment name -> list of training histories, one per
# entry of config['pixels_allowed'] and in that order (plot() relies on this).
# NOTE: re-running this cell discards every recorded experiment.
experiments = defaultdict(list)

In [None]:
# plot experiments for later

def plot(experiment_names=None):
    """Plot the final loss and accuracy of recorded experiments per pixel budget.

    One panel is drawn per metric. Each experiment contributes one line: the
    last recorded value of that metric from each of its histories, plotted
    against ``config['pixels_allowed']`` on a log-scaled x axis.

    Args:
        experiment_names: Optional container of experiment names to include.
            When None, every entry of ``experiments`` is drawn.
    """
    metric_names = ['loss', 'accuracy']
    n_panels = len(metric_names)

    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5))
    # plt.subplots returns a bare Axes for a single panel; normalise to a sequence.
    axes = axes.flatten() if n_panels != 1 else [axes]

    for ax, metric_name in zip(axes, metric_names):
        for name, results in experiments.items():
            wanted = experiment_names is None or name in experiment_names
            if wanted:
                final_values = [run[metric_name][-1] for run in results]
                ax.plot(config['pixels_allowed'], final_values, label=name, alpha=.7)

        if len(ax.lines) == 0:
            print(f"No data found for metric: {metric_name}")

        ax.set_title(metric_name.title())
        ax.set_xlabel('# Pixels Allowed (log scale)')
        ax.set_xscale('log')
        ax.grid()

        # Label exactly the tested budgets instead of the default log ticks.
        ax.set_xticks(config['pixels_allowed'])
        ax.set_xticklabels(config['pixels_allowed'])
        ax.minorticks_off()

        if metric_name == 'accuracy':
            ax.set_yticks([*(.1*i for i in range(11)), .93])

    # only one legend, taken from the first panel
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper left', bbox_to_anchor=(1, 0.6))

    plt.tight_layout()
    plt.show()

### Baseline

Train a baseline MLP on all 784 pixels. We'll later use its first-layer weights to find important pixels.

In [None]:
# Baseline: an MLP that sees all config['in_dim'] input pixels.
baseline_mlp = MLP([config['in_dim'], *config['hidden_dims'], config['out_dim']])

# Trains with the default optimizer and loss and prints validation metrics at
# every evaluation (verbose defaults to True).
history = train(baseline_mlp)

### How does restricting to a random set of pixels do?

Train one model per pixel budget, each restricted to a **random** subset of pixels.

In [None]:
class PixelMLP(nn.Module):
    '''
    MLP that only uses a subset of input for prediction.

    The input is flattened to (batch, -1), the columns given by
    `pixel_indices` are selected, and the result is passed to an inner `MLP`
    built from `dims`; `dims[0]` must therefore match the number of selected
    columns.

    With the default `pixel_indices=None` the indexing step only inserts a
    singleton axis, which the inner MLP flattens away again, so every input
    column is used.

    NOTE(review): `pixel_indices` is stored as a plain attribute, not a
    registered buffer, so presumably it is neither moved by `.to(device)` nor
    included in `state_dict()`; confirm before running on a non-CPU device.
    '''

    def __init__(self, dims, pixel_indices=None):
        super().__init__()
        self.mlp = MLP(dims)
        self.pixel_indices = pixel_indices

    def forward(self, x):
        flat = x.view(x.size(0), -1)
        selected = flat[:, self.pixel_indices]
        return self.mlp(selected)

In [None]:
# One model per pixel budget, each trained on a fresh random subset of pixels.
# Histories are collected locally and assigned in one go, so re-running this
# cell replaces experiments['random'] instead of appending a second set of
# runs (plot() needs exactly one history per entry of config['pixels_allowed']).
random_histories = []
for n_pixels in config['pixels_allowed']:
    # First n_pixels entries of a random permutation of all pixel indices.
    pixel_indices = np.random.permutation(config['in_dim'])[:n_pixels]
    dims = [n_pixels, *config['hidden_dims'], config['out_dim']]

    pixel_mlp = PixelMLP(dims, pixel_indices)
    history = train(pixel_mlp, verbose=False)
    random_histories.append(history)
experiments['random'] = random_histories

del pixel_mlp

In [None]:
# Plot final loss/accuracy per pixel budget for every experiment recorded so far.
plot()

### What if you train only on important pixels?

We can try to identify important pixels by the L2 norm of the baseline MLP's first-layer weights attached to each pixel.

In [None]:
# First-layer weights of the baseline MLP, detached from autograd.
weights = baseline_mlp.linear_layers[0].weight.detach()

# L2 norm of the weights attached to each input pixel: square, sum over
# axis 0, square root. One value per pixel (784, shown as 28x28 in the
# heatmap cell below).
weight_norms = (weights**2).sum(axis=0)**.5

# Pixel indices ordered from largest to smallest weight norm.
important_pixels = torch.argsort(weight_norms, descending=True)

In [None]:
# Heatmap of the per-pixel weight norms laid out on the 28x28 image grid.
plt.imshow(weight_norms.view(28, 28), aspect='auto')
plt.colorbar()
# plt.title('L2 Norm of First Layer Weights for MLP')
plt.axis('off')
plt.show()

Now train models on only the most important pixels and compare.

In [None]:
# One model per pixel budget, each trained on the n_pixels highest-ranked
# entries of `important_pixels`.
# Histories are collected locally and assigned in one go, so re-running this
# cell replaces experiments['L2'] instead of appending a second set of runs
# (plot() needs exactly one history per entry of config['pixels_allowed']).
l2_histories = []
for n_pixels in config['pixels_allowed']:
    pixel_indices = important_pixels[:n_pixels]
    dims = [n_pixels, *config['hidden_dims'], config['out_dim']]

    pixel_mlp = PixelMLP(dims, pixel_indices)
    history = train(pixel_mlp, verbose=False)
    l2_histories.append(history)
experiments['L2'] = l2_histories

del pixel_mlp

In [None]:
# Plot final loss/accuracy per pixel budget for every experiment recorded so far.
plot()

### Can we do even better by identifying "orthogonal" important pixels?

In [None]:
def sparse_train(
    model, verbose=True, optimizer = None, loss_fn = None,
    lasso_weight = .05,
    lasso_final = config['in_dim'],
    warmup_wait=.1,
    final_wait=.1,
    ):
    """Train ``model`` with an annealed per-pixel penalty on its first layer.

    For every input pixel the L2 norm of its first-layer weight column is
    penalised. The penalty has two parts:

    * a clipped term: norms are clipped at ``accept_threshold``, so only
      columns whose norm is below the threshold receive a gradient from it;
    * a constant base term on all norms at 0.2x the strength.

    ``accept_threshold`` is the ``anneal``-quantile of the current norms.
    ``anneal`` stays at 0 for the first ``warmup_wait`` fraction of training,
    rises linearly to ``(in_dim - lasso_final) / in_dim``, and holds that
    value for the last ``final_wait`` fraction.

    Args:
        model: Module exposing ``linear_layers[0].weight`` (e.g. ``MLP``).
        verbose: When true, print the metrics at every evaluation.
        optimizer: Optimizer to step; Adam with ``config['lr']`` when None.
        loss_fn: Task loss; ``nn.CrossEntropyLoss()`` when None.
        lasso_weight: Strength of the clipped penalty term.
        lasso_final: Number of pixels left outside the clipped penalty once
            annealing has finished. The default (all pixels) disables the
            clipped term's growth.
        warmup_wait: Fraction of ``config['max_iters']`` before annealing starts.
        final_wait: Fraction of ``config['max_iters']`` held at the final value.

    Returns:
        dict mapping each metric name to the list of its recorded values,
        plus ``'steps'``: the zero-based step index of each evaluation.

    Raises:
        ValueError: If ``lasso_final`` is outside ``[0, config['in_dim']]`` or
            if ``warmup_wait + final_wait`` leaves no room to anneal.
    """
    model.to(device)
    model.train()

    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=config['lr'])
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    max_iters = config['max_iters']
    in_dim = config['in_dim']

    # Schedule constants do not depend on the step, so compute and validate
    # them once instead of failing (or silently misbehaving) mid-training.
    if not 0 <= lasso_final <= in_dim:
        raise ValueError(
            f"lasso_final must be between 0 and {in_dim}, got {lasso_final}"
        )
    # Fraction of pixels subject to the clipped penalty once annealing is done.
    final_proportion = (in_dim - lasso_final) / in_dim

    warmup_steps = max_iters * warmup_wait
    anneal_duration = max_iters * (1 - final_wait - warmup_wait)
    if anneal_duration <= 0:
        raise ValueError(
            "warmup_wait + final_wait must be less than 1, got "
            f"{warmup_wait} + {final_wait}"
        )

    history = defaultdict(list)

    for step, (x, y) in enumerate(loop(train_loader, max_iters)):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        ## lasso loss start
        # L2 norm of each first-layer weight column (one per input pixel).
        # NOTE(review): sqrt has an unbounded derivative at 0, so a column
        # that becomes exactly zero could yield NaN gradients; verify if
        # training ever diverges.
        norm = model.linear_layers[0].weight.pow(2).sum(dim=0).sqrt()

        # Linear ramp from 0 to 1 between the warmup and final phases.
        progress = max(min((step - warmup_steps) / anneal_duration, 1), 0)
        anneal = progress * final_proportion
        # .item() turns the threshold into a plain float, so no gradient
        # flows through the quantile itself.
        accept_threshold = norm.quantile(anneal).item()

        lasso_loss = torch.clip(norm, 0, accept_threshold).sum() * lasso_weight

        loss = loss_fn(model(x), y) + lasso_loss + (lasso_weight * .2 * norm.sum())

        loss.backward()
        optimizer.step()

        if (step + 1) % config['eval_iters'] == 0:
            # Metrics come from the task loss only; the penalty is not included.
            metrics = eval(model, loss_fn)

            if verbose:
                print(f'step {step+1:{len(str(config["max_iters"]))}d}:', end=' ')
                print(' | '.join([f"{metric}: {value:.4f}" for metric, value in metrics.items()]))

            for metric, value in metrics.items():
                history[metric].append(value)

            history['steps'].append(step)
    return dict(history)

In [None]:
# Demo of sparse annealing: train one full-input MLP with lasso_final=16 and
# show the resulting per-pixel L2 norms of its first-layer weights.
sparse_mlp = MLP([config['in_dim'], *config['hidden_dims'], config['out_dim']])
history = sparse_train(sparse_mlp, lasso_final=16)

# Same per-pixel norm and ranking computation as for the baseline, applied to
# the sparse model.
# NOTE(review): this rebinds the globals `weights`, `weight_norms` and
# `important_pixels` that were computed from baseline_mlp, so re-running the
# L2 experiment cell after this one would use the sparse ranking instead.
weights = sparse_mlp.linear_layers[0].weight.detach()
weight_norms = (weights**2).sum(axis=0)**.5
important_pixels = torch.argsort(weight_norms, descending=True)

# Heatmap of the norms on the 28x28 image grid.
plt.imshow(weight_norms.view(28, 28), aspect='auto')
plt.colorbar()
# plt.title('L2 Norm of First Layer Weights for Sparse MLP')
plt.axis('off')
plt.show()

In [None]:
# For every pixel budget: (1) train a full-input MLP with the annealed lasso
# penalty, (2) rank pixels by first-layer column norm, (3) train a fresh
# PixelMLP restricted to the top-ranked pixels.
# Histories are collected locally and assigned in one go, so re-running this
# cell replaces experiments['sparse'] instead of appending a second set of
# runs (plot() needs exactly one history per entry of config['pixels_allowed']).
sparse_histories = []
for n_pixels in config['pixels_allowed']:
    # train sparse mlp
    sparse_mlp = MLP([config['in_dim'], *config['hidden_dims'], config['out_dim']])
    sparse_train(sparse_mlp, lasso_final=n_pixels, verbose=False)

    # compute important pixels from first layer of sparse mlp
    weights = sparse_mlp.linear_layers[0].weight.detach()
    weight_norms = (weights**2).sum(axis=0)**.5
    important_pixels = torch.argsort(weight_norms, descending=True)

    # train pixel mlp
    dims = [n_pixels, *config['hidden_dims'], config['out_dim']]
    pixel_mlp = PixelMLP(dims, important_pixels[:n_pixels])
    history = train(pixel_mlp, verbose=False)
    sparse_histories.append(history)
experiments['sparse'] = sparse_histories

del pixel_mlp

In [None]:
# Plot final loss/accuracy per pixel budget for every experiment recorded so far.
plot()

### Plot MNIST examples (via GPT)

In [None]:
# Load the MNIST training split with ToTensor only (no normalization), used
# for displaying example digits.
# NOTE(review): this rebinds the global `transform` defined in the setup
# section; re-running the dataset cell after this one would pick up this
# un-normalized transform.
transform = transforms.Compose([transforms.ToTensor()])
mnist_data = datasets.MNIST(root='data/mnist', train=True, download=True, transform=transform)

In [None]:
def plot_mnist_examples(data, examples_per_digit=5):
    """Show a grid of sample images: one column per digit, one row per example.

    The dataset is scanned once and the scan stops as soon as every digit has
    enough examples, rather than walking the whole dataset once per digit.

    Args:
        data: Iterable of ``(image, label)`` pairs. ``label`` must be
            convertible with ``int()``; ``image`` must support ``.squeeze()``
            and be displayable by ``imshow``.
        examples_per_digit: Number of rows, i.e. images shown per digit. If a
            digit has fewer examples in ``data``, its remaining cells stay
            empty.
    """
    digits = list(range(10))
    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        examples_per_digit, len(digits), figsize=(20, 10), dpi=400, squeeze=False,
    )

    # fig.suptitle('MNIST Digits', fontsize=16)

    # First `examples_per_digit` images of each digit, in dataset order.
    samples = {digit: [] for digit in digits}
    still_needed = examples_per_digit * len(digits)
    for img, label in data:
        bucket = samples.get(int(label))
        if bucket is None or len(bucket) >= examples_per_digit:
            continue
        bucket.append(img)
        still_needed -= 1
        if still_needed == 0:
            break

    for digit in digits:
        found = samples[digit]
        for i in range(examples_per_digit):
            ax = axes[i, digit]
            if i < len(found):
                ax.imshow(found[i].squeeze(), cmap='gray', interpolation='none')
            ax.axis('off')
            if i == 0:
                ax.set_title(str(digit), fontsize=14)

    # plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

# Render the example grid from the un-normalized training images in `mnist_data`.
plot_mnist_examples(mnist_data)