<a href="https://colab.research.google.com/github/peeyushsinghal/EVA8/blob/main/S10-Assignment-Solution/EVA8_S10_ViT_Run.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title importing libraries
import numpy as np
# from collections import defaultdict
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision

In [2]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

device: cuda


In [3]:
#@title importing model


def pair(data):
  """Return ``data`` as-is when it is a tuple, otherwise the 2-tuple ``(data, data)``."""
  if isinstance(data, tuple):
    return data
  return (data, data)

class ToPatch(nn.Module):
  """Split an image batch into non-overlapping patches and embed each patch.

  A convolution whose kernel size equals its stride performs the patch
  extraction and the linear projection in one step.

  Args:
    patch_size: kernel size and stride of the patch convolution.
    channels: number of input image channels.
    embedding_dim: number of output channels, i.e. the size of each patch embedding.
  """

  def __init__(self,  patch_size,  channels=3, embedding_dim = 768):
    super().__init__()

    # [B, channels, H, W] -> [B, embedding_dim, H // patch, W // patch]
    patch_projection = nn.Conv2d(
        in_channels=channels,
        out_channels=embedding_dim,
        kernel_size=patch_size,
        stride=patch_size,
        padding=0,
    )
    # Merge the two spatial axes: [B, embedding_dim, num_patches]
    merge_spatial = nn.Flatten(start_dim=2, end_dim=3)

    self.patch = nn.Sequential(patch_projection, merge_spatial)

  def forward(self,x):
    """Return patch embeddings with the patch axis before the embedding axis."""
    embedded = self.patch(x)
    # [B, embedding_dim, num_patches] -> [B, num_patches, embedding_dim]
    return embedded.permute(0, 2, 1)

class FeedForward(nn.Module):
  """Two-layer MLP applied independently at every sequence position.

  The layers are kernel-size-1 ``Conv1d`` modules rather than ``Linear``
  layers, so the input must be laid out channels-first:
  ``[B, in_dim, sequence_length]``. The output has the same shape.

  Args:
    in_dim: number of input (and output) channels.
    out_dim: number of hidden channels between the two convolutions.
    drop_out: dropout probability applied after the activation.
  """

  def __init__(self,
               in_dim = 32,
               out_dim = 3*32,
               drop_out = 0.1
               ):
    super().__init__()
    # Built in expand -> project order so parameter initialisation order is unchanged.
    expand = nn.Conv1d(in_channels=in_dim, out_channels=out_dim, kernel_size=1)
    project = nn.Conv1d(in_channels=out_dim, out_channels=in_dim, kernel_size=1)
    self.ff = nn.Sequential(
        expand,
        nn.GELU(),
        nn.Dropout(drop_out),
        project,
    )

  def forward(self,x):
    """Run the expand -> GELU -> dropout -> project pipeline on ``x``."""
    out = self.ff(x)
    return out

class TransformerEncoderBlock(nn.Module):
  """Pre-norm transformer encoder block.

  Applies ``x + SelfAttention(LayerNorm(x))`` followed by
  ``x + FeedForward(LayerNorm(x))``. Input and output are both
  ``[B, sequence_length, dim]`` (the attention module is built with
  ``batch_first=True``).

  Args:
    num_heads: number of parallel attention heads.
    dim: embedding size of every token.
    transformer_dropout: dropout probability used inside the feed-forward sub-layer.
  """

  def __init__(self,
               num_heads = 4, # number of parallel multi attention heads required
               dim = 32, # number of total dimension of input
               transformer_dropout = 0.1 #Dropout used in feedforward layer
               ):
    super().__init__()
    # attention sub-layer
    self.layer_norm_preattn = nn.LayerNorm(dim)
    self.self_attn = nn.MultiheadAttention(embed_dim = dim,
                                           num_heads = num_heads,
                                           batch_first = True)

    # feed-forward sub-layer
    self.layer_norm_preff = nn.LayerNorm(dim)
    self.feed_forward = FeedForward(in_dim = dim, out_dim = 3*dim, drop_out = transformer_dropout)

  def forward(self,x):
    """Return ``x`` after the attention and feed-forward residual sub-layers."""
    # Attention sub-layer. The attention weights are never used, so they are
    # not requested: this skips computing/averaging them and lets PyTorch
    # pick its optimized attention implementation.
    normed = self.layer_norm_preattn(x)
    attn_out, _ = self.self_attn(query = normed,
                                 key = normed,
                                 value = normed,
                                 need_weights = False)
    x = x + attn_out

    # Feed-forward sub-layer. FeedForward is built from Conv1d layers and
    # therefore expects [B, dim, sequence_length]; permute in and back out.
    normed = self.layer_norm_preff(x)
    ff_out = self.feed_forward(normed.permute(0,2,1))
    return x + ff_out.permute(0,2,1)

