## Vision Transformer (ViT)

In this assignment we will work with the Vision Transformer (ViT). We will build our own ViT model and train it on an image classification task.
The purpose of this homework is to get you familiar with ViT and prepared for the final project.

In [15]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# ViT Implementation

The Vision Transformer can be separated into three parts; we will implement each part and combine them at the end.

For the implementation, feel free to experiment with different setups, as long as attention is the main computation unit and the ViT can be trained to perform the image classification task presented later.
You can read the ViT implementations in other libraries: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py and https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py

## PatchEmbedding
PatchEmbedding is responsible for dividing the input image into non-overlapping patches and projecting them into a specified embedding dimension. It uses a 2D convolution layer with a kernel size and stride equal to the patch size. The output is a sequence of linear embeddings for each patch.

In [17]:
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.num_patches = (image_size // patch_size) ** 2

        # 2D convolution for patch embedding
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # Extract patches
        x = x.view(x.shape[0], x.shape[1], -1).transpose(1, 2)  # Reshape
        return x
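
The claim that a convolution with kernel size and stride equal to the patch size is a per-patch linear projection can be checked numerically. The following is a minimal sketch, not part of the assignment code: the sizes are illustrative assumptions, and `F.unfold` is used only to cut out the patches explicitly for comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions, chosen to match the example setup used later).
batch, in_channels, image_size, patch_size, embed_dim = 2, 3, 64, 8, 256

conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
images = torch.randn(batch, in_channels, image_size, image_size)

# Path A: strided convolution, then flatten the spatial grid into a token axis.
tokens_conv = conv(images).flatten(2).transpose(1, 2)  # (B, N, E)

# Path B: extract non-overlapping patches and apply the same weights as a linear map.
patches = F.unfold(images, kernel_size=patch_size, stride=patch_size)  # (B, C*P*P, N)
weight = conv.weight.reshape(embed_dim, -1)  # (E, C*P*P)
tokens_linear = patches.transpose(1, 2) @ weight.T + conv.bias  # (B, N, E)

num_patches = (image_size // patch_size) ** 2
print(tokens_conv.shape, num_patches)
print(torch.allclose(tokens_conv, tokens_linear, atol=1e-4))
```

Both paths produce one embedding per patch, so the token count equals `(image_size // patch_size) ** 2`.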

## MultiHeadSelfAttention

This class implements the multi-head self-attention mechanism, which is a key component of the transformer architecture. It consists of multiple attention heads that independently compute scaled dot-product attention on the input embeddings. This allows the model to capture different aspects of the input at different positions. The attention outputs are concatenated and linearly transformed back to the original embedding size.

In [18]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.query_projection = nn.Linear(embed_dim, embed_dim)
        self.key_projection = nn.Linear(embed_dim, embed_dim)
        self.value_projection = nn.Linear(embed_dim, embed_dim)
        self.output_projection = nn.Linear(embed_dim, embed_dim)

        # Create the dropout module once in __init__ so it follows model.train() / model.eval()
        self.attn_dropout = nn.Dropout(0.1)

    def forward(self, x):
        batch_size, num_tokens, embedding_dim = x.shape

        query = self.query_projection(x)
        key = self.key_projection(x)
        value = self.value_projection(x)

        query = query.view(batch_size, num_tokens, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        key = key.view(batch_size, num_tokens, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        value = value.view(batch_size, num_tokens, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attention = (query @ key.transpose(-2, -1)) * self.scale
        attention = F.softmax(attention, dim=-1)
        attention = self.attn_dropout(attention)  # Apply dropout

        x = (attention @ value).transpose(1, 2).reshape(batch_size, num_tokens, embedding_dim)
        x = self.output_projection(x)
        return x
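
To make the head-splitting and scaled dot-product steps easier to follow, here is a small standalone sketch on random tensors. The sizes and the `split_heads` helper are hypothetical and exist only for this illustration; the learned projections and dropout are omitted.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions).
batch, num_tokens, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads


def split_heads(t):
    # Hypothetical helper: (B, N, E) -> (B, H, N, D)
    return t.view(batch, num_tokens, num_heads, head_dim).permute(0, 2, 1, 3)


q = split_heads(torch.randn(batch, num_tokens, embed_dim))
k = split_heads(torch.randn(batch, num_tokens, embed_dim))
v = split_heads(torch.randn(batch, num_tokens, embed_dim))

# Scaled dot-product attention, computed independently for every head.
scores = (q @ k.transpose(-2, -1)) * head_dim ** -0.5  # (B, H, N, N)
weights = F.softmax(scores, dim=-1)  # each row sums to 1

# Merge the heads back into a single embedding axis.
out = (weights @ v).transpose(1, 2).reshape(batch, num_tokens, embed_dim)

print(weights.shape, out.shape)
```

Each row of `weights` is a probability distribution over the tokens, which is why the softmax is taken over the last dimension.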

## TransformerBlock
This class represents a single transformer layer. It includes a multi-head self-attention sublayer followed by a position-wise feed-forward network (MLP). Each sublayer is surrounded by residual connections.
You may also want to use layer normalization or other type of normalization.

In [19]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.attention = MultiHeadSelfAttention(embed_dim, num_heads)

        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Pre-LN ordering: normalize, apply the sublayer, then add the residual
        residual = x
        x = self.norm1(x)
        x = self.attention(x)
        x = self.dropout(x)
        x = residual + x

        residual = x  # Residual for the MLP sublayer
        x = self.norm2(x)
        x = self.mlp(x)
        # Note: self.mlp already ends with nn.Dropout, so dropout is applied twice here
        x = self.dropout(x)
        return residual + x
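
The pre-LN residual pattern used by each sublayer can be isolated in a few lines. This is a sketch under assumptions: `PreNormResidual` is a hypothetical helper that is not part of the assignment, and the last MLP layer is zeroed only to show that the residual path then passes the input through unchanged.

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Hypothetical helper computing x + fn(LayerNorm(x))."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return x + self.fn(self.norm(x))


torch.manual_seed(0)
dim = 16
mlp = nn.Sequential(nn.Linear(dim, 32), nn.GELU(), nn.Linear(32, dim))

# Zero the final layer so the sublayer contributes nothing to the output.
nn.init.zeros_(mlp[-1].weight)
nn.init.zeros_(mlp[-1].bias)

block = PreNormResidual(dim, mlp)
x = torch.randn(2, 5, dim)
y = block(x)

print(y.shape, torch.allclose(x, y))
```

Because normalization sits inside the residual branch, the skip connection carries the unnormalized input directly from layer to layer.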

## VisionTransformer
This is the main class that assembles the entire Vision Transformer architecture. It starts with the PatchEmbedding layer to create patch embeddings from the input image. A special class token is prepended to the sequence, and positional embeddings are added to both the patch and class tokens. The resulting sequence is then passed through multiple TransformerBlock layers. The final output is the logits for all classes, computed from the class token.

In [None]:
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout=0.1):
        # Initialize the Vision Transformer model
        super(VisionTransformer, self).__init__()

        # Patch embedding layer that splits the image into patches and embeds them
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)

        # Positional embeddings added to patches to retain positional information
        self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, embed_dim))

        # Class token that will be used for classification task
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # Dropout layer applied after the addition of positional embeddings
        self.dropout = nn.Dropout(dropout)

        # A list of Transformer blocks stacked on top of each other
        self.transformer = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_dim, dropout) for _ in range(num_layers)])

        # MLP head that will predict the final classification
        self.mlp_head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # Forward pass through the Vision Transformer
        batch_size = x.shape[0]  # Get the batch size

        # Pass input through the patch embedding layer
        x = self.patch_embed(x)

        # Expand the class token to match the batch size and prepend it to the sequence of patches
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings to the input sequence
        x = x + self.pos_embed

        # Apply dropout to the embedded sequence (including class token and patches)
        x = self.dropout(x)

        # Pass the embedded sequence through the transformer blocks
        for blk in self.transformer:
            x = blk(x)

        # Select the output corresponding to the class token (first element of the sequence)
        x = x[:, 0]

        # Apply the MLP head to obtain the final class prediction
        x = self.mlp_head(x)

        return x
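
The shape bookkeeping for the class token and positional embeddings can be traced on plain tensors. This is a minimal sketch with illustrative sizes; random tensors stand in for the learned parameters and for the patch embeddings.

```python
import torch

torch.manual_seed(0)

# Illustrative sizes (assumptions).
batch, image_size, patch_size, embed_dim = 4, 64, 8, 256
num_patches = (image_size // patch_size) ** 2  # 8 * 8 = 64

patch_tokens = torch.randn(batch, num_patches, embed_dim)  # stand-in for PatchEmbedding output
cls_token = torch.randn(1, 1, embed_dim)  # stand-in for the learned class token
pos_embed = torch.randn(1, num_patches + 1, embed_dim)  # one position per token, including the class token

# expand() broadcasts the single class token over the batch without copying memory.
cls_tokens = cls_token.expand(batch, -1, -1)
x = torch.cat((cls_tokens, patch_tokens), dim=1)  # (B, N + 1, E), class token first

# pos_embed has batch size 1, so the addition broadcasts over the batch.
x = x + pos_embed

cls_out = x[:, 0]  # the token that would be fed to the classification head
print(x.shape, cls_out.shape)
```

The sequence length is `num_patches + 1`, which is why `pos_embed` is allocated with one extra position.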


## Let's train the ViT!

We will train the ViT on image classification with CIFAR-100. Feel free to change the optimizer and/or add other tricks to improve the training.

In [33]:
# Example usage:
image_size = 64
patch_size = 8
in_channels = 3
embed_dim = 256
num_heads = 8
mlp_dim = 512
num_layers = 6
num_classes = 100
dropout = 0.01
batch_size = 256

In [34]:
model = VisionTransformer(image_size, patch_size, in_channels, embed_dim, num_heads, mlp_dim, num_layers, num_classes, dropout).to(device)
input_tensor = torch.randn(1, in_channels, image_size, image_size).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 100])


In [35]:
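# Load the CIFAR-100 dataset
# Note: the Normalize constants below are the commonly used CIFAR-10 statistics;
# CIFAR-100's per-channel statistics differ slightly.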
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified
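
The `transforms.Normalize` arguments are per-channel means and standard deviations. As a sketch of how such values could be derived, the hypothetical helper below computes them from a batch of images; synthetic uniform data is used here so the example runs without downloading anything.

```python
import torch


def channel_stats(images):
    """Hypothetical helper: per-channel mean and std of a (N, C, H, W) float tensor."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


torch.manual_seed(0)
# Synthetic stand-in for a dataset scaled to [0, 1]; not real CIFAR data.
fake_images = torch.rand(100, 3, 32, 32)
mean, std = channel_stats(fake_images)

print(mean, std)
```

For a real dataset the same function would be applied to the training images after scaling them to the [0, 1] range, which is the range `ToTensor` produces.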


In [36]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
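
With one scheduler step per epoch, `CosineAnnealingLR` follows a half-cosine from the base learning rate down to its minimum over `T_max` epochs. The closed form can be sketched in plain Python; `cosine_lr` is a hypothetical helper, and `eta_min=0.0` is assumed to be the scheduler's default minimum.

```python
import math


def cosine_lr(step, base_lr=1e-3, t_max=100, eta_min=0.0):
    """Hypothetical helper: closed-form cosine annealing value at a given epoch."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


# Learning rate at a few points of the schedule used above (lr=0.001, T_max=100).
for epoch in (0, 25, 50, 75, 100):
    print(epoch, f"{cosine_lr(epoch):.6f}")
```

The rate starts at 0.001, reaches half of that at epoch 50, and decays to zero at epoch 100, which matches `num_epochs` in the training loop.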

In [37]:
num_epochs = 100
best_val_acc = 0
for epoch in range(num_epochs):
    model.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # TODO: Feel free to modify the training loop yourself.
    lr_scheduler.step()

    # Validate the model
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # index of the highest logit per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = 100 * correct / total
    print(f"Epoch: {epoch + 1}, Validation Accuracy: {val_acc:.2f}%")

    # Save the best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_model.pth")

Epoch: 1, Validation Accuracy: 18.61%
Epoch: 2, Validation Accuracy: 26.48%
Epoch: 3, Validation Accuracy: 30.45%
Epoch: 4, Validation Accuracy: 35.51%
Epoch: 5, Validation Accuracy: 36.74%
Epoch: 6, Validation Accuracy: 40.21%
Epoch: 7, Validation Accuracy: 42.85%
Epoch: 8, Validation Accuracy: 44.72%
Epoch: 9, Validation Accuracy: 46.84%
Epoch: 10, Validation Accuracy: 47.87%
Epoch: 11, Validation Accuracy: 48.57%
Epoch: 12, Validation Accuracy: 49.36%
Epoch: 13, Validation Accuracy: 50.76%
Epoch: 14, Validation Accuracy: 51.34%
Epoch: 15, Validation Accuracy: 51.10%
Epoch: 16, Validation Accuracy: 52.46%
Epoch: 17, Validation Accuracy: 52.13%
Epoch: 18, Validation Accuracy: 52.21%
Epoch: 19, Validation Accuracy: 53.58%
Epoch: 20, Validation Accuracy: 52.88%
Epoch: 21, Validation Accuracy: 53.12%
Epoch: 22, Validation Accuracy: 53.78%
Epoch: 23, Validation Accuracy: 53.73%
Epoch: 24, Validation Accuracy: 53.89%
Epoch: 25, Validation Accuracy: 53.05%
Epoch: 26, Validation Accuracy: 53

Please submit your best_model.pth with this notebook, and report the best test result you get.

In [38]:
print("best_val_acc:",best_val_acc)

best_val_acc: 58.26
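
Before submitting, it may be worth checking that the saved checkpoint reloads cleanly. The following is a minimal sketch of the save and load round trip; a small `nn.Linear` stands in for the trained ViT, and a temporary directory is used so that no real file is overwritten.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the trained model and a freshly constructed copy (assumptions).
model_a = nn.Linear(8, 4)
model_b = nn.Linear(8, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.pth")
    torch.save(model_a.state_dict(), path)

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    state = torch.load(path, map_location="cpu")
    model_b.load_state_dict(state)

model_a.eval()
model_b.eval()
x = torch.randn(2, 8)
with torch.no_grad():
    same = torch.equal(model_a(x), model_b(x))

print(same)
```

With the real model, the same pattern applies: construct `VisionTransformer` with the hyperparameters used for training, then call `load_state_dict` on it.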
