## Packages and Libraries

In [1]:
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
from torch.amp import GradScaler, autocast

## Patchify

In [2]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size equals its stride is applied, so every
    output position sees exactly one ``patch_size x patch_size`` patch.

    Args:
        img_size: Side length of the (square) input image; must be a multiple
            of ``patch_size``.
        patch_size: Side length of each square patch.
        in_chans: Number of input channels.
        embed_dim: Size of the embedding produced for each patch.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=384):
        super().__init__()
        assert img_size % patch_size == 0
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens with the spatial grid flattened into dim 1."""
        feature_map = self.proj(x)
        # Collapse the two spatial dims, then move channels last.
        tokens = feature_map.flatten(start_dim=2)
        return tokens.transpose(1, 2)

## MLP

In [3]:
class MLP(nn.Module):
    """Two-layer feed-forward net: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy (``None`` or ``0``).
        out_features: Size of the output; falls back to ``in_features`` when
            falsy.
        drop: Dropout probability; the same dropout module is applied after
            the activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        hidden = self.drop(hidden)
        projected = self.fc2(hidden)
        return self.drop(projected)

## Attention

In [4]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (positive).
        qkv_bias: Whether the fused query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=6, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail early with a clear message: otherwise the floor division below
        # silently truncates and the reshape in forward() fails much later.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; returns ``(B, N, C)``."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))

## Stochastic Depth

In [5]:
class StochasticDepth(nn.Module):
    """Randomly zero whole samples of a residual branch during training.

    Each sample in the batch (dim 0) is kept with probability ``1 - p`` and
    the kept samples are rescaled by ``1 / (1 - p)``. In eval mode, or when
    ``p == 0``, the input is returned unchanged.

    Args:
        p: Drop probability in ``[0, 1]``. A number or a 0-d tensor; it is
            stored as a Python float.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p):
        super().__init__()
        p = float(p)  # accept 0-d tensors (e.g. an element of torch.linspace)
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"drop probability must be in [0, 1], got {p}")
        self.p = p

    def forward(self, x):
        """Apply per-sample drop to ``x`` (any rank >= 1, batch on dim 0)."""
        if not self.training or self.p == 0.:
            return x
        if self.p >= 1.:
            # Everything is dropped; avoid the 0/0 = NaN of the general formula.
            return torch.zeros_like(x)
        survival = 1. - self.p
        # One Bernoulli draw per sample, broadcast over all remaining dims so
        # the mask lines up with dim 0 regardless of the input's rank.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.rand(mask_shape, device=x.device) < survival
        return x * mask / survival

## Blocks

In [6]:
class Block(nn.Module):
    """Pre-norm transformer encoder block.

    Two residual branches are applied in sequence: LayerNorm -> Attention and
    LayerNorm -> MLP. Each branch output passes through the same
    ``StochasticDepth`` module before being added back to the input.

    Args:
        dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an int).
        qkv_bias: Forwarded to ``Attention``.
        drop: Dropout probability for the attention output projection and
            for the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Drop probability handed to ``StochasticDepth``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        mlp_hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = StochasticDepth(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_hidden, drop=drop)

    def forward(self, x):
        """Run the attention branch, then the MLP branch, each with a skip connection."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

## Vision Transformer

In [7]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier using a learnable class token.

    The image is split into patches by ``PatchEmbed``, a class token is
    prepended, learned position embeddings are added, the sequence passes
    through ``depth`` ``Block`` modules, and the normalised class token is fed
    to a linear classification head.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop_rate: Dropout after adding position embeddings and inside blocks.
        attn_drop_rate: Dropout on attention weights.
        drop_path_rate: Stochastic-depth rate of the last block; rates increase
            linearly from 0 at the first block.
        in_chans: Number of input image channels (default 3).
    """

    def __init__(self, img_size=32, patch_size=4, num_classes=10,
                 embed_dim=384, depth=12, num_heads=6, mlp_ratio=4.,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 in_chans=3):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 position for the class token prepended in forward().
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)

        # Per-block drop-path rates as plain Python floats, so blocks store a
        # number rather than a 0-d tensor.
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True,
                  drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=.02)
        nn.init.trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for a batch of images."""
        B = x.size(0)
        x = self.patch_embed(x)
        # Share the single learned class token across the batch.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Classify from the class token (position 0) only.
        return self.head(x[:, 0])

## Learning Rate Scheduler

In [8]:
class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler):
    """Linear warm-up followed by cosine decay to zero.

    For ``last_epoch < warmup_epochs`` the learning rate is
    ``base_lr * (last_epoch + 1) / warmup_epochs``. Afterwards it follows
    ``base_lr * 0.5 * (1 + cos(pi * t))`` where ``t`` is the fraction of the
    decay phase completed, clamped to 1 so the rate stays at zero once
    ``max_epochs`` is reached instead of rising again.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of warm-up steps.
        max_epochs: Step index at which the cosine decay reaches zero.
        last_epoch: Index of the last step; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, last_epoch=-1):
        # Set before calling the base constructor: it performs an initial
        # step, which already calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        cur = self.last_epoch
        if cur < self.warmup_epochs:
            return [base_lr * (cur + 1) / self.warmup_epochs for base_lr in self.base_lrs]
        decay_span = self.max_epochs - self.warmup_epochs
        if decay_span <= 0:
            # No decay phase at all: treat the schedule as finished rather
            # than dividing by zero.
            t = 1.0
        else:
            # Clamp so stepping past max_epochs does not re-enter the rising
            # half of the cosine.
            t = min((cur - self.warmup_epochs) / decay_span, 1.0)
        return [base_lr * 0.5 * (1 + math.cos(math.pi * t)) for base_lr in self.base_lrs]

