### Implementing a simple shift-invariant classifier

Experiments from the paper "Making Vision Transformers Truly Shift-Equivariant".\
https://arxiv.org/abs/2305.16316\

Some questions I had after reading the paper:
1. Why would making the ViT adaptive improve raw classification accuracy?
2. Why do we need to make each module shift invariant? Why can't we just make the tokenization shift-invariant? Doesn't making the tokenization shift invariant mean that the rest of the model will see the same tokens no matter how you shift the input?
3. What is the tradeoff between building shift equivariance into the architecture and creating a training set that is sufficiently shifted around / noised? Could you get similar results, with no change to the model architecture, simply by training on data that contains many more shifts?


In [352]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch.optim import SGD
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset

In [326]:
# Downsample FashionMNIST to 14x14 so each flattened image has 14*14 = 196 pixels.
# NOTE: the pipeline is bound to `transform`, not `transforms`. Rebinding
# `transforms` would shadow the torchvision module and make this cell fail on
# re-run (the Compose object has no `.Compose`/`.Resize` attributes).
transform = transforms.Compose([
    transforms.Resize((14, 14)),
    transforms.ToTensor(),
])

train_data = FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_data = FashionMNIST(root='./data', train=False, transform=transform, download=True)

In [327]:
# Sanity check: shape of the first transformed training image (expected (1, 14, 14)).
train_data[0][0].size()

torch.Size([1, 14, 14])

