In [1]:
!pip install -q torch torchvision

import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

# Normalization values for CIFAR-10
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

train_set = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_set  = CIFAR10(root='./data', train=False, download=True, transform=test_transform)


100%|██████████| 170M/170M [00:03<00:00, 43.5MB/s]


In [2]:
!pip install -q torch torchvision timm

In [3]:
import argparse
import math
import os
import random
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm

In [4]:
# Utilities
# ---------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append((correct_k.mul_(100.0 / batch_size)).item())
        return res


In [5]:
# MixUp helper (optional)
# ---------------------------
def mixup_data(x, y, alpha=1.0, device='cuda'):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


In [6]:
# Vision Transformer (ViT)
# ---------------------------
class PatchEmbed(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, 'Image size must be divisible by patch size.'
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        num_patches = self.grid_size * self.grid_size
        self.num_patches = num_patches

        # Implement as a conv layer for speed: kernel=patch_size, stride=patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # After conv, flatten to (B, num_patches, embed_dim)

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim=None, dropout=0.):
        out_dim = out_dim or in_dim
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0., attn_dropout=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=attn_dropout, batch_first=True)
        self.drop_path = nn.Identity()  # placeholder for stochastic depth if desired
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, mlp_hidden_dim, dropout=dropout)

    def forward(self, x):
        # x: (B, N, dim)
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)  # (B, N, dim)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class ViT(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_chans=3,
        num_classes=10,
        embed_dim=256,
        depth=8,
        num_heads=8,
        mlp_ratio=2.0,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Position embeddings (including cls token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=emb_dropout)

        # Transformer encoder
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=dropout,
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # head init
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)
        # patch proj uses default kaiming

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches+1, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0]  # (B, embed_dim)
        logits = self.head(cls)
        return logits



In [7]:
# Training / Evaluation
# ---------------------------
import numpy as np
from torch.cuda.amp import GradScaler, autocast

def get_dataloaders(batch_size=128, num_workers=4, use_validation=True):
    # Augmentations
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    train_set = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_set = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    if use_validation:
        val_ratio = 0.05
        n = len(train_set)
        n_val = int(val_ratio * n)
        n_train = n - n_val
        train_set, val_set = torch.utils.data.random_split(train_set, [n_train, n_val])
        val_set.dataset.transform = test_transform
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    else:
        val_loader = None

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader


def cosine_scheduler_with_warmup(optimizer, warmup_epochs, max_epochs, base_lr, final_lr=0., last_epoch=-1):
    # returns a LambdaLR scheduler
    def lr_lambda(current_step):
        current_epoch = current_step
        if current_epoch < warmup_epochs:
            return float(current_epoch) / float(max(1, warmup_epochs))
        # cosine decay
        progress = float(current_epoch - warmup_epochs) / float(max(1, max_epochs - warmup_epochs))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) * (1 - final_lr / base_lr) + final_lr/base_lr
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, scaler=None, mixup_alpha=0.0):
    model.train()
    losses = []
    top1 = 0.0
    total = 0
    pbar = tqdm(data_loader, desc=f"Train E{epoch}", leave=False)
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(enabled=(scaler is not None)):
            outputs = model(images)
            loss = criterion(outputs, targets)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1 = accuracy(outputs, targets, topk=(1,))[0]
        losses.append(loss.item())
        top1 += acc1 * images.size(0)
        total += images.size(0)

        pbar.set_postfix({'loss': f'{np.mean(losses):.4f}', 'acc1': f'{(top1/total):.2f}'})

    return np.mean(losses), (top1 / total)


def evaluate(model, criterion, data_loader, device):
    model.eval()
    losses = []
    top1 = 0.0
    total = 0
    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Eval", leave=False)
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, targets)
            acc1 = accuracy(outputs, targets, topk=(1,))[0]
            losses.append(loss.item())
            top1 += acc1 * images.size(0)
            total += images.size(0)
            pbar.set_postfix({'loss': f'{np.mean(losses):.4f}', 'acc1': f'{(top1/total):.2f}'})
    return np.mean(losses), (top1 / total)


