# Lecture 4: Vision Transformers (ViTs)
Welcome to this hands-on tutorial on **Vision Transformers (ViTs)**, a groundbreaking architecture that has transformed the field of computer vision. In this notebook, we’ll explore the key components and mechanisms that make ViTs unique, with a particular focus on **self-attention** and **patch embeddings**.

Unlike traditional Convolutional Neural Networks (CNNs), ViTs use a transformer-based approach to process image data. Images are divided into small, fixed-size patches, which are then embedded into a sequence of vectors. These vectors are processed through a series of **Multi-Head Self-Attention (MHA) layers**, enabling the model to capture both local and global dependencies.
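To make the patch step concrete, the sketch below cuts a batch of images into non-overlapping patches using only tensor reshapes. `patchify` is a hypothetical helper written for illustration. It is not part of this notebook or of `vision_transformer_utils.py`, and it assumes the image sides are divisible by the patch size. In a ViT, each flattened patch is then mapped to an embedding vector by a learned linear projection.

```python
import torch

def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: split (B, C, H, W) images into flattened, non-overlapping patches.

    Returns a tensor of shape (B, num_patches, C * patch_size * patch_size).
    """
    B, C, H, W = images.shape
    assert H % patch_size == 0 and W % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = H // patch_size, W // patch_size  # number of patches along height / width
    # (B, C, gh, p, gw, p): split each spatial axis into (grid index, offset inside the patch)
    x = images.reshape(B, C, gh, patch_size, gw, patch_size)
    # (B, gh, gw, C, p, p): move the grid axes to the front so each patch is grouped together
    x = x.permute(0, 2, 4, 1, 3, 5)
    # (B, gh * gw, C * p * p): one row per patch, in row-major (raster) order
    return x.reshape(B, gh * gw, C * patch_size * patch_size)

images = torch.rand(2, 3, 32, 32)  # a batch of two CIFAR-10-sized images
patches = patchify(images, patch_size=8)
print(patches.shape)  # torch.Size([2, 16, 192])
```

With the CIFAR-10 resolution (32x32) and 8x8 patches, each image becomes a sequence of 16 tokens, each of length 3 x 8 x 8 = 192 before projection.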

By the end of this notebook, you’ll understand how to:

1) Implement patch embeddings to transform images into input sequences for ViTs.
2) Build and train a self-attention block, the core building block of Vision Transformers.
3) Appreciate the significance of pre-training for achieving optimal performance with ViTs.

Through these implementations, we aim to show how ViTs work and to give you the tools to construct and train your first Vision Transformer. Let’s dive into this exciting architecture and explore its potential!



In [None]:
# IMPORT PACKAGES
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import copy
import os

import unittest

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use the GPU if one is available, otherwise fall back to the CPU
print(f'Using device: {device}')

## 3.1 Basic Building Blocks of ViTs
In this part of the assignment you will implement the multi-head self-attention module (`Attention`) and the module that creates patch embeddings from images (`PatchEmbed`). Once your implementation in this notebook passes the tests, copy it to vision_transformer_utils.py. You will need a complete version of that file for the second part of the assignment.

### 3.1.1 Implementing Multi-Head Self-Attention
Please complete the following Attention class. You can do it in about 7 lines of code.
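Before filling in the TODO, it can help to trace the shapes produced by the `qkv` line that the skeleton already provides. The sketch below repeats that reshape outside the class, using the sizes from the unit test (batch 2, 10 tokens, dimension 32, 4 heads). `qkv_proj` is an illustrative stand-in for `self.qkv`, not part of the assignment code.

```python
import torch
import torch.nn as nn

B, N, C, num_heads = 2, 10, 32, 4  # same sizes as the unit test
head_dim = C // num_heads          # 8 channels per head

x = torch.rand(B, N, C)
qkv_proj = nn.Linear(C, C * 3)     # stand-in for self.qkv: one layer producing queries, keys and values

# (B, N, 3C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
qkv = qkv_proj(x).reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
print(qkv.shape)  # torch.Size([3, 2, 4, 10, 8])
```

The leading axis of size 3 separates queries, keys and values, and each head works on `head_dim = dim // num_heads` channels. Comparing every token with every other token inside each head is what gives the attention weights their expected shape `(batch_size, num_heads, N, N)`.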

In [None]:
class Attention(nn.Module):

    """
    Implements a multi-head self-attention mechanism with optional scaling.

    This module computes self-attention, using a scaled dot-product mechanism, over input features. It supports optional biases in the query, key, and value projections, scaling of the attention scores, and dropout in both the attention scores and the output projection.

    Parameters:
    - dim (int): Dimensionality of the input features and the output features.
    - num_heads (int, optional): Number of attention heads; `dim` should be divisible by it. Defaults to 8.
    - qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False.
    - qk_scale (float, optional): Scale factor for query-key dot products. If None, defaults to head_dim ** -0.5, where head_dim = dim // num_heads. When specified, overrides the default scaling.
    - attn_drop (float, optional): Dropout rate for attention weights. Defaults to 0.
    - proj_drop (float, optional): Dropout rate for the output of the final projection layer. Defaults to 0.

    The forward pass accepts an input tensor `x` and returns the transformed tensor and the attention weights. The input tensor is expected to have the shape (batch_size, num_features, dim), where `num_features` is the number of features (or tokens) and `dim` is the feature dimensionality.

    The output consists of the transformed input tensor with the same shape as the input and the attention weights tensor of shape (batch_size, num_heads, num_features, num_features), representing the attention scores applied to the input features.
    """
    

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        #TODO: complete the forward pass
        # q, k, v = 
        
        return x, attn

The unit test below instantiates the Attention class you just implemented and checks that the output shapes and types are correct.
**You do not need to modify this test.**

In [None]:
# Code for testing the Attention module, you do not need to modify this
class TestAttention(unittest.TestCase):

    def setUp(self):
        # Setup your test cases here with different configurations
        self.batch_size = 2
        self.seq_length = 10
        self.embed_dim = 32
        self.num_heads = 4

        self.input_tensor = torch.rand(self.batch_size, self.seq_length, self.embed_dim) # random tensor that represents your input

    def test_output_shape(self):
        """Test if the output shape is correct."""
        attention = Attention(dim=self.embed_dim, num_heads=self.num_heads)
        output, attn = attention(self.input_tensor)
        self.assertEqual(output.shape, self.input_tensor.shape)
        self.assertEqual(attn.shape, (self.batch_size, self.num_heads, self.seq_length, self.seq_length))

    def test_output_type(self):
        """Test if the output types are correct."""
        attention = Attention(dim=self.embed_dim, num_heads=self.num_heads)
        output, attn = attention(self.input_tensor)
        self.assertIsInstance(output, torch.Tensor)
        self.assertIsInstance(attn, torch.Tensor)

Now we are going to test the implemented Attention module with the defined unit test. Please make sure to pass these tests before continuing the assignment. 

**If the tests pass, copy the code of the working Attention module to the vision_transformer_utils.py file for use in the second part of the assignment**.

In [None]:
# Test the Attention module
unittest.main(argv=[''], verbosity=2, exit=False) # make sure these pass before continuing the assignment

### 3.1.2 Implementing Patch Embedding
Please complete the following PatchEmbed class. You can do it in one line of code.
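The docstring below states that the patch embedding is computed with a convolution whose kernel size and stride both equal the patch size. The sketch below, using small made-up sizes, checks why that works: because the kernels never overlap, each output position of the convolution is exactly a linear projection of one flattened patch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p, in_chans, embed_dim = 8, 3, 16  # small illustrative sizes
conv = nn.Conv2d(in_chans, embed_dim, kernel_size=p, stride=p)

img = torch.rand(1, in_chans, 32, 32)
out = conv(img)  # (1, embed_dim, 32 // p, 32 // p) = (1, 16, 4, 4)

# Embed the patch at grid row 1, column 2 by hand with the same weights:
# flatten the (C, p, p) patch and take a dot product with each flattened kernel.
patch = img[0, :, 1 * p:2 * p, 2 * p:3 * p].reshape(-1)          # (C * p * p,)
manual = conv.weight.reshape(embed_dim, -1) @ patch + conv.bias  # (embed_dim,)

print(out.shape)  # torch.Size([1, 16, 4, 4])
print(torch.allclose(out[0, :, 1, 2], manual, atol=1e-5))
```

Note that the convolution returns a spatial grid of shape `(B, embed_dim, H / p, W / p)`, whereas the unit test expects a sequence of shape `(B, num_patches, embed_dim)`.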

In [None]:
class PatchEmbed(nn.Module):
    """
    Converts an image into a sequence of patches and embeds them.

    This module uses a convolutional layer to transform the input images into a flat sequence of embeddings, 
    effectively converting each patch of the image into an embedding vector.

    Parameters:
    - img_size (int, optional): Size of the input image (height and width). Defaults to 224.
    - patch_size (int, optional): Size of each patch (height and width). Defaults to 16.
    - in_chans (int, optional): Number of input channels (e.g., 3 for RGB images). Defaults to 3.
    - embed_dim (int, optional): Dimension of the patch embeddings. Defaults to 768.

    The module calculates the number of patches by dividing the image size by the patch size, both vertically and horizontally. 
    It then applies a 2D convolutional layer to project each patch to the embedding space defined by `embed_dim`.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        
        # TODO: Complete the forward pass
        # x =

        return x


The unit test below instantiates the PatchEmbed class you just implemented and checks that the output shape, the number of patches and the output type are correct. **You do not need to modify this test.**

In [None]:
# Code for testing the Patch Embedding module, you do not need to modify this
class TestPatchEmbed(unittest.TestCase):

    def setUp(self):
        # Example setup: 224x224 image with 3 channels, 16x16 patches, and embedding dimension of 768
        self.img_size = 224
        self.patch_size = 16
        self.in_chans = 3
        self.embed_dim = 768
        self.batch_size = 4

        # Calculate the expected number of patches
        self.expected_num_patches = (self.img_size // self.patch_size) ** 2

        # Create a dummy input tensor
        self.input_tensor = torch.rand(self.batch_size, self.in_chans, self.img_size, self.img_size)

    def test_output_shape(self):
        """Test if the output tensor shape is correct."""
        patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        output = patch_embed(self.input_tensor)
        expected_shape = (self.batch_size, self.expected_num_patches, self.embed_dim)
        self.assertEqual(output.shape, expected_shape)

    def test_num_patches(self):
        """Test if the calculated number of patches is correct."""
        patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size)
        self.assertEqual(patch_embed.num_patches, self.expected_num_patches)

    def test_output_type(self):
        """Test if the output is a tensor."""
        patch_embed = PatchEmbed(img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim)
        output = patch_embed(self.input_tensor)
        self.assertIsInstance(output, torch.Tensor)


Now we are going to test the implemented PatchEmbed module with the defined unit test. Please make sure to pass these tests before continuing the assignment. 

**If the tests pass, copy the code of the working PatchEmbed module to the vision_transformer_utils.py file for use in the second part of the assignment.**

In [None]:
# Run the tests
unittest.main(argv=[''], verbosity=2, exit=False) # make sure these pass before continuing the assignment

## 3.2 Training ViTs

We will use the same dataset as in the first and second notebooks, CIFAR-10.

In [None]:
# preprocess the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # ImageNet channel mean and std (not CIFAR-10's own statistics)
])


# download CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# download the test data
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

### 3.2.1 Constructing ViT Model
Here we assume you have successfully implemented Attention and PatchEmbed in the vision_transformer_utils.py file.

In [None]:
from vision_transformer_utils import vit_small

In [None]:
# Create the Vision Transformer Model with the implemented Attention and PatchEmbed Modules
own_model = vit_small(patch_size=8)

# Send the model to the available device
own_model.to(device)
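
With 32x32 CIFAR-10 images, the patch size directly sets how many tokens the transformer processes. The arithmetic below assumes a single prepended class token, as in the DINO reference ViT that the pretrained weights used later come from. `num_tokens` is a hypothetical helper, not part of vision_transformer_utils.py.

```python
def num_tokens(img_size: int, patch_size: int, num_cls_tokens: int = 1) -> int:
    """Hypothetical helper: sequence length seen by a ViT (one token per patch plus class token(s))."""
    assert img_size % patch_size == 0, "patch size should divide the image size"
    return (img_size // patch_size) ** 2 + num_cls_tokens

for p in (4, 8, 16):
    print(f"patch_size={p:2d}: {num_tokens(32, p)} tokens")
# patch_size= 4: 65 tokens
# patch_size= 8: 17 tokens
# patch_size=16: 5 tokens
```

Because self-attention compares every token with every other token, the number of attention scores per head grows with the square of this count (for example 17² = 289 versus 65² = 4225), which is worth keeping in mind when exploring Q3.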

### 3.2.2 Training ViT Model from scratch
Train your ViT model. Feel free to modify the code below to get the best accuracy you can.

In [None]:
# Set parameters
num_epochs = 3
learning_rate = 0.0001

# Set the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(own_model.parameters(), lr=learning_rate)

# Keep track of the best test accuracy seen so far
best_score = 0

# Train the model
for epoch in range(num_epochs):  # loop over the dataset for the number of specified epochs
    own_model.train()  # (re-)enable training mode; the evaluation below switches to eval mode
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = own_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Iteration [{i + 1}/{len(train_loader)}]', end='\r')

    # log the average training loss for this epoch
    print(f'Finished epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

    # evaluate on the test set
    own_model.eval()  # disable training-only behaviour (e.g. dropout) during evaluation
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = own_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # log the test accuracy and update the best score
    accuracy = 100 * correct / total
    best_score = max(best_score, accuracy)
    print(f'Accuracy of the network on the {total} test images: {accuracy:.2f}% (best so far: {best_score:.2f}%)')

print('Finished Training')
# save the weights from the last epoch
PATH = './first_vit_cifar_net_last.pth'
torch.save(own_model.state_dict(), PATH)

Now we want to test the performance of our model trained from scratch.

In [None]:
# Test your model on the test set
def test_model_on_testset(model, test_loader, device):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')

In [None]:
test_model_on_testset(own_model, test_loader, device)

#### Questions
**Q1**. What do you think of the performance obtained by the model trained from scratch? How can we improve the performance?

**Q2**. What is the influence of the learning rate? Try to increase/decrease the learning rate.

**Q3**. What is the impact of the patch size on the performance? (Note: the input resolution is 32x32 pixels)

**Q4**. What happens if we use another version (e.g. vit_tiny, vit_base) instead of vit_small?


### 3.2.3 Initializing ViT model with ImageNet DINO weights
Now we are going to construct the same model, but rather than initializing it from scratch, we will load weights obtained by pre-training with the self-supervised DINO method!

In [None]:
# Try loading pretrained IMAGENET model (NOTE: patch_size has to be 8 to load in the pre-trained weights)
pretrained_model = vit_small(patch_size=8)

url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
pretrained_model.load_state_dict(state_dict, strict=True)
pretrained_model.to(device)
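
The note that `patch_size` has to be 8 comes down to tensor shapes: parameters such as the patch-embedding kernel stored in the checkpoint have sizes tied to the patch size. The toy sketch below uses two bare `Conv2d` layers (not the real ViT) with 384 output channels, the embedding width of ViT-Small, to show that PyTorch refuses to copy a weight whose shape differs. This happens even with `strict=False`, which only relaxes missing or unexpected keys.

```python
import torch
import torch.nn as nn

# Two stand-in "patch embedding" layers that differ only in patch size
src = nn.Conv2d(3, 384, kernel_size=8, stride=8)    # like the checkpoint (patch_size=8)
dst = nn.Conv2d(3, 384, kernel_size=16, stride=16)  # a model built with patch_size=16

try:
    dst.load_state_dict(src.state_dict(), strict=True)
    loaded = True
except RuntimeError as err:  # raised for size mismatches
    loaded = False
    print("Load failed:", str(err).splitlines()[0])

print(loaded)  # False
```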

Let's see how the model initialized with ImageNet weights performs on the test set.

In [None]:
test_model_on_testset(pretrained_model, test_loader, device)

#### Questions
Q5. Is this performance at the level you would expect? Why do you think this is the case?

### 3.2.4 Finetuning ViT model with ImageNet DINO weights
Fine-tune the model initialized with the ImageNet-pretrained weights.

In [None]:
# Set parameters
num_epochs = 3
learning_rate = 0.00001

# Put the model on the device
pretrained_model.to(device)

# Set the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(pretrained_model.parameters(), lr=learning_rate)

# Keep track of the best test accuracy seen so far
best_score = 0

# Train the model
for epoch in range(num_epochs):  # loop over the dataset for the number of specified epochs
    pretrained_model.train()  # (re-)enable training mode; the evaluation below switches to eval mode
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = pretrained_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Iteration [{i + 1}/{len(train_loader)}]', end='\r')

    # log the average training loss for this epoch
    print(f'Finished epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}')

    # evaluate on the test set
    pretrained_model.eval()  # disable training-only behaviour (e.g. dropout) during evaluation
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = pretrained_model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # log the test accuracy and update the best score
    accuracy = 100 * correct / total
    best_score = max(best_score, accuracy)
    print(f'Accuracy of the network on the {total} test images: {accuracy:.2f}% (best so far: {best_score:.2f}%)')

print('Finished Training')
# save the weights from the last epoch
PATH = './finetuned_model.pth'
torch.save(pretrained_model.state_dict(), PATH)

In [None]:
test_model_on_testset(pretrained_model, test_loader, device) 

#### Questions
**Q6**. What do you think of the performance compared to the model trained from scratch?

**Q7**. What do you think of the speed of convergence compared to the model trained from scratch?

**Q8**. What is the influence of the learning rate? Try to increase/decrease the learning rate.