<a href="https://colab.research.google.com/github/peeyushsinghal/EVA8/blob/main/S10-Assignment-Solution/EVA8_S10_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [84]:
# !pip install torchinfo
# !pip install einops

In [85]:
#@title Importing Libraries
import torch
from torch import nn
from torchinfo import summary

In [86]:
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [87]:
def pair(data):
  return data if isinstance(data,tuple) else (data,data)

In [88]:
class ToPatch(nn.Module):
  def __init__(self,  patch_size,  channels=3, embedding_dim = 768):
    super().__init__()

    self.patch= nn.Sequential(
        nn.Conv2d(in_channels = channels, out_channels = embedding_dim, kernel_size = patch_size, stride = patch_size, padding =0), #3X32x32 -> out_channels x (image_height // patch_height) x (image_width // patch_width) [(out_channels)x(image_height // patch_height) x (image_width // patch_width)]
        nn.Flatten(start_dim=2, end_dim=3), # conversion to 2d- out_channels x [(image_height // patch_height) x (image_width // patch_width)] == out_channels x num_patches
    )

  def forward(self,x):
    x = self.patch(x)
    x = x.permute(0,2,1) # [B x out_channels x num_patches] -> [B x  num_patches x out_channels]
    return x

In [89]:
class FeedForward(nn.Module):
  def __init__(self,
               in_dim = 32,
               out_dim = 3*32,
               drop_out = 0.1
               ):
    super().__init__()
    self.ff = nn.Sequential(
        nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size = 1), # using 1x1 conv instead of linear layer
        nn.GELU(),
        nn.Dropout(drop_out),
        nn.Conv1d(in_channels=out_dim, out_channels=in_dim, kernel_size = 1), # using 1x1 conv instead of linear layer
    )

  def forward(self,x):
    return self.ff(x)

In [90]:
class TransformerEncoderBlock(nn.Module):
  def __init__(self,  
               num_heads = 4, # number of parallel multi attention heads required
               dim = 32, # number of total dimension of input
               transformer_dropout = 0.1 #Dropout used in feedforward layer
               ):
    super().__init__()
    # attention block
    self.layer_norm_preattn= nn.LayerNorm(dim)
    self.self_attn = nn.MultiheadAttention(embed_dim = dim, 
                                           num_heads = num_heads,
                                           batch_first = True)
    

    # mlp block
    self.layer_norm_preff = nn.LayerNorm(dim)
    self.feed_forward = FeedForward(in_dim = dim, out_dim = 3*dim, drop_out = transformer_dropout)

  

  def forward(self,x):
    # attention block
    x_attn_residual = self.layer_norm_preattn(x)
    # print("after layer norm, size : ",x_attn_residual.shape)
    x_attn_residual, attn_output_weights  = self.self_attn(query =x_attn_residual, 
                                            key = x_attn_residual, 
                                            value = x_attn_residual,
                                            need_weights = True,
                                            average_attn_weights=True)
    # print("after attention, size : ",x_attn_residual.shape)
    x = x + x_attn_residual
    # print("after residual addition, size : ", x.shape) # batch, (num_patches + 1), (embedding_dim)

    # Feed Forward block
    x_ff_residual = self.layer_norm_preff(x)
    x_ff_residual = self.feed_forward(x_ff_residual.permute(0,2,1)) # permutation required to get B x C x (num_patches + 1) format. C = Embedding Dim
    # print("after feed forward, size : ", x_ff_residual.shape) 
    x = x + x_ff_residual.permute(0,2,1) # residual requires permutation to get back into format of x, i.e., B x(num_patches + 1) x C format. C = Embedding Dim
    # print("after adding residual in feed forward, size : ", x_ff_residual.shape) 
    return x

In [91]:
class TransformerStack(nn.Module):
  ## MultiHead Attention Block
  def __init__(self,
               num_blocks = 4, # number of transformers blocks stacked on each other
               num_heads = 4, # number of parallel multi attention heads required
               dim = 32, # number of total dimension of input
               transformer_dropout = 0.1 #Dropout used in attention
               ):
    super().__init__()
    self.tranformer_stack = nn.ModuleList([]) # initialized
    for _ in range(num_blocks):
      self.tranformer_stack.append(TransformerEncoderBlock(num_heads=num_heads, dim = dim, transformer_dropout = transformer_dropout))


  def forward(self,x):
    for transformer_block in self.tranformer_stack:
      x = transformer_block(x)
    return x

