<a href="https://colab.research.google.com/github/samitha278/nanoViT/blob/main/build_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from dataclasses import dataclass

In [2]:
# create patches from image : ImageNet-1k

# linear projection of patches

# Embedding: patch + positional

# transformer encoder * n

#MLP


In [25]:
# Mini-batch size passed to the DataLoaders built below (ViTConfig.batch_size repeats this value).
batch_size = 32

## Getting data

In [37]:
# Convert each MNIST PIL image into a float tensor; the same transform is used for both splits.
transform_train = transforms.Compose([transforms.ToTensor()])

# Both splits are downloaded into ./data on first use.
train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform_train)
val_dataset = datasets.MNIST(root='data', train=False, download=True, transform=transform_train)


# Shuffle only the training split: SGD benefits from a fresh order every epoch,
# while validation metrics do not depend on order and a fixed order keeps
# evaluation runs reproducible and comparable.
train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

100%|██████████| 9.91M/9.91M [00:00<00:00, 59.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.74MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.9MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.26MB/s]


In [41]:
# Peek at exactly one training batch to confirm the tensor shape and label format.
# next(iter(...)) pulls a single batch, avoiding a manual counter and the extra
# batch a counting loop would load before it reaches its break.
image, label = next(iter(train_data))
print(image.shape,label)

torch.Size([32, 1, 28, 28]) tensor([2, 2, 0, 0, 5, 0, 2, 4, 4, 4, 1, 8, 1, 4, 7, 4, 7, 8, 6, 7, 2, 1, 1, 2,
        2, 4, 3, 6, 3, 0, 2, 9])


## Configuration



In [42]:
@dataclass
class ViTConfig:
  """Hyper-parameters for the ViT model and its input data.

  NOTE: the attributes below carry no type annotations, so @dataclass does not
  turn them into fields; they are plain class-level constants readable from
  either the class or an instance (ViTConfig.n_embd or ViTConfig().n_embd).
  """
  batch_size = 32
  num_classes = 10
  img_size = 28       # side length of the square input image, in pixels
  patch_size = 7      # side length of one square patch, in pixels
  # Patches per image. Integer division keeps this an int (16 for 28/7);
  # true division produced the float 16.0, which cannot be used where an
  # integer size or count is required.
  n_patch = (img_size // patch_size) ** 2

  n_head = 4
  n_layer = 4

  n_embd = 32


## Multihead Attention for Encoder

In [17]:
class Attention(nn.Module):
  """Multi-head self-attention (no masking) over a sequence of token embeddings.

  Args:
    n_embd: embedding width C of every token; must be divisible by n_head.
    n_head: number of attention heads.

  Raises:
    ValueError: if n_embd is not a multiple of n_head.
  """


  def __init__(self,n_embd,n_head) :
    super().__init__()

    # Fail at construction time instead of with an opaque view() error in forward().
    if n_embd % n_head != 0:
      raise ValueError(f"n_embd ({n_embd}) must be divisible by n_head ({n_head})")

    self.nh = n_head

    self.w = nn.Linear(n_embd,3*n_embd)    # fused key/query/value projection: 3 * n_head * head_size
    self.proj = nn.Linear(n_embd,n_embd)   # output projection applied after the heads are merged


  def forward(self,x):
    """Apply self-attention to x.

    Args:
      x: tensor of shape (B, T, C) with C equal to n_embd.

    Returns:
      Tensor of shape (B, T, C).
    """

    B,T,C = x.shape
    head_size = C//self.nh

    kqv = self.w(x)        # B,T, 3* C

    k,q,v = torch.chunk(kqv,3, dim = -1)      # each B,T,C

    # Split the channel dimension into heads and move the head axis forward.
    key   = k.view(B, T, self.nh, head_size).transpose(1, 2)    # B, n_head, T, head_size
    query = q.view(B, T, self.nh, head_size).transpose(1, 2)    # ""
    value = v.view(B, T, self.nh, head_size).transpose(1, 2)    # ""

    # Scaled dot-product attention weights, normalised over the key axis.
    weight = ( query @ key.transpose(-1,-2) )  * (head_size ** -0.5)    #B,nh,T,T
    weight = F.softmax(weight,dim = -1)

    out = weight @ value      #B,nh,T,head_size

    # transpose() returns a new view and is NOT in-place, so its result must be
    # kept. Without it, view(B,T,C) would reinterpret the (B,nh,T,head_size)
    # buffer and mix features of different tokens and heads together.
    # contiguous() is needed because view() cannot merge transposed dims.
    out = out.transpose(1,2).contiguous().view(B,T,C)    #B,T,C

    out = self.proj(out)

    return out




In [18]:
# Smoke test: push a random (batch=32, tokens=8, channels=128) tensor
# through a 4-head attention layer.
x = torch.randn(32, 8, 128)

attn = Attention(n_embd=128, n_head=4)

out = attn(x)

In [19]:
# Sanity check: the attention output should have the same shape as the input, (32, 8, 128).
out.shape

torch.Size([32, 8, 128])

### MLP

In [20]:
class MLP(nn.Module):
  """Feed-forward network: Linear -> GELU -> Linear.

  The hidden layer is 4 * n_embd wide and the output is projected back to
  n_embd, so the last dimension of the input is preserved.
  """

  def __init__(self, n_embd):
    super().__init__()

    hidden_dim = 4 * n_embd

    # Attribute names and creation order are kept as-is so parameter names
    # and random initialisation stay the same.
    self.layer = nn.Linear(n_embd, hidden_dim)   # expand
    self.gelu = nn.GELU()
    self.proj = nn.Linear(hidden_dim, n_embd)    # project back

  def forward(self, x):
    """Return proj(gelu(layer(x))); the last dim of x must equal n_embd."""
    expanded = self.layer(x)
    activated = self.gelu(expanded)
    return self.proj(activated)

## Block

In [None]:
class Block(nn.Module):
  """One pre-norm transformer encoder block.

  Each sub-layer (attention, then MLP) is applied to a LayerNorm-ed copy of
  the input and added back through a residual connection.

  Args:
    n_layer: accepted for interface compatibility; not used by this block.
    n_embd: embedding width handed to the LayerNorms, Attention and MLP.
    n_head: number of heads handed to Attention.
  """

  def __init__(self, n_layer, n_embd, n_head):
    super().__init__()

    # Attention sub-layer with its preceding normalisation.
    self.ln_1 = nn.LayerNorm(n_embd)
    self.attn = Attention(n_embd, n_head)

    # Feed-forward sub-layer with its preceding normalisation.
    self.ln_2 = nn.LayerNorm(n_embd)
    self.mlp = MLP(n_embd)

  def forward(self, x):
    """Return x after the attention and MLP residual updates (same shape as x)."""
    attn_out = self.attn(self.ln_1(x))
    x = x + attn_out

    mlp_out = self.mlp(self.ln_2(x))
    return x + mlp_out



## Patch Embedding