<a href="https://colab.research.google.com/github/samitha278/nanoViT/blob/main/test2_vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import gdown
import time
import math
import matplotlib.pyplot as plt

from dataclasses import dataclass

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class ViT(nn.Module):

    def __init__(self,config):
      super().__init__()


      self.config = config

      self.embd = PatchEmbedding(config)

      self.block = nn.ModuleList([Block(config)  for i in range(config.n_layer)])


      self.ln = nn.LayerNorm(config.n_embd)

      self.layer = nn.Linear(config.n_embd,config.num_classes)





    def forward(self,x,targets = None):

      B,C,H,W = x.shape


      #embedding
      out =  self.embd(x)

      #blocks
      for block in self.block:
        out = block(out)

      #layer norm
      out = self.ln(out)

      #linear layer
      out = self.layer(out[:,0])


      if targets is None:
        return out
      else:
        return out,F.cross_entropy(out,targets.view(-1))



#_______________________________________________________________________________


class PatchEmbedding(nn.Module):



  def __init__(self,config):
    super().__init__()

    self.config = config

    self.n_patches = config.n_patch
    self.patch_dim = config.im_channels* config.patch_size ** 2


    #patch embedding
    self.patch_embd = nn.Sequential(
        nn.LayerNorm(self.patch_dim),
        nn.Linear(self.patch_dim,config.n_embd),
        nn.LayerNorm(config.n_embd)
    )

    #cls tokens
    self.cls_token = nn.Parameter(torch.randn((config.n_embd,),device=device))

    #possitional embedding
    self.pos_embd = nn.Embedding(self.n_patches+1,config.n_embd)    # +1 for cls token





  def forward(self,x):

    B,C,H,W = x.shape

    # B,C,H,W -> B, n_patches , patch_dim    # patch_dim = C* patch_size*patch_size

    patch_size = self.config.patch_size

    patches = F.unfold(x, patch_size, stride = patch_size).transpose(-1,-2)

    #patch embedding
    patch_embd = self.patch_embd(patches)        # B, n_patches , n_embd

    #class token
    class_tok = self.cls_token.expand(B,1,-1)     # B , 1 , n_embd


    patch_embd = torch.cat((class_tok,patch_embd),dim =1 )    # B, n_patches +1  , n_embd

    #positional embedding
    pos_embd = self.pos_embd(torch.arange(0,self.n_patches+1,device=device))     # B, n_patches +1  , n_embd

    out = patch_embd + pos_embd

    return out





#_______________________________________________________________________________



class Block(nn.Module):


  def __init__(self,config):
    super().__init__()



    self.ln_1 = nn.LayerNorm(config.n_embd)
    self.attn = Attention(config.n_embd,config.n_head)
    self.ln_2 = nn.LayerNorm(config.n_embd)
    self.mlp = MLP(config.n_embd)


  def forward(self,x):

    x = x + self.attn(self.ln_1(x))
    x = x + self.mlp(self.ln_2(x))

    return x


#_______________________________________________________________________________


class MLP(nn.Module):


  def __init__(self,n_embd):
    super().__init__()


    self.layer = nn.Linear(n_embd,4*n_embd)
    self.gelu = nn.GELU()
    self.proj = nn.Linear(4*n_embd,n_embd)
    self.dropout = nn.Dropout(0.2)


  def forward(self,x):


    x = self.gelu(self.layer(x))
    x = self.proj(x)
    x = self.dropout(x)

    return x


#_______________________________________________________________________________


class Attention(nn.Module):


  def __init__(self,n_embd,n_head) :
    super().__init__()

    self.nh = n_head

    self.w = nn.Linear(n_embd,3*n_embd)    # 3 * n_head * head_size
    self.proj = nn.Linear(n_embd,n_embd)


  def forward(self,x):

    B,T,C = x.shape

    wei = self.w(x)        # B,T, 3* C

    k,q,v = torch.chunk(wei,3, dim = -1)      # each B,T,C

    head_size = C//self.nh

    key   = k.view(B, T, self.nh, head_size).transpose(1, 2)    # B, n_head, T, head_size
    query = q.view(B, T, self.nh, head_size).transpose(1, 2)    # ""
    value = v.view(B, T, self.nh, head_size).transpose(1, 2)


    # weight = ( query @ key.transpose(-1,-2) )  * (head_size ** -0.5)    #B,nh,T,T
    # weight = F.softmax(weight,dim = -1)

    # out = weight @ value      #B,nh,T,n_head


    #Flash Attention
    out = F.scaled_dot_product_attention(query,key,value)



    out = out.transpose(1,2).view(B,T,C)    #B,T,nh,n_head

    out = self.proj(out)

    return out