In [114]:
class Head(nn.Module):
  def __init__(self,
               num_classes = 10, # number of classes
               dim = 32, # input dimension
               head_p_drop = 0.1 # drop out
               ):
    super().__init__()
    self.layer_norm_prehead= nn.LayerNorm(dim)
    self.head = nn.Sequential(
        nn.GELU(),
        nn.Dropout(head_p_drop),
        nn.Conv1d(in_channels = dim, out_channels = num_classes, kernel_size = 1)
    )

  def forward(self,x, pool = 'cls'):
    print("before head block, size :", x.shape)
    if pool == 'cls':
      x_cls = x[:,0,:] # getting the first dimension this gives [batch x dim]
    else:
      x_cls = x[:,1:,:].mean(dim=1) # ignoring the first dimension  this gives [batch x num_patch x dim], mean gives [batch x dim]
    x_cls = x_cls.unsqueeze(dim=1) # [batch x dim] -> batch x 1 x dim
    # print("before head block, before permutation, size :", x_cls.shape)
    x_cls = self.layer_norm_prehead(x_cls)
    x_cls = x_cls.permute(0,2,1)
    print("before head block, after permutation, size :", x_cls.shape)
    x_cls = self.head(x_cls)
    print("after head block , size :", x_cls.shape)
    output = x_cls.view(-1,10)
    
    # if pool == 'mean':
    #   x_mean = x[:,1:,:] # ignoring the first dimension  this gives [batch x num_patch x dim]
    #   print("before taking mean, mean, size :", x_mean.shape)
    #   x_mean = x_mean.mean(dim =1) # this gives [batch x dim]
    #   x_mean = x_mean.unsqueeze(dim=1) # [batch x dim] -> batch x 1 x dim
    #   print("before head block, mean, size :", x_mean.shape)
    #   x_mean = self.layer_norm_prehead(x_mean)
    #   print("before head block, after layernorm, mean, size :", x_mean.shape)
    #   x_mean = x_mean.permute(0,2,1)
    #   print("before head block, after permutation, mean, size :", x_mean.shape)
    #   x_mean = self.head(x_mean)
    #   print("after head block, mean, size :", x_mean.shape)
    #   output = x_mean.view(-1,10)
    return output

