# ViT Assignment
Authors: Alexander Wan, Aryan Jain

### Assignment Goals


1. Familiarity with the Vision Transformer architecture
2. Familiarity with the self-attention algorithm
3. Practice with PyTorch matrix operations



### Tasks
1. Implement multi-head self-attention
2. Incorporate that into a ViT

### Runtime Acceleration
Colab limits GPU usage. While you're developing, set `device` below to `'cpu'` and change your runtime to CPU (Runtime > Change runtime type). Switch to `'cuda'` (and a GPU runtime) only when you're ready to train.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use 'cpu' while developing; switch to 'cuda' (and a GPU runtime) when you're ready to train.
device = 'cpu'
# device = 'cuda'


### Multi-head self-attention
Begin by implementing multi-head self-attention. Do **not** use any `for` loops. Instead, express all of the computation as [batch matrix multiplications](https://pytorch.org/docs/stable/generated/torch.bmm.html) or [Linear layers](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html).

Useful references include the lecture slides on transformers and ViTs, and the [illustrated transformer](https://jalammar.github.io/illustrated-transformer/) blog post.

Hint: you are not required to follow the skeleton code below exactly. Feel free to use `torch.einsum` if you prefer, but you will need to work out how it works from the PyTorch documentation yourself. It is somewhat unintuitive at first, but extremely powerful once you understand it.
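As an illustration of the `torch.einsum` route, the sketch below computes scaled dot-product attention for heads that have already been split out. It checks the result against the equivalent batched-matmul computation. It assumes Q, K and V are laid out as (batch, heads, seq, d_k). The function names `attention_einsum` and `attention_matmul` are hypothetical and not part of the skeleton.

```python
import math
import torch

def attention_einsum(Q, K, V):
    """Scaled dot-product attention via einsum; Q, K, V: (batch, heads, seq, d_k)."""
    d_k = Q.shape[-1]
    # 'bhqd,bhkd->bhqk': dot every query with every key, separately per batch and head
    scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) / math.sqrt(d_k)
    weights = scores.softmax(dim=-1)                    # normalize over the key axis
    # 'bhqk,bhkd->bhqd': weighted sum of the value vectors for every query
    return torch.einsum('bhqk,bhkd->bhqd', weights, V)

def attention_matmul(Q, K, V):
    """The same computation written with batched matrix multiplication."""
    d_k = Q.shape[-1]
    weights = (Q @ K.transpose(-2, -1) / math.sqrt(d_k)).softmax(dim=-1)
    return weights @ V

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 4, 7, 8) for _ in range(3))
print(torch.allclose(attention_einsum(Q, K, V), attention_matmul(Q, K, V), atol=1e-5))  # True
```

Reading an einsum string index by index is often the quickest way to see which axes are summed over. Here, `d` disappears in the first call and `k` in the second.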


In [None]:
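# forward() of the MSA module; assumes its __init__ (not shown) sets input_dim, embed_dim,
# num_heads and the nn.Linear layers K_embed, Q_embed, V_embed and out_embed.
def forward(self, x):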
    """
    x: (batch_size, max_length, input_dim)
    returns: (batch_size, max_length, embed_dim)
    """
    batch_size, max_length, given_input_dim = x.shape
    assert given_input_dim == self.input_dim
    assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
    indiv_dim = self.embed_dim // self.num_heads       # d_k per head

    # Project to K, Q, V
    flat = x.reshape(-1, self.input_dim)               # (batch_size*max_length, input_dim)
    K = self.K_embed(flat).reshape(batch_size, max_length, self.embed_dim)
    Q = self.Q_embed(flat).reshape(batch_size, max_length, self.embed_dim)
    V = self.V_embed(flat).reshape(batch_size, max_length, self.embed_dim)

    # Split into heads → (batch, seq, heads, d_k)
    K = K.reshape(batch_size, max_length, self.num_heads, indiv_dim)
    Q = Q.reshape(batch_size, max_length, self.num_heads, indiv_dim)
    V = V.reshape(batch_size, max_length, self.num_heads, indiv_dim)

    # Rearrange for batched mat‑muls → (batch, heads, seq, d_k)
    K = K.permute(0, 2, 1, 3)
    Q = Q.permute(0, 2, 1, 3)
    V = V.permute(0, 2, 1, 3)

    # Merge heads with batch dim → (batch*heads, seq, d_k)
    K = K.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    Q = Q.reshape(batch_size * self.num_heads, max_length, indiv_dim)
    V = V.reshape(batch_size * self.num_heads, max_length, indiv_dim)

    # Scaled‑dot‑product attention
    scores = torch.bmm(Q, K.transpose(1, 2))            # (bh, seq, seq)
    scores /= indiv_dim ** 0.5                          # scale by √d_k
    weights = F.softmax(scores, dim=-1)                 # attention weights
    attn_out = torch.bmm(weights, V)                    # (bh, seq, d_k)

    # Restore original dimensions
    attn_out = attn_out.reshape(batch_size, self.num_heads, max_length, indiv_dim)
    attn_out = attn_out.permute(0, 2, 1, 3)             # (batch, seq, heads, d_k)
    attn_out = attn_out.reshape(batch_size, max_length, self.embed_dim)

    return self.out_embed(attn_out)
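
One way to test an implementation like the one above is to compare it with PyTorch's built-in `nn.MultiheadAttention`, which splits the embedding dimension across heads in the same way. The sketch below reimplements the same reshape/permute/bmm steps as a standalone function and uses bias-free projections. It then copies the weights into `nn.MultiheadAttention` and checks that the outputs agree. The name `manual_msa` is hypothetical. The sequence length is deliberately not a multiple of the number of heads, because the computation does not require it.

```python
import torch
import torch.nn as nn

def manual_msa(x, q_lin, k_lin, v_lin, out_lin, num_heads):
    """Multi-head self-attention with bmm, following the same steps as forward() above."""
    B, L, E = x.shape
    d = E // num_heads

    def split(t):  # (B, L, E) -> (B*heads, L, d)
        return t.reshape(B, L, num_heads, d).permute(0, 2, 1, 3).reshape(B * num_heads, L, d)

    Q, K, V = split(q_lin(x)), split(k_lin(x)), split(v_lin(x))
    weights = torch.softmax(torch.bmm(Q, K.transpose(1, 2)) / d ** 0.5, dim=-1)
    out = torch.bmm(weights, V)                                   # (B*heads, L, d)
    out = out.reshape(B, num_heads, L, d).permute(0, 2, 1, 3).reshape(B, L, E)
    return out_lin(out)

torch.manual_seed(0)
E, H, L = 16, 4, 5            # L = 5 is deliberately not a multiple of H
q_lin, k_lin, v_lin, out_lin = (nn.Linear(E, E, bias=False) for _ in range(4))

mha = nn.MultiheadAttention(E, H, bias=False, batch_first=True)
with torch.no_grad():
    # PyTorch stacks the Q, K, V projection weights into a single (3E, E) matrix
    mha.in_proj_weight.copy_(torch.cat([q_lin.weight, k_lin.weight, v_lin.weight], dim=0))
    mha.out_proj.weight.copy_(out_lin.weight)

x = torch.randn(2, L, E)
with torch.no_grad():
    ours = manual_msa(x, q_lin, k_lin, v_lin, out_lin, H)
    ref, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(ours, ref, atol=1e-5))  # True
```

To test your own module this way, you would copy its `Q_embed`, `K_embed`, `V_embed` and `out_embed` weights in the same order. If those layers have biases, build `nn.MultiheadAttention` with `bias=True` and copy `in_proj_bias` and `out_proj.bias` as well.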


### Implement the ViT architecture
You will be implementing the ViT architecture based on the "An image is worth 16x16 words" paper.

Although the ViT and Transformer architectures are very similar, note a few differences:

1. Image patches instead of discrete tokens as input.
2. [GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used as the activation in each transformer layer's MLP (instead of ReLU).
3. LayerNorm is applied before each sublayer instead of after it.
4. Dropout after every linear layer except the KQV projections, and also directly after the positional embeddings are added to the patch embeddings.
5. A learnable [CLS] token at the beginning of the input.

A useful reference is Figure 1 in the [paper](https://arxiv.org/pdf/2010.11929.pdf).
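For reference, the paper describes the encoder in Figure 1 with the equations below (its Eqs. 1 to 4, restated here with dropout omitted). The symbols are defined as follows:

- $\mathbf{x}_p^i$ is the $i$-th flattened patch, of size $P^2 \cdot C$.
- $\mathbf{E}$ is the patch projection.
- $\mathbf{E}_{pos}$ holds the $N+1$ positional embeddings.
- $L$ is the number of layers, not the sequence length.
- $\mathbf{z}_L^0$ is the CLS-token output of the last layer, which is passed to the classification head.

$$
\begin{aligned}
\mathbf{z}_0 &= [\mathbf{x}_\text{class};\, \mathbf{x}_p^1\mathbf{E};\, \mathbf{x}_p^2\mathbf{E};\, \cdots;\, \mathbf{x}_p^N\mathbf{E}] + \mathbf{E}_{pos}, && \mathbf{E} \in \mathbb{R}^{(P^2 \cdot C) \times D},\ \mathbf{E}_{pos} \in \mathbb{R}^{(N+1) \times D} \\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1, \dots, L \\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1, \dots, L \\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0)
\end{aligned}
$$

The two middle lines show what "LayerNorm before the sublayer" means in practice. The residual path carries the un-normalized input, and only the sublayer's input is normalized.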

First, implement a single layer:

In [None]:
class ViTLayer(nn.Module):
  def __init__(self, num_heads, input_dim, embed_dim, mlp_hidden_dim, dropout=0.1):
    '''
    num_heads: Number of heads for multi-head self-attention
    input_dim: Dimension of the input tokens (must equal embed_dim for the residual connections)
    embed_dim: Dimension of internal key, query, and value embeddings
    mlp_hidden_dim: Hidden dimension of the MLP
    dropout: Dropout rate
    '''

    super().__init__()

    self.input_dim = input_dim
    self.msa = MSA(input_dim, embed_dim, num_heads)

    self.layernorm1 = nn.LayerNorm(embed_dim)
    self.w_o_dropout = nn.Dropout(dropout)
    self.layernorm2 = nn.LayerNorm(embed_dim)
    self.mlp = nn.Sequential(nn.Linear(embed_dim, mlp_hidden_dim),
                              nn.GELU(),
                              nn.Dropout(dropout),
                              nn.Linear(mlp_hidden_dim, embed_dim),
                              nn.Dropout(dropout))

  def forward(self, x):
    """
    x: (batch_size, seq_len, input_dim)
    returns: (batch_size, seq_len, embed_dim)
    """
    # 1) LayerNorm, 2) MSA, 3) dropout, 4) residual with the original x
    attn_out = self.w_o_dropout(self.msa(self.layernorm1(x)))
    x = x + attn_out                      # first residual

    # 5) LayerNorm, 6) MLP (its dropout layers live inside self.mlp), 7) residual
    mlp_out = self.mlp(self.layernorm2(x))
    x = x + mlp_out                       # second residual

    return x
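
As a cross-check on shapes and structure, PyTorch ships a built-in encoder layer. It can be configured as a pre-LN block with a GELU MLP, which is close to `ViTLayer` above. The sketch below uses arbitrary toy sizes. Treat it as a reference point, not a drop-in replacement: `nn.TransformerEncoderLayer` also applies dropout to the attention weights, and its attention projections include biases by default.

```python
import torch
import torch.nn as nn

# Built-in encoder layer configured to resemble ViTLayer: GELU MLP, and
# LayerNorm applied before each sublayer (norm_first=True).
ref_layer = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=128, dropout=0.1,
    activation='gelu', batch_first=True, norm_first=True,
)

x = torch.randn(2, 17, 64)     # (batch, seq_len = 16 patches + CLS, embed_dim)
ref_layer.eval()               # disable dropout for a deterministic check
with torch.no_grad():
    y = ref_layer(x)
print(y.shape)                 # torch.Size([2, 17, 64])
```

Your own `ViTLayer` should likewise return a tensor with the same shape as its input. This is required because `input_dim` and `embed_dim` must match for the residual connections.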


A portion of the full network is already implemented for you. Your task is to implement the preprocessing code, converting raw images into patch embeddings + positional embeddings + dropout, with a learnable CLS token at the beginning of the input.

Note that the patch embeddings are added elementwise to the positional embeddings, so each input token has dimension `embed_dim`.
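The patch extraction in the cell below (reshape, permute, flatten, then a Linear layer) is equivalent to a `Conv2d` whose `kernel_size` and `stride` both equal the patch size. That convolution form is a common alternative way to write the patch embedding. The sketch below checks the equivalence numerically on arbitrary toy sizes. The only subtlety is weight layout: the Linear layer sees each patch flattened in (row, column, channel) order, so its weight must be permuted into `Conv2d`'s (out, channel, row, column) layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, I, P, D = 2, 3, 32, 4, 48          # batch, channels, image size, patch size, embed_dim
H = W = I // P                            # patches per side
images = torch.randn(B, C, I, I)

# Linear patch embedding on flattened (P, P, C) patches, same layout as the forward below
patches = (images.reshape(B, C, H, P, W, P)
                 .permute(0, 2, 4, 3, 5, 1)        # (B, H, W, P, P, C)
                 .reshape(B, H * W, P * P * C))    # (B, N, P*P*C)
linear = nn.Linear(P * P * C, D)
out_linear = linear(patches)                       # (B, N, D)

# Equivalent convolution: kernel_size = stride = P, weights rearranged to (D, C, P, P)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(D, P, P, C).permute(0, 3, 1, 2))
    conv.bias.copy_(linear.bias)
out_conv = conv(images).flatten(2).transpose(1, 2)  # (B, D, H, W) -> (B, N, D)

print(torch.allclose(out_linear, out_conv, atol=1e-5))  # True
```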

In [None]:
def forward(self, images):
    """
    images: (batch, 3, image_dim, image_dim)
    returns: (batch, num_classes)  – logits
    """
    device = images.device
    B = images.size(0)                     # batch size
    C, P, I = 3, self.patch_dim, self.image_dim
    H = W = I // P                         # number of patches along each axis
    num_patches = H * W

    # 1. split image into (B, H, W, P, P, 3)
    imgs = images.reshape(B, C, H, P, W, P)          # (B, 3, H, P, W, P)
    imgs = imgs.permute(0, 2, 4, 3, 5, 1)           # (B, H, W, P, P, 3)
    patches = imgs.reshape(B, num_patches, self.input_dim)  # flatten each patch

    # 2. patch embeddings
    patch_embeddings = self.patch_embedding(patches)        # (B, N, embed_dim)

    # 3. prepend class token
    cls_tok = self.cls_token.expand(B, -1, -1)              # (B, 1, embed_dim)
    x = torch.cat([cls_tok, patch_embeddings], dim=1)       # (B, N+1, embed_dim)

    # 4. add (learnable) positional embeddings, then dropout
    x = x + self.position_embedding[:, : x.size(1), :]
    x = self.embedding_dropout(x)

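    # 5. optional: pad seq_len to a multiple of num_heads (only needed if the MSA
    #    implementation requires it). The zero tokens are not masked, so they still
    #    take part in attention; skip this step if your MSA accepts any length.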
    pad_len = (self.num_heads - x.size(1) % self.num_heads) % self.num_heads
    if pad_len:
        pad = torch.zeros(B, pad_len, x.size(2), device=device)
        x = torch.cat([x, pad], dim=1)

    # 6. encoder stack
    for layer in self.encoder_layers:
        x = layer(x)

    # 7. classification head (take CLS token before any padding)
    cls_rep = self.layernorm(x[:, 0])
    logits = self.mlp_head(cls_rep)
    return logits
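
Step 5 above pads the sequence with zero tokens without masking them. Those tokens therefore still receive attention weight and can change the outputs for the real tokens, including CLS. The sketch below illustrates this on random toy tensors with `F.scaled_dot_product_attention`, which requires PyTorch 2.0 or later. It also shows that a boolean `attn_mask` hiding the padded keys recovers the unpadded result. The sketch assumes the padded tokens project to all-zero keys and values, as they would with bias-free projections. With biased projections they would be non-zero, but they would still be attended to.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(1, 1, 5, 8) for _ in range(3))   # (batch, heads, seq, d_k)
out = F.scaled_dot_product_attention(q, k, v)

# Append 3 all-zero "padding" tokens
pad = torch.zeros(1, 1, 3, 8)
qp, kp, vp = (torch.cat([t, pad], dim=2) for t in (q, k, v))
out_padded = F.scaled_dot_product_attention(qp, kp, vp)

# Boolean mask: True = may attend. Block every query from attending to the padded keys.
mask = torch.ones(8, 8, dtype=torch.bool)
mask[:, 5:] = False
out_masked = F.scaled_dot_product_attention(qp, kp, vp, attn_mask=mask)

print(torch.allclose(out_padded[..., :5, :], out))              # False: padding changed real tokens
print(torch.allclose(out_masked[..., :5, :], out, atol=1e-5))   # True
```

Each zero key has a score of 0, so it gets softmax weight $e^0 / Z' > 0$ while contributing a zero value. This shrinks every real token's output by the same factor.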


Now let's train the model! You don't need to write any code for this - just run the cell.

Remember to change the device variable (in the cell at the beginning of the notebook) to `'cuda'` and change your runtime to GPU (Runtime > Change runtime type) as well. The training cell below also selects CUDA automatically when it is available. For reference, each epoch in the staff solution takes ~3 minutes, so training for 30 epochs takes ~1.5 hours on the Colab GPU. We know this is a long training session.

Aim for at least 65% validation accuracy after 30 epochs.

In [None]:
# ============================================================
#  FULL ViT training cell – no external dependencies required
# ============================================================

import torch, math, torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.transforms as T
import torchvision.datasets as datasets
from tqdm.notebook import tqdm

# ---------- (1)  MODEL DEFINITIONS (auto‑defined if missing) ----------
if 'get_vit_small' not in globals():

    class MSA(nn.Module):
        def __init__(self, input_dim, embed_dim, num_heads):
            super().__init__()
            self.num_heads = num_heads
            self.embed_dim = embed_dim
            self.head_dim  = embed_dim // num_heads

            self.q = nn.Linear(input_dim, embed_dim, bias=False)
            self.k = nn.Linear(input_dim, embed_dim, bias=False)
            self.v = nn.Linear(input_dim, embed_dim, bias=False)
            self.out = nn.Linear(embed_dim, embed_dim, bias=False)

        def forward(self, x):                       # (B, L, d)
            B, L, _ = x.shape
            q = self.q(x).view(B, L, self.num_heads, self.head_dim).transpose(1,2)
            k = self.k(x).view(B, L, self.num_heads, self.head_dim).transpose(1,2)
            v = self.v(x).view(B, L, self.num_heads, self.head_dim).transpose(1,2)

            attn = (q @ k.transpose(-1,-2)) / math.sqrt(self.head_dim)
            attn = attn.softmax(dim=-1)
            out  = (attn @ v).transpose(1,2).reshape(B, L, self.embed_dim)
            return self.out(out)

    class ViTLayer(nn.Module):
        def __init__(self, num_heads, dim, mlp_hidden, dropout=0.1):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)
            self.attn  = MSA(dim, dim, num_heads)
            self.drop1 = nn.Dropout(dropout)
            self.norm2 = nn.LayerNorm(dim)
            self.mlp   = nn.Sequential(
                nn.Linear(dim, mlp_hidden),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(mlp_hidden, dim),
                nn.Dropout(dropout)
            )
        def forward(self, x):
            x = x + self.drop1(self.attn(self.norm1(x)))
            x = x + self.mlp(self.norm2(x))
            return x

    class ViT(nn.Module):
        def __init__(self, patch_dim, image_dim, num_layers, num_heads,
                     embed_dim, mlp_hidden, num_classes, dropout):
            super().__init__()
            self.patch_dim  = patch_dim
            self.img_dim    = image_dim
            self.patch_size = patch_dim * patch_dim * 3
            self.num_patches= (image_dim // patch_dim) ** 2

            self.patch_embed = nn.Linear(self.patch_size, embed_dim)
            self.cls_token   = nn.Parameter(torch.zeros(1,1,embed_dim))
            self.pos_embed   = nn.Parameter(torch.zeros(1,self.num_patches+1,embed_dim))
            self.pos_drop    = nn.Dropout(dropout)

            self.layers = nn.ModuleList([
                ViTLayer(num_heads, embed_dim, mlp_hidden, dropout)
                for _ in range(num_layers)
            ])

            self.norm = nn.LayerNorm(embed_dim)
            self.head = nn.Linear(embed_dim, num_classes)

        def forward(self, x):                       # (B,3,H,W)
            B, _, H, W = x.shape
            P = self.patch_dim
            x = x.reshape(B, 3, H//P, P, W//P, P)   # (B,3,h,p,w,p)
            x = x.permute(0,2,4,3,5,1).reshape(B, -1, self.patch_size)
            x = self.patch_embed(x)                 # (B,N,d)

            cls = self.cls_token.expand(B, -1, -1)  # (B,1,d)
            x   = torch.cat((cls, x), dim=1) + self.pos_embed
            x   = self.pos_drop(x)

            for layer in self.layers:
                x = layer(x)
            x = self.norm(x[:,0])
            return self.head(x)

    # helper factory
    def get_vit_small(num_classes=10, patch_dim=4, image_dim=32):
        return ViT(patch_dim, image_dim,
                   num_layers=12, num_heads=6,
                   embed_dim=384, mlp_hidden=1536,
                   num_classes=num_classes, dropout=0.1)

# ---------- (2)  CONFIGURATION ----------
device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size  = 32
num_epochs  = 30
lr          = 5e-4 * batch_size / 256
weight_decay= 0.1
data_root   = './data/cifar10'
train_size  = 40_000

# ---------- (3)  DATASETS & LOADERS ----------
tf_train = T.Compose([
    T.Resize(40), T.RandomCrop(32), T.RandomHorizontalFlip(),
    T.RandomAffine(0, translate=(0.2,0.2), scale=(0.95,1.05)),
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

tf_val = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
transform_val = tf_val   # name expected by the autograder cell below

train_ds = datasets.CIFAR10(data_root, train=True, download=True, transform=tf_train)
val_ds   = datasets.CIFAR10(data_root, train=False, download=True, transform=tf_val)

train_loader = DataLoader(train_ds, batch_size, sampler=SubsetRandomSampler(range(train_size)),
                          num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)

# ---------- (4)  MODEL, LOSS, OPTIM ----------
vit = get_vit_small().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit.parameters(), lr=lr, betas=(0.9,0.95), weight_decay=weight_decay)

# ---------- (5)  TRAIN / VAL LOOP ----------
best_acc = 0.0
for epoch in range(num_epochs):
    vit.train()
    tr_loss = tr_correct = tr_total = 0

    for x, y in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out  = vit(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        tr_loss += loss.item() * x.size(0)
        tr_correct += (out.argmax(1) == y).sum().item()
        tr_total += x.size(0)

    tr_loss /= tr_total
    tr_acc = tr_correct / tr_total

    # ---- validation ----
    vit.eval()
    v_loss = v_correct = v_total = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            out  = vit(x)
            loss = criterion(out, y)
            v_loss += loss.item() * x.size(0)
            v_correct += (out.argmax(1) == y).sum().item()
            v_total += x.size(0)
    v_loss /= v_total
    v_acc = v_correct / v_total

    if v_acc > best_acc:
        best_acc = v_acc
        torch.save(vit.state_dict(), 'best_model.pth')

    print(f"[{epoch+1:2d}] "
          f"train loss {tr_loss:.3f} | acc {tr_acc:.3f} || "
          f"val loss {v_loss:.3f} | acc {v_acc:.3f}")

print("Finished Training")


[ 1] train loss 1.841 | acc 0.302 || val loss 1.695 | acc 0.379

### Autograder and Submission

After you feel confident that you have a decent model, run the cell below.

Feel free to read the code block, but **PLEASE DO NOT TOUCH IT**. It produces a pickle file containing your model's predictions on the CIFAR-10 validation set, and tampering with it may corrupt the file you submit to the Gradescope autograder.
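Once the cell below has written `my_predictions.pickle`, you can optionally sanity-check the file before uploading it, without modifying the autograder cell. The sketch below uses a hypothetical helper, `check_predictions`, which is not part of the autograder. It assumes the file holds a plain list of integer class indices, one for each of the 10,000 CIFAR-10 test images. The demo runs on a stand-in file of dummy predictions, not real model output.

```python
import os
import pickle
import tempfile

def check_predictions(path, expected_len=10_000, num_classes=10):
    """Hypothetical sanity check: the pickle should hold a list of class indices."""
    with open(path, "rb") as f:
        preds = pickle.load(f)
    assert isinstance(preds, list), "expected a list"
    assert len(preds) == expected_len, f"expected {expected_len} predictions, got {len(preds)}"
    assert all(isinstance(p, int) and 0 <= p < num_classes for p in preds), "invalid class index"
    return preds

# Demo on a stand-in file of dummy predictions (not real model output)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dummy_predictions.pickle")
    with open(path, "wb") as f:
        pickle.dump([i % 10 for i in range(10_000)], f)
    preds = check_predictions(path)
    print(len(preds))  # 10000
```

On your real output you would call `check_predictions("my_predictions.pickle")`.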

In [None]:
import pickle

cifar_test = datasets.CIFAR10('./data/cifar10_test', download = True, train = False, transform = transform_val)
loader_test = DataLoader(cifar_test, batch_size=32, shuffle=False)

vit.load_state_dict(torch.load('best_model.pth'))
vit.eval()  # set model to evaluation mode
predictions = []
with torch.no_grad():
    for x, _ in loader_test:
        x = x.to(device=device)  # move to device, e.g. GPU
        scores = vit(x)
        _, preds = scores.max(1)
        predictions.append(preds)
predictions = torch.cat(predictions).tolist()
with open("my_predictions.pickle", "wb") as file:
    pickle.dump(predictions, file)
