<a href="https://colab.research.google.com/github/rudraxx/pytorch/blob/main/Vision_Transformer_in_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer (ViT)

We will implement the Vision Transformer (ViT) network from scratch in PyTorch and train it on the CIFAR-10 dataset.

Links to the original papers:


1.   [Attention Is All You Need](https://arxiv.org/pdf/1706.03762)
2.   [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929)



Articles used as reference:


1.   [Article on Medium](https://towardsdatascience.com/implementing-vision-transformer-vit-from-scratch-3e192c6155f0)



Overview of the ViT architecture:
![image.png](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*Q-mBZkDz7TUnVGw1KPwqOA.png)

In [14]:
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10

In [31]:
# Set up the hyperparameters.
# Note: CIFAR-10 images are 32x32, so they would need to be resized to
# image_size before being passed to the model.
config = {"image_size": 128, "patch_size": 16,
          "num_channels": 3, "hidden_size": 17,
          "hidden_dropout_prob": 0.2}
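
The sizes that the rest of the notebook relies on follow directly from this config. The short sketch below works through the arithmetic. The variable names (`patches_per_side`, `values_per_patch`, `sequence_length`) are illustrative and do not appear elsewhere in the notebook.

```python
# Derived sizes for the config used in this notebook (names are illustrative).
config = {"image_size": 128, "patch_size": 16,
          "num_channels": 3, "hidden_size": 17}

# Number of patches along one side of the image: 128 // 16 = 8
patches_per_side = config["image_size"] // config["patch_size"]

# Total number of patches per image: 8 * 8 = 64
num_patches = patches_per_side ** 2

# Raw values inside one patch before projection: 3 * 16 * 16 = 768
values_per_patch = config["num_channels"] * config["patch_size"] ** 2

# Sequence length seen by the transformer: 64 patches + 1 CLS token = 65
sequence_length = num_patches + 1

print(patches_per_side, num_patches, values_per_patch, sequence_length)
```

Each patch of 768 raw values is projected down to a vector of size `hidden_size`, so the transformer input has shape (B X 65 X 17).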


**Stages of the model pipeline**
1. Convert the images into patches.
2. Run the patches through a linear layer to get the patch embeddings. The layer weights are learned.
3. Add the CLS token as the first token for every instance in the batch.
4. Get the positional embeddings (learnable parameters in this implementation, rather than a fixed sin/cos transform).
5. The input to the transformer is the sum of the patch and positional embeddings.
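
The first two stages can be fused into a single strided convolution, which is the approach the `PatchEmbeddings` class in this notebook takes. The sketch below checks, on a small random batch, that a `Conv2d` with `kernel_size == stride == patch_size` gives the same result as explicitly cutting out patches with `torch.nn.functional.unfold` and applying a linear map. The batch size of 2 and the variable names are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size, num_channels, hidden_size = 16, 3, 17

conv = nn.Conv2d(num_channels, hidden_size,
                 kernel_size=patch_size, stride=patch_size)
x = torch.randn(2, num_channels, 128, 128)

# Path A: strided convolution, then flatten the 8x8 grid into 64 tokens.
out_conv = conv(x).flatten(2).transpose(1, 2)            # (2, 64, 17)

# Path B: cut out non-overlapping patches, then apply a linear map
# that reuses the convolution's weights and bias.
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (2, 768, 64)
patches = patches.transpose(1, 2)                        # (2, 64, 768)
weight = conv.weight.reshape(hidden_size, -1)            # (17, 768)
out_linear = patches @ weight.T + conv.bias              # (2, 64, 17)

# The two paths should agree up to floating-point rounding.
print(out_conv.shape, torch.allclose(out_conv, out_linear, atol=1e-4))
```

Using the convolution avoids materialising the patch tensor, which is why it is a common way to implement this step.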



In [29]:
#Convert images into patches
class PatchEmbeddings(nn.Module):
  def __init__(self, config):
    super().__init__()

    self.image_size   = config["image_size"]   # Side length of the (square) input images
    self.patch_size   = config["patch_size"]   # Side length of each (square) patch
    self.num_channels = config["num_channels"] # Number of image channels
    self.hidden_size  = config["hidden_size"]  # Embedding dimension, d_hidden

    # Calculate the number of patches from the image and patch sizes
    self.num_patches = (self.image_size // self.patch_size)**2 # patches per side, squared

    #Create the projection to convert the images into patches
    # This layer should take each patch and convert it into a 1D vector of size (1, d_hidden)
    self.projection = nn.Conv2d(self.num_channels, self.hidden_size,
                                kernel_size=self.patch_size,
                                stride=self.patch_size)


  def forward(self, x):

    # x is of shape: (B X num_channels X image_size X image_size)
    # Required output is of shape: (B X num_patches X d_hidden)
    # Below, g = image_size // patch_size is the number of patches per side, so num_patches = g*g

    #1) (B X num_channels X image_size X image_size) -> (B X d_hidden X g X g)
    x = self.projection(x)
    #2) (B X d_hidden X g X g) -> (B X d_hidden X num_patches)
    x = x.flatten(2)
    #3) (B X d_hidden X num_patches) -> (B X num_patches X d_hidden)
    x = x.transpose(1,2)

    return x



Input tensor shape to the patch embedding model: 
 torch.Size([32, 3, 128, 128]). (B X num_channels X image_size X image_size)
After the patch embedding:
 torch.Size([32, 64, 17]). (B X num_patches X d_hidden)


In [None]:
#Test out the patch embeddings class
x = torch.randn(32,config["num_channels"], config["image_size"], config["image_size"])
print(f"Input tensor shape to the patch embedding model: \n {x.shape}. (B X num_channels X image_size X image_size)")


patch_embedding = PatchEmbeddings(config)

x = patch_embedding(x)
print(f"After the patch embedding:\n {x.shape}. (B X num_patches X d_hidden)")



In [37]:
# Add the CLS token to the beginning of each sequence
class Embeddings(nn.Module):
  def __init__(self, config):
    super().__init__()

    self.patch_embeddings = PatchEmbeddings(config)

    # Create a learnable [CLS] token. It is prepended to the patch sequence,
    # so it must have the same dimension as the patch embeddings
    self.cls_token = nn.Parameter(torch.randn(1, config["hidden_size"]))

    # The position embedding is added element-wise to the (B X 1+num_patches X d_hidden) tensor,
    # so its shape must account for the extra CLS token
    self.position_embeddings = \
      nn.Parameter(torch.randn(1,self.patch_embeddings.num_patches + 1, config["hidden_size"]))

    # Drop out layer
    self.dropout = nn.Dropout(config["hidden_dropout_prob"])

  def forward(self, x):

    batch_size = x.shape[0]
    #Get the patch embeddings
    x = self.patch_embeddings(x)

    # Add the CLS token to every batch item
    cls_tokens = self.cls_token.unsqueeze(0).repeat(batch_size, 1, 1)  # Shape: (B, 1, d_hidden)
    # Concatenate the CLS token as the first token. The hidden size stays the same;
    # the number of tokens increases by 1.
    x = torch.cat((cls_tokens, x), dim=1)

    #Add positional embedding
    x = x + self.position_embeddings

    # Dropout
    x = self.dropout(x)

    return x
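
The class above learns its position embeddings as an `nn.Parameter`. A fixed sin/cos table, as described in the original Transformer paper, is one possible alternative. The sketch below is a minimal version of that idea and is not part of this notebook's model. The function name `sinusoidal_position_embeddings` is hypothetical, and the slicing of `div_term` is an assumption added so that an odd `hidden_size` such as 17 also works.

```python
import math
import torch

def sinusoidal_position_embeddings(num_positions, hidden_size):
    """Return a fixed (1, num_positions, hidden_size) sin/cos table."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    # One frequency per pair of columns: 1 / 10000^(2i / hidden_size)
    div_term = torch.exp(
        torch.arange(0, hidden_size, 2, dtype=torch.float32)
        * (-math.log(10000.0) / hidden_size)
    )
    table = torch.zeros(num_positions, hidden_size)
    table[:, 0::2] = torch.sin(position * div_term)
    # For odd hidden_size there is one fewer cosine column than sine columns.
    table[:, 1::2] = torch.cos(position * div_term[: hidden_size // 2])
    return table.unsqueeze(0)  # leading dim of 1 broadcasts over the batch

# 64 patches + 1 CLS token, hidden size 17 (values from this notebook's config)
pos = sinusoidal_position_embeddings(65, 17)
print(pos.shape)  # torch.Size([1, 65, 17])
```

Such a table would typically be registered with `register_buffer` rather than wrapped in `nn.Parameter`, since it is not trained.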


In [None]:
#Test code to try the math for addition of the cls to the batch
cls_x1 = torch.randn(1,10)
print("cls_x1", cls_x1.shape)

expanded_cls = cls_x1.unsqueeze(0).repeat(32,1,1)
print("expanded_cls", expanded_cls.shape)

x = torch.randn(32,200, 10)
print("x", x.shape)

cat_result = torch.cat((expanded_cls, x), dim=1)
print("cat_result", cat_result.shape)
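
As a side note, `expand` can stand in for `repeat` when broadcasting the CLS token over the batch. It returns a view instead of copying the data, and `torch.cat` then produces the same result. The sketch below compares the two on the same toy shapes as the test cell above; the variable names are illustrative.

```python
import torch

cls_token = torch.randn(1, 10)

# repeat copies the token 32 times; expand creates a broadcasted view.
repeated = cls_token.unsqueeze(0).repeat(32, 1, 1)    # (32, 1, 10)
expanded = cls_token.unsqueeze(0).expand(32, -1, -1)  # (32, 1, 10)

x = torch.randn(32, 200, 10)
cat_result = torch.cat((expanded, x), dim=1)          # (32, 201, 10)

print(torch.equal(repeated, expanded), cat_result.shape)
```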



In [39]:
#Test out the embeddings
embeddings = Embeddings(config)
print(embeddings)


Embeddings(
  (patch_embeddings): PatchEmbeddings(
    (projection): Conv2d(3, 17, kernel_size=(16, 16), stride=(16, 16))
  )
  (dropout): Dropout(p=0.2, inplace=False)
)
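
The printed module lists only submodules, so the CLS token and the position embeddings do not appear in it even though they are trainable. They are `nn.Parameter` attributes and can be inspected with `named_parameters()`. The sketch below uses a small stand-in class, `TinyEmbeddings`, which is hypothetical and only mirrors the parameter shapes used in this notebook.

```python
import torch
import torch.nn as nn

class TinyEmbeddings(nn.Module):
    """Hypothetical stand-in with the same parameter shapes as Embeddings."""
    def __init__(self, num_patches=64, hidden_size=17):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, hidden_size))
        self.position_embeddings = nn.Parameter(
            torch.randn(1, num_patches + 1, hidden_size))
        self.dropout = nn.Dropout(0.2)

module = TinyEmbeddings()

# print(module) would show only the Dropout submodule;
# named_parameters() reveals the directly registered parameters.
param_shapes = {name: tuple(p.shape) for name, p in module.named_parameters()}
print(param_shapes)
```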


The CLS, position, and patch embeddings have now been created. The next step is to create the Transformer encoder model.

# Multi-head Attention Module

The attention module takes a sequence of embeddings as input and computes query, key, and value vectors for each embedding.
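
Once the query, key, and value vectors are available, a single head typically combines them with scaled dot-product attention, as defined in "Attention Is All You Need". The sketch below shows that computation on random tensors. It is a minimal illustration rather than this notebook's implementation: the function name, the head size of 8, and the omission of dropout are all assumptions.

```python
import math
import torch

def scaled_dot_product_attention(query, key, value):
    """softmax(Q K^T / sqrt(d_k)) V for tensors of shape (B, T, d_k)."""
    d_k = query.size(-1)
    # Similarity of every token with every other token: (B, T, T)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
    # Each row becomes a probability distribution over the tokens.
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the value vectors: (B, T, d_k)
    return weights @ value, weights

torch.manual_seed(0)
# Batch of 2, 65 tokens (64 patches + CLS), assumed head size of 8
q = torch.randn(2, 65, 8)
k = torch.randn(2, 65, 8)
v = torch.randn(2, 65, 8)

attn_output, attn_weights = scaled_dot_product_attention(q, k, v)
print(attn_output.shape, attn_weights.shape)
```

In a full head, `q`, `k`, and `v` would come from learned linear projections of the input embeddings rather than from random tensors.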

In [None]:
class AttentionHead(nn.Module):
  """
  A single attention head
  """

  def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
    super().__init__()

    self.hidden_size = hidden_size
    self.attention_head_size = attention_head_size

    #Create the query, key and value projection layers

    self.query  = nn.Linear(hidden_size, attention_head_size, bias=bias)
    self.key    = nn.Linear(hidden_size, attention_head_size, bias=bias)
    self.value  = nn.Linear(hidden_size, attention_head_size, bias=bias)