In [307]:
# Create a very simple classifier that is just tokenization->embedding->MLP for classification
class Classifier(nn.Module):
    """Tiny patch classifier: flatten -> 1D patches -> linear embedding -> MLP head.

    The flattened image is cut into contiguous 1D chunks of ``patch_size``
    pixels (not 2D patches). Each chunk is embedded by a linear layer, and the
    embedding is then averaged over its ``n_emb`` channels. That leaves one
    scalar per patch, so the head sees a position-dependent vector whose length
    is the number of patches.

    Args:
        patch_size: number of consecutive flattened pixels per token; must
            divide the flattened input size.
        n_emb: width of the per-patch embedding.
        num_classes: number of output classes.
        hidden_size: input width of the head; must equal the number of
            patches (flattened input size // patch_size).
    """

    def __init__(self, patch_size, n_emb, num_classes, hidden_size):
        super().__init__()
        self.patch_size = patch_size
        self.n_emb = n_emb
        self.flatten = nn.Flatten()
        self.e_proj = nn.Linear(patch_size, n_emb)
        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for the batch ``x``."""
        x = self.flatten(x)
        B, N = x.shape
        patches = x.reshape(B, N // self.patch_size, self.patch_size)
        tokens = self.e_proj(patches)  # (B, num_patches, n_emb)
        # NOTE: this averages over the embedding channels of each token, not
        # over tokens, so it is not a global (token) pooling. The mean of a
        # linear map is itself linear, so this equals a single
        # Linear(patch_size, 1) per patch. The output stays position-dependent.
        tokens = tokens.mean(dim=-1)  # (B, num_patches)
        return self.head(tokens)

    def loss(self, x, y):
        """Cross-entropy between the logits for ``x`` and the integer labels ``y``."""
        return self.loss_fn(self(x), y)

    @torch.no_grad()
    def accuracy(self, x, y):
        """Return the fraction (0-dim tensor) of samples whose predicted class equals ``y``.

        Runs without autograd, since no gradient is needed. The argmax is taken
        on the logits directly: softmax is monotonic and cannot change the
        predicted class.
        """
        y_hat = self(x).argmax(dim=-1)
        return (y_hat == y).float().mean()

In [308]:
# Train the baseline (non-adaptive) classifier on the unshifted training set.
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
patch_size = 4
model = Classifier(
    patch_size=patch_size,
    n_emb=32,
    num_classes=10,
    hidden_size=(14*14)//patch_size  # one head input per 1D patch
)
optimizer = SGD(model.parameters(), lr=0.1)
train_losses = []
val_losses = []
val_accuracy = []
for epoch in range(50):
    model.train()
    running_loss = 0
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        loss = model.loss(inputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_dataloader)

    # Validation runs without autograd and uses one forward pass per batch for
    # both metrics. Results are weighted by batch size so the smaller final
    # batch is not over-counted.
    model.eval()
    val_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            logits = model(inputs)
            val_loss_sum += model.loss_fn(logits, labels).item() * len(labels)
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_seen += len(labels)
    val_accuracy.append(n_correct / n_seen)
    avg_val_loss = val_loss_sum / n_seen
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch:02d}, train loss = {avg_train_loss:.5f}, val_loss = {avg_val_loss:.5f}, val_acc = {val_accuracy[-1]:.5f}")

Epoch 00, train loss = 1.03559, val_loss = 0.63789, val_acc = 0.77406
Epoch 01, train loss = 0.58487, val_loss = 0.57682, val_acc = 0.78494
Epoch 02, train loss = 0.52111, val_loss = 0.51925, val_acc = 0.81130
Epoch 03, train loss = 0.48312, val_loss = 0.49681, val_acc = 0.82069
Epoch 04, train loss = 0.45684, val_loss = 0.47746, val_acc = 0.82318
Epoch 05, train loss = 0.43632, val_loss = 0.45448, val_acc = 0.83716
Epoch 06, train loss = 0.42048, val_loss = 0.45153, val_acc = 0.83476
Epoch 07, train loss = 0.40782, val_loss = 0.43851, val_acc = 0.83976
Epoch 08, train loss = 0.39626, val_loss = 0.42246, val_acc = 0.85044
Epoch 09, train loss = 0.38539, val_loss = 0.41569, val_acc = 0.85144
Epoch 10, train loss = 0.37869, val_loss = 0.41659, val_acc = 0.85184
Epoch 11, train loss = 0.36950, val_loss = 0.40852, val_acc = 0.85443
Epoch 12, train loss = 0.36433, val_loss = 0.41353, val_acc = 0.85294
Epoch 13, train loss = 0.35869, val_loss = 0.39907, val_acc = 0.85803
Epoch 14, train loss

In [311]:
# Create a very simple classifier that is just tokenization->embedding->MLP for classification
class AdaptiveClassifier(nn.Module):
    def __init__(self, patch_size, n_emb, num_classes, hidden_size):
        super().__init__()
        self.patch_size = patch_size
        self.n_emb = n_emb
        self.flatten = nn.Flatten()
        self.e_proj = nn.Linear(patch_size, n_emb)
        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()
        N = 14*14
        self.roll_idxs = (torch.arange(N)[:, None] - torch.arange(N)) % N
    def forward(self, x):
        x = self.flatten(x)
        B, N = x.shape
        # F = l2 norm, calculate m_star
        with torch.no_grad():
            # For each batch, try each roll from 0..N-1            
            shifts = torch.stack([x[i, self.roll_idxs] for i in range(B)]).mT
            patches = shifts.view(B, N, N // self.patch_size, self.patch_size)
            tokens = self.e_proj(patches)
            norms = tokens.norm(dim=(2, 3))
            m_star = norms.argmax(dim=1)
            # print(norms.shape, m_star.shape)
        for i in range(B):
            x[i] = torch.roll(x[i], shifts=m_star[i].item())
        patches = x.reshape(B, N // self.patch_size, self.patch_size)
        tokens = self.e_proj(patches)
        tokens = tokens.mean(dim=-1) # avg global pooling
        logits = self.head(tokens)
        return logits
    def loss(self, x, y):
        logits = self.forward(x)
        return self.loss_fn(logits, y)
    def accuracy(self, x, y):
        logits = self(x)
        probs = torch.softmax(logits, dim=-1)
        y_hat = torch.argmax(probs, dim=-1)
        return (y_hat == y).float().mean()

In [312]:
# Train the adaptive-tokenization classifier on the unshifted training set.
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
patch_size = 4
adaptive_model = AdaptiveClassifier(
    patch_size=patch_size,
    n_emb=32,
    num_classes=10,
    hidden_size=(14*14)//patch_size  # one head input per 1D patch
)
optimizer = SGD(adaptive_model.parameters(), lr=0.1)
train_losses = []
val_losses = []
val_accuracy = []
for epoch in range(50):
    adaptive_model.train()
    running_loss = 0
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        loss = adaptive_model.loss(inputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_dataloader)

    # Validation runs without autograd and uses one (expensive, shift-searching)
    # forward pass per batch for both metrics. Results are weighted by batch
    # size so the smaller final batch is not over-counted.
    adaptive_model.eval()
    val_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            logits = adaptive_model(inputs)
            val_loss_sum += adaptive_model.loss_fn(logits, labels).item() * len(labels)
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_seen += len(labels)
    val_accuracy.append(n_correct / n_seen)
    avg_val_loss = val_loss_sum / n_seen
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch:02d}, train loss = {avg_train_loss:.5f}, val_loss = {avg_val_loss:.5f}, val_acc = {val_accuracy[-1]:.5f}")

Epoch 00, train loss = 2.07253, val_loss = 1.41390, val_acc = 0.46396
Epoch 01, train loss = 1.17637, val_loss = 1.07518, val_acc = 0.59165
Epoch 02, train loss = 1.02649, val_loss = 1.00768, val_acc = 0.62720
Epoch 03, train loss = 0.95728, val_loss = 0.95449, val_acc = 0.63269
Epoch 04, train loss = 0.91092, val_loss = 0.89455, val_acc = 0.66014
Epoch 05, train loss = 0.87358, val_loss = 0.87685, val_acc = 0.67332
Epoch 06, train loss = 0.84198, val_loss = 0.86405, val_acc = 0.68311
Epoch 07, train loss = 0.81698, val_loss = 0.80702, val_acc = 0.69319
Epoch 08, train loss = 0.79864, val_loss = 0.81607, val_acc = 0.69239
Epoch 09, train loss = 0.78502, val_loss = 0.79122, val_acc = 0.70707
Epoch 10, train loss = 0.76837, val_loss = 0.78953, val_acc = 0.70966
Epoch 11, train loss = 0.76013, val_loss = 0.78583, val_acc = 0.71306
Epoch 12, train loss = 0.74701, val_loss = 0.78499, val_acc = 0.71446
Epoch 13, train loss = 0.73917, val_loss = 0.75675, val_acc = 0.72943
Epoch 14, train loss

In [357]:
def get_val_acc(model, shift=False, *, dataloader=None, max_shift=4):
    """Return a model's accuracy over a dataloader, optionally on cyclically shifted inputs.

    Args:
        model: object with an ``accuracy(inputs, labels)`` method that returns
            the batch's mean accuracy.
        shift: if True, each batch is rolled over its flattened pixels by one
            random offset drawn uniformly from ``1..max_shift``. The same
            offset is used for every sample in the batch.
        dataloader: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-global ``test_dataloader``.
        max_shift: largest offset used when ``shift`` is True.

    Returns:
        The sample-weighted accuracy. Each per-batch mean is weighted by its
        batch size, so a smaller final batch is not over-counted.
    """
    if dataloader is None:
        dataloader = test_dataloader
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            B = inputs.shape[0]
            if shift:
                offset = torch.randint(1, max_shift + 1, (1,)).item()
                inputs = torch.roll(inputs.reshape(B, -1), shifts=offset, dims=1).view(inputs.shape)
            correct = correct + model.accuracy(inputs, labels) * B
            total += B
    return correct / total

# Compare clean vs. shifted validation accuracy for both trained models.
_reports = [
    ("Classifier:", model),
    ("Classifier w/ adaptive tokenization:", adaptive_model),
]
for _idx, (_title, _clf) in enumerate(_reports):
    if _idx > 0:
        print("---")
    print(_title)
    print("val accuracy: ", get_val_acc(_clf))
    print("shifted val accuracy:", get_val_acc(_clf, shift=True))

Classifier:
val accuracy:  tensor(0.8674)
shifted val accuracy: tensor(0.2254)
---
Classifier w/ adaptive tokenization:
val accuracy:  tensor(0.7627)
shifted val accuracy: tensor(0.7549)


Adaptive tokenization makes the classifier robust to input shifts, as expected.\
Now I want to try training a non-adaptive classifier with augmented input data to see if it can become similarly robust without architecture changes.

In [353]:
# Creating a new training dataset with every shift (0-4) for every example.
# Each image is rolled cyclically over its OWN flattened pixels (dims=1).
# Without `dims`, torch.roll flattens the whole batch, so pixels would wrap
# into the neighbouring image.
new_X, new_y = [], []
for inputs, labels in train_dataloader:
    B, C, H, W = inputs.shape
    flat = inputs.reshape(B, -1)
    for r in range(5):
        new_X.append(torch.roll(flat, shifts=r, dims=1))
        new_y.append(labels)
# torch.cat, unlike torch.stack, also accepts a smaller final batch. The
# channel dimension is kept so samples match the source dataset's (C, H, W)
# layout.
new_X = torch.cat(new_X).view(-1, C, H, W)
new_y = torch.cat(new_y)

class shiftedTrainingData(Dataset):
    """Map-style dataset over pre-computed (shifted) samples and their labels.

    Args:
        data: tensor of samples indexed along dim 0. Defaults to the
            notebook-global ``new_X`` when omitted.
        targets: tensor of labels aligned with ``data``. Defaults to the
            notebook-global ``new_y`` when omitted.

    Raises:
        ValueError: if ``data`` and ``targets`` have different lengths.
    """

    def __init__(self, data=None, targets=None):
        self.data = new_X if data is None else data
        self.targets = new_y if targets is None else targets
        if len(self.data) != len(self.targets):
            raise ValueError(
                f"data has {len(self.data)} samples but targets has {len(self.targets)}"
            )

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

In [354]:
# Training an unmodified Classifier on the shift-augmented data.
augmented_train_data = shiftedTrainingData()
train_dataloader = DataLoader(augmented_train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
patch_size = 4
model2 = Classifier(
    patch_size=patch_size,
    n_emb=32,
    num_classes=10,
    hidden_size=(14*14)//patch_size  # one head input per 1D patch
)
optimizer = SGD(model2.parameters(), lr=0.1)
train_losses = []
val_losses = []
val_accuracy = []
for epoch in range(10):
    model2.train()
    running_loss = 0
    for inputs, labels in train_dataloader:
        # Random per-sample cyclic shift of 0-4 flattened pixels, done in one
        # out-of-place gather: shifted[b, j] = flat[b, (j - offsets[b]) % n_pix],
        # i.e. torch.roll(flat[b], offsets[b]).
        # NOTE(review): the dataset is already pre-shifted by 0-4 pixels, so the
        # effective training shift range is 0-8. Confirm that is intended.
        B = inputs.shape[0]
        offsets = torch.randint(0, 5, (B,))
        flat = inputs.reshape(B, -1)
        n_pix = flat.shape[1]
        src = (torch.arange(n_pix)[None, :] - offsets[:, None]) % n_pix
        inputs = flat.gather(1, src).view_as(inputs)
        optimizer.zero_grad()
        loss = model2.loss(inputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_dataloader)

    # Validation runs without autograd and uses one forward pass per batch for
    # both metrics. Results are weighted by batch size so the smaller final
    # batch is not over-counted.
    model2.eval()
    val_loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            logits = model2(inputs)
            val_loss_sum += model2.loss_fn(logits, labels).item() * len(labels)
            n_correct += (logits.argmax(dim=-1) == labels).sum().item()
            n_seen += len(labels)
    val_accuracy.append(n_correct / n_seen)
    avg_val_loss = val_loss_sum / n_seen
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch:02d}, train loss = {avg_train_loss:.5f}, val_loss = {avg_val_loss:.5f}, val_acc = {val_accuracy[-1]:.5f}")

Epoch 00, train loss = 0.81551, val_loss = 0.72434, val_acc = 0.74651
Epoch 01, train loss = 0.53689, val_loss = 0.64040, val_acc = 0.76997
Epoch 02, train loss = 0.48890, val_loss = 0.63119, val_acc = 0.77077
Epoch 03, train loss = 0.46213, val_loss = 0.62897, val_acc = 0.76897
Epoch 04, train loss = 0.44527, val_loss = 0.57735, val_acc = 0.79363
Epoch 05, train loss = 0.43059, val_loss = 0.57648, val_acc = 0.79553
Epoch 06, train loss = 0.42064, val_loss = 0.58249, val_acc = 0.78325
Epoch 07, train loss = 0.41286, val_loss = 0.56082, val_acc = 0.79603
Epoch 08, train loss = 0.40455, val_loss = 0.53400, val_acc = 0.81160
Epoch 09, train loss = 0.40016, val_loss = 0.54917, val_acc = 0.80681


In [359]:
# Clean vs. shifted validation accuracy for the augmentation-trained model.
_title = "Classifier trained w/ augmentation:"
print(_title)
for _label, _kwargs in (("val accuracy: ", {}), ("shifted val accuracy:", {"shift": True})):
    print(_label, get_val_acc(model2, **_kwargs))

Classifier trained w/ augmentation:
val accuracy:  tensor(0.8068)
shifted val accuracy: tensor(0.8197)


Training augmentation does make the unmodified model more robust.

#### Some Final Notes
- Adaptive tokenization is robust to shifts without any shift augmentation during training, as expected.
- Data augmentation can make the unmodified architecture robust to shifts.
- The adaptive model was not able to achieve as low a training loss as the base model.
- The adaptive model also trained quite slowly; this could just have been my implementation, though (I need to look into more tensor tricks).