In [1]:
import torch
import torch.nn as nn
from types import SimpleNamespace
from einops import einsum, rearrange

In [2]:
# Hyper-parameters of the "pico" ViT variant. Downstream modules read these
# fields by attribute name, so the keys (and their order, which drives the
# displayed repr) are kept exactly as they are.
config_vit_pico = SimpleNamespace(
    embed_dim=128,
    num_heads=4,
    depth=6,
    pool='mean',
    img_size=224,
    num_channels=3,
    patch_size=16,
    attention_dropout=0.0,
    residual_dropout=0.0,
    mlp_ratio=4,
    mlp_dropout=0.0,
    pos_dropout=0.0,
    num_classes=1000,
)

config_vit_pico

namespace(embed_dim=128,
          num_heads=4,
          depth=6,
          pool='mean',
          img_size=224,
          num_channels=3,
          patch_size=16,
          attention_dropout=0.0,
          residual_dropout=0.0,
          mlp_ratio=4,
          mlp_dropout=0.0,
          pos_dropout=0.0,
          num_classes=1000)

In [3]:
def params(m):
    """Count the trainable parameters of a module.

    Args:
        m: object exposing ``parameters()`` that yields tensors
           (e.g. an ``nn.Module``).

    Returns:
        int: total number of scalar elements over every parameter whose
        ``requires_grad`` flag is set; ``0`` when there are none.
    """
    total = 0
    for tensor in m.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: object providing ``embed_dim``, ``num_heads``,
            ``attention_dropout`` and ``residual_dropout``.

    Input:
        x: tensor of shape (batch, seq_len, embed_dim). Any ``seq_len`` is
           accepted; the layer holds no sequence-length dependent state.

    Returns:
        Tensor of shape (batch, seq_len, embed_dim).

    Raises:
        AssertionError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads

        # One fused projection producing queries, keys and values.
        self.qkv = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=False)
        self.scale = self.head_size ** -0.5

        self.attention_dropout = nn.Dropout(config.attention_dropout)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def forward(self, x):
        b, t, _ = x.shape

        # (b, t, 3*c) -> (b, t, 3, head_size, n_heads) -> (3, b, n_heads, t, head_size).
        # Equivalent to chunk(3, dim=-1) followed by the einops pattern
        # 'b t (h n) -> b n t h' on each chunk, i.e. the head index is the
        # fastest-varying factor of the channel axis.
        qkv = self.qkv(x).reshape(b, t, 3, self.head_size, self.n_heads)
        q, k, v = qkv.permute(2, 0, 4, 1, 3).unbind(0)

        # Scaled dot-product logits: batch x n_heads x seq_len x seq_len.
        scores = (q @ k.transpose(-2, -1)) * self.scale

        # Normalise over the key axis so every query's weights sum to 1;
        # dropout is applied to the probabilities, not to the raw logits.
        weights = scores.softmax(dim=-1)
        weights = self.attention_dropout(weights)

        attention = weights @ v  # batch x n_heads x seq_len x head_size
        # Merge heads back: 'b n t h -> b t (n h)' -> batch x seq_len x embed_dim.
        attention = attention.permute(0, 2, 1, 3).reshape(b, t, self.n_heads * self.head_size)

        out = self.proj(attention)
        out = self.residual_dropout(out)

        return out

In [5]:
# Smoke test: push a random batch through one attention layer and display the
# output shape. The input is (16, embed_dim, embed_dim), so the middle (token)
# axis simply reuses embed_dim as its length here.
x = torch.rand(16,config_vit_pico.embed_dim,config_vit_pico.embed_dim)
a = MultiHeadAttention(config_vit_pico)
a(x).shape

torch.Size([16, 128, 128])