## Data

In [9]:
def get_loaders(batch_size=128):
    """Build the CIFAR-10 train and test data loaders.

    The training pipeline applies random crop (with 4-pixel padding), random
    horizontal flip and the CIFAR-10 AutoAugment policy before converting to
    a normalised tensor; the test pipeline only converts and normalises.
    The dataset is stored under ``./data`` and downloaded if missing.

    Args:
        batch_size: Batch size used by both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-channel mean / std passed to Normalize.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.247, 0.243, 0.261)
    normalize = transforms.Normalize(channel_mean, channel_std)

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        AutoAugment(AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ]
    test_steps = [transforms.ToTensor(), normalize]
    train_tfms = transforms.Compose(train_steps)
    test_tfms = transforms.Compose(test_steps)

    data_root = "./data"
    train_set = datasets.CIFAR10(data_root, train=True, download=True, transform=train_tfms)
    test_set = datasets.CIFAR10(data_root, train=False, download=True, transform=test_tfms)

    # Options common to both loaders; only the shuffle flag differs.
    shared_opts = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **shared_opts)
    test_loader = DataLoader(test_set, shuffle=False, **shared_opts)
    return train_loader, test_loader

## Train

In [11]:
def _evaluate_accuracy(model, loader, device, device_type):
    """Return top-1 accuracy in percent of ``model`` over ``loader`` (0.0 if empty)."""
    model.eval()
    correct, seen = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            with autocast(device_type=device_type):
                out = model(x)
            correct += (out.argmax(1) == y).sum().item()
            seen += x.size(0)
    return correct / seen * 100 if seen else 0.0


def train(model, train_loader, test_loader, device, epochs=50, patience=10,
          lr=3e-4, weight_decay=0.05, warmup_epochs=10,
          checkpoint_path="best_vit.pth"):
    """Train ``model`` with AdamW, warm-up + cosine LR, AMP and early stopping.

    After every epoch the model is evaluated on ``test_loader``; whenever the
    accuracy improves the weights are saved to ``checkpoint_path``. Training
    stops after ``patience`` consecutive epochs without improvement, and the
    best saved weights are loaded back before returning.

    NOTE(review): checkpoint selection and early stopping use ``test_loader``,
    so the reported best accuracy is selected on the same data it is measured
    on; pass a held-out validation loader here for an unbiased test estimate.

    Args:
        model: Module to train; expected to already live on ``device``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or
            ``torch.device``.
        epochs: Maximum number of epochs; also the length of the LR schedule.
        patience: Epochs without improvement tolerated before stopping.
        lr: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        warmup_epochs: Requested LR warm-up length; clamped below the schedule
            length so the decay phase is never empty.
        checkpoint_path: File the best weights are written to.

    Returns:
        The model, holding the best weights saved during this call (or its
        final weights if no epoch improved on 0% accuracy).
    """
    # autocast/GradScaler need the bare device type ("cuda"), not "cuda:0"
    # or a torch.device instance.
    device_type = torch.device(device).type

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Keep at least one decay epoch so the scheduler never divides by zero
    # when epochs <= warmup_epochs.
    schedule_len = max(epochs, 1)
    warmup = max(0, min(warmup_epochs, schedule_len - 1))
    sched = WarmupCosineLR(opt, warmup_epochs=warmup, max_epochs=schedule_len)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Loss scaling is only meaningful for CUDA fp16 autocast.
    scaler = GradScaler(device_type, enabled=(device_type == "cuda"))

    best_acc = 0.0
    best_epoch = 0
    patience_counter = 0
    checkpoint_saved = False

    for epoch in range(epochs):
        model.train()
        total_loss, total_correct, n = 0.0, 0, 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            with autocast(device_type=device_type):
                out = model(x)
                loss = criterion(out, y)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            batch_size = x.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (out.argmax(1) == y).sum().item()
            n += batch_size

        sched.step()

        train_loss = total_loss / n if n else float("nan")
        train_acc = total_correct / n * 100 if n else 0.0
        test_acc = _evaluate_accuracy(model, test_loader, device, device_type)

        print(f"Epoch {epoch+1}/{epochs} | Train Loss {train_loss:.4f} | "
              f"Train Acc {train_acc:.2f}% | Test Acc {test_acc:.2f}%")

        if test_acc > best_acc:
            best_acc = test_acc
            best_epoch = epoch
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}. "
                      f"Best accuracy: {best_acc:.2f}% (epoch {best_epoch+1})")
                break

    # Only restore a checkpoint written by this call; otherwise a stale file
    # from an earlier run could be loaded (or the load would fail outright).
    if checkpoint_saved:
        state = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state)
    print(f"Training finished. Best Test Accuracy: {best_acc:.2f}% at epoch {best_epoch+1}")
    return model


## Run

### Learning rate = 3e-4, batch_size = 128

In [None]:
# Seed torch and numpy so weight initialisation, augmentation and shuffling
# start from a fixed state. GPU kernels may still cause small run-to-run
# differences.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader, test_loader = get_loaders()
model = VisionTransformer().to(device)
# epochs=1000 is an upper bound: training ends early after 10 epochs without
# improvement. The same value is also the length of the LR schedule.
model = train(model, train_loader, test_loader, device, epochs=1000, patience=10)

100%|██████████| 170M/170M [07:43<00:00, 368kB/s]


Epoch 1/1000 | Train Loss 2.0803 | Train Acc 25.22% | Test Acc 35.31%
Epoch 2/1000 | Train Loss 1.9388 | Train Acc 32.48% | Test Acc 40.05%
Epoch 3/1000 | Train Loss 1.8187 | Train Acc 38.37% | Test Acc 47.59%
Epoch 4/1000 | Train Loss 1.7353 | Train Acc 42.85% | Test Acc 51.34%
Epoch 5/1000 | Train Loss 1.6714 | Train Acc 45.49% | Test Acc 51.77%
Epoch 6/1000 | Train Loss 1.6294 | Train Acc 47.99% | Test Acc 57.32%
Epoch 7/1000 | Train Loss 1.5836 | Train Acc 49.98% | Test Acc 59.99%
Epoch 8/1000 | Train Loss 1.5596 | Train Acc 51.17% | Test Acc 61.11%
Epoch 9/1000 | Train Loss 1.5318 | Train Acc 52.92% | Test Acc 63.19%
Epoch 10/1000 | Train Loss 1.5088 | Train Acc 53.94% | Test Acc 62.39%
Epoch 11/1000 | Train Loss 1.4745 | Train Acc 55.18% | Test Acc 65.29%
Epoch 12/1000 | Train Loss 1.4437 | Train Acc 56.87% | Test Acc 65.67%
Epoch 13/1000 | Train Loss 1.4082 | Train Acc 58.88% | Test Acc 66.95%
Epoch 14/1000 | Train Loss 1.3821 | Train Acc 59.97% | Test Acc 69.03%
Epoch 15/1000 |