<a href="https://colab.research.google.com/github/torrhen/pytorch_miscellaneous/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import torch
from torch import nn
import math
import matplotlib.pyplot as plt

In [3]:
class PatchEmbedding(nn.Module):
  '''
  Turn a batch of images into the token sequence fed to the transformer encoder.

  The image is cut into non-overlapping patch_size x patch_size patches by a strided
  convolution (the hybrid-architecture formulation of section 3.1), each patch is
  projected to embedding_dim, a learnable class token is prepended and a learnable
  position embedding is added.

  Args:
    in_channels: number of channels of the input images.
    patch_size: side length of the square patches; also the stride of the projection.
    embedding_dim: size D of every token produced.
    img_size: expected input resolution, either an int (square images) or a
      (height, width) pair. Only used to size the position embedding; the defaults
      (224 with 16x16 patches) give the usual 196 patches + 1 class token = 197 positions.
  '''
  def __init__(self, in_channels=3, patch_size=16, embedding_dim=768, img_size=224):
    super(PatchEmbedding, self).__init__()
    self.in_channels = in_channels
    self.patch_size = patch_size
    self.embedding_dim = embedding_dim
    self.img_size = img_size
    # number of patches N = (H / P) * (W / P); previously hard-coded as 196
    img_height, img_width = (img_size, img_size) if isinstance(img_size, int) else img_size
    self.num_patches = (img_height // patch_size) * (img_width // patch_size)
    # create input sequence of patches by flattening the spatial dimensions of the feature map and projecting to the embedding dimension used by the transformer.
    self.embedding = nn.Sequential(
        # with the defaults: [B, 3, 224, 224] -> [B, 768, 14, 14]
        nn.Conv2d(in_channels=self.in_channels, out_channels=self.embedding_dim, kernel_size=self.patch_size, stride=self.patch_size, padding=0),
        # [B, 768, 14, 14] -> [B, 768, 196]
        nn.Flatten(start_dim=2, end_dim=3)
    )
    # a single learnable token shared by every image; expanded over the batch in forward()
    self.class_token = nn.Parameter(torch.ones(1, 1, self.embedding_dim))
    # one learnable vector per position (class token + every patch), broadcast over the batch
    self.position_embedding = nn.Parameter(torch.ones(1, self.num_patches + 1, self.embedding_dim))

  def forward(self, x):
    '''
    Args:
      x: image batch of shape [B, C, H, W] where H and W are multiples of patch_size.

    Returns:
      Tensor of shape [B, N + 1, D] with the class token at sequence index 0.

    Raises:
      AssertionError: if H or W is not divisible by patch_size.
    '''
    # input spatial dimensions should be divided without remainder into patches
    height, width = x.shape[-2:] # x = [B, C, H, W]
    assert height % self.patch_size == 0 and width % self.patch_size == 0, \
        f"input size ({height}, {width}) must be divisible by patch size {self.patch_size}"
    # calculate patch embedding
    x = self.embedding(x) # [B, D, N] with N = HW / P^2
    x = x.permute(0, 2, 1) # [B, N, D]
    # torch.cat does not broadcast, so the [1, 1, D] class token has to be expanded to the
    # batch size first (expand returns a view, no parameter copy is made)
    class_token = self.class_token.expand(x.shape[0], -1, -1) # [B, 1, D]
    # prepend class token to patch embedding
    x = torch.cat((class_token, x), dim=1) # [B, N + 1, D]
    # add position embedding to patch embedding
    x = x + self.position_embedding
    return x



In [None]:
class MultiHeadSelfAttention(nn.Module):
  '''
  Layer normalization followed by multi-head self-attention.

  Args:
    d: embedding dimension of the tokens (also the LayerNorm normalized shape).
    h: number of attention heads.
    dropout: dropout probability handed to nn.MultiheadAttention.
  '''
  def __init__(self, d=768, h=12, dropout=0):
    super().__init__()
    self.embedding_dim, self.h, self.dropout = d, h, dropout
    self.LN = nn.LayerNorm(normalized_shape=d)
    # batch_first=True: batched inputs are laid out as [B, N, d]
    self.MSA = nn.MultiheadAttention(embed_dim=d, num_heads=h, dropout=dropout, batch_first=True)

  def forward(self, x):
    '''Normalize x, attend over it with query = key = value, and return the attended tokens.'''
    normed = self.LN(x)
    # need_weights=False: only the attention output is used, the weights slot is discarded
    attended, _ = self.MSA(query=normed, key=normed, value=normed, need_weights=False)
    return attended


In [None]:
class MLPHead(nn.Module):
  '''
  Layer normalization followed by a two-layer feed-forward network.

  The network expands tokens from embedding_dim to hidden_units, applies GELU and
  projects back to embedding_dim, with dropout after each linear layer.

  Args:
    embedding_dim: size of the incoming and outgoing token vectors.
    hidden_units: width of the hidden layer.
    dropout: dropout probability used after both linear layers.
  '''
  def __init__(self, embedding_dim=768, hidden_units=3072, dropout=0.1):
    super().__init__()
    self.d, self.mlp_size, self.dropout = embedding_dim, hidden_units, dropout
    self.LN = nn.LayerNorm(normalized_shape=embedding_dim)
    # expand -> non-linearity -> project back, regularized by dropout after each projection
    layers = [
        nn.Linear(in_features=embedding_dim, out_features=hidden_units),
        nn.GELU(),
        nn.Dropout(p=dropout),
        nn.Linear(in_features=hidden_units, out_features=embedding_dim),
        nn.Dropout(p=dropout),
    ]
    self.MLP = nn.Sequential(*layers)

  def forward(self, x):
    '''Return the feed-forward network applied to the layer-normalized input.'''
    return self.MLP(self.LN(x))



In [None]:
class EncoderLayer(nn.Module):
  '''
  One transformer encoder layer: a self-attention sub-block followed by an MLP
  sub-block, each wrapped in a residual (skip) connection.

  Args:
    embedding_dim: size of the token vectors flowing through the layer.
    h: number of attention heads.
    attn_dropout: dropout probability inside the attention sub-block.
    mlp_size: hidden width of the MLP sub-block.
    mlp_dropout: dropout probability inside the MLP sub-block.
  '''
  def __init__(self, embedding_dim=768, h=12, attn_dropout=0, mlp_size=3072, mlp_dropout=0.1):
    super().__init__()
    self.d, self.h = embedding_dim, h
    self.attn_dropout = attn_dropout
    self.mlp_size, self.mlp_dropout = mlp_size, mlp_dropout
    # attention sub-block
    self.MHA = MultiHeadSelfAttention(d=embedding_dim, h=h, dropout=attn_dropout)
    # feed-forward sub-block
    self.MLP = MLPHead(embedding_dim=embedding_dim, hidden_units=mlp_size, dropout=mlp_dropout)

  def forward(self, x):
    '''Run both sub-blocks in order, adding each sub-block's input back onto its output.'''
    attended = self.MHA(x) + x
    return self.MLP(attended) + attended