In [4]:
# Main: training entrypoint and Utilities
# ---------------------------
import argparse
import math
import os
import random
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import numpy as np
from torch.cuda.amp import GradScaler, autocast

# Utilities
# ---------------------------
def set_seed(seed: int = 42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append((correct_k.mul_(100.0 / batch_size)).item())
        return res


# MixUp helper (optional)
# ---------------------------
def mixup_data(x, y, alpha=1.0, device='cuda'):
    if alpha <= 0:
        return x, y, None, 1.0
    lam = np.random.beta(alpha, alpha)
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Vision Transformer (ViT)
# ---------------------------
class PatchEmbed(nn.Module):
    """Split image into patches and embed them."""
    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, 'Image size must be divisible by patch size.'
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        num_patches = self.grid_size * self.grid_size
        self.num_patches = num_patches

        # Implement as a conv layer for speed: kernel=patch_size, stride=patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        # After conv, flatten to (B, num_patches, embed_dim)

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim=None, dropout=0.):
        out_dim = out_dim or in_dim
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim, num_heads=8, mlp_ratio=4.0, dropout=0., attn_dropout=0., norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=attn_dropout, batch_first=True)
        self.drop_path = nn.Identity()  # placeholder for stochastic depth if desired
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, mlp_hidden_dim, dropout=dropout)

    def forward(self, x):
        # x: (B, N, dim)
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)  # (B, N, dim)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x


class ViT(nn.Module):
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_chans=3,
        num_classes=10,
        embed_dim=256,
        depth=8,
        num_heads=8,
        mlp_ratio=2.0,
        dropout=0.0,
        emb_dropout=0.0,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Position embeddings (including cls token)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=emb_dropout)

        # Transformer encoder
        self.blocks = nn.ModuleList([
            TransformerEncoderLayer(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                attn_dropout=dropout,
            )
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialization
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        # head init
        nn.init.zeros_(self.head.weight)
        nn.init.zeros_(self.head.bias)
        # patch proj uses default kaiming

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # (B,1,embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (B, num_patches+1, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0]  # (B, embed_dim)
        logits = self.head(cls)
        return logits


# Training / Evaluation
# ---------------------------
def get_dataloaders(batch_size=128, num_workers=4, use_validation=True):
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])

    train_set = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    test_set = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

    if use_validation:
        # Optionally split train to train/val
        val_ratio = 0.05
        n = len(train_set)
        n_val = int(val_ratio * n)
        n_train = n - n_val
        train_set, val_set = torch.utils.data.random_split(train_set, [n_train, n_val])
        val_set.dataset.transform = test_transform  # use test transform for validation
        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    else:
        val_loader = None

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    return train_loader, val_loader, test_loader


def cosine_scheduler_with_warmup(optimizer, warmup_epochs, max_epochs, base_lr, final_lr=0., last_epoch=-1):
    # returns a LambdaLR scheduler
    def lr_lambda(current_step):
        current_epoch = current_step
        if current_epoch < warmup_epochs:
            return float(current_epoch) / float(max(1, warmup_epochs))
        # cosine decay
        progress = float(current_epoch - warmup_epochs) / float(max(1, max_epochs - warmup_epochs))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) * (1 - final_lr / base_lr) + final_lr/base_lr
    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, scaler=None, mixup_alpha=0.0):
    model.train()
    losses = []
    top1 = 0.0
    total = 0
    pbar = tqdm(data_loader, desc=f"Train E{epoch}", leave=False)
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        with autocast(enabled=(scaler is not None)):
            outputs = model(images)
            loss = criterion(outputs, targets)

        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        acc1 = accuracy(outputs, targets, topk=(1,))[0]
        losses.append(loss.item())
        top1 += acc1 * images.size(0)
        total += images.size(0)

        pbar.set_postfix({'loss': f'{np.mean(losses):.4f}', 'acc1': f'{(top1/total):.2f}'})

    return np.mean(losses), (top1 / total)


