In [3]:
# installing torch
!pip install -q torch einops

In [4]:
# imports
import torch
import torch.nn as nn
from einops import rearrange, repeat

In [None]:
# setting the parameters
# Every constant below is read as a default argument by the classes in the
# following cells, so this cell has to run before any of them is defined.

# image / batch geometry
image_size = 224
patch_size = 16
num_channels = 3
batch_size = 10

# model width and embedding dropout
embedding_dim = 128
dropout = 0.1

# attention parameters
num_heads = 8
n_heads = num_heads  # second spelling of the head count; aliased so the two cannot drift apart
head_dim = 64

# feed forward layer parameters
ff_hidden_dim = 256
ff_dropout_val = 0.1

# depth of the encoder and size of the classification output
num_layers = 6
num_classes = 10

In [None]:
class PatchEmbed(nn.Module):
  """Split images into patches and embed each patch as a token.

  The forward pass cuts the image into non-overlapping patches, flattens each
  patch, projects it to ``embed_dim`` (LayerNorm -> Linear -> LayerNorm),
  prepends a learnable CLS token, adds a learnable positional embedding and
  applies dropout.

  Args:
    image_height, image_width: expected image size in pixels.
    patch_height, patch_width: patch size in pixels; must divide the image size.
    num_channels: number of image channels.
    embed_dim: size of every output token.
    dropout: dropout probability applied to the final token sequence.

  Raises:
    ValueError: if the image size is not a multiple of the patch size.
  """
  def __init__(self, image_height = image_size, image_width = image_size, patch_height = patch_size, patch_width = patch_size, num_channels = num_channels, embed_dim = embedding_dim, dropout = dropout):
    super().__init__()

    # Fail early: num_patches below uses floor division, so a non-divisible size
    # would otherwise only surface later as a shape error inside forward().
    if image_height % patch_height != 0 or image_width % patch_width != 0:
      raise ValueError(
          f"Image size ({image_height}x{image_width}) must be divisible by "
          f"patch size ({patch_height}x{patch_width})"
      )

    self.image_height = image_height
    self.image_width = image_width
    self.patch_height = patch_height
    self.patch_width = patch_width
    self.num_channels = num_channels
    self.embed_dim = embed_dim
    print(f"Embedding dimension: {self.embed_dim}")

    self.num_patches = (self.image_height // self.patch_height) * (self.image_width // self.patch_width)
    print(f"Number of patches: {self.num_patches}")

    #patch dimension
    self.patch_dimension = self.patch_height * self.patch_width * self.num_channels
    print(f"Patch dimension: patch_height * patch_width * num_channels =  {self.patch_dimension}")

    # learnable CLS token, one vector shared by every sample in the batch
    self.cls_token = nn.Parameter(torch.randn(self.embed_dim))

    self.patch_embed = nn.Sequential(
            # This pre and post layer norm speeds up convergence
            # Comment them if you want pure vit implementation
            nn.LayerNorm(self.patch_dimension),
            nn.Linear(self.patch_dimension, self.embed_dim),
            nn.LayerNorm(self.embed_dim)
        )

    # one learnable position vector per patch plus one for the CLS token
    self.positional_embedding = nn.Parameter(torch.zeros(1,self.num_patches +1, self.embed_dim))
    self.patch_embed_drop = nn.Dropout(dropout)

  def forward(self,x):
    """Embed a batch of images laid out as (batch, channels, height, width).

    Returns a tensor of shape (batch, number_of_patches + 1, embed_dim); the
    token at index 0 along dim 1 is the CLS token.
    """
    print(f"Input shape: {x.shape}")
    batch_size = x.shape[0]

    # rearranging the input: every patch becomes one flat row of ph*pw*c values
    out = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (ph pw c)',
                      ph=self.patch_height,
                      pw=self.patch_width)
    print(f"Rearranged shape: {out.shape}")

    # embedding the patches
    patch_tokens = self.patch_embed(out)
    print(f"Embedding layer :\nInput shape: {out.shape}\nOutput shape: {patch_tokens.shape}")

    # adding cls token
    cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=batch_size)
    print(f"CLS token shape: {cls_tokens.shape}")

    # The patch tokens keep their own name so the log line below reports the
    # shape before concatenation rather than the already-concatenated shape.
    embed_out = torch.cat((cls_tokens, patch_tokens), dim=1)
    print(f"cls token : {cls_tokens.shape} + embedding : {patch_tokens.shape} = output : {embed_out.shape}")

    # adding positional embedding (broadcast over the batch dimension)
    embed_out = embed_out + self.positional_embedding
    print(f"positional embedding shape : {self.positional_embedding.shape}, output shape : {embed_out.shape}")

    # dropout
    embed_out = self.patch_embed_drop(embed_out)

    return embed_out

