<a href="https://colab.research.google.com/github/samitha278/nanoViT/blob/main/build_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from dataclasses import dataclass

In [24]:
# create patches from the image : ImageNet-1k

# linear projection of patches

# Embedding: patch + positional

# transformer encoder * n

# MLP head


In [25]:
batch_size: int = 32  # mini-batch size used by the train/val DataLoaders

## Getting data

In [26]:
# Convert MNIST images to tensors with ToTensor() only (no normalization or
# augmentation); the same transform is reused for the validation split.
transform_train = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform_train)
val_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform_train)


train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps validation
# passes reproducible batch-for-batch.
val_data = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [27]:
# Peek at a single training batch to check tensor shapes and labels.
# next(iter(...)) fetches exactly one batch, whereas a counter loop with a
# `break` pulls a second batch before stopping.
image, label = next(iter(train_data))
print(image.shape, label)

torch.Size([32, 1, 28, 28]) tensor([3, 4, 3, 9, 4, 5, 9, 8, 7, 1, 2, 0, 7, 4, 9, 3, 3, 3, 8, 1, 5, 6, 1, 2,
        9, 5, 0, 1, 7, 8, 5, 1])


## Configuration



In [28]:
@dataclass
class ViTConfig:
  """Hyperparameters for the small MNIST Vision Transformer.

  Annotated attributes are dataclass fields, so they can be overridden per
  instance, e.g. ``ViTConfig(patch_size=4)``. ``n_patch`` is derived from
  ``img_size`` and ``patch_size`` and is not an init argument.

  Raises:
    ValueError: if ``n_embd`` is not divisible by ``n_head``.
  """
  batch_size: int = 32
  num_classes: int = 10
  img_size: int = 28      # MNIST images are 28 x 28
  patch_size: int = 7
  im_channels: int = 1    # MNIST is single-channel; PatchEmbedding reads this

  n_head: int = 4
  n_layer: int = 4

  n_embd: int = 32

  # Integer patch count (floor division, same formula as
  # PatchEmbedding.n_patches). Kept as a plain class attribute so that
  # ViTConfig.n_patch keeps working; __post_init__ recomputes it per instance.
  n_patch = (img_size // patch_size) * (img_size // patch_size)

  def __post_init__(self):
    # Attention splits n_embd evenly across heads.
    if self.n_embd % self.n_head != 0:
      raise ValueError(
          f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
      )
    self.n_patch = (self.img_size // self.patch_size) ** 2


## Multi-head Attention for the Encoder

In [29]:
class Attention(nn.Module):
  """Multi-head self-attention (no masking) used inside the encoder blocks.

  Args:
    n_embd: per-token embedding size C; must be divisible by n_head.
    n_head: number of attention heads.

  Raises:
    ValueError: if n_embd is not divisible by n_head.
  """

  def __init__(self, n_embd, n_head):
    super().__init__()
    if n_embd % n_head != 0:
      raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

    self.nh = n_head

    # One fused projection yields k, q and v for all heads: C -> 3*C.
    self.w = nn.Linear(n_embd, 3 * n_embd)
    # Output projection applied after the heads are merged back together.
    self.proj = nn.Linear(n_embd, n_embd)

  def forward(self, x):
    """Apply self-attention over the token axis.

    Args:
      x: tensor of shape (B, T, C) with C == n_embd.

    Returns:
      Tensor of shape (B, T, C).
    """
    B, T, C = x.shape
    head_size = C // self.nh

    k, q, v = torch.chunk(self.w(x), 3, dim=-1)  # each (B, T, C)

    # (B, T, C) -> (B, nh, T, head_size)
    key = k.view(B, T, self.nh, head_size).transpose(1, 2)
    query = q.view(B, T, self.nh, head_size).transpose(1, 2)
    value = v.view(B, T, self.nh, head_size).transpose(1, 2)

    weight = (query @ key.transpose(-1, -2)) * (head_size ** -0.5)  # (B, nh, T, T)
    weight = F.softmax(weight, dim=-1)

    out = weight @ value  # (B, nh, T, head_size)

    # Merge heads back to (B, T, C). transpose() is not in-place, so its result
    # must be assigned; contiguous() is needed before view() after a transpose.
    # Viewing the (B, nh, T, head_size) tensor directly would mix heads and tokens.
    out = out.transpose(1, 2).contiguous().view(B, T, C)

    return self.proj(out)




In [30]:
# Smoke-test Attention on a random (B=32, T=8, C=128) input with 4 heads.
x = torch.randn((32, 8, 128))
attn = Attention(n_embd=128, n_head=4)
out = attn(x)

In [31]:
out.shape  # Attention keeps the input shape: expected torch.Size([32, 8, 128])

torch.Size([32, 8, 128])

### MLP

In [32]:
class MLP(nn.Module):
  """Feed-forward sub-layer: Linear(C -> 4C) -> GELU -> Linear(4C -> C).

  Args:
    n_embd: input and output feature size C.
  """

  def __init__(self, n_embd):
    super().__init__()
    hidden = 4 * n_embd  # 4x expansion of the hidden layer
    self.layer = nn.Linear(n_embd, hidden)
    self.gelu = nn.GELU()
    self.proj = nn.Linear(hidden, n_embd)

  def forward(self, x):
    """Transform the last dimension of x; leading dimensions pass through."""
    return self.proj(self.gelu(self.layer(x)))

## Block

In [33]:
class Block(nn.Module):
  """Pre-norm encoder block: x + Attn(LN(x)), then x + MLP(LN(x)).

  Args:
    n_layer: accepted for interface compatibility; not used by the block.
    n_embd: token embedding size.
    n_head: number of attention heads.
  """

  def __init__(self, n_layer, n_embd, n_head):
    super().__init__()
    # Keep creation order (attention before MLP) so parameter initialisation
    # draws from the RNG in the same sequence.
    self.ln_1 = nn.LayerNorm(n_embd)
    self.attn = Attention(n_embd, n_head)
    self.ln_2 = nn.LayerNorm(n_embd)
    self.mlp = MLP(n_embd)

  def forward(self, x):
    attn_out = self.attn(self.ln_1(x))
    x = x + attn_out
    mlp_out = self.mlp(self.ln_2(x))
    return x + mlp_out



## Patch Embedding

In [53]:
class PatchEmbedding(nn.Module):
  """Turn a batch of images into a token sequence with a leading class token.

  Each (B, C, H, W) image is cut into non-overlapping patch_size x patch_size
  patches. Every patch is flattened, normalised and projected to n_embd. A
  learned class token is prepended, and a learned positional embedding is added.

  Args:
    config: object exposing img_size, patch_size, n_embd and optionally
      im_channels (falls back to 1, single-channel MNIST, when absent).
  """

  def __init__(self, config):
    super().__init__()

    self.config = config

    self.n_patches = (config.img_size // config.patch_size) ** 2
    # Fall back to one channel for configs that do not define im_channels.
    im_channels = getattr(config, "im_channels", 1)
    # Length of one flattened patch: C * patch_size * patch_size.
    self.patch_dim = im_channels * config.patch_size ** 2

    # Patch embedding: LayerNorm on the raw patch, linear projection, LayerNorm.
    self.patch_embd = nn.Sequential(
        nn.LayerNorm(self.patch_dim),
        nn.Linear(self.patch_dim, config.n_embd),
        nn.LayerNorm(config.n_embd),
    )

    # Learned class token, prepended to every sequence.
    self.cls_token = nn.Parameter(torch.randn((config.n_embd,)))

    # Learned positional embedding; +1 slot for the class token.
    self.pos_embd = nn.Embedding(self.n_patches + 1, config.n_embd)

  def forward(self, x):
    """Embed a batch of images.

    Args:
      x: images of shape (B, C, H, W) whose size yields n_patches patches.

    Returns:
      Tensor of shape (B, n_patches + 1, n_embd); position 0 is the class token.

    Raises:
      ValueError: if x does not produce exactly n_patches patches.
    """
    B = x.shape[0]
    patch_size = self.config.patch_size

    # (B, C, H, W) -> (B, C*p*p, n_patches) -> (B, n_patches, C*p*p)
    patches = F.unfold(x, patch_size, stride=patch_size).transpose(-1, -2)
    if patches.shape[1] != self.n_patches:
      raise ValueError(
          f"expected {self.n_patches} patches, got {patches.shape[1]}; "
          "check that the image size matches config.img_size"
      )

    tokens = self.patch_embd(patches)  # (B, n_patches, n_embd)

    class_tok = self.cls_token.expand(B, 1, -1)  # (B, 1, n_embd)
    tokens = torch.cat((class_tok, tokens), dim=1)  # (B, n_patches + 1, n_embd)

    # Position ids 0..n_patches, created on the input's device; the
    # (n_patches + 1, n_embd) embedding broadcasts over the batch.
    positions = torch.arange(self.n_patches + 1, device=x.device)
    return tokens + self.pos_embd(positions)




In [None]:
@dataclass
class ViTConfig:
  """Hyperparameters for the small MNIST Vision Transformer.

  Annotated attributes are dataclass fields, so they can be overridden per
  instance, e.g. ``ViTConfig(patch_size=4)``. ``n_patch`` is derived from
  ``img_size`` and ``patch_size`` and is not an init argument.

  Raises:
    ValueError: if ``n_embd`` is not divisible by ``n_head``.
  """
  batch_size: int = 32
  num_classes: int = 10
  img_size: int = 28      # MNIST images are 28 x 28
  patch_size: int = 7
  im_channels: int = 1    # MNIST is single-channel; PatchEmbedding reads this

  n_head: int = 4
  n_layer: int = 4

  n_embd: int = 32

  # Integer patch count (same formula as PatchEmbedding.n_patches). Kept as a
  # plain class attribute so that ViTConfig.n_patch keeps working;
  # __post_init__ recomputes it per instance from the actual field values.
  n_patch = (img_size // patch_size) * (img_size // patch_size)

  def __post_init__(self):
    # Attention splits n_embd evenly across heads.
    if self.n_embd % self.n_head != 0:
      raise ValueError(
          f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
      )
    self.n_patch = (self.img_size // self.patch_size) ** 2

In [None]:
# Dummy batch shaped like the MNIST loader output (B, 1, 28, 28), matching the
# 28x28 single-channel ViTConfig instead of a 3x224x224 input.
x = torch.randn((32, 1, 28, 28))

# PatchEmbedding.__init__ reads its hyperparameters from a config object.
patch_embd = PatchEmbedding(ViTConfig())

In [42]:
# F.unfold demo: 224 / 16 = 14 patches per side, 14 * 14 = 196 patches, each
# flattened to 3 * 16 * 16 = 768 values, channel by channel (R values, then G,
# then B).
x = torch.randn(32, 3, 224, 224)

out = F.unfold(x, kernel_size=16, stride=16).transpose(1, 2)
out.shape

torch.Size([32, 196, 768])

In [52]:
# A learnable 768-d class token, broadcast to one token per image:
# (768,) -> (32, 1, 768).
cls_token = nn.Parameter(torch.randn(768))
cls_token.expand(32, 1, -1).shape

torch.Size([32, 1, 768])