In [None]:
!nvidia-smi
!pip install -q einops torchvision
!pip install -q timm

In [None]:
import math, torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from einops import rearrange

class Config:
    """Hyper-parameters for the ViT / CIFAR-10 experiment.

    Everything is a class attribute; ``cfg`` below is a plain instance used only
    for convenient ``cfg.<name>`` access.
    """

    # --- data ---
    img_size = 32          # square input side length in pixels
    in_channels = 3
    num_classes = 10

    # --- model ---
    patch_size = 4         # 32 / 4 -> 8 x 8 = 64 patches
    embed_dim = 384
    depth = 12             # number of transformer blocks
    num_heads = 8
    mlp_ratio = 4          # MLP hidden width = mlp_ratio * embed_dim
    dropout = 0.1
    attn_dropout = 0.1

    # --- optimisation ---
    batch_size = 128
    epochs = 200
    lr = 3e-4
    weight_decay = 0.05
    warmup_epochs = 5

    # --- runtime ---
    device = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()
print("Using device:", cfg.device)


In [None]:
# Per-channel normalisation statistics (commonly quoted CIFAR-10 train-set values).
mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
from timm.data.mixup import Mixup
# Batch-level Mixup/CutMix with label smoothing, applied inside the training loop.
# NOTE: timm's Mixup asserts an even batch size, so the train loader below drops
# the final partial batch (drop_last=True) instead of risking an odd-sized one.
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=cfg.num_classes)

train_tf = transforms.Compose([
    transforms.RandomCrop(cfg.img_size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.RandomErasing(p=0.25),
])

test_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
test_ds  = datasets.CIFAR10("./data", train=False, download=True, transform=test_tf)

# Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
pin = cfg.device == "cuda"
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2,
                          pin_memory=pin, drop_last=True)
test_loader  = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2,
                          pin_memory=pin)


In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size equals its stride applies the same linear
    projection to every ``patch_size x patch_size`` patch.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each patch.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        grid = img_size // patch_size
        self.num_patches = grid * grid

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, num_tokens, embed_dim) patch tokens."""
        feat = self.proj(x)                        # (B, E, H/ps, W/ps)
        tokens = torch.flatten(feat, start_dim=2)  # (B, E, H/ps * W/ps)
        return tokens.permute(0, 2, 1)             # (B, n_patches, E)

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_dropout: dropout probability applied to the attention weights.
        proj_dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``
            (the per-head reshape in ``forward`` would otherwise fail later with
            an opaque shape error).
    """
    def __init__(self, dim, num_heads, attn_dropout=0., proj_dropout=0.):
        super().__init__()
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim*3, bias=True)
        self.attn_drop = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_dropout)

    def forward(self, x):
        """Attend over the token axis of ``x`` (B, N, C) and return (B, N, C)."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim): one slice each for q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))

class MLP(nn.Module):
    """Position-wise feed-forward net: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        mlp_ratio: hidden width as a multiple of ``dim`` (truncated to int).
        dropout: dropout probability, shared by both dropout points.
    """
    def __init__(self, dim, mlp_ratio=4., dropout=0.):
        super().__init__()
        hidden_features = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        h = self.fc1(x)
        h = self.act(h)
        h = self.drop(h)
        h = self.fc2(h)
        return self.drop(h)

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``x + Attn(LN1(x))`` and then ``x + MLP(LN2(x))``.

    Args:
        dim: token embedding size.
        num_heads: attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        dropout: dropout used after the attention projection and inside the MLP.
        attn_dropout: dropout on the attention weights.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., dropout=0., attn_dropout=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads, attn_dropout=attn_dropout, proj_dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, dropout=dropout)

    def forward(self, x):
        attn_out = self.attn(self.norm1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out

class VisionTransformer(nn.Module):
    """ViT classifier: patch embedding, [CLS] token, transformer encoder, linear head.

    Args:
        img_size, patch_size, in_chans: image / patching geometry for ``PatchEmbed``.
        num_classes: size of the output logits.
        embed_dim: token width.
        depth: number of ``TransformerBlock`` layers.
        num_heads, mlp_ratio, dropout, attn_dropout: forwarded to every block
            (``dropout`` is also applied after adding positional embeddings).
    """
    def __init__(self, img_size, patch_size, in_chans, num_classes,
                 embed_dim, depth, num_heads, mlp_ratio, dropout, attn_dropout):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        seq_len = self.patch_embed.num_patches + 1  # +1 for the [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, attn_dropout)
             for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        # Small truncated-normal init for the learned embeddings
        # (pos_embed first, then cls_token).
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for a batch of images ``x``."""
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat([cls, tokens], dim=1) + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        tokens = self.norm(tokens)
        # Classify from the [CLS] position only.
        return self.head(tokens[:, 0])


