<a href="https://colab.research.google.com/github/slucey-cs-cmu-edu/RVSS26/blob/main/Classification_TokenMix_simple.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MNIST Classification with Tokenization + Separable Mixing (U across tokens, V within tokens)

This notebook mirrors the fully-connected MLP notebook, but replaces each dense hidden layer on the flattened image with:

1. **Tokenization**: split the 28×28 image into non-overlapping **p×p** patches (tokens).
2. **Separable mixing layer** (repeated `depth` times):
   - $U$ mixes **across tokens**
   - $V$ mixes **within tokens**

Mathematically, if $X \in \mathbb{R}^{N \times D}$ is the token matrix ($N$ tokens, $D$ channels per token), one layer computes:

$$X \leftarrow \eta( U X V )$$

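where $U \in \mathbb{R}^{N \times N}$, $V \in \mathbb{R}^{D \times D}$, and $\eta$ is an elementwise nonlinearity (ReLU in this notebook).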

At the end, we print parameter counts so you can compare against the fully-connected baseline.


## 1. Imports

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Run on the GPU when CUDA is usable; otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

# Seed PyTorch's global RNG so repeated runs start from the same state (optional)
torch.manual_seed(0)

## 2. Device

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from tqdm.auto import tqdm

# Device selection (this cell repeats the setup already done in the Imports cell):
# run on the GPU when CUDA is usable, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)

# Seed PyTorch's global RNG so repeated runs start from the same state (optional)
torch.manual_seed(0)

## 3. Data

In [None]:
# Convert each image to a tensor, then normalize with the MNIST mean/std
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training and test splits are downloaded into ./data on first use
full_train = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Carve a 10,000-image validation set out of the training data.
# The dedicated, seeded generator keeps the split identical across runs.
train_set, val_set = random_split(
    full_train,
    lengths=[50_000, 10_000],
    generator=torch.Generator().manual_seed(0),
)

# Only the training loader shuffles; validation and test keep a fixed order
batch_size = 128
train_loader = DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=0
)
val_loader = DataLoader(
    val_set, batch_size=batch_size, shuffle=False, num_workers=0
)
test_loader = DataLoader(
    test_set, batch_size=batch_size, shuffle=False, num_workers=0
)

print("train:", len(train_set), "val:", len(val_set), "test:", len(test_set))

## 3.5 Visualize tokenization (patches)

Before training, it helps to **see** what tokenization is doing.

We take a few MNIST images and show:
1. the original image
2. the same image with a patch grid overlay
3. the extracted patches laid out in a grid (this is the token matrix reshaped back into 2D patches)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# ----------------------------
# Self-contained patchify
# ----------------------------
def patchify_mnist_local(x, p: int):
    """Split a batch of MNIST images into non-overlapping p×p patches.

    Args:
        x: Tensor of shape (B, 1, 28, 28).
        p: Patch side length; must divide 28.

    Returns:
        A tuple ``(tokens, Hp, Wp)``. ``tokens`` has shape
        ``(B, Hp * Wp, p * p)`` with one flattened patch per row, ordered
        row by row across the patch grid. ``Hp`` and ``Wp`` are the number
        of patches along the height and width.

    Raises:
        AssertionError: If ``x`` is not (B, 1, 28, 28) or ``p`` does not
            divide 28.
    """
    B, C, H, W = x.shape
    assert C == 1 and H == 28 and W == 28, "Expected MNIST shape (B,1,28,28)"
    # unfold silently drops trailing rows/columns when p does not divide the
    # image size, so reject such patch sizes instead of showing a cropped image.
    assert H % p == 0 and W % p == 0, "Patch size must divide 28"
    patches = x.unfold(2, p, p).unfold(3, p, p)  # (B,1,Hp,Wp,p,p)
    Hp, Wp = patches.size(2), patches.size(3)
    return patches.contiguous().view(B, Hp * Wp, p * p), Hp, Wp

# Reuse the model's patch size when `patch_size` already exists in this session
# (e.g. the hyperparameter cell was run before this one); otherwise fall back to 7
p_to_show = globals().get("patch_size", 7)
print("Using patch size for visualization:", p_to_show)

def show_patch_grid(img_1x28x28, p: int, gap: int = 2):
    """Plot one MNIST digit three ways: raw, with a patch grid, and as separated tokens.

    Args:
        img_1x28x28: Image tensor, either 3-D (channel first; channel 0 is
            plotted) or already 2-D.
        p: Patch side length passed to ``patchify_mnist_local``.
        gap: Width in pixels of the white margin drawn between tokens in the
            third panel.
    """
    # Drop the leading channel axis when present so we always plot a 2D array
    source = img_1x28x28[0] if img_1x28x28.dim() == 3 else img_1x28x28
    img = source.cpu().numpy()

    # Tokenize a (1, 1, H, W) copy, then fold each token back into a p×p tile
    tokens, Hp, Wp = patchify_mnist_local(torch.tensor(img)[None, None], p=p)
    tiles = tokens[0].reshape(Hp, Wp, p, p).cpu().numpy()

    _, (ax_orig, ax_grid, ax_tokens) = plt.subplots(1, 3, figsize=(10, 3))

    # (1) original
    ax_orig.imshow(img, cmap="gray")
    ax_orig.set_title(f"Original (p={p})")
    ax_orig.axis("off")

    # (2) grid overlay; the -0.5 shift puts each line on a pixel boundary
    ax_grid.imshow(img, cmap="gray")
    for boundary in (k - 0.5 for k in range(0, 29, p)):
        ax_grid.axhline(boundary, linewidth=1)
        ax_grid.axvline(boundary, linewidth=1)
    ax_grid.set_title("Patch grid overlay")
    ax_grid.axis("off")

    # (3) tokens pasted onto a white canvas, `gap` pixels apart
    stride = p + gap
    mosaic = np.ones((Hp * p + (Hp - 1) * gap, Wp * p + (Wp - 1) * gap))
    for i in range(Hp):
        top = i * stride
        for j in range(Wp):
            left = j * stride
            mosaic[top:top + p, left:left + p] = tiles[i, j]

    ax_tokens.imshow(mosaic, cmap="gray")
    ax_tokens.set_title("Tokens shown separately")
    ax_tokens.axis("off")

    plt.tight_layout()
    plt.show()

# Show a few examples: take the first batch from the training loader and
# visualize its first three images with the patch size chosen above
x_batch, y_batch = next(iter(train_loader))
for i in range(3):
    # Print the ground-truth digit so the plot can be checked by eye
    print(f"Label: {int(y_batch[i])}")
    show_patch_grid(x_batch[i], p=p_to_show, gap=2)


## 4. Baseline: Fully Connected MLP (for comparison)

In [None]:
class MLP(nn.Module):
    """Fully connected baseline: 784 -> hidden -> hidden -> 10 with ReLU in between.

    Args:
        hidden: Width of both hidden layers.
    """

    def __init__(self, hidden=256):
        super().__init__()
        in_features, num_classes = 28 * 28, 10
        self.fc0 = nn.Linear(in_features, hidden)
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, 10) for images of shape (B, 1, 28, 28)."""
        # One row of 784 pixels per image
        h = x.view(x.size(0), -1)
        for hidden_layer in (self.fc0, self.fc1):
            h = F.relu(hidden_layer(h))
        # No activation on the last layer: the outputs are raw class scores
        return self.fc2(h)

# Build the baseline with 256 hidden units, move it to the selected device,
# and print its layer structure
net = MLP(hidden=256)
net = net.to(device)
print(net)

## 5. Tokenization + Separable Mixing Model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def patchify_mnist(x, p: int):
    """Turn MNIST images into a token matrix of flattened p×p patches.

    Args:
        x: Tensor of shape (B, 1, 28, 28).
        p: Patch side length; must divide 28.

    Returns:
        Tensor of shape (B, N, D) with N = (28 / p)**2 tokens and D = p**2
        values per token.

    Raises:
        AssertionError: If the input shape is wrong or ``p`` does not divide 28.
    """
    batch, channels, height, width = x.shape
    assert channels == 1 and height == 28 and width == 28, "Expected MNIST shape (B,1,28,28)"
    assert height % p == 0 and width % p == 0, "Patch size must divide 28"
    # Window of size p with stride p along rows, then columns: (B,1,Hp,Wp,p,p)
    tiles = x.unfold(2, p, p).unfold(3, p, p)
    num_tokens = tiles.size(2) * tiles.size(3)
    values_per_token = p * p
    return tiles.contiguous().view(batch, num_tokens, values_per_token)

def left_mul_tokens(X, U):
    """Mix across tokens: return ``U @ X[b]`` for every batch item ``b``.

    Args:
        X: Tensor of shape (B, N, D).
        U: Tensor of shape (N, N).

    Returns:
        Tensor of shape (B, N, D).
    """
    batch, n_tokens, n_channels = X.shape
    # Put the token axis last and fold batch and channel together, so a single
    # 2D matmul processes every (batch, channel) column at once.
    columns = X.transpose(1, 2).reshape(batch * n_channels, n_tokens)
    mixed = columns.matmul(U.T)  # row-vector form of U @ column
    # Unfold back to (B, D, N) and restore the (B, N, D) layout
    return mixed.reshape(batch, n_channels, n_tokens).transpose(1, 2)

def right_mul_channels(X, V):
    """Mix within tokens: return ``X[b] @ V`` for every batch item ``b``.

    Args:
        X: Tensor of shape (B, N, D).
        V: Tensor of shape (D, D).

    Returns:
        Tensor of shape (B, N, D).
    """
    batch, n_tokens, n_channels = X.shape
    # Every token is a row, so stacking all tokens of all images gives one 2D matmul
    rows = X.reshape(batch * n_tokens, n_channels)
    return rows.matmul(V).reshape(batch, n_tokens, n_channels)

class TokenMixLayer(nn.Module):
    """One separable mixing layer: ``X -> relu(U X V)``.

    ``U`` (N×N) mixes across tokens and ``V`` (D×D) mixes within tokens. Both
    are learnable and start as standard-normal noise scaled by 1/sqrt(size).

    Args:
        N: Number of tokens.
        D: Number of values per token.
    """

    def __init__(self, N, D):
        super().__init__()
        # U is drawn before V; the order fixes which random numbers each one gets
        self.U = self._random_square(N)
        self.V = self._random_square(D)

    @staticmethod
    def _random_square(size):
        """Return a (size, size) parameter of N(0, 1) samples times 1/sqrt(size)."""
        return nn.Parameter(torch.randn(size, size) * (1.0 / (size ** 0.5)))

    def forward(self, X):
        """Apply token mixing, then channel mixing, then ReLU to X of shape (B, N, D)."""
        across_tokens = left_mul_tokens(X, self.U)
        within_tokens = right_mul_channels(across_tokens, self.V)
        return F.relu(within_tokens)

class TokenMixNet(nn.Module):
    """Patchify, apply ``depth`` TokenMixLayers, then classify the flattened tokens.

    Args:
        patch_size: Side length p of each square patch.
        depth: Number of stacked TokenMixLayer blocks.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, patch_size=7, depth=2, num_classes=10):
        super().__init__()
        patches_per_side = 28 // patch_size
        self.p = patch_size
        self.N = patches_per_side * patches_per_side  # number of tokens
        self.D = patch_size * patch_size              # values per token (flattened patch)
        mixers = [TokenMixLayer(self.N, self.D) for _ in range(depth)]
        self.layers = nn.ModuleList(mixers)
        self.fc = nn.Linear(self.N * self.D, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images ``x`` of shape (B, 1, 28, 28)."""
        tokens = patchify_mnist(x, self.p)  # (B, N, D)
        for mixer in self.layers:
            tokens = mixer(tokens)
        # Concatenate all tokens of an image into one feature vector: (B, N*D)
        flat = tokens.reshape(x.size(0), -1)
        return self.fc(flat)


## 6. Parameter counting (so you can compare models)

In [None]:
def count_params(model):
    """Return the total number of trainable scalar parameters in ``model``."""
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)

def pretty_int(n):
    """Format ``n`` with comma thousands separators, e.g. 1234567 -> '1,234,567'."""
    return format(n, ",")


## 7. Choose model + hyperparameters

In [None]:
# Choose settings
# `epochs` is the number of passes over the training set made by each training
# loop below; `lr` is the learning rate handed to both Adam optimizers.
epochs = 5
lr = 1e-3

# Baseline MLP settings (matches the original fully-connected notebook)
mlp_hidden_dim = 256

# TokenMix settings
# `patch_size` must divide 28 (patchify_mnist asserts this on the first forward pass).
patch_size = 4   # try 7 (16 tokens) or 4 (49 tokens)
token_depth = 4  # number of TokenMix layers

# Create models (MLP first, then TokenMixNet) and move both to the selected device
mlp = MLP(hidden=mlp_hidden_dim).to(device)
token_net = TokenMixNet(patch_size=patch_size, depth=token_depth).to(device)

# Report trainable parameter counts before training for a size comparison
print("Baseline MLP params:", pretty_int(count_params(mlp)))
print("TokenMixNet params:", pretty_int(count_params(token_net)))


## 8. Training and evaluation

In [None]:
import torch.optim as optim

def train_one_epoch(model, loader, optimizer):
    """Train ``model`` for one pass over ``loader`` and return the mean loss.

    Each batch is moved to the module-level ``device`` before the forward pass.

    Args:
        model: Network mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, integer_labels)`` batches.
        optimizer: Optimizer built over ``model``'s parameters.

    Returns:
        Cross-entropy loss averaged over the samples processed in this pass.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        # cross_entropy returns the batch mean; weight it by the batch size so
        # a smaller final batch does not skew the epoch average.
        batch_size_seen = x.size(0)
        total_loss += loss.item() * batch_size_seen
        num_samples += batch_size_seen
    if num_samples == 0:
        raise ValueError("loader yielded no samples")
    # Divide by the samples actually seen rather than len(loader.dataset): the
    # two differ when the loader skips samples (e.g. drop_last=True), and a
    # plain iterable of batches has no .dataset attribute at all.
    return total_loss / num_samples

@torch.no_grad()
def eval_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode, gradients are disabled by the
    decorator, and each batch is moved to the module-level ``device``.

    Args:
        model: Network mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, integer_labels)`` batches.

    Returns:
        Accuracy as a float in [0, 1].
    """
    model.eval()
    hits = 0
    seen = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        # The predicted class is the index of the largest logit
        predictions = model(inputs).argmax(dim=1)
        hits += predictions.eq(targets).sum().item()
        seen += targets.numel()
    return hits / seen


## 9. Train Baseline MLP

In [None]:
mlp_opt = optim.Adam(mlp.parameters(), lr=lr)

# Track progress on the held-out validation split. The test set is scored only
# once, after training, so it is never used to make decisions during training.
for epoch in range(1, epochs + 1):
    loss = train_one_epoch(mlp, train_loader, mlp_opt)
    val_acc = eval_accuracy(mlp, val_loader)
    print(f"Epoch {epoch:02d} | loss={loss:.4f} | val_acc={val_acc*100:.2f}%")

# Final, one-off measurement on the untouched test set
acc = eval_accuracy(mlp, test_loader)
print(f"Final | test_acc={acc*100:.2f}%")


## 10. Train TokenMixNet

In [None]:
token_opt = optim.Adam(token_net.parameters(), lr=lr)

# Track progress on the held-out validation split. The test set is scored only
# once, after training, so it is never used to make decisions during training.
for epoch in range(1, epochs + 1):
    loss = train_one_epoch(token_net, train_loader, token_opt)
    val_acc = eval_accuracy(token_net, val_loader)
    print(f"Epoch {epoch:02d} | loss={loss:.4f} | val_acc={val_acc*100:.2f}%")

# Final, one-off measurement on the untouched test set
acc = eval_accuracy(token_net, test_loader)
print(f"Final | test_acc={acc*100:.2f}%")


## 11. Final parameter comparison

In [None]:
# Repeat the trainable-parameter counts next to the training results above.
# Training changes parameter values, not how many there are, so these match
# the counts printed when the models were created.
print("Baseline MLP params:", pretty_int(count_params(mlp)))
print("TokenMixNet params:", pretty_int(count_params(token_net)))

# Pointer to the hyperparameters worth experimenting with
print("\nTip: Try changing `mlp_hidden_dim`, `token_depth`, and `patch_size` (7 vs 4) and rerun.")