def evaluate(model, criterion, data_loader, device):
    model.eval()
    losses = []
    top1 = 0.0
    total = 0
    with torch.no_grad():
        pbar = tqdm(data_loader, desc="Eval", leave=False)
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, targets)
            acc1 = accuracy(outputs, targets, topk=(1,))[0]
            losses.append(loss.item())
            top1 += acc1 * images.size(0)
            total += images.size(0)
            pbar.set_postfix({'loss': f'{np.mean(losses):.4f}', 'acc1': f'{(top1/total):.2f}'})
    return np.mean(losses), (top1 / total)


def main(args):
    set_seed(args.seed)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using device: {device}')

    train_loader, val_loader, test_loader = get_dataloaders(batch_size=args.batch_size, num_workers=args.num_workers)

    model = ViT(
        img_size=32,
        patch_size=args.patch_size,
        in_chans=3,
        num_classes=10,
        embed_dim=args.embed_dim,
        depth=args.depth,
        num_heads=args.num_heads,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
        emb_dropout=args.emb_dropout,
    )
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    scheduler = cosine_scheduler_with_warmup(
        optimizer,
        warmup_epochs=args.warmup_epochs,
        max_epochs=args.epochs,
        base_lr=args.lr,
        final_lr=args.min_lr
    )

    scaler = GradScaler() if args.use_amp and device == 'cuda' else None

    best_val_acc = 0.0
    best_test_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(1, args.epochs + 1):
        # train
        train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device, epoch, scaler=scaler)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)

        # step scheduler per-epoch (our lambda scheduler expects epoch steps)
        scheduler.step()

        # eval on validation or test
        if val_loader is not None:
            val_loss, val_acc = evaluate(model, criterion, val_loader, device)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            is_best = val_acc > best_val_acc
            if is_best:
                best_val_acc = val_acc
                # evaluate on test set as well for reporting
                test_loss, test_acc = evaluate(model, criterion, test_loader, device)
                best_test_acc = test_acc
                # save best
                torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(),
                            'val_acc': val_acc, 'test_acc': test_acc}, args.checkpoint_path)
                print(f"New best val acc: {val_acc:.2f}%, test acc: {best_test_acc:.2f}%, saved checkpoint.")
        else:
            # No validation split: test every epoch
            test_loss, test_acc = evaluate(model, criterion, test_loader, device)
            if test_acc > best_test_acc:
                best_test_acc = test_acc
                torch.save({'epoch': epoch, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict(),
                            'test_acc': test_acc}, args.checkpoint_path)
                print(f"New best test acc: {test_acc:.2f}%, saved checkpoint.")

        print(f"Epoch {epoch}: train_acc={train_acc:.2f}%  val_acc={best_val_acc:.2f}%  best_test_acc={best_test_acc:.2f}%")

    # Final evaluation
    print("Training complete. Loading best checkpoint and evaluating on test set...")
    if os.path.exists(args.checkpoint_path):
        ckpt = torch.load(args.checkpoint_path, map_location=device)
        model.load_state_dict(ckpt['model_state'])
        test_loss, test_acc = evaluate(model, criterion, test_loader, device)
        print(f"Best checkpoint test acc: {test_acc:.2f}%")
    else:
        test_loss, test_acc = evaluate(model, criterion, test_loader, device)
        print(f"No checkpoint saved. Final test acc: {test_acc:.2f}%")

    print("Done.")