In [103]:
class ViT(nn.Module):
  def __init__(self, 
               image_size, 
               patch_size, 
               dim = None, # if None, use the information as per image size else use the dimensions provided
               pool = 'cls', # whether the pooling is based on class token ('cls') or mean pooling ('mean')
               num_classes = 10, 
               emb_dropout = 0.1 #Dropout for patch and position embeddings
               ):
    super().__init__()
    
    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    self.pool = pool

    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)

    assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

    num_patches = (image_height // patch_height) * (image_width // patch_width)
    channels = 3 # hard coding
    patch_dim = channels * patch_height * patch_width

    if dim:
      embedding_dim = dim # specific dimension
    else:
      embedding_dim = patch_dim # 3xP_hxP_w

    self.to_patch = ToPatch(patch_size = patch_size, channels = channels, embedding_dim = embedding_dim)
    self.class_token = nn.Parameter(data= torch.randn(1, 1, embedding_dim), requires_grad=True)
    self.pos_embedding = nn.Parameter(data = torch.randn(1, num_patches + 1, embedding_dim),requires_grad=True)
    self.embedding_dropout = nn.Dropout(p=emb_dropout)

    self.transformer = TransformerStack(num_blocks = 4,  num_heads = 4,  dim = 32,  transformer_dropout = 0.1)

    self.head_output = Head(num_classes = num_classes, dim = embedding_dim, head_p_drop = 0.1)

  def forward(self,x):
    x = self.to_patch(x)
    # print("after to_patch, size :", x.shape)
    
    batch_size, num_patches, embedding_dim = x.shape[0], x.shape[-2], x.shape[-1]
    # print(f'num_patches : {num_patches}, embedding_dim : {embedding_dim}')
    
    class_token_across_batch = self.class_token.expand(batch_size,-1,-1) # -1 means not to expand in that direction
    x = torch.cat((class_token_across_batch,x),dim=1) # dim 0 is batch_size, dim 1 is num_patches and dim 2 is embedding_dim
    # print("after concatenation with class token, size :", x.shape)
    
    pos_emeddings_across_batch = self.pos_embedding.expand(batch_size,-1,-1) # -1 means not to expand in that direction
    x = x + pos_emeddings_across_batch 
    # print("after adding with postional embeddings, size :", x.shape)

    x = self.embedding_dropout(x)

    x = self.transformer(x)
    # print("after transformer, size :", x.shape)

    if self.pool == 'cls':
      x = self.head_output(x,pool='cls')
    if self.pool == 'mean':
      x = self.head_output(x,pool='mean')

    
    print("after head_output, size :", x.shape)
    return x

In [94]:
# class xx(nn.Module):
#   def __init__(self,):
#     super().__init__()

#   def forward(self,x):
#     return x

In [117]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:", device)
batch_size = 8
model = ViT(image_size = 32, patch_size = 2, dim = 32, pool = 'mean').to(device)
summary(model, input_size=(batch_size, 3, 32, 32), col_names = ['input_size','output_size','num_params'], depth = 10, row_settings = ["var_names"])

device: cpu
before head block, size : torch.Size([8, 257, 32])
before head block, before permutation cls, size : torch.Size([8, 1, 32])
before head block, after permutation cls, size : torch.Size([8, 32, 1])
after head block cls, size : torch.Size([8, 10, 1])
after head_output, size : torch.Size([8, 10])


Layer (type (var_name))                                 Input Shape               Output Shape              Param #
ViT (ViT)                                               [8, 3, 32, 32]            [8, 10]                   8,256
├─ToPatch (to_patch)                                    [8, 3, 32, 32]            [8, 256, 32]              --
│    └─Sequential (patch)                               [8, 3, 32, 32]            [8, 32, 256]              --
│    │    └─Conv2d (0)                                  [8, 3, 32, 32]            [8, 32, 16, 16]           416
│    │    └─Flatten (1)                                 [8, 32, 16, 16]           [8, 32, 256]              --
├─Dropout (embedding_dropout)                           [8, 257, 32]              [8, 257, 32]              --
├─TransformerStack (transformer)                        [8, 257, 32]              [8, 257, 32]              --
│    └─ModuleList (tranformer_stack)                    --                        --                   

In [15]:
class ViT_internet(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        # self.pool = pool
        # self.to_latent = nn.Identity()

        # self.mlp_head = nn.Sequential(
        #     nn.LayerNorm(dim),
        #     nn.Linear(dim, num_classes)
        # )

    def forward(self, img):
        x = self.to_patch_embedding(img)

        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        print(x.shape)
        return

        # x = self.transformer(x)

        # x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        # x = self.to_latent(x)
        # return self.mlp_head(x)

In [16]:
batch_size = 8
model_internet = ViT_internet(image_size = 32, patch_size = 2 , num_classes = 10,dim = 32,depth = 4, heads =4, mlp_dim = 10).to(device)
summary(model_internet, input_size=(batch_size, 3, 32, 32), col_names = ['input_size','output_size','num_params'],row_settings = ["var_names"])

torch.Size([8, 257, 32])


Layer (type (var_name))                  Input Shape               Output Shape              Param #
ViT_internet (ViT_internet)              [8, 3, 32, 32]            --                        8,256
├─Sequential (to_patch_embedding)        [8, 3, 32, 32]            [8, 256, 32]              --
│    └─Rearrange (0)                     [8, 3, 32, 32]            [8, 256, 12]              --
│    └─Linear (1)                        [8, 256, 12]              [8, 256, 32]              416
├─Dropout (dropout)                      [8, 257, 32]              [8, 257, 32]              --
Total params: 8,672
Trainable params: 8,672
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.10
Forward/backward pass size (MB): 0.52
Params size (MB): 0.00
Estimated Total Size (MB): 0.62

In [17]:
class ToPatches(nn.Sequential):
    def __init__(self, in_channels, channels, patch_size, hidden_channels=32):
        super().__init__(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(hidden_channels, channels, patch_size, stride=patch_size)
        )

In [18]:
class AddPositionEmbedding(nn.Module):
    def __init__(self, channels, shape):
        super().__init__()
        self.pos_embedding = nn.Parameter(torch.Tensor(channels, *shape))
    
    def forward(self, x):
        return x + self.pos_embedding

In [19]:
class ToEmbedding(nn.Sequential):
    def __init__(self, in_channels, channels, patch_size, shape, p_drop=0.):
        super().__init__(
            ToPatches(in_channels, channels, patch_size),
            AddPositionEmbedding(channels, shape),
            nn.Dropout(p_drop)
        )

In [20]:
class ViT_rohan(nn.Sequential):
    def __init__(self, classes, image_size, channels, head_channels, num_blocks, patch_size,
                 in_channels=3, emb_p_drop=0., trans_p_drop=0., head_p_drop=0.):
        reduced_size = image_size // patch_size
        shape = (reduced_size, reduced_size)
        super().__init__(
            ToEmbedding(in_channels, channels, patch_size, shape, emb_p_drop),
            # TransformerStack(num_blocks, channels, head_channels, shape, trans_p_drop),
            # Head(channels, classes, head_p_drop)
        )
        # self.reset_parameters()
    
    def reset_parameters(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1.)
                nn.init.zeros_(m.bias)
            elif isinstance(m, AddPositionEmbedding):
                nn.init.normal_(m.pos_embedding, mean=0.0, std=0.02)
            elif isinstance(m, SelfAttention2d):
                nn.init.normal_(m.pos_enc, mean=0.0, std=0.02)
            elif isinstance(m, Residual):
                nn.init.zeros_(m.gamma)
    
    def separate_parameters(self):
        parameters_decay = set()
        parameters_no_decay = set()
        modules_weight_decay = (nn.Linear, nn.Conv2d)
        modules_no_weight_decay = (nn.LayerNorm,)

        for m_name, m in self.named_modules():
            for param_name, param in m.named_parameters():
                full_param_name = f"{m_name}.{param_name}" if m_name else param_name

                if isinstance(m, modules_no_weight_decay):
                    parameters_no_decay.add(full_param_name)
                elif param_name.endswith("bias"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, Residual) and param_name.endswith("gamma"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, AddPositionEmbedding) and param_name.endswith("pos_embedding"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, SelfAttention2d) and param_name.endswith("pos_enc"):
                    parameters_no_decay.add(full_param_name)
                elif isinstance(m, modules_weight_decay):
                    parameters_decay.add(full_param_name)

        # sanity check
        # assert len(parameters_decay & parameters_no_decay) == 0
        # assert len(parameters_decay) + len(parameters_no_decay) == len(list(model.parameters()))

        return parameters_decay, parameters_no_decay

In [21]:
batch_size = 8
NUM_CLASSES, IMAGE_SIZE = 10, 32

model_rohan = ViT_rohan(NUM_CLASSES, IMAGE_SIZE, channels=32, head_channels=8, num_blocks=4, patch_size=2).to(device)
summary(model_rohan, input_size=(batch_size, 3, 32, 32), col_names = ['input_size','output_size','num_params'],row_settings = ["var_names"])

Layer (type (var_name))                  Input Shape               Output Shape              Param #
ViT_rohan (ViT_rohan)                    [8, 3, 32, 32]            [8, 32, 16, 16]           --
├─ToEmbedding (0)                        [8, 3, 32, 32]            [8, 32, 16, 16]           --
│    └─ToPatches (0)                     [8, 3, 32, 32]            [8, 32, 16, 16]           --
│    │    └─Conv2d (0)                   [8, 3, 32, 32]            [8, 32, 32, 32]           896
│    │    └─GELU (1)                     [8, 32, 32, 32]           [8, 32, 32, 32]           --
│    │    └─Conv2d (2)                   [8, 32, 32, 32]           [8, 32, 16, 16]           4,128
│    └─AddPositionEmbedding (1)          [8, 32, 16, 16]           [8, 32, 16, 16]           8,192
│    └─Dropout (2)                       [8, 32, 16, 16]           [8, 32, 16, 16]           --
Total params: 13,216
Trainable params: 13,216
Non-trainable params: 0
Total mult-adds (M): 15.79
Input size (MB): 0.10
Forwa