In [None]:
def get_optimizer(model, lr, wd, no_decay_names=("cls_token", "pos_embed")):
    """Build AdamW with weight decay applied only to weight matrices.

    Trainable parameters go to the zero-decay group when any of these hold:
      * ``ndim <= 1``, which covers biases and LayerNorm weights/biases, without
        relying on fragile substring matches like ``'norm' in name``;
      * the name ends in ``.bias``;
      * the last dotted component of the name is in ``no_decay_names``, which by
        default covers the learned [CLS] token and positional embedding (the
        DeiT/timm convention).
    All other trainable parameters are decayed with ``wd``.

    Args:
        model: module whose trainable parameters are optimised.
        lr: learning rate.
        wd: weight decay for the decayed group.
        no_decay_names: parameter (leaf) names always excluded from weight decay.

    Returns:
        ``torch.optim.AdamW`` over the non-empty parameter groups.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        leaf = name.rsplit(".", 1)[-1]
        if param.ndim <= 1 or name.endswith(".bias") or leaf in no_decay_names:
            no_decay.append(param)
        else:
            decay.append(param)
    param_groups = [
        {"params": decay, "weight_decay": wd},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return optim.AdamW([g for g in param_groups if g["params"]], lr=lr, betas=(0.9, 0.999))

def get_cosine_with_warmup(optimizer, warmup_steps, total_steps, min_lr_ratio=0.0):
    """Linear warmup followed by cosine decay, as a per-step ``LambdaLR``.

    The LR multiplier changes in three phases:
      * it rises linearly from 0 to 1 over the first ``warmup_steps`` steps;
      * it then follows half a cosine from 1 down to ``min_lr_ratio`` at ``total_steps``;
      * beyond ``total_steps`` it stays at ``min_lr_ratio``. Without this clamp,
        the cosine would start rising again.

    Args:
        optimizer: optimizer whose LR is scheduled.
        warmup_steps: number of linear-warmup steps.
        total_steps: step at which the cosine reaches its minimum.
        min_lr_ratio: final multiplier (0.0 reproduces plain cosine-to-zero).

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR``; call ``.step()`` once per optimizer step.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        progress = min(1.0, progress)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


In [None]:
criterion = nn.CrossEntropyLoss(weight=None, reduction="mean", label_smoothing=0.0)  # PyTorch defaults, spelled out

@torch.no_grad()
def evaluate(model, loader, device, loss_fn=None):
    """Compute mean loss and top-1 accuracy of ``model`` over ``loader``.

    Switches the model to eval mode (disables dropout) and runs without autograd.

    Args:
        model: classifier returning logits of shape (B, num_classes).
        loader: iterable of ``(images, integer_labels)`` batches.
        device: device each batch is moved to.
        loss_fn: ``loss_fn(logits, labels)`` returning the batch-mean loss.
            Defaults to ``F.cross_entropy``, which is equivalent to a default
            ``nn.CrossEntropyLoss()``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = F.cross_entropy
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for imgs, labels in loader:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        out = model(imgs)
        n = labels.size(0)
        # Re-weight the batch mean by batch size so a short last batch is averaged correctly.
        loss_sum += loss_fn(out, labels).item() * n
        correct += (out.argmax(dim=1) == labels).sum().item()
        total += n
    if total == 0:
        raise ValueError("evaluate() received an empty loader")
    return loss_sum / total, correct / total

def train_one_epoch(model, loader, optimizer, scaler, device, scheduler=None, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    For each batch the function:
      1. applies the module-level ``mixup_fn``, when it is not None;
      2. runs the forward pass under autocast, which is enabled only on CUDA;
      3. does a scaled backward pass;
      4. clips the *unscaled* gradients;
      5. steps the optimizer through ``scaler``;
      6. steps ``scheduler`` once per batch.

    Args:
        model: network to train (switched to train mode).
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer stepped once per batch via ``scaler.step``.
        scaler: ``GradScaler`` used for loss scaling.
        device: device each batch is moved to.
        scheduler: optional per-step LR scheduler.
        max_grad_norm: global L2 norm gradients are clipped to; ``None`` disables clipping.

    Returns:
        Mean training loss per sample actually seen this epoch.
    """
    model.train()
    use_amp = torch.device(device).type == "cuda"
    running_loss, seen = 0.0, 0
    pbar = tqdm(loader, desc="Train")
    for imgs, labels in pbar:
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if mixup_fn is not None:
            imgs, labels = mixup_fn(imgs, labels)

        with autocast(enabled=use_amp):
            out = model(imgs)
            loss = criterion(out, labels)
        scaler.scale(loss).backward()
        if max_grad_norm is not None:
            # Gradients are still multiplied by the loss scale here. Unscale them
            # first so the clipping threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        if scheduler is not None:
            scheduler.step()

        n = imgs.size(0)
        running_loss += loss.item() * n
        seen += n
        pbar.set_postfix({"loss": running_loss / seen})
    return running_loss / max(seen, 1)


In [None]:
# Seed torch's global RNG so model initialisation and DataLoader shuffling are
# repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

device = cfg.device
model = VisionTransformer(cfg.img_size, cfg.patch_size, cfg.in_channels,
                          cfg.num_classes, cfg.embed_dim, cfg.depth, cfg.num_heads,
                          cfg.mlp_ratio, cfg.dropout, cfg.attn_dropout).to(device)

optimizer = get_optimizer(model, cfg.lr, cfg.weight_decay)
# The scheduler is stepped once per batch, so both horizons are counted in optimizer steps.
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * cfg.epochs
warmup_steps = cfg.warmup_epochs * steps_per_epoch
scheduler = get_cosine_with_warmup(optimizer, warmup_steps, total_steps)
# Loss scaling is only meaningful for CUDA fp16 autocast; disable it elsewhere.
scaler = GradScaler(enabled=(torch.device(device).type == "cuda"))

best_acc = 0.0
save_path = "best_vit_cifar10.pth"

for epoch in range(cfg.epochs):
    print(f"Epoch {epoch+1}/{cfg.epochs}")
    train_loss = train_one_epoch(model, train_loader, optimizer, scaler, device, scheduler)
    val_loss, val_acc = evaluate(model, test_loader, device)
    print(f"Train Loss:{train_loss:.4f} | Val Loss:{val_loss:.4f} | Val Acc:{val_acc*100:.2f}%")
    if val_acc > best_acc:
        best_acc = val_acc
        # "epoch" is stored 0-based.
        torch.save({"model_state": model.state_dict(), "acc": best_acc, "epoch": epoch}, save_path)
        print(f"✓ Saved Best: {best_acc*100:.2f}%")
print("Training finished. Best Accuracy:", best_acc*100)
