This notebook contains the code from my Medium blog post on the paper *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*.

Check out my other blog posts on Deep Learning and Computer Vision [here](https://medium.com/@mandalsouvik).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/souvik3333/medium_blogs/blob/main/transformers/ViT/ViT.ipynb)


## ViT architecture

In [None]:
import torch
import torch.nn as nn
in_chans = 3  # input channels (RGB)
embed_dim = 768  # width of every token vector in model space
patch_size = 16  # side length of each square image patch (16x16)
# A convolution whose kernel size equals its stride touches every
# non-overlapping patch exactly once: a learned linear projection per patch.
proj = nn.Conv2d(
    in_channels=in_chans,
    out_channels=embed_dim,
    kernel_size=patch_size,
    stride=patch_size,
)
img = torch.randn(1, 3, 224, 224)  # dummy batch holding one 224x224 RGB image
# (B, C, H', W') -> (B, C, H'*W') -> (B, N, C): one token per patch
x = proj(img).flatten(start_dim=2).transpose(1, 2)
print(x.shape)

In [None]:
# Learnable [CLS] token of shape (1, 1, embed_dim); its final representation
# is what the classification head reads later.
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Broadcast the single token over the batch dimension (a view, no copy).
cls_token = cls_token.expand(x.size(0), -1, -1)
# Prepend the class token in front of the patch tokens along the token axis.
x = torch.cat([cls_token, x], dim=1)
x.shape  # 196 patch tokens + 1 class token -> 197

In [None]:
# Number of patches derived from the image geometry instead of a hard-coded
# 14*14 (224 / 16 = 14 patches per side for the dummy image above), so the
# cell stays correct if img or patch_size is changed.
num_patches = (img.shape[-2] // patch_size) * (img.shape[-1] // patch_size)
# Guard against hidden state, e.g. the class-token cell having been run twice.
assert x.shape[1] == num_patches + 1, "expected one token per patch plus the class token"
# One position embedding per token (+1 for the class token). Zeros here for
# illustration; in a real model these are learned parameters.
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
x = x + pos_embed  # add position encoding (broadcast over the batch dimension)
x.shape

In [None]:
# Three independent linear maps project every token vector into its query,
# key and value representations (same width as the input). They are created
# in q, k, v order, one after another.
fc_q, fc_k, fc_v = (nn.Linear(embed_dim, embed_dim) for _ in range(3))
Q, K, V = fc_q(x), fc_k(x), fc_v(x)
print(Q.shape, K.shape, V.shape)

In [None]:
num_heads = 8
# Read the batch size from the data instead of hard-coding 1, so the cell
# works for any batch.
batch_size = Q.shape[0]
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
head_dim = embed_dim // num_heads  # 768 // 8 = 96 features per head


def split_heads(t):
    """Reshape (batch, tokens, embed_dim) -> (batch, num_heads, tokens, head_dim)."""
    return t.view(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)


# Split each of Q, K and V into `num_heads` independent heads.
Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
print(Q.shape, K.shape, V.shape)  # batch_size, num_head, num_patch+1, feature_vec dim per head

In [None]:
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V.
# Dividing by sqrt(d_head) keeps the logits' scale independent of the head
# width; without it, large dot products saturate the softmax and shrink gradients.
score = torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)  # (B, heads, N, N)
score = torch.softmax(score, dim=-1)  # attention weights: each row sums to 1
score = torch.matmul(score, V)  # normally we apply a dropout layer to the weights before this
score.shape  # batch_size, num_head, num_patches+1, feature_vector_per_head (embed_dim/num_head)

In [None]:
# Swap the head and token axes so each token's per-head outputs sit side by
# side; .contiguous() makes the memory layout match so a later .view works.
score = score.transpose(1, 2).contiguous()
score.shape  # batch_size, num_patches+1, num_head, feature_vector_per_head (embed_dim/num_head)

In [None]:
# Concatenate the heads back into a single embed_dim-wide vector per token
# (the tensor is contiguous, so this reshape is a view).
score = score.reshape(batch_size, -1, embed_dim)
score.shape  # batch_size, num_patches+1, embed_dim

In [None]:
# Position-wise feed-forward (MLP) block:
# Linear(embed_dim -> 4*embed_dim) -> GELU -> Dropout -> Linear(back to embed_dim) -> Dropout
act_layer = nn.GELU  # activation class, instantiated below
in_features = out_features = embed_dim
hidden_features = 4 * embed_dim
fc1 = nn.Linear(in_features, hidden_features)
act = act_layer()
fc2 = nn.Linear(hidden_features, out_features)
drop1, drop2 = nn.Dropout(0.5), nn.Dropout(0.5)

In [None]:
# Run the attention output through the MLP (innermost call applied first).
# NOTE(review): a full ViT encoder block also has an attention output
# projection, LayerNorm before attention and before the MLP, and residual
# connections around both; they are omitted here to keep the walkthrough short.
x = drop2(fc2(drop1(act(fc1(score)))))
x.shape

In [None]:
cls = x[:, 0, :]  # representation at the class-token position: (batch, embed_dim)

In [None]:
num_classes = 10  # assume a 10-class classification problem
# Linear classifier applied only to the class-token representation.
head = nn.Linear(in_features=embed_dim, out_features=num_classes)
pred = head(cls)  # unnormalised logits, one per class
pred

## Train a simple ViT with PyTorch Lightning and timm 🎆

In [None]:
!pip install timm pytorch_lightning -q

In [None]:
import pytorch_lightning as pl
import timm
import torch
import torchmetrics
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

# Seed the global RNGs (and dataloader worker processes) for reproducible runs.
seed_everything(seed=42, workers=True)

In [None]:
class Model(pl.LightningModule):
    """
    Lightning model that fine-tunes a pretrained timm backbone for classification.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        num_classes: number of target classes (size of the classifier output).
        lr: learning rate for Adam.
        max_iter: ``T_max`` of the cosine-annealing LR schedule. The scheduler is
            returned without an explicit interval, so Lightning steps it once per
            epoch; set this to the trainer's ``max_epochs``.
    """
    def __init__(self, model_name, num_classes, lr = 0.001, max_iter=20):
        super().__init__()
        # Store constructor args in checkpoints so Model.load_from_checkpoint can
        # rebuild the module without re-passing them.
        self.save_hyperparameters()
        self.model = timm.create_model(model_name=model_name, pretrained=True, num_classes=num_classes)
        # One metric object per stage. A single shared Accuracy accumulates both
        # training and validation predictions in the same running state, which
        # corrupts both logged accuracies. `task`/`num_classes` are required by
        # torchmetrics >= 0.11.
        self.metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)  # training accuracy
        self.val_metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.loss = torch.nn.CrossEntropyLoss()
        self.lr = lr
        self.max_iter = max_iter

    def forward(self, x):
        """Return the backbone's class logits for a batch of images."""
        return self.model(x)

    def shared_step(self, batch, batch_idx, metric=None):
        """Compute the loss for one batch and update an accuracy metric.

        Args:
            batch: ``(images, labels)`` pair from the dataloader.
            batch_idx: index of the batch (unused; kept for the step signature).
            metric: accuracy metric to update; defaults to ``self.metric``
                (the training metric) for backward compatibility.

        Returns:
            The cross-entropy loss tensor.
        """
        if metric is None:
            metric = self.metric
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        # Calling the metric (rather than .update) also caches the batch value,
        # which Lightning uses for step-level logging.
        metric(preds, y)

        return loss

    def training_step(self, batch, batch_idx):
        """Training step: log loss and training accuracy."""
        loss = self.shared_step(batch, batch_idx, metric=self.metric)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('train_acc', self.metric, on_epoch=True, logger=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: log loss and accuracy from the validation-only metric."""
        loss = self.shared_step(batch, batch_idx, metric=self.val_metric)
        self.log('val_loss', loss, on_step=True, on_epoch=True, logger=True, prog_bar=True)
        self.log('val_acc', self.val_metric, on_epoch=True, logger=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Adam over the backbone parameters with a cosine-annealing LR schedule."""
        optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optim, T_max=self.max_iter)

        return [optim], [scheduler]