In [6]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Two residual sub-layers are applied in sequence: LayerNorm followed by
    multi-head self-attention, then LayerNorm followed by a two-layer GELU
    MLP whose hidden width is ``embed_dim * mlp_ratio``.

    Args:
        config: object providing ``embed_dim``, ``mlp_ratio``,
            ``mlp_dropout`` and the fields ``MultiHeadAttention`` reads.
    """

    def __init__(self,config):
        super().__init__()
        width = config.embed_dim
        hidden = width * config.mlp_ratio

        # Attention sub-layer.
        self.ln1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(config)

        # Feed-forward sub-layer.
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.mlp_dropout),
        )

    def forward(self,x):
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.mlp(self.ln2(x))
        return x + transformed

In [31]:
class ViT(nn.Module):
    """Vision Transformer classifier with mean-pooled patch tokens.

    An image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is flattened and linearly embedded, a learned
    positional embedding is added, the tokens pass through ``depth``
    transformer blocks, and the token average is fed to a linear head.

    Args:
        config: object providing ``img_size``, ``patch_size``,
            ``num_channels``, ``embed_dim``, ``depth``, ``pos_dropout``,
            ``num_classes`` and the fields read by ``TransformerBlock``.
            The object is not modified; a shallow copy extended with the
            derived ``num_patches`` and ``patch_dim`` is stored as
            ``self.config``.

    Note:
        ``config.pool`` is not consulted; tokens are always mean-pooled.
    """

    def __init__(self,config):

        super().__init__()

        # Work on a private copy so the derived fields below do not leak into
        # a config object the caller may share between several models.
        import copy
        config = copy.copy(config)

        config.num_patches = (config.img_size // config.patch_size) ** 2
        config.patch_dim = config.num_channels * config.patch_size ** 2

        self.config = config

        self.patch_embedding = nn.Sequential(
            nn.LayerNorm(self.config.patch_dim),
            nn.Linear(self.config.patch_dim, self.config.embed_dim, bias=False),
            nn.LayerNorm(self.config.embed_dim)
        )
        # One learned position vector per patch, broadcast over the batch.
        self.pos_embed = nn.Parameter(torch.randn(1,self.config.num_patches,self.config.embed_dim),requires_grad=True)
        self.pos_dropout = nn.Dropout(self.config.pos_dropout)

        self.transformer_blocks = nn.ModuleList([TransformerBlock(self.config) for _ in range(self.config.depth)])

        self.head = nn.Linear(self.config.embed_dim,self.config.num_classes)

    def forward(self,x):
        """Map images (batch, channels, height, width) to logits (batch, num_classes)."""

        # Cut into patches and flatten each: (b, c, H, W) -> (b, num_patches, patch_dim).
        x = rearrange(x,'b c (h p1) (w p2) -> b (h w) (p1 p2 c)',
                      p1=self.config.patch_size,
                      p2=self.config.patch_size
                     )
        x = self.patch_embedding(x)
        # Out-of-place add: avoids modifying the embedding output in place.
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        for block in self.transformer_blocks:
            x = block(x)

        # Average over the token axis, then classify.
        x = x.mean(dim=1)
        x = self.head(x)

        return x
        

In [32]:
# Build the pico model, run one random 3-channel 224x224 image through it and
# display the logits shape, the trainable-parameter count and the config
# stored on the model.
model = ViT(config_vit_pico)
x = torch.rand(1,3,224,224)
model(x).shape, params(model), model.config

torch.Size([1, 196, 128])


(torch.Size([1, 1000]),
 1440744,
 namespace(embed_dim=128,
           num_heads=4,
           depth=6,
           pool='mean',
           img_size=224,
           num_channels=3,
           patch_size=16,
           attention_dropout=0.0,
           residual_dropout=0.0,
           mlp_ratio=4,
           mlp_dropout=0.0,
           pos_dropout=0.0,
           num_classes=1000,
           num_patches=196,
           patch_dim=768))

In [33]:
# Display the module tree of the pico model.
model

ViT(
  (patch_embedding): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=128, bias=False)
    (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_dropout): Dropout(p=0.0, inplace=False)
  (transformer_blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=128, out_features=384, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=128, out_features=128, bias=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.0

In [34]:
# Hyper-parameters of the "base" ViT variant. Carries the same set of keys as
# config_vit_pico (including `pool`) so both configs stay interchangeable.
config_vit_base = SimpleNamespace(
    embed_dim = 768,
    num_heads = 12,
    depth = 12,
    pool = 'mean',
    img_size = 224,
    num_channels = 3,
    patch_size = 16,
    attention_dropout = 0.,
    residual_dropout = 0.,
    mlp_ratio = 4,
    mlp_dropout = 0.,
    pos_dropout = 0.,
    num_classes = 1000
)

# Build the base-sized model, run one random 3-channel 224x224 image through
# it and display the logits shape together with the trainable-parameter count.
base = ViT(config_vit_base)
x = torch.rand(1,3,224,224)
base(x).shape, params(base)

torch.Size([1, 196, 768])


(torch.Size([1, 1000]), 86530024)

In [35]:
# Display the module tree of the base model.
base

ViT(
  (patch_embedding): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=768, bias=False)
    (2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (pos_dropout): Dropout(p=0.0, inplace=False)
  (transformer_blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
        (3): Dropout(

In [15]:
# NOTE(review): scratch cell — the tensor is only displayed, never assigned,
# and its execution count (15) is out of order with the cells above. Looks
# safe to delete; confirm nothing outside this view relies on it.
torch.randn(1, 128, 128)

tensor([[[-0.9628,  0.7251,  0.7515,  ..., -0.0577,  2.6292,  0.7296],
         [ 0.8384, -0.2621,  0.1058,  ..., -0.9684,  0.2455,  0.1826],
         [ 0.5375,  0.0921, -0.1405,  ..., -0.2842, -0.2453,  1.3140],
         ...,
         [-0.4933,  0.2625,  1.3823,  ...,  1.2591, -0.4266, -0.3067],
         [ 1.0388,  0.2339, -0.1040,  ..., -1.0064, -1.0018, -0.6355],
         [-1.6927,  1.6512,  1.0339,  ...,  0.3604, -0.8847, -0.6002]]])