# ViT-Base Implementation

## Install a Library and Import Packages

In [28]:
!pip install einops



In [29]:
import torch
import torch.nn as nn
from einops import rearrange

## Build the model

### **Embedded Patches**

**Getting x_p**

First, divide each image into non-overlapping patches. Suppose the input tensor has shape `(batch_size, num_channels, p_height * h_count, p_width * w_count)` and we cut it into patches of size p_height × p_width. Each image then yields h_count * w_count patches, each of shape (p_height, p_width, num_channels), and we flatten every patch into a vector of length p_height * p_width * num_channels. As a result, x_p has shape (batch_size, h_count * w_count, p_height * p_width * num_channels).



**Linear Projection into D dimension**

- The bracketed term of Eq. (1), $z_0 = [x_\text{class};\, x_p^1 E;\, x_p^2 E;\, \dots;\, x_p^N E] + E_{pos}$

We project each flattened patch with the linear projection E, i.e. `nn.Linear(p_height * p_width * num_channels, d_model)`, so the projected patches have shape (batch_size, N, d_model), where N = h_count * w_count is the number of patches (the paper's notation). The learnable class token x_class has shape (batch_size, 1, d_model). Concatenating it in front of the projected patches gives the bracketed term of Eq. (1) the shape `(batch_size, N+1, d_model)`.


- [...] + E_pos
E_pos has shape `(N+1, d_model)`. In the code it is stored as `(1, N+1, d_model)` and broadcast over the batch, so [...] + E_pos also has shape `(batch_size, N+1, d_model)`. Therefore the output of Eq. (1), z_0, has shape (batch_size, N+1, d_model).

**Encoder**

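Next, z_0 is passed through the encoder, a stack of L identical layers. Here MSA means multi-head self-attention and LN means layer normalization. Each layer computes

$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \qquad z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell, \qquad \ell = 1, \dots, L.$$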


**Classification Head**


The diagram in the paper makes it look as if the encoder output is passed directly to the classification head. However, the authors' code first applies a final LayerNorm and only then passes the result to the classification head.


A different classification head is used depending on whether the model is being pre-trained or fine-tuned:

**Pre-training**: MLP with one hidden layer

**Fine-Tuning**: a single linear layer




In [30]:
# ViT-Base hyper-parameters.
d_model = 768      # token embedding size D
mlp_size = 3072    # hidden width of the MLP block in each encoder layer
heads = 12         # attention heads per encoder layer
layers = 12        # number of encoder layers L

# The number of patches N is derived from the image and patch size instead of
# being hard-coded: (224 // 14) ** 2 == 256.
image_size = 224   # input image height == width, in pixels
patch_size = 14    # patch height == width, in pixels
N = (image_size // patch_size) ** 2

batch_size = 2

assert d_model % heads == 0, "d_model must be divisible by heads"

In [31]:
class EncoderLayer(nn.Module):
  """One pre-norm Transformer encoder layer:

      z' = MSA(LN(z)) + z
      z  = MLP(LN(z')) + z'

  Args:
    ln_input_shape: `normalized_shape` passed to both LayerNorms. nn.LayerNorm
      normalizes over the last len(ln_input_shape) dims, so (N+1, d_model)
      normalizes over the whole token sequence, while d_model gives a
      per-token LayerNorm.
    d_model: token embedding dimension D.
    num_heads: number of attention heads (must divide d_model).
    mlp_size: hidden width of the MLP block (defaults to 3072, the ViT-Base
      value used in this notebook).
  """
  def __init__(self, ln_input_shape, d_model, num_heads, mlp_size=3072):
    super().__init__()
    self.layer_norm1 = nn.LayerNorm(ln_input_shape)
    # batch_first=True because inputs are (batch, tokens, d_model). With the
    # default (False) the batch axis would be treated as the sequence axis and
    # tokens of different images would attend to each other.
    self.mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads,
                                     dropout=0.2, batch_first=True)
    self.layer_norm2 = nn.LayerNorm(ln_input_shape)

    self.mlp = nn.Sequential(
        nn.Linear(d_model, mlp_size),
        nn.GELU(),
        nn.Linear(mlp_size, d_model),
    )

  def forward(self, x):
    """Apply the attention and MLP sub-blocks, each with a residual connection.

    input shape:  (batch_size, N+1, d_model)
    output shape: (batch_size, N+1, d_model)
    """
    x_prime = self.layer_norm1(x)  # (batch_size, N+1, d_model)

    # Self-attention: query, key and value are all the normalized tokens.
    # The attention weights are discarded, so don't ask MHA to compute them.
    attn_out, _ = self.mha(x_prime, x_prime, x_prime, need_weights=False)
    first_out = attn_out + x  # residual connection 1

    x_prime = self.layer_norm2(first_out)  # (batch_size, N+1, d_model)
    second_out = self.mlp(x_prime)  # (batch_size, N+1, d_model)
    return second_out + first_out  # residual connection 2

In [32]:
class ViT(nn.Module):
  """Vision Transformer backbone without a classification head.

  The pipeline is: patch embedding -> prepend CLS token -> add position
  embedding -> stack of EncoderLayer.

  Args:
    p1, p2: patch height and width in pixels.
    d_model: token embedding dimension D.
    batch_size: unused; kept for backward compatibility. The batch size is
      read from the input tensor in forward().
    H, W: input image height and width; must be divisible by p1 and p2.
    ln_input_shape: forwarded to every EncoderLayer's LayerNorms. If it
      includes the token axis it must be (N+1, d_model), where
      N = (H // p1) * (W // p2).
    heads: number of attention heads per encoder layer.
    num_channels: number of input image channels (3 for RGB).
    num_layers: number of stacked encoder layers.
  """
  def __init__(self, p1, p2, d_model, batch_size, H, W, ln_input_shape, heads,
               num_channels=3, num_layers=12):
    super().__init__()
    if H % p1 != 0 or W % p2 != 0:
      raise ValueError(
          f"image size ({H}, {W}) must be divisible by patch size ({p1}, {p2})")
    self.p1 = p1
    self.p2 = p2
    num_patches = (H // p1) * (W // p2)  # N in the paper

    # E: flattened patch of p1 * p2 * num_channels values -> d_model
    self.E_projection = nn.Linear(p1 * p2 * num_channels, d_model)
    self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
    self.encoder_layers = nn.ModuleList([
        EncoderLayer(ln_input_shape=ln_input_shape, d_model=d_model, num_heads=heads)
        for _ in range(num_layers)
    ])
    # One learned position embedding per patch plus one for the CLS token.
    self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, d_model))

  def forward(self, img):
    """Embed an image batch and run it through the encoder stack.

    input shape:  (batch_size, num_channels, H, W)
    output shape: (batch_size, N+1, d_model)
    """
    batch_size = img.shape[0]  # img.size is a method; shape is subscriptable

    # Split into non-overlapping p1 x p2 patches. In '(x patch_x)' the in-patch
    # offset is the inner (fast-varying) factor, so each patch is a contiguous
    # block of pixels rather than a strided sample of the whole image.
    img_patches = rearrange(
        img, 'b c (x patch_x) (y patch_y) -> b (x y) (patch_x patch_y c)',
        patch_x=self.p1, patch_y=self.p2)  # (batch_size, N, p1*p2*c)

    # Apply the E matrix.
    img_patches = self.E_projection(img_patches)  # (batch_size, N, D)

    # Prepend the CLS token, shared across the batch.
    cls_token = self.cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, D)
    img_patches = torch.cat((cls_token, img_patches), dim=1)  # (batch_size, N+1, D)

    # Add the position embedding, broadcast over the batch.
    img_patches = img_patches + self.pos_embedding

    for encoder_layer in self.encoder_layers:
      img_patches = encoder_layer(img_patches)

    return img_patches  # (batch_size, N+1, D)

In [33]:
# batch_size is set in the configuration cell. Re-assigning it here would
# silently override that value, so only sanity-check it.
assert isinstance(batch_size, int) and batch_size > 0, "batch_size must be a positive int"

In [34]:
# 14x14 patches on 224x224 RGB images; the remaining hyper-parameters come
# from the configuration cell so they cannot drift out of sync.
model = ViT(p1=14, p2=14, d_model=d_model, batch_size=batch_size, H=224, W=224,
            ln_input_shape=(N + 1, d_model), heads=heads)

In [35]:
# output = model(torch.randn(batch_size, 3, 224, 224))

In [36]:
# output.shape

In [37]:
def count_parameters(model, trainable_only=False):
  """Return the total number of scalar parameters in `model`.

  Args:
    model: an nn.Module.
    trainable_only: if True, count only parameters with requires_grad=True.
  """
  return sum(
      p.numel() for p in model.parameters()
      if p.requires_grad or not trainable_only
  )

# Report the total parameter count of the model built above.
num_params = count_parameters(model)
print(f"Number of parameters: {num_params}")

Number of parameters: 95142144
