1. Preparations

In [1]:
import gc
from contextlib import nullcontext
from dataclasses import dataclass, field

import matplotlib.pyplot as plt
import numpy as np
import rich
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
def get_device(verbose: bool = False) -> torch.device:
    """Pick the best available compute device.

    Preference order is CUDA, then Apple MPS, then CPU; later backends are only
    probed when the earlier ones are unavailable.

    Args:
        verbose: When True, print a one-line message naming the chosen backend.

    Returns:
        The selected ``torch.device``.
    """
    candidates = (
        ("cuda", torch.cuda.is_available, "Using GPU"),
        ("mps", torch.backends.mps.is_available, "Using Apple Silicon GPU"),
    )
    for backend_name, is_available, message in candidates:
        if is_available():
            chosen, note = torch.device(backend_name), message
            break
    else:
        chosen, note = torch.device("cpu"), "Using CPU"

    if verbose:
        print(note)
    return chosen


def clear_cuda():
    """Free as much memory as possible between experiments.

    Always runs Python's garbage collector first so unreachable tensors are
    released; then, only if CUDA is available, empties the CUDA caching
    allocator and runs CUDA IPC collection.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def tensor_to_device(*tensors, device: torch.device = torch.device("cpu"), non_blocking=True):
    """Move every given tensor to ``device``.

    Args:
        *tensors: Tensors to move (each must support ``.to(device, non_blocking=...)``).
        device: Destination device; CPU by default.
        non_blocking: Passed straight through to ``.to``.

    Returns:
        A tuple of the moved tensors, in input order, when two or more are
        given; otherwise the single moved tensor. Calling with no tensors at
        all raises ``IndexError``.
    """
    results = []
    for tensor in tensors:
        results.append(tensor.to(device, non_blocking=non_blocking))
    if len(results) > 1:
        return tuple(results)
    return results[0]


def print_color(text: str, color: str = "green"):
    """Print ``text`` via rich, wrapped in a ``[color]...[/color]`` markup tag.

    Note: ``text`` is not escaped, so any rich markup it contains is
    interpreted as well.
    """
    open_tag = f"[{color}]"
    close_tag = f"[/{color}]"
    rich.print(f"{open_tag}{text}{close_tag}")


def get_ctx(use_mixed: bool, device: torch.device, amp_mode: str = "auto"):
    """Return the autocast context to wrap forward passes in.

    Args:
        use_mixed: Master switch; when False, autocast is never used.
        device: Device the computation runs on; its ``.type`` is forwarded to
            ``torch.autocast``.
        amp_mode: ``"off"`` disables autocast; ``"fp16"`` / ``"bf16"`` force
            that dtype; ``"auto"`` picks a per-device default:
            CUDA -> bfloat16 if the GPU reports bf16 support, else float16;
            MPS -> float16; CPU -> bfloat16.

    Returns:
        A ``torch.autocast`` instance, or ``contextlib.nullcontext()`` when
        autocast is disabled or ``"auto"`` has no default for ``device.type``.

    Raises:
        ValueError: If mixed precision is requested with an ``amp_mode`` other
            than ``"auto"``, ``"off"``, ``"fp16"`` or ``"bf16"``.
    """
    if not use_mixed or amp_mode == "off":
        print("Not using autocast context")
        return nullcontext()

    device_type = device.type

    if amp_mode == "fp16":
        dtype = torch.float16
    elif amp_mode == "bf16":
        dtype = torch.bfloat16
    elif amp_mode == "auto":
        if device_type == "cuda":
            # GPUs without bf16 support fall back to fp16.
            # NOTE: fp16 training usually also needs a GradScaler.
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        elif device_type == "mps":
            dtype = torch.float16
        elif device_type == "cpu":
            # CPU autocast is built around bfloat16; older PyTorch releases
            # warn and disable autocast when asked for float16 on CPU.
            dtype = torch.bfloat16
        else:
            print("Not using autocast context")
            return nullcontext()
    else:
        # Previously any unrecognised value silently behaved like "auto".
        raise ValueError(
            f"amp_mode must be one of 'auto', 'off', 'fp16', 'bf16'; got {amp_mode!r}"
        )

    print(f"Using autocast with dtype={dtype} on device type={device_type}")
    return torch.autocast(device_type=device_type, dtype=dtype)

2. Vision Transformer Model

2.1 Model Config

In [3]:
@dataclass
class ModelConfig:
    """Hyper-parameters for the Vision Transformer built in this notebook.

    Defaults describe a small model for square 32x32, 3-channel images and a
    10-way classifier.
    """

    # --- input / output ---
    image_size: int = 32        # side length of the square input image
    patch_size: int = 4         # side length of each square patch
    num_channels: int = 3       # channels in the input image
    num_classes: int = 10       # width of the classifier output

    # --- transformer encoder ---
    num_layers: int = 6         # number of stacked encoder blocks
    num_heads: int = 8          # attention heads; d_model must be divisible by this
    d_model: int = 128          # token embedding width
    d_ff: int = 512             # hidden width of the feed-forward network

    # --- regularisation ---
    dropout_rate: float = 0.1            # dropout in the FFN and the classifier head
    attention_dropout_rate: float = 0.1  # dropout applied to attention weights

2.2 Patch Embedding

In [4]:
class PatchEmbedder(nn.Module):
    """Cut an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size equals its stride (and has no padding) applies
    one shared linear projection to every ``patch_size x patch_size`` tile.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        side = config.image_size // config.patch_size
        self.num_patches_per_side = side
        self.num_patches = side * side

        self.proj = nn.Conv2d(
            config.num_channels,
            config.d_model,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

    def forward(self, x: torch.Tensor):
        """Map images ``(B, C, H, W)`` to patch tokens ``(B, N, d_model)``."""
        # (B, C, H, W) -> (B, d_model, H/P, W/P) -> (B, d_model, N) -> (B, N, d_model)
        return self.proj(x).flatten(start_dim=2).transpose(1, 2)
        

2.3 Position Embedding

In [5]:
class PosEmbedder(nn.Module):
    """Prepend a learnable [CLS] token and add learnable position embeddings.

    Expects patch tokens ``(B, N, d_model)`` with
    ``N == (image_size // patch_size) ** 2`` and returns ``(B, N + 1, d_model)``.

    Args:
        config: Model configuration (reads ``image_size``, ``patch_size``, ``d_model``).
        init_std: Standard deviation of the truncated-normal initialisation
            used for both the position embeddings and the [CLS] token.
    """

    def __init__(self, config: "ModelConfig", init_std: float = 0.02):
        super().__init__()
        num_tokens = (config.image_size // config.patch_size) ** 2 + 1  # patches + [CLS]

        self.position_embeddings = nn.Parameter(torch.empty(1, num_tokens, config.d_model))
        self.cls_token = nn.Parameter(torch.empty(1, 1, config.d_model))

        # A unit-variance randn init makes these parameters dwarf the patch
        # projections at the start of training; ViT/DeiT reference code uses a
        # small truncated normal (std 0.02) instead.
        bound = 2.0 * init_std
        nn.init.trunc_normal_(self.position_embeddings, std=init_std, a=-bound, b=bound)
        nn.init.trunc_normal_(self.cls_token, std=init_std, a=-bound, b=bound)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (B, 1, d_model)
        x = torch.cat((cls_tokens, x), dim=1)                   # (B, N + 1, d_model)
        return x + self.position_embeddings

2.4 Multi-Head Attention (MHA)

In [6]:
class MHA(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    ``forward`` returns ``(out, attn_weight)`` where ``out`` is ``(B, N, d_model)``
    and ``attn_weight`` is the pre-dropout softmax ``(B, num_heads, N, N)``.
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()

        self.num_heads = config.num_heads
        self.d_model = config.d_model
        assert config.d_model % config.num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = config.d_model // config.num_heads
        # 1/sqrt(d_k) computed once as a Python float instead of allocating a
        # new tensor and taking its sqrt on every forward pass.
        self.scale = self.d_k ** -0.5
        self.qkv_linear = nn.Linear(config.d_model, config.d_model * 3)
        self.out_linear = nn.Linear(config.d_model, config.d_model)

        self.attention_dropout = nn.Dropout(config.attention_dropout_rate)

    def forward(self, x: torch.Tensor):
        B, N, C = x.shape  # batch size, number of tokens, embedding dimension

        # (B, N, 3*C) -> (B, N, 3, H, d_k) -> (3, B, H, N, d_k) -> q, k, v each (B, H, N, d_k)
        q, k, v = (
            self.qkv_linear(x).reshape(B, N, 3, self.num_heads, self.d_k).permute(2, 0, 3, 1, 4).unbind(0)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        attn_weight = torch.softmax(scores, dim=-1)  # (B, H, N, N)
        attn = self.attention_dropout(attn_weight)

        context = torch.matmul(attn, v)                      # (B, H, N, d_k)
        context = context.transpose(1, 2).reshape(B, N, C)   # (B, N, d_model)

        out = self.out_linear(context)  # (B, N, d_model)

        return out, attn_weight
        

2.5 Feed-Forward Network (FFN)

In [7]:
class FFN(nn.Module):
    """Position-wise feed-forward network.

    Computes Linear(d_model -> d_ff) -> ReLU -> Dropout -> Linear(d_ff -> d_model)
    -> Dropout on every token independently.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_ff)
        self.fc2 = nn.Linear(config.d_ff, config.d_model)
        # A single parameter-free Dropout module is reused after both layers.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: torch.Tensor):
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

2.6 Layer Normalization

In [8]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Normalises with the biased (population) variance, ``unbiased=False``.

    Args:
        in_dim: Size of the last dimension being normalised.
        eps: Added to the variance before the square root for stability.
    """

    def __init__(self, in_dim: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(in_dim))   # per-feature scale
        self.beta = nn.Parameter(torch.zeros(in_dim))   # per-feature shift
        self.eps = eps

    def forward(self, x: torch.Tensor):
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / (variance + self.eps).sqrt()
        return self.gamma * normalized + self.beta
        

2.7 ViT Encoder Block

In [9]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``h = x + MHA(norm1(x))`` then ``h + FFN(norm2(h))``; the attention
    weights returned by ``MHA`` are discarded.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.mha = MHA(config)
        self.ffn = FFN(config)
        self.norm1 = LayerNorm(config.d_model)
        self.norm2 = LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor):
        attn_out = self.mha(self.norm1(x))[0]
        h = x + attn_out
        return h + self.ffn(self.norm2(h))

2.8 Classifier Head

In [10]:
class MLPHead(nn.Module):
    """Classify from the token at sequence position 0 (the [CLS] slot in this model).

    Computes Linear(d_model -> d_model) -> ReLU -> Dropout -> Linear(d_model -> num_classes).
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_model)
        self.fc2 = nn.Linear(config.d_model, config.num_classes)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x: torch.Tensor):
        # (B, N, d_model) -> (B, d_model)
        cls_repr = x[:, 0, :]
        hidden = self.dropout(F.relu(self.fc1(cls_repr)))
        return self.fc2(hidden)
        

2.9 Full ViT Model

In [11]:
class Backbone(nn.Module):
    """ViT feature extractor.

    Runs patch embedding, [CLS] token + position embeddings, the stack of
    encoder blocks, and a final LayerNorm. Maps ``(B, C, H, W)`` images to
    ``(B, N + 1, d_model)`` tokens.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.patch_embedder = PatchEmbedder(config)
        self.pos_embedder = PosEmbedder(config)
        self.encoder_layers = nn.ModuleList([EncoderBlock(config) for _ in range(config.num_layers)])
        # The encoder blocks are pre-norm, so their output is an un-normalised
        # residual sum. ViT applies one last LayerNorm to the encoder output
        # before the classification head.
        self.norm = LayerNorm(config.d_model)

    def forward(self, x: torch.Tensor):
        x = self.patch_embedder(x)  # (B, N, d_model)
        x = self.pos_embedder(x)    # (B, N + 1, d_model)
        for layer in self.encoder_layers:
            x = layer(x)
        return self.norm(x)

In [12]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    ``Backbone`` turns ``(B, C, H, W)`` images into tokens and ``MLPHead`` turns
    those tokens into ``(B, num_classes)`` logits.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.backbone = Backbone(config)
        self.class_head = MLPHead(config)

    def forward(self, x: torch.Tensor):
        tokens = self.backbone(x)
        return self.class_head(tokens)

2.10 Dummy Forward-Pass Test

In [13]:
# Smoke test: a random batch must come out of the model as (batch, num_classes) logits.
DEVICE = get_device(verbose=True)
config = ModelConfig()
model = ViT(config).to(DEVICE)

dummy_batch_size = 2
dummy_input = torch.randn(
    dummy_batch_size, config.num_channels, config.image_size, config.image_size, device=DEVICE
)
# This forward pass is only a shape check, so don't build an autograd graph
# (which would stay alive on the device through the global `dummy_output`).
with torch.no_grad():
    dummy_output = model(dummy_input)

expected_shape = (dummy_batch_size, config.num_classes)
assert dummy_output.shape == expected_shape, (
    f"expected output shape {expected_shape}, got {tuple(dummy_output.shape)}"
)

Using GPU
