In [1]:
import torch
import torch.nn as nn

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Break the problem into subproblems
ViT is split into:
* Input Block
* Output 
* Layers
* Block
* Model

In the research paper, lowercase x denotes an image, and capital letters denote matrices.

### ViT works by
- input image ---> Patches
- patches     ---> Transformer
- transformer ---> MLP
- MLP         ---> Classifier

In [5]:
# Notebook-wide configuration.
SHAPE = 224          # side length of the square input image, in pixels
COLOR_CHANNEL = 3    # number of image channels
PATCH_SIZE = 16      # side length of one square patch, in pixels
NUM_CLASSES = 1000   # number of target classes

## 1.  PATCH LAYER/BLOCK

In [10]:
# Number of PATCH_SIZE x PATCH_SIZE tiles that fit in a SHAPE x SHAPE image.
NO_OF_PATCHES = SHAPE * SHAPE // PATCH_SIZE ** 2
# One raw image in channels-last layout: (height, width, channels).
INPUT_SHAPE = (SHAPE, SHAPE, COLOR_CHANNEL)
# One row per patch; each row holds every pixel value of that patch.
OUTPUT_SHAPE = (NO_OF_PATCHES, PATCH_SIZE ** 2 * COLOR_CHANNEL)

print("INput shape: ", INPUT_SHAPE)
print("Output shape: Single 1D sequence of PAtches ", OUTPUT_SHAPE)

INput shape:  (224, 224, 3)
Output shape: Single 1D sequence of PAtches  (196, 768)


In [43]:
# Random stand-in image in channels-last layout (INPUT_SHAPE), moved to `device`.
test_input=torch.randn(*INPUT_SHAPE).to(device)

In [44]:
# Display the tensor's shape to confirm it matches INPUT_SHAPE.
test_input.shape

torch.Size([224, 224, 3])

In [45]:
# A convolution whose kernel size equals its stride visits each
# PATCH_SIZE x PATCH_SIZE tile exactly once and maps it to a 768-dim vector.
# Move the layer to `device` so its weights live alongside test_input, which is
# also placed on `device`; otherwise the forward pass fails on a CUDA machine.
conv2d = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=PATCH_SIZE, stride=PATCH_SIZE, padding=0).to(device)
conv2d

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))

In [50]:
# Conv2d expects channels-first input, so reorder (H, W, C) -> (C, H, W).
# The input is unbatched here, so the result has no batch dimension either.
image_out_patch=conv2d(test_input.permute(2,0,1))
image_out_patch.shape

torch.Size([768, 14, 14])

In [51]:
image_out_patch.requires_grad # True because it was produced by conv2d, whose weights require grad

True

In [None]:
# Before passing the patches to the transformer, collapse the 2D patch grid
# into a single sequence axis.
# image_out_patch is unbatched (C, H', W'), so add a leading batch dimension
# first: nn.Flatten(start_dim=2, end_dim=3) needs a 4D (N, C, H', W') tensor and
# raises "Dimension out of range" on a 3D one.
flatten_image_in_transformer=nn.Flatten(start_dim=2, end_dim=3)(image_out_patch.unsqueeze(0))
flatten_image_in_transformer.shape

In [32]:
# Swap the last two axes so each row is one patch: (N, embed, patches) -> (N, patches, embed).
# NOTE: this reassigns the same name, so running the cell a second time swaps the axes back.
flatten_image_in_transformer = flatten_image_in_transformer.transpose(1, 2)
flatten_image_in_transformer.shape

torch.Size([1, 196, 768])

## 2. Transformer