In [None]:
# attention

class Attention(nn.Module):
  """Multi-head self-attention over a sequence of tokens.

  A single bias-free linear layer produces queries, keys and values, each of
  width ``n_heads * head_dim``. Scaled dot-product attention is computed per
  head, the heads are concatenated and projected back to ``embedding_dim``.

  Args:
    n_heads: number of attention heads.
    head_dim: width of each head.
    embedding_dim: width of the input and output tokens.
    dropout: dropout probability applied after the output projection.
  """
  def __init__(self, n_heads=n_heads, head_dim = head_dim, embedding_dim = embedding_dim, dropout = dropout):
    super().__init__()

    self.n_heads = n_heads
    self.head_dim = head_dim
    self.embedding_dim = embedding_dim

    # attention dimension
    self.attention_dim = self.n_heads * self.head_dim
    print(f"Attention dimension: no of heads * head dimension =  {self.attention_dim}")

    # key query value, produced by one fused projection and split in forward()
    self.qkv_projection = nn.Linear(self.embedding_dim, 3*self.attention_dim, bias = False)

    # final output projection
    self.projection_out = nn.Sequential(
        nn.Linear(self.attention_dim, self.embedding_dim),
        nn.Dropout(dropout)
    )

  def _split_heads(self, tensor, label):
    """Reshape (b, n, n_heads*head_dim) to (b, n_heads, n, head_dim) and log the change."""
    heads = rearrange(tensor, 'b n (n_h h_dim) -> b n_h n h_dim',
                      n_h=self.n_heads, h_dim=self.head_dim)
    print(f"Rearranged {label} shape :\nbefore : {tensor.shape}\nafter : {heads.shape} [batch_size, no_of_head, no_of_patch, head_dim]")
    return heads

  def forward(self,x):
    """Apply self-attention to ``x`` of shape (batch, tokens, embedding_dim).

    Returns a tensor with the same shape as ``x``.
    """

    # changing the input dimension (embedding dimension) to attention dimension*3
    qkv_out = self.qkv_projection(x)
    print(f"QKV dimension change :\nInput shape: {x.shape}\nOutput shape: {qkv_out.shape}")

    # splitting the output into q k v matrix
    q, k , v = qkv_out.split(self.attention_dim, dim = -1)
    print(f"q k v splitting :\ninput shape : {qkv_out.shape}\nshape of each q k v : {q.shape} ")

    # rearranging q k v for processing: one (tokens, head_dim) matrix per head
    q_new = self._split_heads(q, "q")
    k_new = self._split_heads(k, "k")
    v_new = self._split_heads(v, "v")

    # attention weight calculation, scaled by 1/sqrt(head_dim)
    att = torch.matmul(q_new,k_new.transpose(-2,-1))*(self.head_dim**-0.5)
    print(f"(q * K)*(head_dim**-0.5):Output shape: {att.shape}")

    # passing through softmax layer (normalises over the key axis)
    att = torch.nn.functional.softmax(att, dim=-1)
    print(f"After passing through softmax :Output shape: {att.shape}")

    # softmax layer output * v
    out = torch.matmul(att, v_new)
    print(f"Softmax_out * V :Output shape: {out.shape}")

    # B x Heads x N x Head Dimension -> B x N x (Attention Dimension)
    out_final = rearrange(out, 'b n_h n h_dim -> b n (n_h h_dim)')
    print(f"Rearranged out shape :\nbefore : {out.shape}\nafter : {out_final.shape}")

    # final projection back to the embedding dimension
    out_attention = self.projection_out(out_final)
    print(f"Final output shape :\nInput shape: {out_final.shape}\nOutput shape: {out_attention.shape}")

    return out_attention

