## Vision Transformer (ViT)

In this assignment we will work with the Vision Transformer (ViT). We will build our own ViT model and train it on an image classification task.
The purpose of this homework is to get you familiar with ViT and prepared for the final project.

In [1]:
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# ViT Implementation

The Vision Transformer can be separated into three parts. We will implement each part and combine them at the end.

For the implementation, feel free to experiment with different setups, as long as you use attention as the main computation unit and the ViT can be trained to perform the image classification task presented later.
You can read the ViT implementations in other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding divides the input image into non-overlapping patches and projects each one into the specified embedding dimension. It uses a 2D convolution layer with kernel size and stride both equal to the patch size, which is equivalent to applying one shared linear projection to every flattened patch. The output is a sequence of embeddings, one per patch.

In [3]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (image_size // patch_size) ** 2

    def forward(self, x):
        # x: B x C x H x W
        x = self.projection(x)  # B x embed_dim x H/patch_size x W/patch_size
        x = x.flatten(2)  # B x embed_dim x N, where N = H*W/patch_size^2
        x = x.transpose(1, 2)  # B x N x embed_dim
        return x
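
As a quick sanity check on the shapes above, the number of patches depends only on the image size and the patch size. The helper below is a hypothetical illustration (the name `patch_grid` is not part of the assignment) and assumes square images whose side length is divisible by the patch size.

```python
def patch_grid(image_size, patch_size):
    """Return (patches_per_side, num_patches) for a square image."""
    if image_size % patch_size != 0:
        # Assumption of this sketch: we reject sizes that do not divide evenly.
        raise ValueError("image_size must be divisible by patch_size")
    per_side = image_size // patch_size
    return per_side, per_side ** 2


per_side, num_patches = patch_grid(224, 16)
print(per_side, num_patches)  # 14 196
```

With the 224 x 224 inputs and 16 x 16 patches used later in this notebook, the patch sequence has 196 tokens, and 197 once the class token is prepended.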


## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [4]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v)
    return values, attention


class MultiHeadSelfAttention(nn.Module):

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Stack all weight matrices 1...h together for efficiency
        self.qkv_proj = nn.Linear(embed_dim, 3*embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        # Initialize the weights of the attention layers using Xavier uniform initialization
        # This helps with training stability by keeping the variance of activations roughly constant
        # across layers at the start of training
        # Xavier initialization is particularly well-suited for tanh activations and similar
        # The weights are initialized with Xavier uniform while biases are set to 0
        nn.init.xavier_uniform_(self.qkv_proj.weight)  # Initialize query/key/value projection weights
        self.qkv_proj.bias.data.fill_(0)  # Initialize query/key/value projection biases
        nn.init.xavier_uniform_(self.o_proj.weight)  # Initialize output projection weights
        self.o_proj.bias.data.fill_(0)  # Initialize output projection biases

    def forward(self, x):
        # Get input dimensions
        batch_size, seq_length, _ = x.size()

        # Project input into query, key and value vectors all at once
        # Output shape: (batch_size, seq_length, 3*embed_dim)
        qkv = self.qkv_proj(x)

        # Reshape and permute to separate heads and split QKV
        # First reshape to: (batch_size, seq_length, num_heads, 3*head_dim)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
        # Permute to: (batch_size, num_heads, seq_length, 3*head_dim)
        qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
        # Split into query, key, value tensors along last dimension
        q, k, v = qkv.chunk(3, dim=-1)

        # Compute scaled dot-product attention
        # values shape: (batch_size, num_heads, seq_length, head_dim)
        # _attention contains the attention weights but we don't use them here
        values, _attention = scaled_dot_product(q, k, v)

        # Reshape attention output
        # First permute back to: (batch_size, seq_length, num_heads, head_dim)
        values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        # Then combine heads: (batch_size, seq_length, embed_dim)
        values = values.reshape(batch_size, seq_length, self.embed_dim)

        # Final linear projection
        o = self.o_proj(values)

        return o
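
To make the attention formula concrete, here is a small plain-Python re-derivation of softmax(QK^T / sqrt(d_k)) V on a two-token example. It is only a sketch for checking the arithmetic by hand: the helper names `softmax` and `attention` are hypothetical, and it handles a single head with no masking and no batching.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Compute softmax(q k^T / sqrt(d_k)) v for lists of row vectors."""
    d_k = len(q[0])
    weights = []
    for q_row in q:
        logits = [
            sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k)
            for k_row in k
        ]
        weights.append(softmax(logits))
    values = [
        [sum(w * v_row[j] for w, v_row in zip(w_row, v)) for j in range(len(v[0]))]
        for w_row in weights
    ]
    return values, weights


# Two tokens with identical keys: every query attends uniformly,
# so each output row is the mean of the value rows.
q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 1.0], [1.0, 1.0]]
v = [[2.0, 0.0], [0.0, 4.0]]
values, weights = attention(q, k, v)
print(weights)  # each row is [0.5, 0.5]
print(values)   # each row is [1.0, 2.0]
```

Each row of the attention weights sums to 1, so every output token is a weighted average of the value vectors.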





## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is wrapped in a residual connection.
You may also want to use layer normalization or another type of normalization.

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()

        # Multi-head self attention layer
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)

        # Layer normalization layers
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # MLP block
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # First sublayer: Multi-head self attention with residual connection
        attention_out = self.attention(self.norm1(x))
        x = x + self.dropout(attention_out)

        # Second sublayer: MLP with residual connection
        mlp_out = self.mlp(self.norm2(x))
        x = x + self.dropout(mlp_out)

        return x
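
The block above applies layer normalization before each sublayer (often called pre-norm). The sketch below contrasts that ordering with the post-norm ordering of the original Transformer. It is only an illustration: a single `nn.Linear` stands in for the attention or MLP sublayer, and all sizes are arbitrary.

```python
import torch
import torch.nn as nn

embed_dim = 8
norm = nn.LayerNorm(embed_dim)
sublayer = nn.Linear(embed_dim, embed_dim)  # stand-in for attention or the MLP

x = torch.randn(2, 5, embed_dim)  # (batch, tokens, embed_dim)

# Pre-norm (the ordering used by TransformerBlock above):
# normalize the sublayer input and leave the residual path untouched.
pre_norm_out = x + sublayer(norm(x))

# Post-norm (the ordering of the original Transformer):
# add the residual first, then normalize the sum.
post_norm_out = norm(x + sublayer(x))

print(pre_norm_out.shape, post_norm_out.shape)
```

Both variants preserve the input shape, so either one can be stacked `num_layers` times. They differ only in where the normalization sits relative to the residual sum.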

## VisionTransformer
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is prepended to the sequence, and positional embeddings are added to both the patch and class tokens. The sequence is then passed through multiple TransformerBlock layers. The final output is the logits for all classes, computed from the class token.

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()

        # Patch embedding layer
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)

        # Calculate number of patches
        num_patches = (image_size // patch_size) ** 2

        # Class token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        # Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ])

        # Layer normalization
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Create patch embeddings
        x = self.patch_embed(x)  # Shape: (batch_size, num_patches, embed_dim)

        # Add class token
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embeddings
        x = x + self.pos_embed
        x = self.dropout(x)

        # Apply transformer blocks
        for block in self.transformer_blocks:
            x = block(x)

        # Layer normalization
        x = self.norm(x)

        # Use the class token for classification
        x = x[:, 0]

        # Classification head
        x = self.head(x)

        return x
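
The class-token and positional-embedding steps rely on broadcasting, which is easy to get wrong. Here is a minimal shape walkthrough with made-up sizes. The class token is set to zeros purely so the result is easy to verify, whereas the model above initializes it randomly and learns it.

```python
import torch

batch_size, num_patches, embed_dim = 4, 9, 8

patches = torch.randn(batch_size, num_patches, embed_dim)
cls_token = torch.zeros(1, 1, embed_dim)           # one vector shared by all samples
pos_embed = torch.randn(1, num_patches + 1, embed_dim)

# expand() returns a broadcast view, so no per-sample copy is made.
cls_tokens = cls_token.expand(batch_size, -1, -1)  # (4, 1, 8)
tokens = torch.cat([cls_tokens, patches], dim=1)   # (4, 10, 8)
tokens = tokens + pos_embed                        # broadcast over the batch dimension

print(tokens.shape)  # torch.Size([4, 10, 8])
```

Because the positional embedding has shape `(1, num_patches + 1, embed_dim)`, the model only accepts the image size it was built for.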

## Let's train the ViT!

We will train the ViT on image classification with CIFAR-100. Feel free to change the optimizer and/or add other tricks to improve training.

In [7]:
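# Hyperparameters
image_size = 224
patch_size = 16
in_channels = 3
embed_dim = 1032  # must be divisible by num_heads (1032 / 12 = 86); other values: 1024, 768
num_heads = 12
mlp_dim = 4096  # other value: 3072
num_layers = 12
num_classes = 100
dropout = 0.1

batch_size = 128

# Number of DataLoader worker processes (faster loading)
num_workers = 4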
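
To get a feel for how large this configuration is, the learnable parameters can be counted by hand. The function below is a rough sketch (the name `vit_param_count` is hypothetical). It assumes the exact layer layout of the classes defined earlier in this notebook, with a bias on every linear and convolution layer.

```python
def vit_param_count(image_size, patch_size, in_channels, embed_dim,
                    mlp_dim, num_layers, num_classes):
    """Count learnable parameters of the VisionTransformer defined above."""
    num_patches = (image_size // patch_size) ** 2

    # Conv2d patch projection: weights + bias.
    patch_embed = in_channels * patch_size ** 2 * embed_dim + embed_dim
    # Class token and positional embeddings.
    tokens = embed_dim + (num_patches + 1) * embed_dim

    # One TransformerBlock: qkv_proj + o_proj, two MLP layers, two LayerNorms.
    attention = (embed_dim * 3 * embed_dim + 3 * embed_dim) + (embed_dim ** 2 + embed_dim)
    mlp = (embed_dim * mlp_dim + mlp_dim) + (mlp_dim * embed_dim + embed_dim)
    layer_norms = 2 * 2 * embed_dim
    block = attention + mlp + layer_norms

    # Final LayerNorm and classification head.
    head = 2 * embed_dim + (embed_dim * num_classes + num_classes)

    return patch_embed + tokens + num_layers * block + head


total = vit_param_count(224, 16, 3, 1032, 4096, 12, 100)
print(f"{total / 1e6:.1f}M parameters")  # 153.8M parameters
```

The count does not depend on `num_heads`, because the heads only partition the `embed_dim` channels (86 per head in this setup).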


In [8]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 100])


In [9]:
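# Load the CIFAR-100 dataset
# Note: the mean/std passed to Normalize below are the values commonly quoted for CIFAR-10, reused here.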
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:12<00:00, 13.2MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [10]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,
    epochs=num_epochs,
    steps_per_epoch=len(trainloader),
    pct_start=0.3,  # Spend 30% of time ramping up, 70% ramping down
)

# Let cuDNN benchmark and pick the fastest convolution algorithms (helps when input sizes are fixed)
torch.backends.cudnn.benchmark = True
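
`OneCycleLR` only changes the learning rate when `scheduler.step()` is called, normally once per batch. Here is a minimal sketch of what the schedule does, using a tiny stand-in model and a hypothetical 10-step run instead of the real training setup. The `toy_` names are illustrative only.

```python
import torch
import torch.nn as nn

toy_model = nn.Linear(4, 2)  # tiny stand-in for the ViT
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.001)
total_steps = 10
toy_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    toy_optimizer, max_lr=0.001, total_steps=total_steps, pct_start=0.3
)

lrs = []
for step in range(total_steps):
    lrs.append(toy_optimizer.param_groups[0]["lr"])
    toy_optimizer.step()      # no gradients were computed, so no weights change
    if step < total_steps - 1:
        toy_scheduler.step()  # advance the schedule once per batch

print(f"first lr: {lrs[0]:.2e}, peak lr: {max(lrs):.2e}, last lr: {lrs[-1]:.2e}")
```

The scheduler overrides the optimizer's `lr` as soon as it is constructed: training starts at `max_lr / div_factor` (`div_factor` defaults to 25), rises to `max_lr`, then anneals far below the starting value. If `scheduler.step()` is never called, the learning rate stays at that initial value for the whole run.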

In [None]:
# Train the model
best_val_acc = 0
start_epoch = 0

# Try to load checkpoint if exists
checkpoint_path = "training_checkpoint.pth"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch']
    best_val_acc = checkpoint['best_val_acc']
    print(f"Resuming from epoch {start_epoch} with best validation accuracy: {best_val_acc:.2f}%")
else:
    print("No checkpoint found. Starting new.")

for epoch in range(start_epoch, num_epochs):
    model.train()
    train_pbar = tqdm.tqdm(trainloader, desc=f'Training Epoch {epoch+1}/{num_epochs}')
    for i, data in enumerate(train_pbar):
        inputs, labels = data
        inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
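        # NOTE: scheduler.step() is disabled here. OneCycleLR expects one call per batch;
        # without it the learning rate stays at the scheduler's initial value
        # (max_lr / div_factor, i.e. max_lr / 25 by default).
        # scheduler.step()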

        train_pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    val_losses = []
    with torch.no_grad():
        val_pbar = tqdm.tqdm(testloader, desc=f'Validation Epoch {epoch+1}/{num_epochs}')
        for data in val_pbar:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_losses.append(loss.item())
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            val_pbar.set_postfix({'val_loss': f'{torch.tensor(val_losses).mean().item():.4f}'})

    val_acc = 100 * correct / total
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Save the best model separately (update best_val_acc first so the checkpoint stores the current best)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print(f"New best model saved with validation accuracy: {val_acc:.2f}%")

    # Save checkpoint so training can resume from the next epoch
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_acc': best_val_acc,
    }
    torch.save(checkpoint, checkpoint_path)

No checkpoint found. Starting new.


Training Epoch 1/10: 100%|██████████| 391/391 [09:57<00:00,  1.53s/it, loss=3.4577]
Validation Epoch 1/10: 100%|██████████| 79/79 [00:39<00:00,  1.99it/s, val_loss=3.6654]


Epoch: 1, Validation Accuracy: 13.33%
New best model saved with validation accuracy: 13.33%


Training Epoch 2/10: 100%|██████████| 391/391 [09:56<00:00,  1.52s/it, loss=3.2250]
Validation Epoch 2/10: 100%|██████████| 79/79 [00:39<00:00,  1.99it/s, val_loss=3.1946]


Epoch: 2, Validation Accuracy: 21.74%
New best model saved with validation accuracy: 21.74%


Training Epoch 3/10: 100%|██████████| 391/391 [09:55<00:00,  1.52s/it, loss=3.2133]
Validation Epoch 3/10: 100%|██████████| 79/79 [00:39<00:00,  2.00it/s, val_loss=2.8976]


Epoch: 3, Validation Accuracy: 27.89%
New best model saved with validation accuracy: 27.89%


Training Epoch 4/10: 100%|██████████| 391/391 [09:55<00:00,  1.52s/it, loss=3.0310]
Validation Epoch 4/10: 100%|██████████| 79/79 [00:39<00:00,  2.00it/s, val_loss=2.7330]


Epoch: 4, Validation Accuracy: 31.06%
New best model saved with validation accuracy: 31.06%


Training Epoch 5/10:  48%|████▊     | 189/391 [04:50<05:07,  1.52s/it, loss=2.5591]

Please submit your best_model.pth with this notebook, and report the best test results you get.