class TransformerStack(nn.Module):
  """A sequence of identically configured ``TransformerEncoderBlock`` modules.

  Args:
    num_blocks: how many encoder blocks to chain.
    num_heads: attention heads per block.
    dim: embedding size of every token.
    transformer_dropout: dropout probability forwarded to each block
      (each block hands it to its feed-forward sub-layer).
  """

  def __init__(self,
               num_blocks = 4,
               num_heads = 4,
               dim = 32,
               transformer_dropout = 0.1
               ):
    super().__init__()
    # NOTE: the attribute name (including its spelling) is part of the
    # state_dict keys, so it is kept as-is.
    encoder_blocks = [
        TransformerEncoderBlock(num_heads=num_heads,
                                dim=dim,
                                transformer_dropout=transformer_dropout)
        for _ in range(num_blocks)
    ]
    self.tranformer_stack = nn.ModuleList(encoder_blocks)

  def forward(self,x):
    """Feed ``x`` through every block in order and return the final output."""
    out = x
    for block in self.tranformer_stack:
      out = block(out)
    return out

class Head(nn.Module):
  """Classification head: pool the token sequence, normalise, and project to class logits.

  Args:
    num_classes: number of output logits per sample.
    dim: embedding size of the incoming tokens.
    head_p_drop: dropout probability applied before the final projection.
  """

  def __init__(self,
               num_classes = 10, # number of classes
               dim = 32, # input dimension
               head_p_drop = 0.1 # drop out
               ):
    super().__init__()
    self.layer_norm_prehead= nn.LayerNorm(dim)
    self.head = nn.Sequential(
        nn.GELU(),
        nn.Dropout(head_p_drop),
        nn.Conv1d(in_channels = dim, out_channels = num_classes, kernel_size = 1)
    )

  def forward(self,x, pool = 'cls'):
    """Return logits of shape ``[B, num_classes]`` for ``x`` of shape ``[B, tokens, dim]``.

    Args:
      x: token sequence; index 0 along the token axis is treated as the class token.
      pool: ``'cls'`` uses the token at index 0; any other value averages
        the tokens from index 1 onward.
    """
    if pool == 'cls':
      x_cls = x[:,0,:] # [B, dim]
    else:
      x_cls = x[:,1:,:].mean(dim=1) # mean over the non-class tokens -> [B, dim]

    x_cls = x_cls.unsqueeze(dim=1) # [B, dim] -> [B, 1, dim]
    x_cls = self.layer_norm_prehead(x_cls)
    x_cls = x_cls.permute(0,2,1) # Conv1d expects channels first: [B, dim, 1]
    x_cls = self.head(x_cls) # [B, num_classes, 1]

    # Drop the trailing length-1 axis instead of reshaping to a hard-coded
    # 10 columns, so the head works for any num_classes.
    return x_cls.squeeze(dim=-1)

