## Vision Transformer (ViT)

In this project I work with the Vision Transformer (ViT). I will build my own ViT model from scratch and train it on an image classification task.


In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


# ViT Implementation

The Vision Transformer can be separated into three parts; we will implement each part and combine them at the end.


You can read about the ViT implementations in other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding divides the input image into non-overlapping patches and projects each patch into an embedding of the specified dimension. It uses a 2D convolution whose kernel size and stride both equal the patch size, so every patch is mapped by the same linear projection. The output is a sequence of embeddings, one per patch.

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies one
    shared linear projection to every patch.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of channels in the input image.
        embed_dim: size of the embedding produced for each patch.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        self.projection_layer = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn a batched image tensor into a sequence of patch embeddings.

        Expects a 4-D (batch, in_channels, H, W) tensor and returns a
        (batch, num_patches, embed_dim) tensor.
        """
        # (B, E, H/p, W/p) -> (B, E, N) -> (B, N, E)
        patch_grid = self.projection_layer(x)
        return patch_grid.flatten(start_dim=2).transpose(1, 2)

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Queries, keys and values are linear projections of the input. Each head
    attends over the whole sequence using its own ``head_dim``-wide slice of
    those projections. The per-head outputs are concatenated and mixed by a
    final linear layer.

    Args:
        embed_dim: size of each token embedding (last input dimension).
        num_heads: number of attention heads; must divide ``embed_dim`` evenly.

    Raises:
        ValueError: if ``embed_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # Without this check head_dim is floored, and the reshapes in forward can
        # silently reinterpret the tensor (wrong sequence length) instead of failing.
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query_layer = nn.Linear(embed_dim, embed_dim)
        self.key_layer = nn.Linear(embed_dim, embed_dim)
        self.value_layer = nn.Linear(embed_dim, embed_dim)
        self.output_layer = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, embed_dim).

        Returns:
            A tensor with the same shape as ``x``.
        """
        batch_size, seq_len, _ = x.shape
        queries = self._split_heads(self.query_layer(x), batch_size, seq_len)
        keys = self._split_heads(self.key_layer(x), batch_size, seq_len)
        values = self._split_heads(self.value_layer(x), batch_size, seq_len)

        # (B, H, S, S) similarity scores, scaled by 1/sqrt(head_dim).
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_probs, values)  # weighted sum of values

        # Merge heads back: (B, H, S, D) -> (B, S, H, D) -> (B, S, E).
        attention_output = attention_output.transpose(1, 2).reshape(
            batch_size, seq_len, self.embed_dim
        )
        return self.output_layer(attention_output)







## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is wrapped in a residual connection.
You may also want to use layer normalization or another type of normalization.

In [None]:
class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then an MLP, each with residual add + LayerNorm.

    LayerNorm is applied after each residual addition (post-norm ordering).
    Dropout is applied only inside the MLP branch.

    Args:
        embed_dim: token embedding size (input and output width).
        num_heads: number of heads passed to MultiHeadSelfAttention.
        mlp_dim: hidden width of the feed-forward MLP.
        dropout: probability used by the two Dropout layers in the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super().__init__()
        # Constructor arguments kept as plain attributes for introspection.
        self.embed_dim, self.num_heads, self.mlp_dim, self.dropout = embed_dim, num_heads, mlp_dim, dropout

        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)
        # Layer positions (0 and 3 are the Linear layers) appear in state_dict keys.
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.layer_norm1, self.layer_norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Map a token sequence to a new sequence of the same shape."""
        attended = self.layer_norm1(x + self.attention(x))
        return self.layer_norm2(attended + self.mlp(attended))


## VisionTransformer:
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer, which creates patch embeddings from the input image. A special class token is prepended to the sequence, and positional embeddings are added to both the patch tokens and the class token. The sequence is then passed through multiple TransformerBlock layers. The final output is the logits for all classes, computed from the class token's final representation.

In [None]:
class VisionTransformer(nn.Module):
    """ViT classifier: patch embedding -> [CLS] + positions -> transformer stack -> head.

    Args:
        image_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_channels: number of input image channels.
        embed_dim: token embedding size used throughout the model.
        num_heads: attention heads per TransformerBlock.
        mlp_dim: hidden width of each block's MLP.
        num_layers: number of stacked TransformerBlocks.
        num_classes: number of output logits.
        dropout: dropout probability applied to the embedded tokens and inside
            every TransformerBlock.
    """

    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        super().__init__()

        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        tokens_per_image = self.patch_embedding.num_patches + 1  # patches plus one class token
        # Both are learned and start from zeros.
        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.position_embedding = nn.Parameter(torch.zeros(1, tokens_per_image, embed_dim))

        encoder_layers = [
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(num_layers)
        ]
        self.transformer_blocks = nn.ModuleList(encoder_layers)
        self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return (batch, num_classes) logits for a batch of images ``x``."""
        patch_tokens = self.patch_embedding(x)
        cls_tokens = self.class_token.expand(patch_tokens.shape[0], -1, -1)
        # Class token goes first, so its final state is read back at index 0.
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1) + self.position_embedding
        tokens = self.dropout(tokens)
        for layer in self.transformer_blocks:
            tokens = layer(tokens)
        return self.mlp_head(tokens[:, 0])

