# Vision Transformer (ViT) implementation

This project is a PyTorch implementation of the paper An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale. It uses the labml libraries (`labml_nn` and `labml_helpers`) and is based on the labmlai implementation. The model can be trained on image classification datasets such as CIFAR-10.

Vision transformers apply a pure transformer architecture to images, without convolutional layers. They split the image into patches and apply a transformer to the patch embeddings. Each patch embedding is a linear transformation of the patch's flattened pixel values. (The implementation below computes this projection with a strided convolution, which is the same operation.) A standard transformer encoder then processes the patch embeddings together with a special classification token (`[CLS]`). The encoder's output for that token is passed to a multi-layer perceptron (MLP) head that classifies the image.
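To make the patch-embedding step concrete, here is a small PyTorch sketch (all sizes are arbitrary). It splits an image into non-overlapping patches by reshaping, flattens each patch and applies an `nn.Linear`. It then checks that an `nn.Conv2d` with kernel size and stride equal to the patch size, sharing the linear layer's weights, produces the same embeddings. This equivalence is why `PatchEmbeddings` can use a convolution even though the architecture is described as convolution-free.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch_size, in_channels, image_size, patch_size, d_model = 2, 3, 8, 4, 16

# Strided convolution: kernel size and stride both equal the patch size
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

# Linear layer over flattened patches, sharing the convolution's weights
# (conv weight `[d_model, C, p, p]` flattened to `[d_model, C * p * p]`)
linear = nn.Linear(in_channels * patch_size * patch_size, d_model)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(d_model, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(batch_size, in_channels, image_size, image_size)

# Split into patches: `[batch, C, H, W] -> [batch, C, H/p, p, W/p, p]`
n = image_size // patch_size
patches = x.reshape(batch_size, in_channels, n, patch_size, n, patch_size)
# `-> [batch, H/p, W/p, C, p, p] -> [batch, num_patches, C * p * p]`
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(batch_size, n * n, -1)

# Linear transformation of each flattened patch: `[batch, num_patches, d_model]`
linear_emb = linear(patches)

# Convolution output `[batch, d_model, H/p, W/p] -> [batch, num_patches, d_model]`
conv_emb = conv(x).flatten(2).transpose(1, 2)

print(torch.allclose(linear_emb, conv_emb, atol=1e-5))  # True
```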

Self-attention by itself has no notion of order. It treats the patch embeddings as an unordered set, so they carry no information about where each patch sits in the image. To supply this spatial information, learned positional embeddings are added to the patch embeddings. They consist of one vector per patch location and are trained along with the other parameters by gradient descent.
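The missing spatial information can be shown directly. This sketch uses PyTorch's `nn.MultiheadAttention` as a stand-in for the transformer's self-attention and a random tensor in place of trained positional embeddings. Shuffling the input patches only shuffles the output in the same way. Once positional embeddings are added, that symmetry is broken.

```python
import torch
from torch import nn

torch.manual_seed(0)

seq_len, batch_size, d_model, n_heads = 5, 2, 16, 4

# Self-attention over patch embeddings of shape `[seq_len, batch_size, d_model]`
attn = nn.MultiheadAttention(d_model, n_heads)
attn.eval()


def self_attend(seq: torch.Tensor) -> torch.Tensor:
    # Query, key and value are all the same sequence
    out, _ = attn(seq, seq, seq, need_weights=False)
    return out


x = torch.randn(seq_len, batch_size, d_model)
# A fixed shuffle of the patch order
perm = torch.tensor([4, 0, 3, 1, 2])
# Stand-in for learned positional embeddings: one vector per location
pos = torch.randn(seq_len, 1, d_model)

with torch.no_grad():
    # No positional information: shuffling the input just shuffles the output
    equivariant = torch.allclose(self_attend(x)[perm], self_attend(x[perm]), atol=1e-5)
    # Positional embeddings added: the shuffled input now produces different outputs
    with_pos_equivariant = torch.allclose(
        self_attend(x + pos)[perm], self_attend(x[perm] + pos), atol=1e-5
    )

print(equivariant)           # True
print(with_pos_equivariant)  # False
```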

Vision transformers perform well when pre-trained on large datasets. The paper pre-trains with an MLP classification head and then fine-tunes with a single linear layer in its place. Its vision transformer pre-trained on a dataset of 300 million images surpasses previous state-of-the-art (SOTA) results. The paper also uses higher-resolution images after pre-training (for fine-tuning and inference) while keeping the patch size the same. This produces more patches and therefore a longer input sequence. The positional embeddings for the new patch locations are obtained by 2D interpolation of the learned positional embeddings.
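A sketch of the interpolation step. It assumes positional embeddings laid out like the `LearnedPositionalEmbeddings` parameter in this notebook (`[positions, 1, d_model]`), with patches ordered row by row. The helper `interpolate_pos_emb` is hypothetical. Bilinear mode is also an assumption, since the paper only specifies 2D interpolation. In this implementation the `[CLS]` token occupies position 0, so presumably only the rows after it would be treated as the patch grid and resized.

```python
import torch
import torch.nn.functional as F


def interpolate_pos_emb(pos_emb: torch.Tensor, old_grid: int, new_grid: int) -> torch.Tensor:
    """
    Hypothetical helper: resize patch positional embeddings of shape
    `[old_grid * old_grid, 1, d_model]` (row-major patch order, no `[CLS]` row)
    to `[new_grid * new_grid, 1, d_model]` with bilinear 2D interpolation.
    """
    d_model = pos_emb.shape[-1]
    # `[old_grid * old_grid, 1, d_model] -> [1, d_model, old_grid, old_grid]`
    grid = pos_emb.reshape(old_grid, old_grid, d_model).permute(2, 0, 1).unsqueeze(0)
    # Interpolate over the two spatial axes
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    # Back to `[new_grid * new_grid, 1, d_model]` in row-major patch order
    return grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, 1, d_model)


# Example: 32x32 images with 4x4 patches give an 8x8 grid;
# 48x48 images with the same patch size give a 12x12 grid
pos_emb = torch.randn(8 * 8, 1, 64)
new_pos_emb = interpolate_pos_emb(pos_emb, old_grid=8, new_grid=12)
print(new_pos_emb.shape)  # torch.Size([144, 1, 64])
```

`LearnedPositionalEmbeddings` below simply slices the first rows of a `max_len`-row table, so it does not perform this interpolation itself.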

```python
import torch
from torch import nn
from labml_helpers.module import Module
from labml_nn.transformers import TransformerLayer
from labml_nn.utils import clone_module_list
```

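```python
class PatchEmbeddings(Module):
    """
    Split the image into `patch_size` x `patch_size` patches and
    project each patch linearly to a `d_model`-dimensional embedding.
    """

    def __init__(self, d_model: int, patch_size: int, in_channels: int):
        super().__init__()

        # A convolution whose kernel size and stride both equal the patch size.
        # This is equivalent to splitting the image into non-overlapping patches
        # and applying the same linear transformation to each flattened patch.
        self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor):
        """
        `x` is the input image of shape `[batch_size, channels, height, width]`.
        """
        # Apply the convolution; the output has shape `[batch_size, d_model, h, w]`,
        # where `h` and `w` are the number of patches along each axis
        x = self.conv(x)
        # Get the shape
        bs, c, h, w = x.shape
        # Rearrange to shape `[patches, batch_size, d_model]`
        x = x.permute(2, 3, 0, 1)
        x = x.view(h * w, bs, c)

        # Return the patch embeddings
        return x
```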
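As a quick shape check, the snippet below reproduces the `forward` computation of `PatchEmbeddings` with a bare `nn.Conv2d`, so it runs without labml. It uses CIFAR-10-sized 32x32 inputs with 4x4 patches (the patch size and `d_model` are illustrative choices). It also shows that patches end up in row-major order: the embedding at sequence index `row * w + col` is the convolution output at that grid cell.

```python
import torch
from torch import nn

torch.manual_seed(0)

# CIFAR-10-sized input: 32x32 RGB images, split into 4x4 patches
batch_size, d_model, patch_size = 2, 8, 4
conv = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
images = torch.randn(batch_size, 3, 32, 32)

# `[batch_size, d_model, 8, 8]`
out = conv(images)
bs, c, h, w = out.shape
# Same rearrangement as `PatchEmbeddings.forward`: `[64, batch_size, d_model]`
seq = out.permute(2, 3, 0, 1).view(h * w, bs, c)
print(seq.shape)  # torch.Size([64, 2, 8])

# Patches are ordered row by row: sequence index `row * w + col`
row, col = 3, 5
print(torch.equal(seq[row * w + col], out[:, :, row, col]))  # True
```

With the `[CLS]` token prepended, such an input becomes a sequence of 65 tokens.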

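```python
class LearnedPositionalEmbeddings(Module):
    """
    Add a learned positional embedding to each patch embedding.
    """

    def __init__(self, d_model: int, max_len: int = 5_000):
        super().__init__()
        # Positional embeddings for each location, shape `[max_len, 1, d_model]`;
        # the size-1 middle dimension broadcasts over the batch
        self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

    def forward(self, x: torch.Tensor):
        """
        `x` is the sequence of patch embeddings, of shape `[patches, batch_size, d_model]`.
        """
        # Get the positional embeddings for the given patches
        pe = self.positional_encodings[:x.shape[0]]
        # Add to patch embeddings and return
        return x + pe
```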
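A toy check of how the `[max_len, 1, d_model]` parameter interacts with a batch. It uses a standalone `nn.Parameter` in place of the module, and the sizes are illustrative. The size-1 middle dimension broadcasts the same positional vector to every image in the batch, and only the rows actually sliced out receive gradients.

```python
import torch
from torch import nn

max_len, d_model = 10, 4
seq_len, batch_size = 3, 2

# Same layout as `positional_encodings`: `[max_len, 1, d_model]`, zero-initialized
positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model))

# Patch embeddings: `[seq_len, batch_size, d_model]`
x = torch.randn(seq_len, batch_size, d_model)
# Slice the first `seq_len` rows: `[seq_len, 1, d_model]`
pe = positional_encodings[:x.shape[0]]
# Broadcasts to `[seq_len, batch_size, d_model]`
y = x + pe
print(y.shape)  # torch.Size([3, 2, 4])

# Each used entry gets gradient 1 from each of the 2 batch items
y.sum().backward()
print(positional_encodings.grad[:seq_len].abs().sum().item())  # 24.0
# Rows beyond `seq_len` receive no gradient
print(positional_encodings.grad[seq_len:].abs().sum().item())  # 0.0
```

In other words, `max_len` is only an upper bound. Rows past the longest sequence seen during training are never updated and stay at their zero initialization.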

```python
class ClassificationHead(Module):
    """
    Two-layer MLP that maps the `[CLS]` token's encoding to class logits.
    """

    def __init__(self, d_model: int, n_hidden: int, n_classes: int):
        super().__init__()
        # First layer
        self.linear1 = nn.Linear(d_model, n_hidden)
        # Activation
        self.act = nn.ReLU()
        # Second layer
        self.linear2 = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor):
        """
        `x` is the encoding of the `[CLS]` token, of shape `[batch_size, d_model]`.
        """
        # First layer and activation
        x = self.act(self.linear1(x))
        # Second layer
        x = self.linear2(x)

        # Return the logits
        return x
```
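For the pre-train/fine-tune recipe described earlier, this sketch builds an MLP head with the same structure as `ClassificationHead`, plus the single linear layer that replaces it for fine-tuning. The class counts and sizes are placeholders. The zero initialization of the fine-tuning layer follows the paper's description rather than anything in this notebook.

```python
import torch
from torch import nn

# Placeholder sizes: encoder width, MLP hidden size, class counts
d_model, n_hidden = 64, 128
n_pretrain_classes, n_finetune_classes = 1000, 10

# Pre-training head: same structure as `ClassificationHead` (Linear -> ReLU -> Linear)
pretrain_head = nn.Sequential(
    nn.Linear(d_model, n_hidden),
    nn.ReLU(),
    nn.Linear(n_hidden, n_pretrain_classes),
)

# Fine-tuning head: a single linear layer, zero-initialized
finetune_head = nn.Linear(d_model, n_finetune_classes)
nn.init.zeros_(finetune_head.weight)
nn.init.zeros_(finetune_head.bias)

# Stand-in for the normalized `[CLS]` encodings of a batch of 8 images
cls_encoding = torch.randn(8, d_model)
print(pretrain_head(cls_encoding).shape)  # torch.Size([8, 1000])
print(finetune_head(cls_encoding).shape)  # torch.Size([8, 10])
```

With `VisionTransformer`, swapping heads would amount to assigning a new module to its `classification` attribute, for example `model.classification = finetune_head`.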


```python
class VisionTransformer(Module):
    """
    Vision transformer: patch embeddings, a `[CLS]` token, positional embeddings,
    a stack of transformer layers and a classification head.
    """

    def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                 patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                 classification: ClassificationHead):
        super().__init__()
        # Patch embeddings
        self.patch_emb = patch_emb
        # Positional embeddings
        self.pos_emb = pos_emb
        # Classification head
        self.classification = classification
        # Make `n_layers` copies of the transformer layer
        self.transformer_layers = clone_module_list(transformer_layer, n_layers)

        # `[CLS]` token embedding
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
        # Final normalization layer
        self.ln = nn.LayerNorm([transformer_layer.size])

    def forward(self, x: torch.Tensor):
        """
        `x` is the input image of shape `[batch_size, channels, height, width]`.
        """
        # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
        x = self.patch_emb(x)
        # Prepend the `[CLS]` token embedding to each sequence before feeding the transformer
        cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls_token_emb, x])
        # Add positional embeddings
        x = self.pos_emb(x)

        # Pass through transformer layers with no attention masking
        for layer in self.transformer_layers:
            x = layer(x=x, mask=None)

        # Get the transformer output of the `[CLS]` token (which is the first in the sequence)
        x = x[0]

        # Layer normalization
        x = self.ln(x)

        # Classification head, to get logits
        x = self.classification(x)

        # Return the logits
        return x
```
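To see the whole pipeline run end to end without the labml dependencies, here is a minimal sketch that mirrors `VisionTransformer` using only PyTorch. It swaps labml's `TransformerLayer` for `nn.TransformerEncoderLayer` with `norm_first=True` (pre-layer-norm, consistent with the final `LayerNorm` above). It also inlines the patch embeddings, positional embeddings and classification head. `TinyViT` and all of its hyperparameters are hypothetical choices for 32x32 CIFAR-10-sized inputs.

```python
import torch
from torch import nn


class TinyViT(nn.Module):
    """
    Hypothetical minimal ViT mirroring `VisionTransformer`, built only from PyTorch modules.
    """

    def __init__(self, image_size=32, patch_size=4, in_channels=3, d_model=64,
                 n_heads=4, d_ff=128, n_layers=2, n_hidden=128, n_classes=10):
        super().__init__()
        n_patches = (image_size // patch_size) ** 2
        # Patch embeddings via a strided convolution
        self.patch_emb = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
        # Learned positional embeddings for `[CLS]` + patches: `[n_patches + 1, 1, d_model]`
        self.pos_emb = nn.Parameter(torch.zeros(n_patches + 1, 1, d_model))
        # `[CLS]` token embedding
        self.cls_token_emb = nn.Parameter(torch.randn(1, 1, d_model))
        # Sequence-first (`batch_first=False`) pre-norm encoder layers
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout=0.0, norm_first=True)
            for _ in range(n_layers)
        ])
        # Final layer normalization and two-layer MLP classification head
        self.ln = nn.LayerNorm(d_model)
        self.classification = nn.Sequential(
            nn.Linear(d_model, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `[batch, C, H, W] -> [batch, d_model, h, w] -> [h * w, batch, d_model]`
        x = self.patch_emb(x).flatten(2).permute(2, 0, 1)
        # Prepend the `[CLS]` token to every sequence in the batch
        cls = self.cls_token_emb.expand(-1, x.shape[1], -1)
        x = torch.cat([cls, x])
        # Add positional embeddings (broadcast over the batch)
        x = x + self.pos_emb[:x.shape[0]]
        # Transformer layers with no attention mask
        for layer in self.layers:
            x = layer(x)
        # Take the `[CLS]` output, normalize it and classify
        return self.classification(self.ln(x[0]))


model = TinyViT()
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

Keeping the sequence-first `[seq_len, batch_size, d_model]` layout throughout matches the convention used by `PatchEmbeddings` and `LearnedPositionalEmbeddings`, so `x[0]` is the `[CLS]` output for every image in the batch.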