# Vision Transformers

Practical by Niki Amini-Naieni, Iro Laina, and Andrew Zisserman, adapted and extended from a tutorial by Phillip Lippe.

This practical introduces Vision Transformers (ViTs) and explores their role in modern computer vision. Since [Alexey Dosovitskiy et al.](https://openreview.net/pdf?id=YicbFdNTTy) first demonstrated that Transformers could be applied successfully to image recognition tasks, the field has seen rapid developments. But how do these models work in practice, and how do they compare to more traditional convolutional networks? And how can we take advantage of large-scale pretraining to adapt state-of-the-art Transformer models to new tasks?

You will explore these questions in two stages.

In **Part 1**, you will implement and train a Vision Transformer from scratch on the CIFAR-10 dataset and compare its classification performance to a CNN trained under similar conditions. This will give practical experience with both architectures.

In **Part 2**, you will investigate how Vision Transformers can be used effectively without training them end-to-end. You will extract frozen features from a model pretrained with self-supervision on large-scale datasets and train a linear classifier on top, and you will also evaluate CLIP’s zero-shot classification capabilities. This part shows how powerful pretrained representations can be when applied to smaller datasets, and how modern vision systems often rely on large-scale pretraining.

If you are not familiar with Transformers yet, take a look at [this online tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html) where the fundamentals of Multi-Head Attention and Transformers are discussed. We will use [PyTorch Lightning](https://www.pytorchlightning.ai/) \(introduced [here](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.html)\).

**IMPORTANT**

Please make sure to change your runtime type to GPU **before** starting. Click the dropdown menu in the top right and select "Change Runtime Type". Then, check the circle next to any of the available GPUs.






# **Part 1**
---
## Learning Outcomes
By the end of Part 1, you should be able to:

* Load and use datasets and dataloaders in PyTorch

* Reason about tensor shapes and dimensions

* Understand and implement the core components of a Vision Transformer in code

* Explain the memory requirements and scaling behavior of self-attention

* Compare pre-layer and post-layer normalization and discuss their benefits and drawbacks

* Understand the purpose of positional embeddings, explore different implementations, and evaluate their strengths and limitations

* Explain the role of the CLS token

* Interpret learning curves and compare the training dynamics of different architectures

* Compare the behavior of CNNs and Vision Transformers when trained from scratch on the same small dataset, and explain the differences you observe







In [None]:
## Standard libraries
import os
import numpy as np
import glob
import random
import math
import json
from functools import partial
from PIL import Image

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
import matplotlib_inline.backend_inline  # IPython.display.set_matplotlib_formats is deprecated
matplotlib_inline.backend_inline.set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for loading bars
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim

## Torchvision
import torchvision
from torchvision.datasets import CIFAR10, Flowers102
from torchvision import transforms

# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet "pytorch-lightning>=1.4"  # quoted so the shell does not treat ">" as a redirect
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Import tensorboard
%load_ext tensorboard

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial15"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

We provide a pre-trained Vision Transformer which we download in the next cell. However, Vision Transformers can be relatively quickly trained on CIFAR10 with an overall training time of less than an hour on an NVIDIA TitanRTX. Feel free to experiment with training your own Transformer once you have gone through the whole notebook.

In [None]:
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/"
# Files to download
pretrained_files = ["tutorial15/ViT.ckpt", "tutorial15/tensorboards/ViT/events.out.tfevents.ViT",
                    "tutorial5/tensorboards/ResNet/events.out.tfevents.resnet"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name.split("/",1)[1])
    if "/" in file_name.split("/",1)[1]:
        os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)

We load the CIFAR10 dataset below. We use the same setup of the datasets and data augmentations that were used for the CNNs we will benchmark against to keep a fair comparison. The constants in the `transforms.Normalize` correspond to the values that scale and shift the data to a zero mean and standard deviation of one.
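
To make the role of these constants more concrete, here is a minimal sketch of how per-channel statistics of this kind can be computed, and what normalizing with them does. It uses random stand-in data rather than CIFAR10, so the resulting numbers differ from the constants above; the names `fake_images`, `mean`, and `std` are illustrative only.

```python
import torch

torch.manual_seed(0)

# Stand-in for a dataset of RGB images scaled to [0, 1], shape [N, C, H, W].
# (Random data for illustration only; these are not CIFAR10 images.)
fake_images = torch.rand(1000, 3, 8, 8)

# Per-channel statistics over all images and all pixel positions.
mean = fake_images.mean(dim=(0, 2, 3))  # shape [3]
std = fake_images.std(dim=(0, 2, 3))    # shape [3]

# transforms.Normalize(mean, std) applies (x - mean) / std per channel.
normalized = (fake_images - mean[None, :, None, None]) / std[None, :, None, None]

print("channel means after normalization:", normalized.mean(dim=(0, 2, 3)))
print("channel stds after normalization: ", normalized.std(dim=(0, 2, 3)))
```

After normalization, each channel should have a mean close to zero and a standard deviation close to one, which is the property the CIFAR10 constants are chosen to give on the real training data.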

In [None]:
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])
# For training, we add some augmentation. Networks have too much capacity and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.49139968, 0.48215841, 0.44653091], [0.24703223, 0.24348513, 0.26158784])
                                     ])
# Loading the training dataset. We need to split it into a training and validation part
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
pl.seed_everything(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
pl.seed_everything(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)

# Visualize some examples
NUM_IMAGES = 4
CIFAR_images = torch.stack([val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=4, normalize=True, pad_value=0.9)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8,8))
plt.title("Image examples of the CIFAR10 dataset")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()
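
The "little trick" in the cell above relies on the fact that two calls to `random_split` with the same random state produce the same partition. The sketch below illustrates this on a toy dataset of 10 items. As an assumption made to keep it self-contained, it passes an explicitly seeded `torch.Generator` instead of calling `pl.seed_everything`; `augmented_view` and `plain_view` are hypothetical stand-ins for the two CIFAR10 dataset objects.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Two views of the same 10 samples, standing in for the augmented and
# non-augmented versions of the training data.
samples = torch.arange(10)
augmented_view = TensorDataset(samples)
plain_view = TensorDataset(samples)

# Same seed -> same permutation -> same partition for both views.
train_part, _ = random_split(augmented_view, [8, 2],
                             generator=torch.Generator().manual_seed(42))
_, val_part = random_split(plain_view, [8, 2],
                           generator=torch.Generator().manual_seed(42))

train_idx = set(train_part.indices)
val_idx = set(val_part.indices)
print("train indices:", sorted(train_idx))
print("val indices:  ", sorted(val_idx))
print("overlap:", train_idx & val_idx)
```

Because both splits come from the same permutation, the training part (with augmentation) and the validation part (without) never share an image.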

# **Question 1**
> Write code to print out the image size. What is the height and width of the images in CIFAR10?

## Transformers for image classification

Transformers are, at their core, set-processing models: they are permutation-equivariant architectures, i.e., if the input is permuted, the output is permuted in the same way. To apply Transformers to sequences, we simply add a positional encoding to the input feature vectors, and the model learns by itself what to do with it. So, why not apply the same approach to images? This is exactly what [Alexey Dosovitskiy et al.](https://openreview.net/pdf?id=YicbFdNTTy) proposed in their paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". Specifically, the Vision Transformer is a model, first used for image classification, that views images as sequences of smaller patches. As a preprocessing step, we split an image of, for example, $48\times 48$ pixels into nine $16\times 16$ patches. Each of those patches is considered to be a "word"/"token" and is projected to a feature space. After adding positional encodings and prepending a token for classification (the "class token"), we can apply a Transformer as usual to this sequence and start training it for our task. A nice GIF visualization of the architecture is shown below (figure credit - [Phil Wang](https://github.com/lucidrains/vit-pytorch/blob/main/images/vit.gif)):

<center width="100%"><img src="https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial15/vit.gif?raw=1" width="600px"></center>

We will walk step by step through the Vision Transformer, and implement all parts by ourselves. First, let's implement the image preprocessing: an image of size $N\times N$ has to be split into $(N/M)^2$ patches of size $M\times M$. These represent the input words to the Transformer.

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x
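
As a quick sanity check of the patch layout, the sketch below applies the same reshape-and-permute logic to the $48\times 48$ example from the text. The helper `to_patches` is a condensed copy of `img_to_patch` (flattened output only), repeated here so that the snippet runs on its own.

```python
import torch

def to_patches(x, patch_size):
    # Condensed version of img_to_patch with flattened patches.
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)       # [B, H', W', C, p_H, p_W]
    return x.flatten(1, 2).flatten(2, 4)  # [B, H'*W', C*p_H*p_W]

# One RGB image of 48x48 pixels with a distinct value at every position.
image = torch.arange(3 * 48 * 48, dtype=torch.float32).reshape(1, 3, 48, 48)
patches = to_patches(image, patch_size=16)
print(patches.shape)  # [1, (48/16)^2, 3*16*16]

# Patches are ordered row by row: patch 0 is the top-left corner,
# patch 1 is the block directly to its right.
top_left = image[0, :, :16, :16].reshape(-1)
right_of_it = image[0, :, :16, 16:32].reshape(-1)
print(torch.equal(patches[0, 0], top_left))
print(torch.equal(patches[0, 1], right_of_it))
```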

Let's take a look at how that works for our CIFAR examples above. We choose a patch size of 4 and visualize the patches below:

In [None]:
img_patches = img_to_patch(CIFAR_images, patch_size=4, flatten_channels=False)

fig, ax = plt.subplots(CIFAR_images.shape[0], 1, figsize=(14,3))
fig.suptitle("Images as input sequences of patches")
for i in range(CIFAR_images.shape[0]):
    img_grid = torchvision.utils.make_grid(img_patches[i], nrow=64, normalize=True, pad_value=0.9)
    img_grid = img_grid.permute(1, 2, 0)
    ax[i].imshow(img_grid)
    ax[i].axis('off')
plt.show()
plt.close()

# **Question 2**
> Given the image size in Question 1, how many patches would we obtain per image using patch sizes 2, 4, 8, 16, and 32? In the above code, we set the patch size to 4. Add code below to print the shape of ```img_patches```. What does each dimension of ```img_patches``` represent?

# **Question 3**
> How does increasing the number of patches ($n$) influence the required memory for attention and why?

Compared to the original images, it is now much harder to recognize the objects from those patch lists. Still, this is the input we provide to the Transformer for classifying the images. The model has to learn by itself how to combine the patches to recognize the objects. The inductive bias of CNNs, that an image is a grid of pixels, is lost in this input format.

After we have looked at the preprocessing, we can now start building the Transformer model. The fundamentals of Multi-Head Attention are revisited [here](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html). For now, we will simply use the PyTorch module `nn.MultiheadAttention` ([docs](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=multihead#torch.nn.MultiheadAttention)). Further, we use the Pre-Layer Normalization version of the Transformer blocks proposed by [Ruibin Xiong et al.](http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf) in 2020. The idea is to apply Layer Normalization not in between residual blocks, but instead as a first layer in the residual blocks. This reorganization of the layers supports better gradient flow and removes the necessity of a warm-up stage. A visualization of the difference between the standard Post-LN and the Pre-LN version is shown below.

<center width="100%"><img src="https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial15/pre_layer_norm.svg?raw=1" width="400px"></center>
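
For contrast with the Pre-LN block used in this practical, the Post-LN ordering from the figure can be sketched as follows. `PostLNBlock` is a hypothetical name and the sizes are toy values; the sketch only illustrates where the normalization sits and is not used elsewhere in the notebook.

```python
import torch
import torch.nn as nn

class PostLNBlock(nn.Module):
    """Post-LN ordering: sublayer -> residual add -> LayerNorm."""

    def __init__(self, embed_dim, hidden_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm_1 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm_2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Normalization is applied to the residual stream after each addition.
        x = self.norm_1(x + self.attn(x, x, x)[0])
        x = self.norm_2(x + self.mlp(x))
        return x

# Default input layout for nn.MultiheadAttention: [sequence, batch, embed].
x = torch.randn(5, 2, 32)
block = PostLNBlock(embed_dim=32, hidden_dim=64, num_heads=4)
print(block(x).shape)
```

In the Pre-LN variant, the two `LayerNorm` calls move inside the residual branches, so the skip connection itself carries the un-normalized input from layer to layer.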

The implementation of the Pre-LN attention block looks as follows:

In [None]:
class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply to the attention weights and
                      in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )


    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x
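
One detail worth keeping in mind: constructed as above, `nn.MultiheadAttention` uses its default `batch_first=False` and therefore expects inputs of shape `[sequence, batch, embed]`. The sketch below, with toy sizes chosen for illustration, shows the layout and checks that batch elements do not influence each other.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 32, 4
attn = nn.MultiheadAttention(embed_dim, num_heads)  # batch_first=False by default
attn.eval()

tokens = torch.randn(2, 5, embed_dim)  # [batch, sequence, embed]
seq_first = tokens.transpose(0, 1)     # [sequence, batch, embed]
out = attn(seq_first, seq_first, seq_first)[0]
print(out.shape)  # sequence-first, like the input

# Each batch element is processed independently of the others.
single = seq_first[:, :1]
out_single = attn(single, single, single)[0]
print(torch.allclose(out[:, :1], out_single, atol=1e-5))
```

This is why a batch-first tensor has to be transposed before it is passed through a stack of `AttentionBlock` modules.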

# **Question 4**
> Why does Pre-Layer Normalization help with gradient flow and remove the need for a warm-up stage? Hint: review Theorem 1 in [Ruibin Xiong et al.](http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf).

Now we have all modules ready to build our own Vision Transformer. Besides the Transformer encoder, we need the following modules:

* A **linear projection** layer that maps the input patches to a feature vector of larger size. It is implemented by a simple linear layer that takes each $M\times M$ patch independently as input.
* A **classification token** that is added to the input sequence. We will use the output feature vector of the classification token (CLS token in short) for determining the classification prediction.
* Learnable **positional encodings** that are added to the tokens before being processed by the Transformer. Those are needed to learn position-dependent information, and convert the set to a sequence. Since we usually work with a fixed resolution, we can learn the positional encodings instead of having the pattern of sine and cosine functions.
* An **MLP head** that takes the output feature vector of the CLS token, and maps it to a classification prediction. This is usually implemented by a small feed-forward network or even a single linear layer.

With those components in mind, let's implement the full Vision Transformer below:

# **Question 5**
> Here we opt to *learn* the positional embeddings instead of keeping them fixed. Why is this okay here? What could we do to handle higher resolution images?

In [None]:
class VisionTransformer(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+num_patches,embed_dim))


    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transformer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        return out

# **Question 6**
> The code to initialize the input layer is: ```self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)```. What is the input layer doing in the forward function? Describe its role.

# **Question 7**
> Look at the forward function in the ```VisionTransformer``` class. At what position is the "class" token? What is its role?

Finally, we can put everything into a PyTorch Lightning Module as usual. We use `torch.optim.AdamW` as the optimizer, which is Adam with a corrected weight decay implementation. Since we use the Pre-LN Transformer version, we do not need to use a learning rate warmup stage anymore. Instead, we use the same learning rate scheduler as the CNNs we will benchmark against.
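
To see what this scheduler does over a full run, the following sketch steps `MultiStepLR` through 180 epochs with the milestones and decay factor used in the module below. It assumes the base learning rate of 3e-4 and the epoch count chosen later in this practical, and drives the optimizer with a single dummy parameter rather than a real model.

```python
import torch
import torch.optim as optim

# A single dummy parameter is enough to drive the optimizer and scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.AdamW([param], lr=3e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

lr_per_epoch = []
for epoch in range(180):
    lr_per_epoch.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients here; a real loop would train for one epoch
    scheduler.step()   # Lightning steps the scheduler once per epoch by default

for epoch in (0, 99, 100, 149, 150, 179):
    print(f"epoch {epoch:3d}: lr = {lr_per_epoch[epoch]:.1e}")
```

The learning rate stays constant until epoch 100, drops by a factor of 10 there, and drops by another factor of 10 at epoch 150.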

In [None]:
class ViT(pl.LightningModule):

    def __init__(self, model_kwargs, lr):
        super().__init__()
        self.save_hyperparameters()
        self.model = VisionTransformer(**model_kwargs)
        self.example_input_array = next(iter(train_loader))[0]

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr)
        lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100,150], gamma=0.1)
        return [optimizer], [lr_scheduler]

    def _calculate_loss(self, batch, mode="train"):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = F.cross_entropy(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        self.log(f'{mode}_loss', loss)
        self.log(f'{mode}_acc', acc)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._calculate_loss(batch, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="val")

    def test_step(self, batch, batch_idx):
        self._calculate_loss(batch, mode="test")

## Experiments

Commonly, Vision Transformers are applied to large-scale image classification benchmarks such as ImageNet to leverage their full potential. However, here we take a step back and ask: can Vision Transformers also succeed on classical, small benchmarks such as CIFAR10? To find this out, we train a Vision Transformer from scratch on the CIFAR10 dataset. Let's first create a training function for our PyTorch Lightning module which also loads the pre-trained model if you have downloaded it above.

In [None]:
def train_model(train, **kwargs):
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, "ViT"),
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",
                         devices=1,
                         max_epochs=180,
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),
                                    LearningRateMonitor("epoch")])
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # If train is False, load the downloaded pretrained checkpoint and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, "ViT.ckpt")
    if not train:
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = ViT.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42) # To be reproducible
        model = ViT(**kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = ViT.load_from_checkpoint(trainer.checkpoint_callback.best_model_path) # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result

The next step would be to start training our model. As seen in our implementation, we have a couple of hyperparameters that we have to set. When creating this notebook, we have performed a small grid search over hyperparameters and listed the best hyperparameters in the cell below. Nevertheless, it is worth discussing the influence that each hyperparameter has, and what intuition we have for choosing its value.

First, let's consider the patch size. The smaller we make the patches, the longer the input sequences to the Transformer become. While in general, this allows the Transformer to model more complex functions, it requires a longer computation time due to the way attention weights scale with sequence length, as discussed in Question 3. Furthermore, small patches can make the task more difficult since the Transformer has to learn which patches are close-by, and which are far away. We experimented with patch sizes of 2, 4, and 8 which gives us the input sequence lengths of 256, 64, and 16 respectively. We found 4 to result in the best performance and hence pick it below.

Next, the embedding and hidden dimensionality have a similar impact on a Transformer as on an MLP. The larger the sizes, the more complex the model becomes, and the longer it takes to train. In Transformers, however, we have one more aspect to consider: the query-key sizes in the Multi-Head Attention layers. Each key has a feature dimensionality of `embed_dim/num_heads`. Considering that we have an input sequence length of 64, a minimum reasonable size for the key vectors is 16 or 32. Lower dimensionalities can restrict the possible attention maps too much. We observed that more than 8 heads are not necessary for the Transformer, and therefore pick an embedding dimensionality of `256`. The hidden dimensionality in the feed-forward networks is usually 2-4x larger than the embedding dimensionality, and thus we pick `512`.
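
The arithmetic behind these choices is simple enough to check directly. The short sketch below uses the values from this paragraph; the loop over other head counts is only an illustration of the trade-off and not a recommendation.

```python
embed_dim = 256
num_heads = 8
hidden_dim = 512

# nn.MultiheadAttention requires embed_dim to be divisible by num_heads.
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
print("per-head key/query dimensionality:", head_dim)
print("feed-forward expansion factor:", hidden_dim / embed_dim)

# With a fixed embedding size, more heads leave fewer dimensions per head.
for heads in (4, 8, 16, 32):
    print(f"{heads:2d} heads -> {embed_dim // heads} dims per head")
```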

Finally, the learning rate for Transformers is usually relatively small, and in papers, a common value to use is 3e-5. However, since we work with a smaller dataset and have a potentially easier task, we found that we are able to increase the learning rate to 3e-4 without any problems. To reduce overfitting, we use a dropout value of 0.2. Remember that we also use small image augmentations as regularization during training.

Feel free to explore the hyperparameters yourself by changing the values below. In general, the Vision Transformer did not prove to be very sensitive to the hyperparameter choices on the CIFAR10 dataset.

Instead of training the model, in the interest of time, we will just load the pretrained checkpoint. If you have the time, you can change ```train=False``` to ```train=True``` in the input to the ```train_model``` function and try retraining yourself.

In [None]:
model, results = train_model(train=False, model_kwargs={
                                'embed_dim': 256,
                                'hidden_dim': 512,
                                'num_heads': 8,
                                'num_layers': 6,
                                'patch_size': 4,
                                'num_channels': 3,
                                'num_patches': 64,
                                'num_classes': 10,
                                'dropout': 0.2
                            },
                            lr=3e-4)
print("ViT results", results)

The Vision Transformer achieves a validation and test performance of about 77%. In comparison, almost all CNN architectures tested [here](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.html) obtain a classification performance of around 90%. This is a considerable gap and shows that although Vision Transformers perform strongly on ImageNet with potential pretraining, they cannot come close to simple CNNs on CIFAR10 when being trained from scratch. The differences between a CNN and Transformer can be well observed in the training curves. Let's look at them in a tensorboard below:

In [None]:
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH!
%tensorboard --logdir ../saved_models/tutorial15/tensorboards/

<center><img src="https://github.com/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial15/tensorboard_screenshot.png?raw=1" width="100%"/></center>

The tensorboard compares the Vision Transformer to a ResNet trained on CIFAR10. When looking at the training losses, we see that the ResNet learns much more quickly in the first iterations. While the learning rate might have an influence on the initial learning speed, we see the same trend in the validation accuracy. The ResNet matches the best performance of the Vision Transformer after just 5 epochs (about 2000 iterations). Further, while the ResNet training loss and validation accuracy follow a similar trend, the validation performance of the Vision Transformer changes only marginally after 10k iterations, even though its training loss has barely started to decrease at that point. Yet, the Vision Transformer is also able to achieve close to 100% accuracy on the training set.


# **Question 8**
> Based on these curves, do you think we could get improved performance by training for longer (i.e., for more epochs)?

# **Question 9**
> What would explain this behavior in the learning curves? Why is the vision transformer not able to achieve the same validation accuracy as the ResNet when trained on a small dataset like CIFAR10? What could we try to fix this issue?

# ✅ **Checkpoint 1**
Check your answers to Part 1 with a TA before moving on to Part 2.

# **Part 2**
---
## Learning Outcomes
As we have seen in Part 1, vision transformers underperform when trained from scratch on small datasets. Then how can we benefit from the vision transformer architecture when we do not have access to millions (or even billions) of annotated images for our task? We will answer this question in Part 2.

By the end of Part 2, you should be able to:

* Load and use pretrained Vision Transformer models from the transformers library

* Visualize and interpret attention maps at different layers of a ViT

* Adapt pretrained vision transformers to new classification tasks using linear probing

* Perform and analyse zero-shot classification using CLIP

The first method we will investigate is adapting pretrained vision transformers to classification tasks using linear probing, i.e., training only a linear classifier on top of frozen features. Specifically, we will adapt [DINOv2](https://arxiv.org/abs/2304.07193), a vision transformer trained with self-supervision on millions of images, to flower classification. Start by loading the model from the transformers library in the next cell.

In [None]:
from transformers import AutoModel, AutoImageProcessor

dinov2_vits14 = AutoModel.from_pretrained("facebook/dinov2-small", trust_remote_code=True).to(device)
dinov2_vits14.set_attn_implementation('eager')
preprocess = AutoImageProcessor.from_pretrained("facebook/dinov2-small", trust_remote_code=True)

You can visualize the model architecture and image pre-processing pipeline by printing them out in the next cell.

In [None]:
print(dinov2_vits14)
print(preprocess)

# **Question 10**
> The model outputs pretrained features, vectors representing the image. Based on the output of the previous cell, how long is one of these vectors?

It is very important to remember that the vision transformer loaded in this part was *pretrained* on large-scale datasets. This pretraining enables it to operate as a good *feature extractor*, producing meaningful representations of images that can be used for downstream tasks such as classification.
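
In outline, linear probing on top of such a feature extractor looks like the sketch below. To stay self-contained it uses synthetic clusters in place of real backbone features, so `feature_dim`, `num_classes`, and the training settings are placeholders and not the values for DINOv2 or the flowers dataset.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for frozen backbone features: 3 classes, 16-dim vectors.
# (Synthetic clusters; feature_dim is a placeholder, not DINOv2's size.)
num_classes, feature_dim, per_class = 3, 16, 100
centers = 5.0 * torch.randn(num_classes, feature_dim)
labels = torch.arange(num_classes).repeat_interleave(per_class)
features = centers[labels] + torch.randn(len(labels), feature_dim)

# The probe is a single linear layer; only its weights are trained.
probe = nn.Linear(feature_dim, num_classes)
optimizer = torch.optim.Adam(probe.parameters(), lr=0.1)

for step in range(100):
    loss = F.cross_entropy(probe(features), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

accuracy = (probe(features).argmax(dim=-1) == labels).float().mean().item()
print(f"training accuracy of the linear probe: {accuracy:.2f}")
```

With a real backbone, `features` would be computed once under `torch.no_grad()` and the backbone's parameters would never be passed to the optimizer.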

We can examine how meaningful the representations of the pretrained vision transformer are by visualising the attention maps at the final layer of the model. We will do this next for an example image.

In [None]:
def visualize_attention(attentions, img, patch_size, device):
    # make the image divisible by the patch size
    w, h = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - \
        img.shape[2] % patch_size
    img = img[:, :w, :h].unsqueeze(0)

    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size

    nh = attentions.shape[1]  # number of heads

    # keep only the attention from the CLS token (query index 0) to the patch tokens
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    attentions = nn.functional.interpolate(attentions.unsqueeze(
        0), scale_factor=patch_size, mode="bicubic")[0].cpu().numpy()

    return attentions


def plot_attention(img, attention):
    n_heads = attention.shape[0]

    plt.figure(figsize=(10, 10))
    text = ["Original Image", "Head Mean"]
    for i, fig in enumerate([img, np.mean(attention, 0)]):
        plt.subplot(1, 2, i+1)
        plt.imshow(fig, cmap='inferno')
        plt.title(text[i])
    plt.show()

    plt.figure(figsize=(10, 10))
    for i in range(n_heads):
        plt.subplot(n_heads//3, 3, i+1)
        plt.imshow(attention[i], cmap='inferno')
        plt.title(f"Head n: {i+1}")
    plt.tight_layout()
    plt.show()
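
To make the indexing inside ```visualize_attention``` more concrete, here is a minimal NumPy sketch of the same reshaping applied to a random stand-in for one layer's attention tensor. The image size (224), patch size (14), and number of heads (6) are assumptions chosen for illustration, and the random values are not real model outputs.

```python
import numpy as np

# Assumed toy configuration: 224x224 input, 14x14 patches, 6 attention heads.
image_size, patch_size, num_heads = 224, 14, 6
grid = image_size // patch_size      # 16 patches per side
num_tokens = 1 + grid * grid         # 1 [CLS] token + 256 patch tokens

# Random stand-in for one layer's attention: (batch, heads, queries, keys).
rng = np.random.default_rng(0)
attn = rng.random((1, num_heads, num_tokens, num_tokens))
attn /= attn.sum(axis=-1, keepdims=True)  # each query row sums to 1, like a softmax output

# Row 0 is the [CLS] query; columns 1: are the patch keys.
cls_to_patches = attn[0, :, 0, 1:]                         # (heads, 256)
cls_maps = cls_to_patches.reshape(num_heads, grid, grid)   # (heads, 16, 16)

print(cls_maps.shape)  # (6, 16, 16)
```

Each of the resulting 16 x 16 maps holds one value per image patch, which is why the function then upsamples by the patch size to get back to image resolution.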

Obtain the example image by running the next cell.

In [None]:
! wget "https://github.com/aryan-jadon/Medium-Articles-Notebooks/raw/main/Visualizing%20Attention%20in%20Vision%20Transformer/corgi_image.jpg"

Visualise the attention maps at the final layer of the pretrained model by running the next cell.

In [None]:
path = '/content/corgi_image.jpg'
img = Image.open(path)
img_pre = preprocess(img, return_tensors="pt")
with torch.no_grad():
    outputs = dinov2_vits14(img_pre['pixel_values'].to(device), output_attentions=True)
layer = -1
attentions = outputs.attentions[layer] # get the attention map at [layer]
attentions = visualize_attention(attentions, img_pre['pixel_values'][0], 14, device) # pass a single [C, H, W] image
plot_attention(img, attentions)

# **Question 11**
> Modify the ```layer``` variable in the previous cell to visualise the attention maps earlier and later in the network. What general pattern do you observe? Why might the model's pretraining cause this pattern to emerge?



Next, we are going to load and visualise the [Oxford Flowers Dataset](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). This is a small dataset for classifying 102 different flower species.

In [None]:
# Load Oxford Flowers data.
batch_size = 64
transform_vis = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # convert to RGB
    transforms.ToTensor(), # convert to PyTorch tensor
    transforms.Resize(244), # resize while maintaining aspect ratio
    transforms.CenterCrop(224) # take a square center crop
])
preprocess_return_tensor = transforms.Lambda(lambda image: preprocess(image, return_tensors="pt"))
flowers_train_data_unnormalized = Flowers102(
    '../data', split="train", download=True, transform=transform_vis
    )
flowers_train_data = Flowers102(
    '../data', split="train", download=True, transform=preprocess_return_tensor
    )
flowers_val_data = Flowers102(
    '../data', split="val", download=True, transform=preprocess_return_tensor
    )
flowers_test_data = Flowers102(
    '../data', split="test", download=True, transform=preprocess_return_tensor
    )

# Visualize data
visualization_loader = torch.utils.data.DataLoader(
    flowers_train_data_unnormalized,
    batch_size=batch_size, shuffle=True)

# Get one batch of images and labels
images, labels = next(iter(visualization_loader))

# Create a grid of images
fig = plt.figure(figsize=(12, 8))
for i in range(batch_size):
    plt.subplot(8, 8, i + 1)  # 8x8 grid
    plt.tight_layout()
    plt.imshow(images[i].permute(1, 2, 0))  # [C, H, W] → [H, W, C]
    plt.title(f"Label: {labels[i].item()}")
    plt.axis('off')

plt.show()

# **Question 12**
> How many images are in the Oxford Flowers training set? How many images are in the CIFAR10 training set? Given this, what do you think would happen if we tried to train the vision transformer from scratch on Oxford Flowers?

Instead of training the Dinov2 vision transformer from scratch, we are going to train a linear classifier to predict the flower class from features obtained with the pretrained Dinov2 model. To make training faster, we will pre-compute and save the features from Dinov2 before training our classifier on top. Run the next cell to define the ```compute_embeddings``` function.

In [None]:
def compute_embeddings(dataset, suffix) -> None:
    """
    Computes DINOv2 embeddings for all images in the dataset and saves them
    as an N x 384 dimensional array in an npy file, where N is the number of
    images, and 384 is the embedding dimension.

    [dataset]: PyTorch dataset containing images to encode with DINOv2
    [suffix]: suffix to add to the name of the npy file

    Results are saved in the npy file: dinov2_embeddings_[suffix].npy
    """
    embeddings_list = []
    with torch.no_grad():
        for ind in range(len(dataset)):
            image, number = dataset[ind]
            outputs = dinov2_vits14(image['pixel_values'].to(device))
            last_hidden_states = outputs[0]
            cls_token = last_hidden_states[:, 0, :]
            embeddings_np = cls_token.squeeze().cpu().numpy()
            print(embeddings_np.shape)
            embeddings_list.append(embeddings_np)
            print("Embedded image: " + str(ind + 1) + "/" + str(len(dataset)))

    all_embeddings = np.array(embeddings_list)
    output_file_name = "dinov2_embeddings_" + suffix + ".npy"
    np.save(output_file_name, all_embeddings)
    print("Saved all Dinov2 embeddings in " + output_file_name)

# **Question 13**
> Examine the ```compute_embeddings``` function. What feature vector are we saving for each image and why?

Hint: what vector did we use for classification in Part 1? It has a particular name and a particular role...

Run the next cell to compute and save the embeddings for the training, validation, and test images in the Oxford Flowers Dataset. Running this cell may take a couple of minutes. Move on to the next question while you wait for this cell to finish running.

In [None]:
# Compute and save the embeddings for the training images.
compute_embeddings(flowers_train_data, "train")

# Compute and save the embeddings for the validation images.
compute_embeddings(flowers_val_data, "val")

# Compute and save the embeddings for the test images.
compute_embeddings(flowers_test_data, "test")

# **Question 14**
> Below we define a ```Dataset``` class for the Oxford Flowers Dataset. In the ```__getitem__``` method, what is ```x_tensor```?

In [None]:
# Define the dataset.
class Dinov2Flowers(torch.utils.data.Dataset):
    """Dinov2 features for Oxford Flowers dataset images.
    """
    def __init__(self, split):
        self.dataset = Flowers102('../data', split, download=True)
        self.embeddings = np.load("dinov2_embeddings_" + split + ".npy")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_embeddings = self.embeddings[idx, :] / 8.0 # shape: (384,)
        x_tensor = torch.tensor(img_embeddings / 768.0, dtype=torch.float32)
        flower_class = self.dataset[idx][1]
        return x_tensor, flower_class

Below we define a simple linear probe classifier for training on top of the Dinov2 features.

In [None]:
# Define the classifier.
class SimpleClassifier(nn.Module):
    """
    A simple linear probe classifier.
    """

    def __init__(self, in_features, out_features):
        super(SimpleClassifier, self).__init__()
        self.layer1 = nn.Linear(in_features, out_features)
        self.activation1 = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation1(x)
        return x
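
The classifier above ends in ```nn.LogSoftmax```, and the training functions later in this section use ```F.nll_loss```. Together, these two steps compute the usual cross-entropy loss on the raw outputs of the linear layer. The NumPy sketch below checks this equivalence on a few made-up logits; the helper names are hypothetical and are not part of the notebook.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def nll_loss(log_probs, targets):
    # Mean negative log-probability assigned to the correct class.
    return -log_probs[np.arange(len(targets)), targets].mean()

# Toy logits for 2 samples and 3 classes (made-up numbers).
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 2])

log_probs = log_softmax(logits)
loss = nll_loss(log_probs, targets)

# Cross-entropy computed directly from softmax probabilities gives the same value.
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
cross_entropy = -np.log(probs[np.arange(len(targets)), targets]).mean()

print(np.isclose(loss, cross_entropy))  # True
```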


Initialise the classifier and put it on the GPU in the next cell.

In [None]:
# Initialize the classifier.
# The Dinov2 feature vector fed to the classifier has dimension 384 => [in_features]=384.
# Oxford Flowers has 102 classes => [out_features]=102.
model = SimpleClassifier(in_features=384, out_features=102)
model = model.cuda() # put the model on the GPU for fast training

# **Question 15**
> Write code to count the number of parameters in the simple classifier and the Dinov2 transformer model respectively. Which model has more parameters? Would it be faster to train the linear probe or the Dinov2 model from scratch and why?

Run the next two cells to define the training and validation functions as well as a function to set the seeds for reproducibility. Setting the seeds ensures rerunning the notebook produces the same results each time.

In [None]:
# Define training functions.
def train(epoch):
    loss_avg = 0
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        # This will zero out the gradients for this batch.
        optimizer.zero_grad()
        output = model(data)
        # Calculate the negative log likelihood loss. It is useful to train a classification problem with C classes.
        loss = F.nll_loss(output, target)
        loss_avg += loss.item()
        loss.backward()
        # Do a one-step update on our parameters.
        optimizer.step()
        # Print out the loss.
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    return loss_avg / len(train_loader)  # average of the per-batch mean losses

def val():
    model.eval()
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data, target = data.cuda(), target.cuda()
        with torch.inference_mode():
            output = model(data)
        val_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

    val_loss /= len(val_loader.dataset)
    print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
    return val_loss, 100. * correct / len(val_loader.dataset)

In [None]:
def set_all_seeds(seed: int = 42):
    """
    Sets seeds for Python random, NumPy, and PyTorch (CPU and GPU)
    for full reproducibility.
    """
    # Python
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU

    # Ensure deterministic behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"[Seed Set] All seeds set to {seed}")
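
As a small illustration of why seeding gives reproducible results, the sketch below uses only Python's built-in ```random``` module. The helper ```draw``` is hypothetical; the same idea applies to the NumPy and PyTorch generators seeded in ```set_all_seeds```.

```python
import random

def draw(seed, n=3):
    # Re-seed the generator, then draw n pseudo-random numbers.
    random.seed(seed)
    return [random.random() for _ in range(n)]

first = draw(0)
second = draw(0)
other = draw(1)

print(first == second)  # True: same seed, same sequence
print(first == other)   # False: different seed, different sequence
```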

Now we can set the hyperparameters for training our simple classifier. The hyperparameters are ```batch_size```, ```learning_rate```, and ```epochs```. They should be chosen to maximise accuracy on the validation set. With more time, you could rerun the training procedure for different sets of hyperparameters and keep the set that achieves the best validation accuracy. For now, they are set to the values below, which were found after some minimal tuning on the validation set.

In [None]:
# Set the seed for reproducibility.
set_all_seeds(0)
# Choose training hyperparameters.
batch_size = 32
learning_rate = 0.001
epochs = 60

# Set up the dataloaders.
train_loader = torch.utils.data.DataLoader(
    Dinov2Flowers(split="train"),
    batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(
    Dinov2Flowers(split="val"),
    batch_size=1, shuffle=False)

# Set up the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
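
The selection procedure described above (rerun training for several hyperparameter sets and keep the one with the best validation accuracy) could be sketched as a simple grid search. In this sketch, ```validation_score``` is a hypothetical stand-in with an arbitrary formula so that the example runs on its own; a real version would train the classifier with the given settings and return the accuracy measured on the validation set. The candidate values are illustrative, not tuned.

```python
from itertools import product

def validation_score(batch_size, learning_rate, epochs):
    # Hypothetical stand-in for "train with these settings, then evaluate on
    # the validation set". The formula is arbitrary and is NOT a real result.
    return -(abs(batch_size - 32) / 8
             + abs(learning_rate - 1e-3) * 1e3
             + abs(epochs - 60) / 10)

# Candidate values for each hyperparameter (illustrative only).
grid = {
    "batch_size": [16, 32, 64],
    "learning_rate": [1e-4, 1e-3, 1e-2],
    "epochs": [30, 60],
}

best_config, best_score = None, float("-inf")
for values in product(*grid.values()):
    config = dict(zip(grid.keys(), values))
    score = validation_score(**config)
    # Keep the configuration with the highest validation score seen so far.
    if score > best_score:
        best_config, best_score = config, score

print(best_config)
```

Only the validation score is consulted inside the loop; the test set is left untouched until a final configuration has been chosen.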

Run the next cell to train the classifier. Training will take around five minutes, and the model should achieve around 94% accuracy on the validation set after training completes. Please work on the next question while you wait for the cell to finish running.

In [None]:
epoch_list = []
train_losses = []
val_losses = []
val_accs = []
for epoch in range(1, epochs + 1):
    epoch_list.append(epoch)
    train_loss = train(epoch)
    train_losses.append(train_loss)
    val_loss, val_acc = val()
    val_losses.append(val_loss)
    val_accs.append(val_acc)

# **Question 16**
> Why is it important to use the validation set to determine hyperparameters? Why not use the training or test sets?

Run the next cell to evaluate our model on the test set.

In [None]:
def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        with torch.inference_mode():
            output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, 100. * correct / len(test_loader.dataset)

test_loader = torch.utils.data.DataLoader(
    Dinov2Flowers(split="test"),
    batch_size=1, shuffle=False)
test()

# **Question 17**
> How well does our model perform on the test set? Assuming that 70% accuracy is considered 'pretty good' on this dataset, are you satisfied with our model's performance? What techniques and strategies did we use that could explain the performance of our model?

# ✅ **Checkpoint 2**
Check your answers so far for Part 2 with a TA before moving on to the rest of the practical.

In the next section, we will investigate another method of adapting powerful pretrained vision transformers to our specific flower classification task. Unlike in the previous section, we will not need to do any training. Instead, we will use a vision transformer pretrained with [Contrastive Language-Image Pretraining (CLIP)](https://arxiv.org/abs/2103.00020), for 'zero-shot classification.'

CLIP is a method that allows for large-scale pretraining of models using weak supervision. Given a large-scale dataset of images and captions describing the images, CLIP trains an image model ('image encoder') and a text model ('text encoder') to map each image to a feature vector and each caption to a feature vector such that images and their matching captions are mapped to vectors that are 'close' to each other, while images and unrelated captions are mapped to vectors that are 'far' away from each other. 'Closeness' and 'farness' are evaluated using the cosine similarity. This training objective is a form of weak supervision, since the texts are not constrained to a particular task and can be easily obtained by collecting images and unstructured captions from the web. Because the data constraints are few, this objective scales easily to millions or even billions of image-text pairs.
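
One way to turn this 'close' versus 'far' objective into a loss is a symmetric cross-entropy over the matrix of pairwise cosine similarities, roughly following the pseudocode in the CLIP paper. The NumPy sketch below is a simplified illustration under stated assumptions: the function names are hypothetical, the features are toy one-hot vectors rather than encoder outputs, and the temperature is fixed at an illustrative value, whereas CLIP learns it during training.

```python
import numpy as np

def l2_normalize(x):
    # Scale each row to unit L2 norm so dot products equal cosine similarities.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

def cross_entropy(logits, labels):
    # Mean cross-entropy over rows, via a numerically stable log-softmax.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def clip_style_loss(image_features, text_features, temperature=0.07):
    # Pairwise cosine similarities between every image and every caption.
    img = l2_normalize(image_features)
    txt = l2_normalize(text_features)
    logits = img @ txt.T / temperature             # (N, N)
    labels = np.arange(len(logits))                # i-th image matches i-th caption
    loss_images = cross_entropy(logits, labels)    # image -> caption direction
    loss_texts = cross_entropy(logits.T, labels)   # caption -> image direction
    return (loss_images + loss_texts) / 2

# Toy features: 4 mutually orthogonal vectors (not real encoder outputs).
features = np.eye(4)
aligned = clip_style_loss(features, features)          # every pair matches
mismatched = clip_style_loss(features, features[::-1]) # every pair is wrong

print(aligned < mismatched)  # True
```

The loss is low when each image is most similar to its own caption and high when the pairs are mismatched, which is the behaviour the training objective rewards.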

Zero-shot classification means that the CLIP model can be adapted to new classification tasks without being trained on any data samples (i.e., trained on *zero* data samples) for those classification tasks. This is possible because the CLIP training objective is general and allows for the specification of the text descriptions at inference time. Specifically, given a new problem with $C$ classes, we can encode the $C$ classes as $C$ text descriptions with the CLIP text encoder. For any image, we can encode it with the image encoder. Then, using the cosine similarity, we can compare the image CLIP feature vector to the text feature vectors for each class and pick the class that the image is closest to. This is what we will do next for the flower dataset. In our case, the image encoder is a vision transformer, and the text encoder is a text transformer. The image and text transformers were pretrained jointly using the CLIP objective.

Run the next cell to install the CLIP library from OpenAI.

In [None]:
!pip install git+https://github.com/openai/CLIP.git

Run the next cell to import the clip module and see a list of available models.

In [None]:
import clip

clip.available_models()

We will choose the ViT-L/14 model in the next cell, which includes both the pretrained image encoder, a vision transformer with a 14 x 14 patch size, and the text encoder, a text transformer pretrained using CLIP with the image encoder.

In [None]:
model, preprocess = clip.load("ViT-L/14")
model.cuda().eval()

print(model)
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")

Next, we will load the Oxford Flowers Dataset again but this time using the CLIP image pre-processing pipeline.

In [None]:
# Get image input.
flowers_test_data_vis = Flowers102('../data', split="test", download=True)
flowers_test_data = Flowers102('../data', split="test", download=True, transform=preprocess)

To get the text input, we will use the flower classes in the Oxford Flowers dataset. Rather than just encoding the classes by themselves, we will place each class into a fixed *prompt template* that adds some context to the text descriptions. Because captions tend to describe the images they correspond to, adding the prompt template makes the flower classes more similar to the text descriptions the model was trained on. This improves the accuracy on our downstream classification task.

In [None]:
# Get text input
flower_text_descriptions = [f"a photo of a {label}, a type of flower" for label in flowers_test_data.classes]
texts = [label for label in flowers_test_data.classes]
text_tokens = clip.tokenize(flower_text_descriptions).cuda()

The cosine similarity for two vectors $\mathbf{v}$ and $\mathbf{w}$ is the dot product of the two normalised vectors. Letting $|| \cdot ||$ represent the $\mathcal{L}_{2}$ norm, we have the following equation for the cosine similarity:

$\mathrm{similarity}(\mathbf{v}, \mathbf{w}) = \frac{\mathbf{v}\cdot \mathbf{w}}{|| \mathbf{v} || \, || \mathbf{w} ||}$ with

$\mathbf{v}\cdot \mathbf{w} = \sum_{i=1}^{N}v_{i}w_{i}$,

where $v_{i}$ is the $i^{th}$ element of $\mathbf{v}$, and $w_{i}$ is the $i^{th}$ element of $\mathbf{w}$.
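
As a quick worked example of the formula, the NumPy sketch below evaluates the cosine similarity for three pairs of toy 2-D vectors. The function name is hypothetical, and the sketch handles a single pair of vectors rather than batches of features.

```python
import numpy as np

def cosine_similarity(v, w):
    # Dot product divided by the product of the two L2 norms.
    return np.dot(v, w) / (np.linalg.norm(v) * np.linalg.norm(w))

v = np.array([1.0, 0.0])
print(cosine_similarity(v, np.array([3.0, 0.0])))   # same direction -> 1.0
print(cosine_similarity(v, np.array([0.0, 2.0])))   # orthogonal -> 0.0
print(cosine_similarity(v, np.array([-5.0, 0.0])))  # opposite direction -> -1.0
```

The value depends only on the angle between the vectors, not on their lengths.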

# **Question 18**
> In the below code, the image and text features are computed for the Oxford Flowers test set. Each image is mapped to a vector of length 768, and each text description is mapped to a vector of length 768 in a pretrained joint image-text embedding space where the vectors can be compared with the cosine similarity. Write code in the ```TO DO``` block below to calculate the cosine similarity between the image and text features, and store the result in the ```similarity``` variable. The rest of the code chooses the flower class with the highest cosine similarity as the predicted label for the image. What classification accuracy do you get when you run the code?

In [None]:
def test():
    model.eval()
    correct = 0
    sample_ind = 1
    for data, target in test_loader:
        data, target = data.cuda(), target.cuda()
        with torch.inference_mode():
            # Get the image and text features
            image_features = model.encode_image(data).float()
            text_features = model.encode_text(text_tokens).float()
            print("image_features.shape: " + str(image_features.shape))
            print("text_features.shape: " + str(text_features.shape))
            # TO DO: Calculate the cosine similarity and save it to [similarity]


            pred = similarity.max(1, keepdim=True)[1] # get the index of the max similarity
            sample_correct = pred.eq(target.data.view_as(pred)).long().cpu().sum()
            print(str(sample_ind) + "/" + str(len(test_loader)))
            sample_ind += 1
            correct += sample_correct

    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

test_loader = torch.utils.data.DataLoader(
    flowers_test_data,
    batch_size=32, shuffle=False)
test()

# **Question 19**
> Are you satisfied with the performance of CLIP on our flower classification task? How does it compare to training the linear probe on top of Dinov2 features? Why do you think one method works better than the other? What could we do to improve the zero-shot performance of CLIP on this task?

# ✅ **Checkpoint 3**
Congratulations! You have finished the lab! Check your answers to the remaining questions with a TA.

## Conclusion

In this practical, we have investigated the Vision Transformer (ViT) and seen how to best leverage its capabilities for the task of image classification. We implemented and trained a ViT from scratch on the small CIFAR10 dataset, and we adapted vision transformers pretrained on large-scale datasets to flower classification. The code from this lab may be used and adapted for your own projects, so we recommend saving this notebook for future reference.

### References

Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." International Conference on Learning Representations (2021). [link](https://arxiv.org/pdf/2010.11929.pdf)

Chen, Xiangning, et al. "When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations." arXiv preprint arXiv:2106.01548 (2021). [link](https://arxiv.org/abs/2106.01548)

Amini-Naieni, Niki, et al. "Instant Uncertainty Calibration of NeRFs Using a Meta-Calibrator." European Conference on Computer Vision (2024). [link](https://arxiv.org/abs/2312.02350)

Oquab, Maxime, et al. "DINOv2: Learning Robust Visual Features without Supervision." Transactions on Machine Learning Research (2024). [link](https://arxiv.org/abs/2304.07193)

Radford, Alec, et al. "Learning Transferable Visual Models From Natural Language Supervision." International Conference on Machine Learning (2021). [link](https://arxiv.org/abs/2103.00020)

Nilsback, Maria-Elena, and Andrew Zisserman. "Automated Flower Classification over a Large Number of Classes." Indian Conference on Computer Vision, Graphics and Image Processing (2008). [link](https://www.robots.ox.ac.uk/~vgg/publications/2008/Nilsback08/)

Tolstikhin, Ilya, et al. "MLP-mixer: An all-MLP Architecture for Vision." arXiv preprint arXiv:2105.01601 (2021). [link](https://arxiv.org/abs/2105.01601)

Xiong, Ruibin, et al. "On layer normalization in the transformer architecture." International Conference on Machine Learning. PMLR, 2020. [link](http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf)

### Code borrowed and adapted from the following sources

[1] https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb

[2] https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/DINO/Visualize_self_attention_of_DINO.ipynb#scrollTo=4o2nUmAOZ2Bo

[3] https://colab.research.google.com/drive/1tRRuT21W3VUvORCFRazrVaFLSWYbYoqL?usp=sharing#scrollTo=Rb5jVjzvpacU

[4] https://github.com/TrasperJ/102-flowers-classfication-with-PyTorch/blob/master/102flower_classification.ipynb

[5] https://github.com/niki-amini-naieni/instantcalibration

[6] https://colab.research.google.com/github/phlippe/uvadlc_notebooks/blob/master/docs/tutorial_notebooks/tutorial15/Vision_Transformer.ipynb