## Let's train the ViT!

We will train the ViT for image classification on CIFAR-100.

In [None]:
# Hyperparameters for the ViT / CIFAR-100 experiment.

# Input geometry: 32x32 RGB images cut into 4x4 patches -> (32 // 4) ** 2 = 64 patches.
image_size, patch_size, in_channels = 32, 4, 3

# Transformer width and depth (embed_dim must be divisible by num_heads).
embed_dim, num_heads, mlp_dim, num_layers = 256, 8, 512, 8

# Task size and regularization.
num_classes, dropout = 100, 0.2

# Mini-batch size for the train and test DataLoaders.
batch_size = 128

In [None]:
# Build the model on the selected device, then push one random image through it
# as a shape sanity check: the printed shape should be (1, num_classes).
model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=embed_dim,
    num_heads=num_heads,
    mlp_dim=mlp_dim,
    num_layers=num_layers,
    num_classes=num_classes,
    dropout=dropout,
)
model = model.to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size)
input_tensor = input_tensor.to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 100])


In [None]:
# Load the CIFAR-100 dataset
from torchvision.transforms import AutoAugment, AutoAugmentPolicy

# Per-channel CIFAR-100 mean and std. Train and test MUST be normalized with the
# same statistics; otherwise evaluation inputs are scaled differently from the
# inputs the model was trained on.
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

transform_train = transforms.Compose([
    AutoAugment(policy=AutoAugmentPolicy.CIFAR10),  # AutoAugment for CIFAR-like datasets
    # NOTE(review): crop size is hard-coded to 32 while the test transform resizes
    # to image_size; the two agree only while image_size == 32.
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])


transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:18<00:00, 9.23MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Define the loss function, optimizer and learning-rate schedule.
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4, weight_decay=1e-2)

# Cosine annealing with warm restarts. With T_0=10 and T_mult=2 the cycle lengths
# are 10, 20, 40, 80, ... scheduler steps, so (stepped once per epoch) the LR
# restarts after epochs 10, 30 and 70.
# NOTE(review): a 100-epoch run ends 30 epochs into the 80-epoch cycle, before the
# LR has annealed back down; confirm this is intended (70 or 150 epochs would end
# exactly on a cycle boundary).
scheduler = CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=10, T_mult=2)


In [None]:
from itertools import accumulate

# Training loop: train for num_epochs, evaluate after every epoch, and checkpoint
# the weights whenever evaluation accuracy improves.
# NOTE(review): the evaluation phase iterates testloader, so the "best" checkpoint
# is selected on that split.
num_epochs = 100
best_val_acc = 0

for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0
    correct_train = 0  # Count of correct predictions
    total_train = 0    # Total number of samples in training data

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predicted class = index of the largest logit; detach() instead of the
        # discouraged `.data` attribute.
        predicted = outputs.detach().argmax(dim=1)
        correct_train += (predicted == labels).sum().item()  # Sum correct predictions
        total_train += labels.size(0)                        # Track total samples
    # The scheduler is stepped once per epoch, after that epoch's optimizer steps.
    scheduler.step()

    # Calculate epoch loss and accuracy
    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = 100 * correct_train / total_train
    print(f"Epoch: {epoch + 1}, Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_accuracy:.2f}%")

    # ---- Validation phase ----
    model.eval()
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    # Calculate validation accuracy
    val_acc = 100 * correct_val / total_val
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Save the best model based on validation accuracy
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")
        print(f"New best model saved with validation accuracy: {best_val_acc:.2f}%")


Epoch: 1, Loss: 4.3502, Training Accuracy: 5.31%
Epoch: 1, Validation Accuracy: 10.96%
New best model saved with validation accuracy: 10.96%
Epoch: 2, Loss: 3.9845, Training Accuracy: 11.38%
Epoch: 2, Validation Accuracy: 19.22%
New best model saved with validation accuracy: 19.22%
Epoch: 3, Loss: 3.7334, Training Accuracy: 16.39%
Epoch: 3, Validation Accuracy: 23.11%
New best model saved with validation accuracy: 23.11%
Epoch: 4, Loss: 3.5512, Training Accuracy: 20.42%
Epoch: 4, Validation Accuracy: 27.46%
New best model saved with validation accuracy: 27.46%
Epoch: 5, Loss: 3.3982, Training Accuracy: 23.83%
Epoch: 5, Validation Accuracy: 30.60%
New best model saved with validation accuracy: 30.60%
Epoch: 6, Loss: 3.2842, Training Accuracy: 26.70%
Epoch: 6, Validation Accuracy: 35.17%
New best model saved with validation accuracy: 35.17%
Epoch: 7, Loss: 3.1932, Training Accuracy: 28.75%
Epoch: 7, Validation Accuracy: 36.27%
New best model saved with validation accuracy: 36.27%
Epoch: 

Please submit your best_model.pth with this notebook, and report the best test results you obtain.