In [1]:
import torch
import torch.nn as nn
from types import SimpleNamespace
from einops import einsum, rearrange

In [2]:
# Hyperparameters for the small "pico" ViT variant used in this notebook.
config_vit_pico = SimpleNamespace(
    embed_dim=128,          # token / model width
    num_heads=4,            # attention heads per block
    depth=6,                # number of transformer blocks
    pool="mean",
    img_size=224,
    num_channels=3,
    patch_size=16,
    attention_dropout=0.0,
    residual_dropout=0.0,
    mlp_ratio=4,            # MLP hidden width = embed_dim * mlp_ratio
    mlp_dropout=0.0,
    pos_dropout=0.0,
    num_classes=1000,
)

config_vit_pico

namespace(embed_dim=128,
          num_heads=4,
          depth=6,
          pool='mean',
          img_size=224,
          num_channels=3,
          patch_size=16,
          attention_dropout=0.0,
          residual_dropout=0.0,
          mlp_ratio=4,
          mlp_dropout=0.0,
          pos_dropout=0.0,
          num_classes=1000)

In [3]:
def params(m):
    """Return the total number of trainable scalars in module ``m``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = (p for p in m.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        config: namespace providing ``embed_dim``, ``num_heads``,
            ``attention_dropout`` and ``residual_dropout``.

    ``forward`` maps a ``(batch, seq_len, embed_dim)`` tensor to a tensor of
    the same shape.
    """

    def __init__(self,config):
        super().__init__()
        self.embed_dim = config.embed_dim
        self.n_heads = config.num_heads
        assert self.embed_dim % self.n_heads == 0, 'embedding dimension must be divisible by number of heads'
        self.head_size = self.embed_dim // self.n_heads
        # NOTE(review): this is embed_dim, not an actual sequence length, and
        # nothing in this class reads it; kept only so existing attribute
        # access keeps working.
        self.seq_len = config.embed_dim

        # One fused projection producing q, k and v side by side.
        self.qkv = nn.Linear(self.embed_dim, self.head_size * self.n_heads * 3,bias=False)
        self.scale = self.head_size ** -0.5

        self.attention_dropout = nn.Dropout(config.attention_dropout)

        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.residual_dropout = nn.Dropout(config.residual_dropout)

    def _split_heads(self, z):
        """(batch, seq, embed) -> (batch, heads, seq, head_size).

        The feature axis is read as (head_size, heads) with the head index
        varying fastest, i.e. the einops pattern 'b t (h n) -> b n t h'.
        """
        b, t, _ = z.shape
        return z.reshape(b, t, self.head_size, self.n_heads).permute(0, 3, 1, 2)

    def _merge_heads(self, z):
        """(batch, heads, seq, head_size) -> (batch, seq, embed).

        Equivalent to the einops pattern 'b n t h -> b t (n h)'.
        """
        b, _, t, _ = z.shape
        return z.permute(0, 2, 1, 3).reshape(b, t, self.embed_dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, embed_dim)."""
        # q, k, v individually: batch x seq_len x embed_dim
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = self._split_heads(q)
        k = self._split_heads(k)
        v = self._split_heads(v)

        # Scaled similarity of every query with every key:
        # batch x n_heads x seq_len x seq_len
        scores = (q @ k.transpose(-2, -1)) * self.scale

        # Normalise over the key axis so each query's weights sum to 1.
        # Without this the raw scores would be used as mixing weights.
        weights = scores.softmax(dim=-1)
        weights = self.attention_dropout(weights)

        attention = weights @ v  # batch x n_heads x seq_len x head_size
        attention = self._merge_heads(attention)  # batch x seq_len x embed_dim

        out = self.proj(attention)
        out = self.residual_dropout(out)

        return out

In [5]:
# Smoke test: run the attention module on a random batch of 16 and display the
# output shape. The sequence length is simply set equal to embed_dim here.
x = torch.rand(16,config_vit_pico.embed_dim,config_vit_pico.embed_dim)
a = MultiHeadAttention(config_vit_pico)
a(x).shape

torch.Size([16, 128, 128])

In [6]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Applies LayerNorm -> attention and LayerNorm -> MLP, each wrapped in a
    residual connection. ``config`` supplies ``embed_dim``, ``mlp_ratio`` and
    ``mlp_dropout`` (plus whatever ``MultiHeadAttention`` reads).
    """

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        hidden = width * config.mlp_ratio  # MLP expansion width

        self.ln1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.mlp_dropout),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

In [7]:
class ViT(nn.Module):
    """Vision Transformer classifier without a class token.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches, each patch is linearly embedded, a learned positional term is
    added, the tokens pass through ``depth`` transformer blocks, and the
    mean over tokens is fed to a linear classification head.

    Note: constructing the model writes ``num_patches`` and ``patch_dim``
    onto the ``config`` object that was passed in.
    """

    def __init__(self, config):
        super().__init__()

        # Derived sizes are stored on the caller's config object itself.
        side = config.img_size // config.patch_size
        config.num_patches = side ** 2
        config.patch_dim = config.num_channels * config.patch_size ** 2
        self.config = config

        patch_dim = self.config.patch_dim
        width = self.config.embed_dim

        self.patch_embedding = nn.Sequential(
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, width, bias=False),
            nn.LayerNorm(width),
        )
        # One learned positional vector per patch, broadcast over the batch.
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.config.num_patches, width),
            requires_grad=True,
        )
        self.pos_dropout = nn.Dropout(self.config.pos_dropout)

        blocks = [TransformerBlock(config) for _ in range(self.config.depth)]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.head = nn.Linear(width, self.config.num_classes)

    def forward(self, x):
        """Map images ``(b, c, H, W)`` to logits ``(b, num_classes)``."""
        p = self.config.patch_size
        # Cut into patches and flatten each one to a vector of length p*p*c.
        tokens = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
        tokens = self.patch_embedding(tokens)
        tokens += self.pos_embed
        tokens = self.pos_dropout(tokens)

        for block in self.transformer_blocks:
            tokens = block(tokens)

        # Mean-pool over the token axis, then classify.
        pooled = tokens.mean(dim=1)
        return self.head(pooled)
        