In [None]:
# The pretrained ViT expects 224x224 inputs, so upsample the CIFAR-10 images,
# convert them to tensors in [0, 1], then normalise each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
batch_size = 128

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=8,
)

# Test split: used as the validation set below, kept in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=8,
)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Pretrained ViT-Tiny (16x16 patches, 224x224 input) with a classifier sized to
# len(classes). max_iter is the cosine-schedule length in epochs.
model = Model(
    model_name="vit_tiny_patch16_224",
    num_classes=len(classes),
    lr=1e-3,
    max_iter=10,
)

In [None]:
# Keep the checkpoint with the lowest validation loss. The filename template is
# filled from logged metrics; its prefix names the timm model being trained
# (previously a corrupted "vit_tpytorch_lightning6_224").
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    dirpath='./checkpoints',
    filename='vit_tiny_patch16_224-cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}'
)

In [None]:
# `gpus=` and `stochastic_weight_avg=` were removed from Trainer in PyTorch
# Lightning 2.0, and the install cell pulls the latest release, so select
# the device with `accelerator`/`devices` and add SWA as a callback.
trainer = Trainer(
    deterministic=True,
    logger=False,
    callbacks=[
        checkpoint_callback,
        # NOTE(review): current Lightning requires an explicit SWA learning rate.
        # 1e-4 is roughly where the cosine schedule sits when SWA starts (by
        # default over the last 20% of epochs); tune as needed.
        StochasticWeightAveraging(swa_lrs=1e-4),
    ],
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    max_epochs=10,
)

In [None]:
# Train, using the CIFAR-10 test split as the validation set.
trainer.fit(model, train_dataloaders=trainloader, val_dataloaders=testloader)