class ViT(nn.Module):
  """Vision Transformer classifier.

  Pipeline: patch embedding -> prepend a learnable class token -> add
  learnable position embeddings -> dropout -> transformer stack ->
  classification head.

  Args:
    image_size: int or ``(height, width)`` tuple of the input images.
    patch_size: int or ``(height, width)`` tuple of each patch; must divide the image size.
    dim: embedding size. When falsy (e.g. ``None``) the flattened patch
      size ``channels * patch_height * patch_width`` is used instead.
    pool: ``'cls'`` to classify from the class token, ``'mean'`` to
      classify from the mean of the patch tokens.
    num_classes: number of output logits.
    emb_dropout: dropout probability applied after adding position embeddings.
    channels: number of input image channels.
    num_blocks: number of transformer encoder blocks.
    num_heads: attention heads per block; must divide the embedding size.
    transformer_dropout: dropout probability forwarded to the transformer stack.
  """

  def __init__(self,
               image_size,
               patch_size,
               dim = None, # if None, use the information as per image size else use the dimensions provided
               pool = 'cls', # whether the pooling is based on class token ('cls') or mean pooling ('mean')
               num_classes = 10,
               emb_dropout = 0.1, #Dropout for patch and position embeddings
               channels = 3,
               num_blocks = 4,
               num_heads = 4,
               transformer_dropout = 0.1
               ):
    super().__init__()

    assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    self.pool = pool

    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)

    assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

    num_patches = (image_height // patch_height) * (image_width // patch_width)
    patch_dim = channels * patch_height * patch_width

    # Fall back to the flattened patch size when no explicit width is given.
    embedding_dim = dim if dim else patch_dim

    self.to_patch = ToPatch(patch_size = (patch_height, patch_width), channels = channels, embedding_dim = embedding_dim)
    self.class_token = nn.Parameter(data= torch.randn(1, 1, embedding_dim), requires_grad=True)
    # One position embedding per patch plus one for the class token.
    self.pos_embedding = nn.Parameter(data = torch.randn(1, num_patches + 1, embedding_dim),requires_grad=True)
    self.embedding_dropout = nn.Dropout(p=emb_dropout)

    # The stack must operate at the model's embedding size; a fixed width
    # here only works when embedding_dim happens to equal it.
    self.transformer = TransformerStack(num_blocks = num_blocks,
                                        num_heads = num_heads,
                                        dim = embedding_dim,
                                        transformer_dropout = transformer_dropout)

    self.head_output = Head(num_classes = num_classes, dim = embedding_dim, head_p_drop = 0.1)

  def forward(self,x):
    """Return class logits of shape ``[B, num_classes]`` for an image batch ``x``."""
    x = self.to_patch(x) # [B, num_patches, embedding_dim]

    batch_size = x.shape[0]

    # Prepend the class token to every sample (-1 keeps that axis' size).
    class_token_across_batch = self.class_token.expand(batch_size,-1,-1)
    x = torch.cat((class_token_across_batch,x),dim=1) # [B, num_patches + 1, embedding_dim]

    # The position embedding has a leading axis of size 1 and broadcasts over the batch.
    x = x + self.pos_embedding
    x = self.embedding_dropout(x)

    x = self.transformer(x)

    # self.pool was validated in __init__, so it can be forwarded directly.
    return self.head_output(x, pool=self.pool)



In [4]:
# Problem dimensions: 10 output classes, 32x32 input images.
NUM_CLASSES = 10
IMAGE_SIZE = 32

# 2x2 patches embedded into 32-dimensional tokens, classified from the class token.
model = ViT(
    image_size=IMAGE_SIZE,
    patch_size=2,
    dim=32,
    pool='cls',
    num_classes=NUM_CLASSES,
)

In [5]:
# Move the model to the selected device. As the last expression of the cell,
# the returned module is also displayed (its layer summary is the cell output).
model.to(device)

ViT(
  (to_patch): ToPatch(
    (patch): Sequential(
      (0): Conv2d(3, 32, kernel_size=(2, 2), stride=(2, 2))
      (1): Flatten(start_dim=2, end_dim=3)
    )
  )
  (embedding_dropout): Dropout(p=0.1, inplace=False)
  (transformer): TransformerStack(
    (tranformer_stack): ModuleList(
      (0): TransformerEncoderBlock(
        (layer_norm_preattn): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
        )
        (layer_norm_preff): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (feed_forward): FeedForward(
          (ff): Sequential(
            (0): Conv1d(32, 96, kernel_size=(1,), stride=(1,))
            (1): GELU(approximate='none')
            (2): Dropout(p=0.1, inplace=False)
            (3): Conv1d(96, 32, kernel_size=(1,), stride=(1,))
          )
        )
      )
      (1): TransformerEncoderBlock(
        (layer_

In [6]:
# Report the total number of scalar values across all model parameters.
print(
    "Number of parameters: {:,}".format(
        sum(weight.numel() for weight in model.parameters())
    )
)

Number of parameters: 51,562


In [7]:
# ---- Hyper-parameters and data pipeline for CIFAR-10 ----
IMAGE_SIZE = 32

NUM_CLASSES = 10
# Previously declared as 8 while both loaders hard-coded 4; set to the value
# that was actually in effect and now shared by both loaders.
NUM_WORKERS = 4
BATCH_SIZE = 128       # training batch size
TEST_BATCH_SIZE = 64   # evaluation batch size
EPOCHS = 25

# NOTE(review): the training cell passes its own literal lr / weight_decay to
# AdamW instead of these two constants -- confirm which values are intended.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-1

# Per-channel statistics used to normalise the inputs.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Training-time augmentation. RandomErasing operates on tensors, so it comes
# after ToTensor/Normalize.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=1, magnitude=8),
    transforms.ColorJitter(0.1, 0.1, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
    transforms.RandomErasing(p=0.25)
])

# Evaluation uses no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std)
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=TEST_BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data




Files already downloaded and verified


In [8]:
import time

clip_norm = True


def lr_schedule(t):
    """Piecewise-linear learning rate as a function of (fractional) epoch ``t``.

    Ramps 0 -> 0.01 over the first 2/5 of training, decays to 0.01/20 at
    4/5, then to 0 at the final epoch.
    """
    return np.interp([t], [0, EPOCHS*2//5, EPOCHS*4//5, EPOCHS],
                     [0, 0.01, 0.01/20.0, 0])[0]


# Mixed precision is only meaningful on CUDA; on CPU both the scaler and
# autocast are disabled so the same loop runs unchanged.
use_amp = device.type == "cuda"

opt = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(EPOCHS):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0

    model.train()
    for i, (X, y) in enumerate(trainloader):
        # Use the device chosen at the top of the notebook instead of assuming CUDA.
        X, y = X.to(device), y.to(device)

        # The schedule is evaluated per step, at a fractional epoch position.
        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        opt.param_groups[0].update(lr=lr)

        opt.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = model(X)
            loss = criterion(output, y)

        scaler.scale(loss).backward()
        if clip_norm:
            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(opt)
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)
        scaler.update()

        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)

    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for X, y in testloader:
            X, y = X.to(device), y.to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'ViT: Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')


ViT: Epoch: 0 | Train Acc: 0.1780, Test Acc: 0.2883, Time: 71.9, lr: 0.001000
ViT: Epoch: 1 | Train Acc: 0.2807, Test Acc: 0.3714, Time: 65.8, lr: 0.002000
ViT: Epoch: 2 | Train Acc: 0.3436, Test Acc: 0.4328, Time: 64.1, lr: 0.003000
ViT: Epoch: 3 | Train Acc: 0.3933, Test Acc: 0.4798, Time: 63.1, lr: 0.004000
ViT: Epoch: 4 | Train Acc: 0.4330, Test Acc: 0.5002, Time: 65.6, lr: 0.005000
ViT: Epoch: 5 | Train Acc: 0.4564, Test Acc: 0.5178, Time: 64.3, lr: 0.006000
ViT: Epoch: 6 | Train Acc: 0.4816, Test Acc: 0.5498, Time: 64.9, lr: 0.007000
ViT: Epoch: 7 | Train Acc: 0.4980, Test Acc: 0.5481, Time: 66.2, lr: 0.008000
ViT: Epoch: 8 | Train Acc: 0.5129, Test Acc: 0.5611, Time: 63.8, lr: 0.009000
ViT: Epoch: 9 | Train Acc: 0.5261, Test Acc: 0.5691, Time: 65.3, lr: 0.010000
ViT: Epoch: 10 | Train Acc: 0.5385, Test Acc: 0.6042, Time: 64.7, lr: 0.009050
ViT: Epoch: 11 | Train Acc: 0.5614, Test Acc: 0.6253, Time: 63.5, lr: 0.008100
ViT: Epoch: 12 | Train Acc: 0.5723, Test Acc: 0.6294, Time: 65