if __name__ == '__main__':
    import sys
    # Explicitly set sys.argv to a list containing only the script name
    sys.argv = [sys.argv[0]]

    parser = argparse.ArgumentParser(description='ViT CIFAR-10 training')
    parser.add_argument('--epochs', default=150, type=int)
    parser.add_argument('--batch-size', default=256, type=int)
    parser.add_argument('--lr', default=3e-4, type=float)
    parser.add_argument('--min-lr', default=1e-6, type=float)
    parser.add_argument('--weight-decay', default=0.05, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--patch-size', default=4, type=int)
    parser.add_argument('--embed-dim', dest='embed_dim', default=256, type=int)
    parser.add_argument('--depth', default=8, type=int)
    parser.add_argument('--num-heads', dest='num_heads', default=8, type=int)
    parser.add_argument('--mlp-ratio', dest='mlp_ratio', default=2.0, type=float)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--emb-dropout', default=0.0, type=float)
    parser.add_argument('--warmup-epochs', default=10, type=int)
    parser.add_argument('--use-amp', action='store_true', help='Use mixed precision')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num-workers', default=4, type=int)
    parser.add_argument('--checkpoint-path', default='vit_cifar10_checkpoint.pth', type=str)
    args = parser.parse_args([]) # Pass an empty list to parse_args()
    main(args)

Using device: cuda


  with autocast(enabled=(scaler is not None)):


New best val acc: 10.12%, test acc: 10.00%, saved checkpoint.
Epoch 1: train_acc=9.99%  val_acc=10.12%  best_test_acc=10.00%




New best val acc: 32.44%, test acc: 33.05%, saved checkpoint.
Epoch 2: train_acc=26.26%  val_acc=32.44%  best_test_acc=33.05%




New best val acc: 43.36%, test acc: 42.71%, saved checkpoint.
Epoch 3: train_acc=37.10%  val_acc=43.36%  best_test_acc=42.71%




New best val acc: 48.40%, test acc: 49.42%, saved checkpoint.
Epoch 4: train_acc=45.89%  val_acc=48.40%  best_test_acc=49.42%




New best val acc: 54.76%, test acc: 54.16%, saved checkpoint.
Epoch 5: train_acc=52.19%  val_acc=54.76%  best_test_acc=54.16%




New best val acc: 59.20%, test acc: 58.57%, saved checkpoint.
Epoch 6: train_acc=57.94%  val_acc=59.20%  best_test_acc=58.57%




New best val acc: 60.28%, test acc: 59.89%, saved checkpoint.
Epoch 7: train_acc=61.69%  val_acc=60.28%  best_test_acc=59.89%




New best val acc: 63.36%, test acc: 62.56%, saved checkpoint.
Epoch 8: train_acc=65.64%  val_acc=63.36%  best_test_acc=62.56%




New best val acc: 64.00%, test acc: 63.15%, saved checkpoint.
Epoch 9: train_acc=68.50%  val_acc=64.00%  best_test_acc=63.15%




New best val acc: 66.40%, test acc: 65.17%, saved checkpoint.
Epoch 10: train_acc=71.27%  val_acc=66.40%  best_test_acc=65.17%




Epoch 11: train_acc=73.75%  val_acc=66.40%  best_test_acc=65.17%




New best val acc: 66.60%, test acc: 65.38%, saved checkpoint.
Epoch 12: train_acc=77.45%  val_acc=66.60%  best_test_acc=65.38%




New best val acc: 67.48%, test acc: 65.65%, saved checkpoint.
Epoch 13: train_acc=80.30%  val_acc=67.48%  best_test_acc=65.65%




Epoch 14: train_acc=83.57%  val_acc=67.48%  best_test_acc=65.65%




Epoch 15: train_acc=85.26%  val_acc=67.48%  best_test_acc=65.65%




Epoch 16: train_acc=88.13%  val_acc=67.48%  best_test_acc=65.65%




Epoch 17: train_acc=89.46%  val_acc=67.48%  best_test_acc=65.65%




Epoch 18: train_acc=91.35%  val_acc=67.48%  best_test_acc=65.65%




Epoch 19: train_acc=92.56%  val_acc=67.48%  best_test_acc=65.65%




Epoch 20: train_acc=93.57%  val_acc=67.48%  best_test_acc=65.65%




Epoch 21: train_acc=94.31%  val_acc=67.48%  best_test_acc=65.65%




Epoch 22: train_acc=95.12%  val_acc=67.48%  best_test_acc=65.65%




New best val acc: 67.84%, test acc: 66.04%, saved checkpoint.
Epoch 23: train_acc=95.77%  val_acc=67.84%  best_test_acc=66.04%




Epoch 24: train_acc=96.00%  val_acc=67.84%  best_test_acc=66.04%




New best val acc: 67.88%, test acc: 65.86%, saved checkpoint.
Epoch 25: train_acc=96.28%  val_acc=67.88%  best_test_acc=65.86%




Epoch 26: train_acc=96.81%  val_acc=67.88%  best_test_acc=65.86%




Epoch 27: train_acc=96.90%  val_acc=67.88%  best_test_acc=65.86%




Epoch 28: train_acc=97.20%  val_acc=67.88%  best_test_acc=65.86%




Epoch 29: train_acc=97.06%  val_acc=67.88%  best_test_acc=65.86%




Epoch 30: train_acc=97.32%  val_acc=67.88%  best_test_acc=65.86%




Epoch 31: train_acc=97.71%  val_acc=67.88%  best_test_acc=65.86%




Epoch 32: train_acc=97.56%  val_acc=67.88%  best_test_acc=65.86%




Epoch 33: train_acc=97.63%  val_acc=67.88%  best_test_acc=65.86%




Epoch 34: train_acc=97.69%  val_acc=67.88%  best_test_acc=65.86%




New best val acc: 68.72%, test acc: 67.35%, saved checkpoint.
Epoch 35: train_acc=98.02%  val_acc=68.72%  best_test_acc=67.35%




Epoch 36: train_acc=98.15%  val_acc=68.72%  best_test_acc=67.35%




Epoch 37: train_acc=97.91%  val_acc=68.72%  best_test_acc=67.35%




Epoch 38: train_acc=97.88%  val_acc=68.72%  best_test_acc=67.35%




Epoch 39: train_acc=98.29%  val_acc=68.72%  best_test_acc=67.35%




Epoch 40: train_acc=98.29%  val_acc=68.72%  best_test_acc=67.35%




Epoch 41: train_acc=98.52%  val_acc=68.72%  best_test_acc=67.35%




Epoch 42: train_acc=98.20%  val_acc=68.72%  best_test_acc=67.35%




Epoch 43: train_acc=98.29%  val_acc=68.72%  best_test_acc=67.35%




Epoch 44: train_acc=98.57%  val_acc=68.72%  best_test_acc=67.35%




Epoch 45: train_acc=98.46%  val_acc=68.72%  best_test_acc=67.35%




Epoch 46: train_acc=98.30%  val_acc=68.72%  best_test_acc=67.35%




Epoch 47: train_acc=98.83%  val_acc=68.72%  best_test_acc=67.35%




Epoch 48: train_acc=98.80%  val_acc=68.72%  best_test_acc=67.35%




Epoch 49: train_acc=98.54%  val_acc=68.72%  best_test_acc=67.35%




New best val acc: 68.92%, test acc: 66.55%, saved checkpoint.
Epoch 50: train_acc=98.64%  val_acc=68.92%  best_test_acc=66.55%




Epoch 51: train_acc=98.52%  val_acc=68.92%  best_test_acc=66.55%




Epoch 52: train_acc=98.81%  val_acc=68.92%  best_test_acc=66.55%




Epoch 53: train_acc=98.99%  val_acc=68.92%  best_test_acc=66.55%




Epoch 54: train_acc=98.68%  val_acc=68.92%  best_test_acc=66.55%




Epoch 55: train_acc=98.99%  val_acc=68.92%  best_test_acc=66.55%




Epoch 56: train_acc=99.26%  val_acc=68.92%  best_test_acc=66.55%




Epoch 57: train_acc=99.20%  val_acc=68.92%  best_test_acc=66.55%




Epoch 58: train_acc=98.86%  val_acc=68.92%  best_test_acc=66.55%




New best val acc: 69.36%, test acc: 67.52%, saved checkpoint.
Epoch 59: train_acc=99.21%  val_acc=69.36%  best_test_acc=67.52%




Epoch 60: train_acc=99.13%  val_acc=69.36%  best_test_acc=67.52%




Epoch 61: train_acc=98.96%  val_acc=69.36%  best_test_acc=67.52%




Epoch 62: train_acc=99.35%  val_acc=69.36%  best_test_acc=67.52%




Epoch 63: train_acc=99.24%  val_acc=69.36%  best_test_acc=67.52%




Epoch 64: train_acc=99.26%  val_acc=69.36%  best_test_acc=67.52%




Epoch 65: train_acc=99.00%  val_acc=69.36%  best_test_acc=67.52%




Epoch 66: train_acc=99.46%  val_acc=69.36%  best_test_acc=67.52%




Epoch 67: train_acc=99.40%  val_acc=69.36%  best_test_acc=67.52%




Epoch 68: train_acc=99.16%  val_acc=69.36%  best_test_acc=67.52%




Epoch 69: train_acc=99.53%  val_acc=69.36%  best_test_acc=67.52%




Epoch 70: train_acc=99.45%  val_acc=69.36%  best_test_acc=67.52%




Epoch 71: train_acc=99.36%  val_acc=69.36%  best_test_acc=67.52%




Epoch 72: train_acc=99.40%  val_acc=69.36%  best_test_acc=67.52%




Epoch 73: train_acc=99.53%  val_acc=69.36%  best_test_acc=67.52%




Epoch 74: train_acc=99.55%  val_acc=69.36%  best_test_acc=67.52%




Epoch 75: train_acc=99.32%  val_acc=69.36%  best_test_acc=67.52%




Epoch 76: train_acc=99.64%  val_acc=69.36%  best_test_acc=67.52%




Epoch 77: train_acc=99.79%  val_acc=69.36%  best_test_acc=67.52%




Epoch 78: train_acc=99.50%  val_acc=69.36%  best_test_acc=67.52%




Epoch 79: train_acc=99.48%  val_acc=69.36%  best_test_acc=67.52%




Epoch 80: train_acc=99.53%  val_acc=69.36%  best_test_acc=67.52%




Epoch 81: train_acc=99.63%  val_acc=69.36%  best_test_acc=67.52%




Epoch 82: train_acc=99.78%  val_acc=69.36%  best_test_acc=67.52%




Epoch 83: train_acc=99.73%  val_acc=69.36%  best_test_acc=67.52%




Epoch 84: train_acc=99.67%  val_acc=69.36%  best_test_acc=67.52%




Epoch 85: train_acc=99.83%  val_acc=69.36%  best_test_acc=67.52%




Epoch 86: train_acc=99.76%  val_acc=69.36%  best_test_acc=67.52%




Epoch 87: train_acc=99.62%  val_acc=69.36%  best_test_acc=67.52%




Epoch 88: train_acc=99.84%  val_acc=69.36%  best_test_acc=67.52%




Epoch 89: train_acc=99.72%  val_acc=69.36%  best_test_acc=67.52%




Epoch 90: train_acc=99.76%  val_acc=69.36%  best_test_acc=67.52%




Epoch 91: train_acc=99.84%  val_acc=69.36%  best_test_acc=67.52%




Epoch 92: train_acc=99.77%  val_acc=69.36%  best_test_acc=67.52%




Epoch 93: train_acc=99.88%  val_acc=69.36%  best_test_acc=67.52%




Epoch 94: train_acc=99.99%  val_acc=69.36%  best_test_acc=67.52%




Epoch 95: train_acc=100.00%  val_acc=69.36%  best_test_acc=67.52%




Epoch 96: train_acc=100.00%  val_acc=69.36%  best_test_acc=67.52%




Epoch 97: train_acc=100.00%  val_acc=69.36%  best_test_acc=67.52%




Epoch 98: train_acc=100.00%  val_acc=69.36%  best_test_acc=67.52%




New best val acc: 69.40%, test acc: 68.56%, saved checkpoint.
Epoch 99: train_acc=100.00%  val_acc=69.40%  best_test_acc=68.56%




Epoch 100: train_acc=100.00%  val_acc=69.40%  best_test_acc=68.56%




New best val acc: 69.44%, test acc: 68.60%, saved checkpoint.
Epoch 101: train_acc=100.00%  val_acc=69.44%  best_test_acc=68.60%




New best val acc: 69.48%, test acc: 68.64%, saved checkpoint.
Epoch 102: train_acc=100.00%  val_acc=69.48%  best_test_acc=68.64%




New best val acc: 69.56%, test acc: 68.66%, saved checkpoint.
Epoch 103: train_acc=100.00%  val_acc=69.56%  best_test_acc=68.66%




New best val acc: 69.56%, test acc: 68.67%, saved checkpoint.
Epoch 104: train_acc=100.00%  val_acc=69.56%  best_test_acc=68.67%




New best val acc: 69.68%, test acc: 68.69%, saved checkpoint.
Epoch 105: train_acc=100.00%  val_acc=69.68%  best_test_acc=68.69%




Epoch 106: train_acc=100.00%  val_acc=69.68%  best_test_acc=68.69%




New best val acc: 69.72%, test acc: 68.80%, saved checkpoint.
Epoch 107: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 108: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 109: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 110: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 111: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 112: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 113: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 114: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 115: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 116: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 117: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 118: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 119: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




Epoch 120: train_acc=100.00%  val_acc=69.72%  best_test_acc=68.80%




New best val acc: 69.76%, test acc: 69.32%, saved checkpoint.
Epoch 121: train_acc=100.00%  val_acc=69.76%  best_test_acc=69.32%




Epoch 122: train_acc=100.00%  val_acc=69.76%  best_test_acc=69.32%




New best val acc: 69.80%, test acc: 69.35%, saved checkpoint.
Epoch 123: train_acc=100.00%  val_acc=69.80%  best_test_acc=69.35%




New best val acc: 69.88%, test acc: 69.37%, saved checkpoint.
Epoch 124: train_acc=100.00%  val_acc=69.88%  best_test_acc=69.37%




Epoch 125: train_acc=100.00%  val_acc=69.88%  best_test_acc=69.37%




Epoch 126: train_acc=100.00%  val_acc=69.88%  best_test_acc=69.37%




Epoch 127: train_acc=100.00%  val_acc=69.88%  best_test_acc=69.37%




New best val acc: 69.92%, test acc: 69.37%, saved checkpoint.
Epoch 128: train_acc=100.00%  val_acc=69.92%  best_test_acc=69.37%




Epoch 129: train_acc=100.00%  val_acc=69.92%  best_test_acc=69.37%




New best val acc: 69.96%, test acc: 69.39%, saved checkpoint.
Epoch 130: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




Epoch 131: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




Epoch 132: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




Epoch 133: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




Epoch 134: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




Epoch 135: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.39%




New best val acc: 69.96%, test acc: 69.47%, saved checkpoint.
Epoch 136: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.47%




Epoch 137: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.47%




Epoch 138: train_acc=100.00%  val_acc=69.96%  best_test_acc=69.47%




New best val acc: 70.04%, test acc: 69.43%, saved checkpoint.
Epoch 139: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 140: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 141: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 142: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 143: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 144: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 145: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 146: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 147: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 148: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 149: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%




Epoch 150: train_acc=100.00%  val_acc=70.04%  best_test_acc=69.43%
Training complete. Loading best checkpoint and evaluating on test set...


                                                                              

Best checkpoint test acc: 69.43%
Done.


