<a href="https://colab.research.google.com/github/yukitiec/Research/blob/main/VisionTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Import Library

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#Make Module

##Input Layer

In [5]:
class VitInputLayer(nn.Module):
  """Turn a batch of images into the token sequence consumed by a ViT encoder.

  Each image is cut into non-overlapping square patches, every patch is
  embedded into a vector of length ``emb_dim`` (done with a strided Conv2d),
  a learnable class token is prepended and a learnable positional embedding
  is added to every token.
  """

  def __init__(self,
               in_channels:int = 3,
               emb_dim:int = 384,
               num_patch_row:int = 2,
               image_size:int = 32):
    """
    Args:
      in_channels : number of channels of the input images
      emb_dim : length of each token vector after embedding
      num_patch_row : number of patches along one side of the image
      image_size : side length of the input image
    """
    super(VitInputLayer,self).__init__()
    self.in_channels = in_channels
    self.emb_dim = emb_dim
    self.num_patch_row = num_patch_row
    self.image_size = image_size

    #total number of patches (num_patch_row rows x num_patch_row columns)
    self.num_patch = self.num_patch_row**2

    #side length of one patch
    self.patch_size = int(self.image_size//self.num_patch_row)

    #split the image into patches and embed each one in a single step:
    #a Conv2d whose kernel and stride both equal the patch size
    self.patch_emb_layer = nn.Conv2d(
        in_channels = self.in_channels,
        out_channels = self.emb_dim,
        kernel_size = self.patch_size,
        stride = self.patch_size
    )

    #class token, shared by every sample in the batch
    self.cls_token = nn.Parameter(
        torch.randn(1,1,emb_dim)
    )

    #positional embedding
    #(num_patch+1) vectors are needed because the class token sits at the head
    self.pos_emb = nn.Parameter(
        torch.randn(1,self.num_patch+1,emb_dim)
    )

  def forward(self,x: torch.Tensor) -> torch.Tensor:
    """
    Args:
      x : input image (B,C,H,W)
        B:Batch size, C:Channel, H:Height, W:Width

    Return:
      z_0 : input for ViT (B,N,D)
        B:Batch size, N:Num of tokens (patches + class token),
        D:Length of embedded vectors
    """

    #Patch embedding
    #(B,C,H,W) -> (B,D,H/P,W/P)  (P = patch size)
    z_0 = self.patch_emb_layer(x)

    #flatten the patch grid (everything from dim 2 on)
    #(B,D,H/P,W/P) -> (B,D,Np)  (Np = number of patches)
    z_0 = z_0.flatten(2)

    #put the token axis before the feature axis
    #(B,D,Np) -> (B,Np,D)
    z_0 = z_0.transpose(1,2)

    #Concatenate the class token at the head of the embeddings
    #(B,Np,D) -> (B,N,D) N = Np + 1
    #cls token : (1,1,D) -> (B,1,D); expand returns a broadcast view, so no
    #batch-sized copy is made before torch.cat builds the result
    cls_tokens = self.cls_token.expand(x.size(0),-1,-1)
    z_0 = torch.cat([cls_tokens, z_0], dim=1)

    #Positional embedding, broadcast over the batch axis
    #(B,N,D) + (1,N,D) -> (B,N,D)
    z_0 = z_0 + self.pos_emb

    return z_0

##Multi-Head Self-Attention

In [14]:
class MultiHeadSelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention over a token sequence."""

  def __init__(self,
               emb_dim:int = 384,
               head:int = 3,
               dropout:float = 0):
    """
    Args:
      emb_dim : the length of embedded vector
      head : num of heads; must be positive and divide emb_dim evenly
      dropout : the rate of dropout

    Raises:
      ValueError : if head is not positive or emb_dim is not divisible by head
    """

    super(MultiHeadSelfAttention, self).__init__()

    #fail at construction time instead of with an opaque shape error
    #inside forward() when q/k/v are split into heads
    if head <= 0 or emb_dim % head != 0:
      raise ValueError(
          f"emb_dim ({emb_dim}) must be divisible by a positive head count "
          f"(got head={head})"
      )

    self.head = head
    self.emb_dim = emb_dim
    self.head_dim = emb_dim // head
    self.sqrt_dh = self.head_dim**0.5 #scaling factor for the attention logits

    #linear layers for q,k,v
    self.w_q = nn.Linear(emb_dim,emb_dim,bias=False)
    self.w_k = nn.Linear(emb_dim,emb_dim,bias=False)
    self.w_v = nn.Linear(emb_dim,emb_dim,bias=False)

    #dropout applied to the attention weights
    self.attn_drop = nn.Dropout(dropout)

    #linear layer (followed by dropout) for the output of MHSA
    self.w_o = nn.Sequential(
        nn.Linear(emb_dim,emb_dim),
        nn.Dropout(dropout)
    )

  def forward(self, z:torch.Tensor) -> torch.Tensor:
    """
    Args:
      z: input for MHSA (B,N,D)
        B:Batch size, N:Num of tokens, D:length of embedded vectors

    Return:
      out: output of MHSA (B,N,D)
    """

    batch_size, num_token, _ = z.size()

    #project the input to queries, keys and values
    q = self.w_q(z)
    k = self.w_k(z)
    v = self.w_v(z)

    #split q,k,v into heads
    #(B,N,D) -> (B,N,h,D//h)
    q = q.view(batch_size,num_token,self.head,self.head_dim)
    k = k.view(batch_size,num_token,self.head,self.head_dim)
    v = v.view(batch_size,num_token,self.head,self.head_dim)

    #move the head axis forward so attention runs per head
    #(B,N,h,D//h) -> (B,h,N,D//h)
    q = q.transpose(1,2)
    k = k.transpose(1,2)
    v = v.transpose(1,2)

    #transpose k for the dot product
    #(B,h,N,D//h) -> (B,h,D//h,N)
    k_T = k.transpose(2,3)
    #scaled dot product
    #(B,h,N,D//h)@(B,h,D//h,N) -> (B,h,N,N)
    dots = (q@k_T)/self.sqrt_dh
    #softmax over the key axis, so every query row sums to 1
    attn = F.softmax(dots,dim=-1)
    #Dropout
    attn = self.attn_drop(attn)

    #weighted sum of the values
    #(B,h,N,N)@(B,h,N,D//h) -> (B,h,N,D//h)
    out = attn@v
    #(B,h,N,D//h) -> (B,N,h,D//h)
    out = out.transpose(1,2)
    #merge the heads back; reshape (not view) because the transpose above
    #leaves the tensor non-contiguous
    #(B,N,h,D//h) -> (B,N,D)
    out = out.reshape(batch_size,num_token,self.emb_dim)

    #output layer
    out = self.w_o(out)

    return out

##Encoder

In [21]:
  class VitEncoderBlock(nn.Module):
    """One ViT encoder block: LayerNorm -> MHSA and LayerNorm -> MLP, each wrapped in a residual connection."""

    def __init__(self,
                 emb_dim:int = 384,
                 head:int = 8,
                 hidden_dim:int = 384*4,
                 dropout:float = 0):
      """
      Args:
        emb_dim : length of the token vectors entering and leaving the block
        head : number of attention heads handed to the MHSA sub-layer
        hidden_dim : width of the hidden layer of the MLP (default 384*4)
        dropout : dropout rate used in both the MHSA and the MLP
      """
      super().__init__()

      #attention sub-layer: normalisation followed by MHSA
      self.ln1 = nn.LayerNorm(emb_dim)
      self.msa = MultiHeadSelfAttention(emb_dim=emb_dim, head=head, dropout=dropout)

      #feed-forward sub-layer: normalisation followed by a two-layer MLP
      self.ln2 = nn.LayerNorm(emb_dim)
      mlp_layers = [
          nn.Linear(emb_dim, hidden_dim),
          nn.GELU(),
          nn.Dropout(dropout),
          nn.Linear(hidden_dim, emb_dim),
          nn.Dropout(dropout),
      ]
      self.mlp = nn.Sequential(*mlp_layers)

    def forward(self,z:torch.Tensor) ->  torch.Tensor:
      """
      Args:
        z : tokens entering the block (B,N,D)

      Return:
        tokens leaving the block (B,N,D)
      """
      #attention branch plus skip connection
      attn_branch = self.msa(self.ln1(z))
      after_attn = attn_branch + z

      #MLP branch plus skip connection
      mlp_branch = self.mlp(self.ln2(after_attn))
      return mlp_branch + after_attn

##Vision Transformer

#Training Session

In [23]:
#Smoke test: push one random batch through every module and print the shapes
batch_size = 2
channel = 3
height = width = 32
x = torch.randn((batch_size, channel, height, width))

#input layer; the shape should be (2,5,384): 4 patches + 1 class token
input_layer = VitInputLayer(num_patch_row=2)
z_0 = input_layer(x)
print("after input layer", z_0.shape, sep="\n")

#MHSA keeps the (B,N,D) shape
mhsa = MultiHeadSelfAttention()
out = mhsa(z_0)
print("after MHSA layer", out.shape, sep="\n")

#a full encoder block keeps the (B,N,D) shape as well
vit_enc = VitEncoderBlock()
z_1 = vit_enc(z_0)
print("after Vit Encoder Block", z_1.shape, sep="\n")


after input layer
torch.Size([2, 5, 384])
after MHSA layer
torch.Size([2, 5, 384])
after Vit Encoder Block
torch.Size([2, 5, 384])