In [34]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size equals its stride embeds every
    non-overlapping patch into one vector; the resulting 2D grid of patch
    vectors is then flattened into a sequence.

    Args:
        input_shape: Number of channels of the input images.
        patch_size: Side length of each square patch, in pixels.
        output_shape: Embedding dimension produced for every patch.
    """
    def __init__(self,
                input_shape,
                patch_size,
                output_shape):
        super().__init__()

        # Patching layer: kernel == stride, so each patch is seen exactly once.
        self.patcher=nn.Conv2d(in_channels=input_shape,
                                out_channels=output_shape,
                                kernel_size=patch_size,
                                stride=patch_size,
                                padding=0)
        # Merge the two spatial grid axes (dims 2 and 3) into one sequence axis.
        self.flatten=nn.Flatten(start_dim=2, end_dim=3)

    def forward(self,x):
        """Embed ``x`` of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, num_patches, output_shape).

        Raises:
            AssertionError: If height or width is not divisible by the patch size.
        """
        # Validate against this layer's own patch size (Conv2d stores
        # kernel_size as a (height, width) tuple) instead of a notebook global,
        # so instances built with a different patch_size are checked correctly.
        patch_h, patch_w = self.patcher.kernel_size
        image_h, image_w = x.shape[-2], x.shape[-1]
        assert image_h % patch_h == 0 and image_w % patch_w == 0, "Image resolution should be divisible by patch size"

        # forward pass
        x_patched=self.patcher(x)
        x_flattened=self.flatten(x_patched)

        # (batch, embed, patches) -> (batch, patches, embed)
        return x_flattened.permute(0,2,1)




In [37]:
# Instantiate the patch-embedding layer and move it to `device`, so its weights
# sit on the same device as the inputs created with `.to(device)`.
create_patch=PatchEmbedding(input_shape=COLOR_CHANNEL,
                            patch_size=PATCH_SIZE,
                            output_shape=768).to(device)

In [54]:
# Replace the earlier test image with a batched, channels-first one: (1, 3, SHAPE, SHAPE).
test_input=torch.randn(1,3,SHAPE,SHAPE).to(device)
test_input.shape

torch.Size([1, 3, 224, 224])

In [57]:
# Run the image through the patch-embedding layer and display the resulting shape.
patch_embedding=create_patch(test_input)
patch_embedding.shape

torch.Size([1, 196, 768])

### In ViT there is a catch: the first embedding in the sequence is a learnable class token, which you create manually and which does not take any input from the flattened patches

In [59]:
# Create the learnable class token that is prepended to the patch sequence
# (this is the class token, not a positional encoding).
batch_size=1

# Allocate it directly on `device` so it can be concatenated with
# patch_embedding, which is computed from an input placed on `device`.
class_token=nn.Parameter(torch.ones(batch_size,1,768,device=device),requires_grad=True)
class_token.shape

torch.Size([1, 1, 768])

In [60]:
# Prepend the class token to the patch sequence along the token axis (dim=1).
# Despite the variable name, no positional embedding is added in this cell.
addition_of_pos_embedding=torch.cat((class_token,patch_embedding),dim=1)
addition_of_pos_embedding.shape

torch.Size([1, 197, 768])

### MSA BLOCK (Multi head self Attention Block)

In [72]:
class MSA(nn.Module):
    """Multi-head self-attention sub-layer with layer normalisation applied first.

    Args:
        embedding_dim: Size of the last axis of the input sequence.
        no_heads: Number of attention heads.
        dropout: Dropout probability passed to the attention module.
    """

    def __init__(self, embedding_dim, no_heads, dropout=0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.multi_head_attention = nn.MultiheadAttention(
            embedding_dim,
            no_heads,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, x):
        """Return the attention output for ``x``; the output has the same shape as ``x``."""
        normed = self.layer_norm(x)
        # Self-attention: the same normalised tensor is query, key and value.
        # The attention weights are not requested, so only the output is kept.
        return self.multi_head_attention(
            query=normed, key=normed, value=normed, need_weights=False
        )[0]


In [73]:
# Build the attention block on `device` so its weights match the device of the
# token sequence it is applied to.
mlsa=MSA(embedding_dim=768,no_heads=12).to(device)

In [74]:
# Apply the attention block to the token sequence and display the output shape.
mlsa(addition_of_pos_embedding).shape

torch.Size([1, 197, 768])

### MLP BLOCK

In [76]:
class MLP(nn.Module):
    """Feed-forward sub-layer: LayerNorm, then Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embedding_dim: Size of the last axis of the input and of the output.
        mlp_size: Width of the hidden layer.
        dropout: Dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, embedding_dim, mlp_size, dropout=0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embedding_dim)

        # Expand to mlp_size, apply the non-linearity, then project back.
        stages = [
            nn.Linear(embedding_dim, mlp_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_size, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Normalise ``x`` and pass it through the feed-forward stack."""
        return self.mlp(self.layer_norm(x))

### Transformer Block

In [77]:
class Transformer(nn.Module):
    """One encoder block: an MSA sub-layer and an MLP sub-layer, each wrapped in a residual connection.

    Args:
        embedding_dim: Size of the last axis of the token sequence.
        no_heads: Number of attention heads for the MSA sub-layer.
        mlp_size: Hidden width of the MLP sub-layer.
        dropout: Dropout probability forwarded to both sub-layers.
    """

    def __init__(self, embedding_dim, no_heads, mlp_size, dropout=0):
        super().__init__()
        self.msa = MSA(embedding_dim=embedding_dim, no_heads=no_heads, dropout=dropout)
        self.mlp = MLP(embedding_dim=embedding_dim, mlp_size=mlp_size, dropout=dropout)

    def forward(self, x):
        """Apply attention and then the MLP, adding each sub-layer's input back to its output."""
        after_attention = self.msa(x) + x
        after_mlp = self.mlp(after_attention) + after_attention
        return after_mlp

## Putting things Together

In [78]:
class ViT(nn.Module):
    """Vision Transformer classifier assembled from the blocks defined above.

    Pipeline: patch embedding -> prepend class token -> add positional
    embedding -> stack of Transformer blocks -> LayerNorm + Linear head applied
    to the class-token position.

    Args:
        img_size: Side length of the square input images, in pixels.
        in_channels: Number of channels of the input images.
        patch_size: Side length of each square patch, in pixels.
        no_of_transformers: Number of stacked Transformer blocks.
        embedding_dim: Size of every token embedding.
        mlp_size: Hidden width of the MLP inside each Transformer block.
        no_heads: Number of attention heads in each Transformer block.
        no_classes: Number of output logits.
        dropout: Dropout probability forwarded to every Transformer block.

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """
    def __init__(self,
                    img_size,
                    in_channels,
                    patch_size,
                    no_of_transformers,
                    embedding_dim,
                    mlp_size,
                    no_heads,
                    no_classes,
                    dropout=0):
        super().__init__()
        # Fail fast: with a non-divisible size the floor division below would
        # silently give a patch count that does not match what the patch
        # convolution produces, and the positional table would be mis-sized.
        if img_size % patch_size != 0:
            raise ValueError(
                f"img_size ({img_size}) must be divisible by patch_size ({patch_size})"
            )
        self.num_patches=(img_size//patch_size)**2
        # The attribute name keeps its original spelling so that existing
        # state_dict keys remain loadable.
        self.class_ebmedding=nn.Parameter(torch.randn(1,1,embedding_dim),requires_grad=True)
        # One learnable position vector per token: all patches plus the class token.
        self.positional_embedding=nn.Parameter(torch.randn(1,self.num_patches+1,embedding_dim),requires_grad=True)

        self.patch_embedding=PatchEmbedding(input_shape=in_channels,
                                            patch_size=patch_size,
                                            output_shape=embedding_dim)

        self.transformer=nn.Sequential(*[Transformer(embedding_dim=embedding_dim,
                                                    no_heads=no_heads,
                                                    mlp_size=mlp_size,
                                                    dropout=dropout) for _ in range(no_of_transformers)])
        self.classifier=nn.Sequential(
            nn.LayerNorm(embedding_dim),
            nn.Linear(embedding_dim,no_classes)
        )

    def forward(self,x):
        """Return class logits of shape (batch, no_classes) for a batch of images ``x``."""
        batch_size=x.shape[0]

        # Broadcast the single learnable class token across the batch.
        class_token=self.class_ebmedding.expand(batch_size,-1,-1)
        x=self.patch_embedding(x)
        x=torch.cat((class_token,x),dim=1)
        # Out-of-place add, so the concatenated tensor is never mutated.
        x=x+self.positional_embedding
        x=self.transformer(x)
        # Classify from the class-token position (index 0 on the token axis).
        x=self.classifier(x[:,0])

        return x



In [81]:
from torchinfo import summary

# Build a 6-block ViT with a 10-way head and place it on `device`.
# Note that the NUM_CLASSES constant is not used here.
model = ViT(
    img_size=SHAPE,
    in_channels=COLOR_CHANNEL,
    patch_size=PATCH_SIZE,
    no_of_transformers=6,
    embedding_dim=768,
    mlp_size=3072,
    no_heads=12,
    no_classes=10,
    dropout=0.1,
).to(device)


In [83]:
# Layer-by-layer report for a single (1, 3, SHAPE, SHAPE) input.
summary(
    model,
    input_size=(1, 3, SHAPE, SHAPE),
    col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
    depth=4,
)

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
ViT                                           [1, 3, 224, 224]          [1, 10]                   152,064                   --                        --
├─PatchEmbedding: 1-1                         [1, 3, 224, 224]          [1, 196, 768]             --                        --                        --
│    └─Conv2d: 2-1                            [1, 3, 224, 224]          [1, 768, 14, 14]          590,592                   [16, 16]                  115,756,032
│    └─Flatten: 2-2                           [1, 768, 14, 14]          [1, 768, 196]             --                        --                        --
├─Sequential: 1-2                             [1, 197, 768]             [1, 197, 768]             --                        --                        --
│    └─Transformer: 2-3                       [1, 197, 768]       

In [84]:
# Fresh random batch of one channels-first image for the end-to-end check.
test_input=torch.randn(1,3,SHAPE,SHAPE).to(device)

In [85]:
# Forward pass on one random image. model.eval() is not called in the cells
# shown, so dropout (0.1) is presumably still active here — TODO confirm.
out=model(test_input)

In [88]:
# Show the raw logits and their shape.
print(out)
print(out.shape)

tensor([[-0.1228,  0.1758, -1.1480, -0.2348,  0.0632, -0.0963,  0.3380, -0.5666,
         -0.3432, -0.1435]], grad_fn=<AddmmBackward0>)
torch.Size([1, 10])


In [87]:
# Index of the largest logit. argmax without `dim` searches the flattened
# tensor, which equals the class index only because the batch holds one image.
print(torch.argmax(out))

tensor(6)
