### Implementing a simple shift-invariant classifier

In [284]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch.optim import SGD
import matplotlib.pyplot as plt
import numpy as np

In [285]:
# Resize every image to 14x14 and convert it to a tensor.
# The pipeline is bound to `preprocess` (not `transforms`) so the imported
# `torchvision.transforms` module is not shadowed and this cell can be re-run.
preprocess = transforms.Compose([
    transforms.Resize((14, 14)),
    transforms.ToTensor(),
])

train_data = FashionMNIST(root='./data', train=True, transform=preprocess, download=True)
test_data = FashionMNIST(root='./data', train=False, transform=preprocess, download=True)

In [286]:
# Sanity check: shape of the first training image as returned by the dataset.
train_data[0][0].shape

torch.Size([1, 14, 14])

In [287]:
# Baseline classifier: tokenization -> embedding -> MLP head (no shift handling).
class Classifier(nn.Module):
    """Patch-based baseline classifier.

    The flattened input is cut into consecutive chunks of ``patch_size`` values,
    each chunk is linearly embedded, every embedding is averaged down to a single
    scalar, and the resulting per-patch vector is classified by a two-layer MLP.

    Args:
        patch_size: Number of consecutive flattened values per patch. Must divide
            the flattened input length.
        n_emb: Output width of the patch embedding.
        num_classes: Number of output logits.
        hidden_size: Input width of the MLP head; must equal the number of
            patches (flattened input length // patch_size).
    """

    def __init__(self, patch_size, n_emb, num_classes, hidden_size):
        super().__init__()
        self.patch_size = patch_size
        self.n_emb = n_emb
        self.flatten = nn.Flatten()
        self.e_proj = nn.Linear(patch_size, n_emb)
        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return logits of shape (B, num_classes) for a batch ``x``."""
        x = self.flatten(x)
        B, N = x.shape
        patches = x.reshape(B, N // self.patch_size, self.patch_size)
        tokens = self.e_proj(patches)  # (B, num_patches, n_emb)
        # Mean over the embedding axis: one scalar per patch. Patch order is
        # preserved, so the head input still depends on patch position.
        tokens = tokens.mean(dim=-1)
        return self.head(tokens)

    def loss(self, x, y):
        """Return the cross-entropy loss of the logits for ``x`` against labels ``y``."""
        return self.loss_fn(self.forward(x), y)

    def accuracy(self, x, y):
        """Return the fraction of samples whose highest logit matches ``y``."""
        # Pure evaluation: skip autograd bookkeeping. Softmax is monotonic, so the
        # argmax can be taken on the logits directly.
        with torch.no_grad():
            y_hat = self(x).argmax(dim=-1)
        return (y_hat == y).float().mean()

In [290]:
# Train the baseline Classifier with plain SGD and track per-epoch metrics.
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
patch_size = 4
model = Classifier(
    patch_size=patch_size,
    n_emb=32,
    num_classes=10,
    hidden_size=(14*14)//patch_size  # head input width = number of patches
)
optimizer = SGD(model.parameters(), lr=0.1)
train_losses = []
val_losses = []
val_accuracy = []
# NOTE(review): the recorded output of this cell stops after epoch 09 -- confirm
# whether 50 epochs is really intended here.
for epoch in range(50):
    model.train()
    running_loss = 0
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        loss = model.loss(inputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_dataloader)

    model.eval()
    running_loss = 0
    val_acc = 0
    # Validation needs no gradients; one forward pass per batch serves both the
    # loss and the accuracy.
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            logits = model(inputs)
            running_loss += model.loss_fn(logits, labels).item()
            val_acc += (logits.argmax(dim=-1) == labels).float().mean()
    val_accuracy.append(val_acc / len(test_dataloader))
    avg_val_loss = running_loss / len(test_dataloader)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch:02d}, train loss = {avg_train_loss:.5f}, val_loss = {avg_val_loss:.5f}, val_acc = {val_accuracy[-1]:.5f}")

Epoch 00, train loss = 1.25036, val_loss = 0.66491, val_acc = 0.75399
Epoch 01, train loss = 0.58810, val_loss = 0.56244, val_acc = 0.79503
Epoch 02, train loss = 0.51601, val_loss = 0.52691, val_acc = 0.80351
Epoch 03, train loss = 0.47678, val_loss = 0.48850, val_acc = 0.82009
Epoch 04, train loss = 0.44947, val_loss = 0.47636, val_acc = 0.82788
Epoch 05, train loss = 0.42977, val_loss = 0.45234, val_acc = 0.83536
Epoch 06, train loss = 0.41401, val_loss = 0.44127, val_acc = 0.83946
Epoch 07, train loss = 0.40164, val_loss = 0.43346, val_acc = 0.84425
Epoch 08, train loss = 0.39131, val_loss = 0.42336, val_acc = 0.84675
Epoch 09, train loss = 0.38188, val_loss = 0.42639, val_acc = 0.84794


In [291]:
def get_val_acc(model, shift=False, dataloader=None):
    """Return the mean per-batch accuracy of ``model`` over a dataloader.

    Args:
        model: Object exposing ``accuracy(inputs, labels)``.
        shift: If True, circularly roll every flattened sample of a batch by one
            random offset in 1..4 (drawn once per batch) before scoring.
        dataloader: Iterable of ``(inputs, labels)`` batches that supports
            ``len()``. Defaults to the notebook-level ``test_dataloader``.

    Returns:
        The unweighted average of the per-batch accuracies.
    """
    if dataloader is None:
        dataloader = test_dataloader
    val_acc = 0
    # Evaluation only: no autograd bookkeeping needed.
    with torch.no_grad():
        for inputs, labels in dataloader:
            if shift:
                B = inputs.shape[0]
                offset = torch.randint(1, 5, (1,)).item()  # upper bound is exclusive
                # Roll along the flattened pixel axis, then restore the batch shape.
                rolled = torch.roll(inputs.reshape(B, -1), shifts=offset, dims=1)
                inputs = rolled.reshape(inputs.shape)
            val_acc += model.accuracy(inputs, labels)
    return val_acc / len(dataloader)

# Report baseline accuracy on the clean test set, then on randomly rolled inputs.
for caption, use_shift in (("val accuracy: ", False), ("randomly shifted val accuracy:", True)):
    print(caption, get_val_acc(model, shift=use_shift))

tensor(0.8479)
tensor(0.2812)


In [292]:
# Earlier loop-based draft of AdaptiveClassifier: tries every roll with torch.roll in a
# Python loop. Superseded by the vectorized version in the next cell; kept for reference.
# class AdaptiveClassifier(nn.Module):
#     def __init__(self, patch_size, n_emb, num_classes, hidden_size):
#         super().__init__()
#         self.patch_size = patch_size
#         self.n_emb = n_emb
#         self.flatten = nn.Flatten()
#         self.e_proj = nn.Linear(patch_size, n_emb)
#         self.head = nn.Sequential(
#             nn.Linear(hidden_size, hidden_size * 4),
#             nn.GELU(),
#             nn.Linear(hidden_size * 4, num_classes),
#         )
#         self.loss_fn = nn.CrossEntropyLoss()
#     def forward(self, x):
#         x = self.flatten(x)
#         B, N = x.shape
#         # F = l2 norm, calculate m_star
#         with torch.no_grad():
#             # For each batch, try each roll from 0..N-1            
#             norms = []
#             for m in range(N):
#                 xm = torch.roll(x, shifts=m, dims=1)
#                 patches = xm.reshape(B, N // self.patch_size, self.patch_size)
#                 tokens = self.e_proj(patches)
#                 norms.append(tokens.norm(dim=(1, 2)))
#             norms = torch.stack(norms)
#             m_star = norms.argmax(dim=0)
#             # print(m_star.shape)
#         for i in range(B):
#             x[i] = torch.roll(x[i], shifts=m_star[i].item())
#         patches = x.reshape(B, N // self.patch_size, self.patch_size)
#         tokens = self.e_proj(patches)
#         tokens = tokens.mean(dim=-1) # avg global pooling
#         logits = self.head(tokens)
#         return logits
#     def loss(self, x, y):
#         logits = self.forward(x)
#         return self.loss_fn(logits, y)
#     def accuracy(self, x, y):
#         logits = self(x)
#         probs = torch.softmax(logits, dim=-1)
#         y_hat = torch.argmax(probs, dim=-1)
#         return (y_hat == y).float().mean()

In [293]:
# Shift-canonicalizing classifier: roll every input to a data-dependent offset,
# then tokenization -> embedding -> MLP head.
class AdaptiveClassifier(nn.Module):
    """Patch classifier that first rolls each sample to a canonical offset.

    For every sample, all N circular shifts of the flattened input are embedded
    and the shift whose patch embeddings have the largest L2 norm is selected.
    The sample rolled by that shift is then classified like the baseline:
    patch embedding, mean over the embedding axis, two-layer MLP head.

    Args:
        patch_size: Number of consecutive flattened values per patch. Must divide
            the flattened input length.
        n_emb: Output width of the patch embedding.
        num_classes: Number of output logits.
        hidden_size: Input width of the MLP head; must equal the number of
            patches (flattened input length // patch_size).
    """

    def __init__(self, patch_size, n_emb, num_classes, hidden_size):
        super().__init__()
        self.patch_size = patch_size
        self.n_emb = n_emb
        self.flatten = nn.Flatten()
        self.e_proj = nn.Linear(patch_size, n_emb)
        self.head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, num_classes),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return logits of shape (B, num_classes); the input ``x`` is left unmodified."""
        x = self.flatten(x)
        B, N = x.shape
        positions = torch.arange(N, device=x.device)
        # Shift selection is a discrete choice, so it is computed without gradients.
        with torch.no_grad():
            # roll_idxs[m, j] = (j - m) % N, hence x[:, roll_idxs][b, m] == roll(x[b], m).
            roll_idxs = (positions[None, :] - positions[:, None]) % N
            shifts = x[:, roll_idxs]  # (B, N, N): every circular shift of every sample
            patches = shifts.reshape(B, N, N // self.patch_size, self.patch_size)
            tokens = self.e_proj(patches)
            norms = tokens.norm(dim=(2, 3))  # (B, N): one score per candidate shift
            m_star = norms.argmax(dim=1)
        # Apply the selected roll out of place. The flattened tensor may be a view
        # of the caller's batch, so writing into it would corrupt the input.
        gather_idxs = (positions[None, :] - m_star[:, None]) % N
        x = torch.gather(x, 1, gather_idxs)
        patches = x.reshape(B, N // self.patch_size, self.patch_size)
        tokens = self.e_proj(patches)
        # Mean over the embedding axis: one scalar per patch, patch order kept.
        tokens = tokens.mean(dim=-1)
        return self.head(tokens)

    def loss(self, x, y):
        """Return the cross-entropy loss of the logits for ``x`` against labels ``y``."""
        return self.loss_fn(self.forward(x), y)

    def accuracy(self, x, y):
        """Return the fraction of samples whose highest logit matches ``y``."""
        # Pure evaluation: skip autograd bookkeeping. Softmax is monotonic, so the
        # argmax can be taken on the logits directly.
        with torch.no_grad():
            y_hat = self(x).argmax(dim=-1)
        return (y_hat == y).float().mean()

In [297]:
# Train the AdaptiveClassifier with plain SGD and track per-epoch metrics.
train_dataloader = DataLoader(train_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=False)
patch_size = 4
adaptive_model = AdaptiveClassifier(
    patch_size=patch_size,
    n_emb=32,
    num_classes=10,
    hidden_size=(14*14)//patch_size  # head input width = number of patches
)
optimizer = SGD(adaptive_model.parameters(), lr=0.1)
train_losses = []
val_losses = []
val_accuracy = []
for epoch in range(10):
    adaptive_model.train()
    running_loss = 0
    for inputs, labels in train_dataloader:
        optimizer.zero_grad()
        loss = adaptive_model.loss(inputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_train_loss = running_loss / len(train_dataloader)

    adaptive_model.eval()
    running_loss = 0
    val_acc = 0
    # Validation needs no gradients. A single forward pass per batch serves both
    # metrics, so loss and accuracy are always computed from the same logits.
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            logits = adaptive_model(inputs)
            running_loss += adaptive_model.loss_fn(logits, labels).item()
            val_acc += (logits.argmax(dim=-1) == labels).float().mean()
    val_accuracy.append(val_acc / len(test_dataloader))
    avg_val_loss = running_loss / len(test_dataloader)
    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    print(f"Epoch {epoch:02d}, train loss = {avg_train_loss:.5f}, val_loss = {avg_val_loss:.5f}, val_acc = {val_accuracy[-1]:.5f}")

Epoch 00, train loss = 1.76255, val_loss = 1.20937, val_acc = 0.54203
Epoch 01, train loss = 1.11391, val_loss = 1.07206, val_acc = 0.59665
Epoch 02, train loss = 0.99718, val_loss = 0.97009, val_acc = 0.62909
Epoch 03, train loss = 0.93370, val_loss = 0.91120, val_acc = 0.66104
Epoch 04, train loss = 0.89270, val_loss = 0.87649, val_acc = 0.66883
Epoch 05, train loss = 0.85919, val_loss = 0.85035, val_acc = 0.67712
Epoch 06, train loss = 0.83117, val_loss = 0.83843, val_acc = 0.68620
Epoch 07, train loss = 0.80866, val_loss = 0.84389, val_acc = 0.68061
Epoch 08, train loss = 0.79503, val_loss = 0.79964, val_acc = 0.70088
Epoch 09, train loss = 0.78123, val_loss = 0.81200, val_acc = 0.70288


In [298]:
# Adaptive model accuracy on the clean test set, then on randomly rolled inputs.
for use_shift in (False, True):
    print(get_val_acc(adaptive_model, shift=use_shift))

tensor(0.7029)
tensor(0.6992)


In [128]:
# Scratch check: compare norms of a random matrix product before and after
# rolling the left factor (roll without `dims` acts on the flattened tensor).
a = torch.randn(5, 5)
a1 = a.roll(2)
b = torch.randn(5, 5)
c, c1 = torch.matmul(a, b), torch.matmul(a1, b)
tuple(t.norm() for t in (a, b, c, c1))

(tensor(4.1601), tensor(4.5330), tensor(7.0751), tensor(7.1127))

In [207]:
# Scratch check: expand() broadcasting, and an index matrix whose row i holds
# (i - j) % 4, used to read a vector in every circular order at once.
a = torch.randn(2, 4, 1)
print(a)
print(a.expand(-1, -1, 4))
rolls = torch.arange(4)
idxs = torch.remainder(rolls.unsqueeze(1) - rolls.unsqueeze(0), 4)
print(idxs)
print(torch.randn(4)[idxs])


tensor([[[ 0.3131],
         [-0.1254],
         [-0.4524],
         [ 0.2357]],

        [[ 1.1113],
         [-0.3757],
         [ 0.0120],
         [ 0.2193]]])
tensor([[[ 0.3131,  0.3131,  0.3131,  0.3131],
         [-0.1254, -0.1254, -0.1254, -0.1254],
         [-0.4524, -0.4524, -0.4524, -0.4524],
         [ 0.2357,  0.2357,  0.2357,  0.2357]],

        [[ 1.1113,  1.1113,  1.1113,  1.1113],
         [-0.3757, -0.3757, -0.3757, -0.3757],
         [ 0.0120,  0.0120,  0.0120,  0.0120],
         [ 0.2193,  0.2193,  0.2193,  0.2193]]])
tensor([[0, 3, 2, 1],
        [1, 0, 3, 2],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
tensor([[-0.6804, -1.2364,  0.0966, -0.4540],
        [-0.4540, -0.6804, -1.2364,  0.0966],
        [ 0.0966, -0.4540, -0.6804, -1.2364],
        [-1.2364,  0.0966, -0.4540, -0.6804]])


In [222]:
# Scratch check: indexing rows with `idxs` and swapping the first two axes
# stacks row-rolled copies of `a` (entry k is `a` with rows rolled by k).
a = torch.randn(4, 4)
print(a)
a[idxs].permute(1, 0, 2)

tensor([[ 1.4199, -0.8493, -0.4666,  0.3999],
        [-1.0931, -1.1288, -0.6471,  0.9726],
        [-1.3034, -1.3349,  0.2783,  0.7967],
        [-0.0505,  0.2216,  0.5465, -0.6069]])


tensor([[[ 1.4199, -0.8493, -0.4666,  0.3999],
         [-1.0931, -1.1288, -0.6471,  0.9726],
         [-1.3034, -1.3349,  0.2783,  0.7967],
         [-0.0505,  0.2216,  0.5465, -0.6069]],

        [[-0.0505,  0.2216,  0.5465, -0.6069],
         [ 1.4199, -0.8493, -0.4666,  0.3999],
         [-1.0931, -1.1288, -0.6471,  0.9726],
         [-1.3034, -1.3349,  0.2783,  0.7967]],

        [[-1.3034, -1.3349,  0.2783,  0.7967],
         [-0.0505,  0.2216,  0.5465, -0.6069],
         [ 1.4199, -0.8493, -0.4666,  0.3999],
         [-1.0931, -1.1288, -0.6471,  0.9726]],

        [[-1.0931, -1.1288, -0.6471,  0.9726],
         [-1.3034, -1.3349,  0.2783,  0.7967],
         [-0.0505,  0.2216,  0.5465, -0.6069],
         [ 1.4199, -0.8493, -0.4666,  0.3999]]])

In [145]:
# Scratch check: repeat((1, 1, 2)) tiles the size-1 last axis twice -> shape (2, 4, 2).
torch.randn(2, 4, 1).repeat((1, 1, 2))

tensor([[[ 0.8327,  0.8327],
         [ 0.2480,  0.2480],
         [ 1.3505,  1.3505],
         [ 0.8663,  0.8663]],

        [[-0.6404, -0.6404],
         [-0.0786, -0.0786],
         [-2.1338, -2.1338],
         [ 0.1364,  0.1364]]])