In [8]:
# Build the pico ViT and push one random (1, 3, 224, 224) input through it.
# Displays the logits shape, the trainable-parameter count, and the config,
# which ViT.__init__ has extended with num_patches and patch_dim.
model = ViT(config_vit_pico)
x = torch.rand(1,3,224,224)
model(x).shape, params(model), model.config

(torch.Size([1, 1000]),
 1440744,
 namespace(embed_dim=128,
           num_heads=4,
           depth=6,
           pool='mean',
           img_size=224,
           num_channels=3,
           patch_size=16,
           attention_dropout=0.0,
           residual_dropout=0.0,
           mlp_ratio=4,
           mlp_dropout=0.0,
           pos_dropout=0.0,
           num_classes=1000,
           num_patches=196,
           patch_dim=768))

In [9]:
# Display the module tree of the pico model.
model

ViT(
  (patch_embedding): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=128, bias=False)
    (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_dropout): Dropout(p=0.0, inplace=False)
  (transformer_blocks): ModuleList(
    (0-5): 6 x TransformerBlock(
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=128, out_features=384, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=128, out_features=128, bias=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.0

In [10]:
# Hyperparameters for the larger "base" ViT variant.
# There is no `pool` entry here; ViT.forward always mean-pools over tokens.
config_vit_base = SimpleNamespace(
    embed_dim=768,          # token / model width
    num_heads=12,           # attention heads per block
    depth=12,               # number of transformer blocks
    img_size=224,
    num_channels=3,
    patch_size=16,
    attention_dropout=0.0,
    residual_dropout=0.0,
    mlp_ratio=4,            # MLP hidden width = embed_dim * mlp_ratio
    mlp_dropout=0.0,
    pos_dropout=0.0,
    num_classes=1000,
)

# Build the base ViT, run one random (1, 3, 224, 224) input through it, and
# display the logits shape together with the trainable-parameter count.
base = ViT(config_vit_base)
x = torch.rand(1,3,224,224)
base(x).shape, params(base)

(torch.Size([1, 1000]), 86530024)

In [11]:
# Display the module tree of the base model.
base

ViT(
  (patch_embedding): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=768, out_features=768, bias=False)
    (2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (pos_dropout): Dropout(p=0.0, inplace=False)
  (transformer_blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=False)
        (attention_dropout): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=False)
        (residual_dropout): Dropout(p=0.0, inplace=False)
      )
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
        (3): Dropout(

In [12]:
# Scratch cell: sample a standard-normal tensor of shape (1, 128, 128).
# The result is only displayed, not stored in a variable.
torch.randn(1, 128, 128)

tensor([[[-3.2585e-01, -2.4704e+00, -1.4879e-01,  ..., -6.1265e-01,
           1.1498e+00, -6.2170e-02],
         [-1.3894e+00,  4.6732e-01, -1.8328e-03,  ...,  1.2957e+00,
           1.9182e-01, -1.8384e-01],
         [ 8.2649e-01, -2.5241e-01,  4.8311e-01,  ..., -1.2235e+00,
           4.6463e-01,  1.5737e-01],
         ...,
         [-1.5943e-01,  1.4821e+00,  8.3911e-01,  ..., -6.4295e-01,
          -1.0655e+00,  8.7774e-01],
         [ 2.2364e+00, -8.3828e-01, -2.2446e+00,  ...,  6.1022e-02,
          -1.2066e+00,  1.0642e+00],
         [ 5.2472e-01,  5.3328e-02, -6.7582e-01,  ...,  6.5444e-01,
           2.0783e-03, -1.1407e+00]]])

# fix

In [37]:
# Random input of shape (1, 3, 128, 128) used by the patch-embedding cells below.
x = torch.rand(1,3,128,128)

In [38]:
# Patch geometry for a 128x128 input with 3 channels.
patch_size = 16
num_patches = (128 // patch_size) ** 2   # patches per image
patch_embed = 3 * patch_size ** 2        # values in one flattened patch
patch_size, num_patches, patch_embed

(16, 64, 768)

In [39]:
# Patchify with a convolution: kernel size 16 and stride 16, so each output
# position covers one non-overlapping 16x16 patch; patch_embed output channels.
patch_conv = nn.Conv2d(3,patch_embed,16,16)
x2 = patch_conv(x)
x2.shape

torch.Size([1, 768, 8, 8])

In [40]:
# Flatten the spatial grid into a token axis: (b, p, h, w) -> (b, h*w, p).
# x2 is overwritten, so re-running only this cell fails: the pattern expects
# a 4-D tensor and x2 is 3-D after the first run.
x2 = rearrange(x2,'b p h w -> b (h w) p')
x2.shape

torch.Size([1, 64, 768])

In [41]:
# Normalise each flattened patch, project it from patch_embed down to 64
# features, and normalise again.
patchnorm = nn.Sequential(
    nn.LayerNorm(patch_embed),
    nn.Linear(patch_embed,64),
    nn.LayerNorm(64)
)
x3 = patchnorm(x2)
x3.shape

torch.Size([1, 64, 64])

In [42]:
# Add a random positional term (one 64-dim vector per patch) to the tokens.
# The add is in place, so re-running this cell adds another fresh random
# term on top of the previous one.
pos = torch.randn(1,num_patches,64)
x3 += pos
x3.shape

torch.Size([1, 64, 64])