In [None]:
# transformer layer
class TransformerLayer(nn.Module):
  """One pre-norm transformer encoder layer.

  Computes ``x + Attention(LayerNorm(x))`` followed by
  ``y + FeedForward(LayerNorm(y))``. Each sub-block has its own LayerNorm.

  Args:
    ff_hidden_dim: hidden width of the feed-forward block.
    embed_dim: width of the input and output tokens.
    ff_dropout_val: dropout probability used inside the feed-forward block.
  """
  def __init__(self,ff_hidden_dim = ff_hidden_dim, embed_dim = embedding_dim,ff_dropout_val = ff_dropout_val):
    super().__init__()

    self.ff_dim = ff_hidden_dim
    self.embed_dim = embed_dim
    self.ff_dropout_val = ff_dropout_val

    # normalization layer applied before the attention block
    self.normalization_layer = nn.LayerNorm(self.embed_dim)

    # attention block; embed_dim is forwarded so a non-default width stays
    # consistent with the LayerNorms and the feed-forward block
    self.attention = Attention(embedding_dim = self.embed_dim)

    # separate normalization layer for the feed-forward block, so the two
    # sub-blocks do not share scale/shift parameters
    self.ff_normalization_layer = nn.LayerNorm(self.embed_dim)

    # feed forward layer
    self.ff_block = nn.Sequential(
        nn.Linear(self.embed_dim, self.ff_dim),
        nn.GELU(),
        nn.Dropout(self.ff_dropout_val),
        nn.Linear(self.ff_dim, self.embed_dim),
        nn.Dropout(self.ff_dropout_val)
    )

  def forward(self,x):
    """Run the layer on ``x``; the output has the same shape as the input."""

    input_val = x

    # attention block
    out_attention =self.attention(self.normalization_layer(input_val))+input_val
    print(f"Passing through attention block + residual connection:\nInput shape: {input_val.shape}\nOutput shape: {out_attention.shape}")

    # feed forward block
    out_transformLayer = self.ff_block(self.ff_normalization_layer(out_attention))+out_attention
    print(f"Passing through feed forward block + residual connection:\nInput shape: {out_attention.shape}\nOutput shape: {out_transformLayer.shape}")

    return out_transformLayer

In [None]:
# full ViT
class ViT(nn.Module):
  """Vision Transformer classifier.

  Pipeline: patch embedding -> ``n_layers`` transformer layers -> LayerNorm ->
  linear classification head applied to the token at index 0 (the CLS token
  that PatchEmbed prepends).

  Args:
    n_layers: number of stacked transformer layers.
    n_class: number of output classes.
    embed_dim: token width used by every sub-module.
  """
  def __init__(self, n_layers = num_layers, n_class = num_classes, embed_dim = embedding_dim):
    super().__init__()

    self.n_layers = n_layers
    self.n_classes = n_class
    self.embed_dim = embed_dim

    # patch embedding layer
    # embed_dim is forwarded explicitly; otherwise a non-default value would
    # only reach the final LayerNorm and head and cause a shape mismatch
    self.patchLayer = PatchEmbed(embed_dim = self.embed_dim)

    # transformer layers
    self.transformerLayers = nn.ModuleList(
        [TransformerLayer(embed_dim = self.embed_dim) for _ in range(self.n_layers)]
    )

    # normalization layer
    self.normal = nn.LayerNorm(self.embed_dim)

    # classification output layer
    self.classificationLayer = nn.Linear(self.embed_dim, self.n_classes)


  def forward(self,x):
    """Classify a batch of images; returns logits of shape (batch, n_class)."""

    print(f"Input shape to VIT: {x.shape}")
    # patching the image
    out = self.patchLayer(x)
    print(f"Output from PATCH LAYER shape: {out.shape}")

    # passing through transformer layers, looping through them
    for layer in self.transformerLayers:
      out = layer(out)

    print(f"Output after passing through TransformerLayers : {out.shape}")
    # normalization
    out = self.normal(out)

    # classification layer, fed only the first token of every sequence
    classification_out = self.classificationLayer(out[:,0])
    print(f"Output after passing through classification layer : {classification_out.shape}")

    return classification_out

In [None]:
# Smoke test: push one random batch through the model and check the output shape.
# Seeding makes the random input and the weight initialisation repeatable.
torch.manual_seed(0)

# Input built from the configuration constants so it follows any change made there.
test_batch_object = torch.randn([batch_size, num_channels, image_size, image_size])
vit_obj = ViT()
vit_out = vit_obj(test_batch_object)

assert vit_out.shape == (batch_size, num_classes), f"unexpected output shape: {tuple(vit_out.shape)}"