In [None]:
@dataclass
class Config:
    num_classes: int = 10
    img_size: int = 32         # For CIFAR10
    im_channels: int = 3
    patch_size: int = 4

    n_head: int = 12
    n_layer: int = 12
    n_embd: int = 768

    dropout = 0.2

    @property
    def n_patch(self):
        return (self.img_size//self.patch_size)**2

In [None]:
transform_train = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.ImageNet(root='path/to/imagenet', train=True, download=True, transform=transform_train)
val_dataset = datasets.ImageNet(root='path/to/imagenet', train=False, download=True, transform=transform_train)

RuntimeError: The archive ILSVRC2012_devkit_t12.tar.gz is not present in the root directory or is corrupted. You need to download it externally and place it in path/to/imagenet.

In [None]:
torch.manual_seed(278)
if torch.cuda.is_available():
    torch.cuda.manual_seed(278)

batch_size = 128


train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_data = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [None]:
import os
from google.colab import drive

# Mount drive
drive.mount('/content/drive')
save_dir = '/content/drive/MyDrive/vit_project'
os.makedirs(save_dir, exist_ok=True)

torch.manual_seed(278)
if torch.cuda.is_available():
    torch.cuda.manual_seed(278)

# Create model WITHOUT compiling yet
vit = model.to(device)
vit_compiled = torch.compile(vit)  # Keep reference to both

#------------------------------------------------------------------------------

max_iter = 10000
warm_up = max_iter * 0.1
max_lr = 3e-4
min_lr = max_lr * 0.1

def get_lr(i):
    if i < warm_up:
        return (max_lr/warm_up) * (i+1)

    if i > max_iter:
        return min_lr

    # cosine decay
    diff = max_lr - min_lr
    steps = max_iter - warm_up
    lr = (diff/2) * math.cos((i - warm_up) * (math.pi / steps)) + (diff/2) + min_lr
    return lr

#------------------------------------------------------------------------------

losses = torch.zeros((max_iter,))
lrs = torch.zeros((max_iter,))
norms = torch.zeros((max_iter,))

# Optimizer
use_fused = True if torch.cuda.is_available() else False
optimizer = torch.optim.AdamW(vit.parameters(), lr=max_lr, weight_decay=0.1, fused=use_fused)

# Gradient Scaler
scaler = torch.amp.GradScaler(device)

train_data_iter = iter(train_data)
best_val_acc = 0.0

# Validation function
def validate():
    vit.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in val_data:
            xb, yb = xb.to(device), yb.to(device)
            with torch.autocast(device_type=device, dtype=torch.float16):
                logits, _ = vit(xb, yb)
            pred = logits.argmax(dim=1)
            correct += (pred == yb).sum().item()
            total += yb.size(0)
    vit.train()
    return correct / total

# Save function (uses original model, not compiled)
def save_checkpoint(step, val_acc, is_best=False):
    checkpoint = {
        'model_state_dict': vit.state_dict(),  # Use vit, NOT vit_compiled
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
        'val_accuracy': val_acc,
        'config': {
            'num_classes': config.num_classes,
            'img_size': config.img_size,
            'im_channels': config.im_channels,
            'patch_size': config.patch_size,
            'n_head': config.n_head,
            'n_layer': config.n_layer,
            'n_embd': config.n_embd,
            'dropout': config.dropout
        }
    }

    if is_best:
        path = f'{save_dir}/vit_best.pth'
        torch.save(checkpoint, path)
        print(f'✓ Best model saved! Val Acc: {val_acc:.4f}')
    else:
        path = f'{save_dir}/vit_step_{step}.pth'
        torch.save(checkpoint, path)
        print(f'Checkpoint saved at step {step}')

#------------------------------------------------------------------------------
# Training Loop

vit.train()  # Set to training mode

for i in range(max_iter):
    t0 = time.time()

    try:
        xb, yb = next(train_data_iter)
    except StopIteration:
        train_data_iter = iter(train_data)
        xb, yb = next(train_data_iter)

    xb, yb = xb.to(device), yb.to(device)

    # Use compiled model for forward pass
    with torch.autocast(device_type=device, dtype=torch.float16):
        logits, loss = vit_compiled(xb, yb)

    optimizer.zero_grad()
    scaler.scale(loss).backward()

    # LR Schedule
    lr = get_lr(i)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Gradient Clipping
    norm = torch.nn.utils.clip_grad_norm_(vit.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    t1 = time.time()
    dt = (t1 - t0) * 1000  # ms

    losses[i] = loss.item()
    lrs[i] = lr
    norms[i] = norm

    # Print progress
    if i % 1000 == 0:
        print(f'{i}/{max_iter}  {loss.item():.4f}  {dt:.4f} ms   norm:{norm.item():.4f}   lr:{lr:.4e}')

    # Validation and checkpointing
    if i % 2000 == 0 and i > 0:
        val_acc = validate()
        print(f'Validation accuracy: {val_acc:.4f}')

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(i, val_acc, is_best=True)

        # Save periodic checkpoint
        save_checkpoint(i, val_acc, is_best=False)

# Final save
final_val_acc = validate()
print(f'\nFinal validation accuracy: {final_val_acc:.4f}')
save_checkpoint(max_iter, final_val_acc, is_best=(final_val_acc > best_val_acc))

# Save training history
torch.save({
    'losses': losses,
    'lrs': lrs,
    'norms': norms
}, f'{save_dir}/training_history.pth')

In [None]:
# Validation Accuracy

torch.manual_seed(278)
if torch.cuda.is_available():
    torch.cuda.manual_seed(278)

correct, total = 0, 0
vit.eval()
vit = vit.to(device)
with torch.no_grad():
    for xb, yb in val_data:
        xb,yb = xb.to(device),yb.to(device)
        logits = vit(xb)
        preds = torch.argmax(logits, dim=-1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)

val_acc = correct / total
print(f"Validation accuracy: {val_acc:.4f}")