In [None]:
import os
import random
import math
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets

In [None]:
def seed_everything(seed=42):
    """Seed every RNG this notebook draws from so that runs are reproducible.

    Covers Python's ``random``, NumPy's legacy global RNG (the MixUp helper samples its
    mixing weight with ``np.random.beta``), and the torch CPU/CUDA generators. Also forces
    deterministic cuDNN kernels and disables cuDNN autotuning.

    Args:
        seed: integer seed applied to all generators.
    """
    import numpy as np  # local import: numpy is only imported further down the notebook

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

Device: cuda


In [None]:
# Hyperparameters for the CIFAR-10 ViT run.
config = dict(
    # --- input / model geometry ---
    img_size=32,
    patch_size=4,
    in_chans=3,
    num_classes=10,
    embed_dim=192,
    depth=8,
    num_heads=8,
    mlp_ratio=4.0,
    # --- regularization ---
    dropout=0.1,
    attn_dropout=0.1,
    # --- optimization ---
    lr=3e-4,
    weight_decay=0.05,
    batch_size=128,
    epochs=50,
    grad_clip=None,      # falsy value disables gradient-norm clipping
    warmup_steps=500,    # warmup length in scheduler steps (one per training batch)
)



In [None]:
# Per-channel CIFAR-10 statistics; both pipelines normalize with them.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2470, 0.2435, 0.2616)

# Training pipeline: pad-and-crop, horizontal flip and RandAugment on the PIL image,
# followed by tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
# Evaluation pipeline: no augmentation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

data_root = "./data"
train_dataset = datasets.CIFAR10(data_root, train=True, transform=train_transform, download=True)
test_dataset = datasets.CIFAR10(data_root, train=False, transform=test_transform, download=True)

# Loader settings shared by both splits; only shuffling differs.
_loader_opts = {"batch_size": config["batch_size"], "num_workers": 2, "pin_memory": True}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print(f"Train size: {len(train_dataset)}, Test size: {len(test_dataset)}")


100%|██████████| 170M/170M [00:03<00:00, 43.5MB/s]


Train size: 50000, Test size: 10000


In [None]:
# Cell: MixUp helpers
import torch.nn.functional as F
import numpy as np

def mixup_data(x, y, alpha=0.8):
    """Apply MixUp to a batch.

    Each sample is blended with a randomly chosen partner from the same batch:
    ``mixed_x = lam * x + (1 - lam) * x[perm]`` with ``lam ~ Beta(alpha, alpha)``.

    Args:
        x: input batch; dimension 0 is the batch dimension.
        y: targets aligned with ``x`` along dimension 0.
        alpha: Beta concentration parameter; ``alpha <= 0`` disables mixing.

    Returns:
        ``(mixed_x, y_a, lam, y_b)``. This is always four values, so callers can unpack the
        result the same way whether or not mixing is enabled. ``y_a`` are the original targets,
        ``y_b`` the partners' targets, and ``lam`` (a Python float) is the weight given to ``y_a``.
    """
    if alpha <= 0:
        # No mixing: same 4-tuple layout as below, with all weight on the original targets.
        return x, y, 1.0, y
    # float() so lam is a plain Python scalar rather than a NumPy float64.
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(x.size(0)).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, lam, y[index]

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Loss for a MixUp batch: ``lam``-weighted blend of the loss against each target set."""
    loss_against_a = criterion(pred, y_a)
    loss_against_b = criterion(pred, y_b)
    return lam * loss_against_a + (1 - lam) * loss_against_b


In [None]:
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping square patches and linearly embed each one.

    Implemented as a ``Conv2d`` whose kernel size and stride both equal ``patch_size``, so
    each output position of the convolution corresponds to exactly one patch.

    Args:
        img_size: side length of the square input image; must be divisible by ``patch_size``.
        patch_size: side length of each square patch.
        in_chans: number of input channels.
        embed_dim: size of each patch embedding.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``, the token count for an
            ``img_size`` x ``img_size`` input.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=128):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, num_tokens, embed_dim)`` patch tokens.

        Tokens are ordered row-major over the patch grid.
        """
        # (B, embed_dim, H', W') -> (B, embed_dim, H'*W') -> (B, H'*W', embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)

In [None]:
_pe = PatchEmbed(img_size=config["img_size"], patch_size=config["patch_size"], in_chans=config["in_chans"], embed_dim=config["embed_dim"])
# Dummy batch follows the configured geometry; no_grad avoids building a throwaway autograd graph.
with torch.no_grad():
    dummy = torch.randn(2, config["in_chans"], config["img_size"], config["img_size"])
    out = _pe(dummy)
assert out.shape == (2, _pe.num_patches, config["embed_dim"]), out.shape
print("PatchEmbed output shape:", out.shape)

PatchEmbed output shape: torch.Size([2, 64, 192])


In [None]:
class MLP(nn.Module):
    """Transformer feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: input and output width.
        hidden_features: hidden width; a falsy value (``None``/0) means ``in_features``.
        dropout: dropout probability applied after the activation and after the output projection.
    """

    def __init__(self, in_features, hidden_features=None, dropout=0.0):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        # Submodules are registered in this order: fc1, act, fc2, drop.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

In [None]:
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    ``x = x + Dropout(MHSA(LayerNorm(x)))`` followed by ``x = x + MLP(LayerNorm(x))``.

    Args:
        embed_dim: token width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim`` (truncated with ``int``).
        dropout: dropout on the attention output and inside the MLP.
        attn_dropout: dropout passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, hidden_dim, dropout)

    def forward(self, x):
        # Self-attention sub-layer (batch_first=True, so x is (batch, seq, embed_dim)).
        h = self.norm1(x)
        attn_out = self.attn(h, h, h, need_weights=False)[0]
        x = x + self.drop1(attn_out)
        # Feed-forward sub-layer.
        return x + self.mlp(self.norm2(x))

In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learned [CLS] token -> add learned position
    embeddings -> dropout -> ``depth`` pre-norm encoder blocks -> LayerNorm -> linear head
    applied to the [CLS] token.

    Args:
        img_size, patch_size, in_chans: image geometry forwarded to ``PatchEmbed``.
        num_classes: number of output logits.
        embed_dim: token width.
        depth: number of encoder blocks.
        num_heads, mlp_ratio, dropout, attn_dropout: forwarded to each encoder block;
            ``dropout`` is also applied after adding position embeddings.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=128, depth=6, num_heads=8, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        # One learned position per patch plus one for the [CLS] token.
        seq_len = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)
        block_kwargs = dict(embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                            dropout=dropout, attn_dropout=attn_dropout)
        self.blocks = nn.ModuleList([TransformerEncoderBlock(**block_kwargs) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Truncated-normal init for the learned tokens (pos_embed first, then cls_token),
        # followed by the per-module init below.
        for token_param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(token_param, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Linear: trunc-normal(std=0.02) weights, zero bias. LayerNorm: weight=1, bias=0."""
        if isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, x):
        tokens = self.patch_embed(x)
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)
        for block in self.blocks:
            tokens = block(tokens)
        # Classify from the normalized [CLS] token (position 0).
        return self.head(self.norm(tokens)[:, 0])

In [None]:
class BatchWarmupCosineScheduler:
    """Per-batch learning-rate schedule: linear warmup, then cosine decay to ``min_lr``.

    After ``k`` calls to :meth:`step`, every param group's ``lr`` is:

    * ``base_lr * k / warmup_steps`` while ``k <= warmup_steps`` (only when ``warmup_steps > 0``);
    * otherwise ``min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * p))`` with
      ``p = (k - warmup_steps) / (total_steps - warmup_steps)`` clamped to ``[0, 1]``.
      The LR is therefore held at ``min_lr`` once ``k >= total_steps``.

    The ``k = 0`` value is written into the optimizer at construction time. The training loop
    calls ``step()`` after ``optimizer.step()``. Without this, the first update would run at the
    LR the optimizer was built with (``base_lr`` in this notebook), which bypasses warmup.

    Args:
        optimizer: optimizer whose param groups all receive the same LR.
        base_lr: peak LR, reached at the end of warmup.
        total_steps: planned number of ``step()`` calls (clamped to >= 1).
        warmup_steps: warmup length in ``step()`` calls (clamped to >= 0).
        min_lr: LR at the end of the cosine decay.
    """

    def __init__(self, optimizer, base_lr, total_steps, warmup_steps=0, min_lr=0.0):
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.total_steps = max(1, total_steps)
        self.warmup_steps = max(0, warmup_steps)
        self.min_lr = min_lr
        self.step_num = 0
        # Make the very first optimizer update use the scheduled step-0 LR.
        self._apply_lr(self._lr_at(self.step_num))

    def _lr_at(self, step_num):
        """Return the scheduled LR after ``step_num`` calls to :meth:`step`."""
        if self.warmup_steps > 0 and step_num <= self.warmup_steps:
            return self.base_lr * float(step_num) / float(self.warmup_steps)
        decay_span = max(1, self.total_steps - self.warmup_steps)
        progress = float(step_num - self.warmup_steps) / float(decay_span)
        progress = min(max(progress, 0.0), 1.0)
        return self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1.0 + math.cos(math.pi * progress))

    def _apply_lr(self, lr):
        """Write ``lr`` into every param group of the wrapped optimizer."""
        for pg in self.optimizer.param_groups:
            pg['lr'] = lr

    def step(self):
        """Advance the schedule by one batch and update the optimizer's LR."""
        self.step_num += 1
        self._apply_lr(self._lr_at(self.step_num))

    def get_lr(self):
        """Return the LR currently set on the optimizer's first param group."""
        return self.optimizer.param_groups[0]['lr']


In [None]:
model = ViT(
    img_size=config["img_size"],
    patch_size=config["patch_size"],
    in_chans=config["in_chans"],
    num_classes=config["num_classes"],
    embed_dim=config["embed_dim"],
    depth=config["depth"],
    num_heads=config["num_heads"],
    mlp_ratio=config["mlp_ratio"],
    dropout=config["dropout"],
    attn_dropout=config["attn_dropout"],
)
print("Model params (M):", sum(p.numel() for p in model.parameters()) / 1e6)

# CPU smoke test before moving to `device`: the dummy batch follows the configured geometry,
# and no_grad skips building an autograd graph that would only be discarded.
with torch.no_grad():
    dummy = torch.randn(4, config["in_chans"], config["img_size"], config["img_size"])
    print("Forward pass shape:", model(dummy).shape)
model.to(device)

Model params (M): 3.583306
Forward pass shape: torch.Size([4, 10])


ViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
  )
  (pos_drop): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-7): 8 x TransformerEncoderBlock(
      (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
      )
      (drop1): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=768, out_features=192, bias=True)
        (drop): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
  (head): Linear(in_features=192, out_features=10, bias=True)
)

In [None]:
def accuracy(output, target, topk=(1,)):
    """Top-k classification accuracy, in percent.

    Args:
        output: ``(N, C)`` scores; classes lie along dimension 1.
        target: ``N`` integer class labels.
        topk: the k values to report.

    Returns:
        A list of Python floats, one per entry of ``topk`` in the same order. Each is the
        percentage of rows whose target is among that row's k highest scores.
    """
    with torch.no_grad():
        k_max = max(topk)
        n = target.size(0)
        # (N, k_max) class indices, highest score first.
        top_idx = output.topk(k_max, dim=1, largest=True, sorted=True).indices
        hits = top_idx.eq(target.unsqueeze(1))
        return [(hits[:, :k].float().sum() * (100.0 / n)).item() for k in topk]

from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

def train_one_epoch_mixup(model, dataloader, criterion, optimizer, device, epoch, scheduler=None, grad_clip=None, mixup_alpha=0.8, use_amp=True, grad_scaler=None):
    """Run one training epoch with MixUp, optional mixed precision, clipping and per-batch scheduling.

    Args:
        model: network to train (switched to ``train()`` mode).
        dataloader: iterable of ``(images, targets)`` batches that also supports ``len()``.
        criterion: loss ``criterion(outputs, targets)``. It is expected to return the batch-mean
            loss, which is re-weighted by batch size for the epoch average.
        optimizer: optimizer stepped once per batch.
        device: device (or device string) each batch is moved to.
        epoch: epoch number, used only in the progress-bar label.
        scheduler: optional per-batch scheduler. ``scheduler.step()`` is called after every
            optimizer step, or ``scheduler()`` if the object is a plain callable.
        grad_clip: max global gradient norm; a falsy value disables clipping.
        mixup_alpha: MixUp Beta concentration; ``<= 0`` disables MixUp.
        use_amp: use CUDA autocast plus loss scaling. Only takes effect on CUDA devices;
            elsewhere the step runs in full precision.
        grad_scaler: loss scaler used for AMP. Defaults to the notebook-level ``scaler`` so the
            dynamic loss scale carries over between epochs.

    Returns:
        ``(mean_loss, mean_acc)`` over the epoch. Both are weighted by batch size, so a short
        final batch counts only for the samples it contains. Accuracy is MixUp-weighted
        (``lam * hit_a + (1 - lam) * hit_b``) and expressed in percent.
    """
    model.train()
    amp_enabled = use_amp and torch.device(device).type == "cuda"
    if amp_enabled and grad_scaler is None:
        grad_scaler = scaler  # notebook-global scaler, shared across epochs

    loss_sum = 0.0     # sum over samples of the per-sample loss
    correct_sum = 0.0  # MixUp-weighted count of correct predictions
    seen = 0
    running_loss = 0.0
    running_acc = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Train Epoch {epoch}")
    for images, targets in pbar:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_size = targets.size(0)

        if mixup_alpha > 0:
            images, targets_a, lam, targets_b = mixup_data(images, targets, mixup_alpha)
        else:
            targets_a, targets_b, lam = targets, targets, 1.0

        optimizer.zero_grad()
        if amp_enabled:
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
            with torch.autocast(device_type="cuda"):
                outputs = model(images)
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            grad_scaler.scale(loss).backward()
            if grad_clip:
                # Clip the true (unscaled) gradients.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            outputs = model(images)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()
            if grad_clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            if hasattr(scheduler, "step"):
                scheduler.step()
            elif callable(scheduler):
                scheduler()

        with torch.no_grad():
            preds = outputs.topk(1, dim=1).indices.squeeze(1)
            hits = (preds == targets_a).float() * lam + (preds == targets_b).float() * (1 - lam)

        # Sample-weighted running averages; the last batch of an epoch is usually smaller.
        seen += batch_size
        loss_sum += loss.item() * batch_size
        correct_sum += hits.sum().item()
        running_loss = loss_sum / seen
        running_acc = 100.0 * correct_sum / seen
        pbar.set_postfix(loss=running_loss, acc=running_acc)
    return running_loss, running_acc


def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients.

    Args:
        model: network to evaluate (switched to ``eval()`` mode).
        dataloader: iterable of ``(images, targets)`` batches that also supports ``len()``.
        criterion: loss ``criterion(outputs, targets)``. It is expected to return the batch-mean
            loss, which is re-weighted by batch size below.
        device: device (or device string) each batch is moved to.

    Returns:
        ``(mean_loss, top1_acc_percent)`` over all samples. Both are weighted by batch size, so
        a short final batch (e.g. 10000 = 78 * 128 + 16 for CIFAR-10 test) is not over-counted.
        Returns ``(0.0, 0.0)`` for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    total_loss = 0.0
    total_acc = 0.0
    with torch.no_grad():
        pbar = tqdm(dataloader, total=len(dataloader), desc="Eval")
        for images, targets in pbar:
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            outputs = model(images)
            loss = criterion(outputs, targets)
            n = targets.size(0)
            # Top-1 via topk, the same tie-breaking the notebook's accuracy() helper uses.
            preds = outputs.topk(1, dim=1).indices.squeeze(1)
            seen += n
            loss_sum += loss.item() * n
            correct += (preds == targets).sum().item()
            total_loss = loss_sum / seen
            total_acc = 100.0 * correct / seen
            pbar.set_postfix(loss=total_loss, acc=total_acc)
    return total_loss, total_acc


  scaler = GradScaler()


In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy with uniform label smoothing, averaged over the batch.

    Per sample: ``(1 - smoothing) * NLL(true class) + smoothing * mean_c(-log p_c)``.

    Args:
        smoothing: weight given to the uniform-target term (0 gives plain cross-entropy).
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        logp = F.log_softmax(x, dim=-1)
        eps = self.smoothing
        # Negative log-likelihood of each sample's true class.
        true_nll = -logp.gather(-1, target[:, None])[:, 0]
        # Cross-entropy against a uniform distribution over classes.
        uniform_ce = -logp.mean(-1)
        return ((1.0 - eps) * true_nll + eps * uniform_ce).mean()

In [None]:
from torch.optim import AdamW

model = model.to(device)
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# Apply weight decay only to weight matrices and conv kernels. Biases and LayerNorm affine params
# (all 1-D) and the learned position / [CLS] embeddings are excluded, following common ViT practice
# (DeiT / timm): shrinking LayerNorm gains and embeddings toward zero is generally undesirable.
no_decay_names = {"pos_embed", "cls_token"}
decay_params, no_decay_params = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if param.ndim <= 1 or name in no_decay_names:
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer = AdamW(
    [
        {"params": decay_params, "weight_decay": config["weight_decay"]},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=config["lr"],
)

total_steps = len(train_loader) * config["epochs"]
warmup_steps = config.get("warmup_steps", 500)

# Stepped once per training batch; it writes the same lr into both param groups.
scheduler = BatchWarmupCosineScheduler(optimizer, base_lr=config["lr"], total_steps=total_steps, warmup_steps=warmup_steps, min_lr=1e-6)

best_acc = 0.0
history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

# NOTE(review): the "best" checkpoint is selected on the test set, so the best test accuracy
# reported later is optimistically biased. Hold out a validation split of the training data
# for model selection if an unbiased test number is needed.
for epoch in range(1, config["epochs"] + 1):
    train_loss, train_acc = train_one_epoch_mixup(model, train_loader, criterion, optimizer, device, epoch, scheduler=scheduler, grad_clip=config.get("grad_clip"))
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["test_loss"].append(test_loss)
    history["test_acc"].append(test_acc)

    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), "best_vit_cifar10.pth")
        print(f" => Saved best model with acc: {best_acc:.2f}%")
    print(f"Epoch {epoch} summary -> train_acc: {train_acc:.2f}%, test_acc: {test_acc:.2f}%")



  with autocast():
Train Epoch 1: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=20.2, loss=2.18]
Eval: 100%|██████████| 79/79 [00:04<00:00, 18.21it/s, acc=35.7, loss=1.91]


 => Saved best model with acc: 35.71%
Epoch 1 summary -> train_acc: 20.21%, test_acc: 35.71%


Train Epoch 2: 100%|██████████| 391/391 [00:46<00:00,  8.37it/s, acc=27.5, loss=2.06]
Eval: 100%|██████████| 79/79 [00:03<00:00, 21.54it/s, acc=41.1, loss=1.78]


 => Saved best model with acc: 41.14%
Epoch 2 summary -> train_acc: 27.54%, test_acc: 41.14%


Train Epoch 3: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=32.5, loss=1.98]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.63it/s, acc=47.1, loss=1.64]


 => Saved best model with acc: 47.13%
Epoch 3 summary -> train_acc: 32.53%, test_acc: 47.13%


Train Epoch 4: 100%|██████████| 391/391 [00:44<00:00,  8.86it/s, acc=36, loss=1.92]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.85it/s, acc=50.8, loss=1.56]


 => Saved best model with acc: 50.84%
Epoch 4 summary -> train_acc: 35.99%, test_acc: 50.84%


Train Epoch 5: 100%|██████████| 391/391 [00:43<00:00,  9.02it/s, acc=37.4, loss=1.89]
Eval: 100%|██████████| 79/79 [00:03<00:00, 22.07it/s, acc=53.6, loss=1.52]


 => Saved best model with acc: 53.62%
Epoch 5 summary -> train_acc: 37.43%, test_acc: 53.62%


Train Epoch 6: 100%|██████████| 391/391 [00:42<00:00,  9.12it/s, acc=38.9, loss=1.87]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.98it/s, acc=57.3, loss=1.44]


 => Saved best model with acc: 57.33%
Epoch 6 summary -> train_acc: 38.90%, test_acc: 57.33%


Train Epoch 7: 100%|██████████| 391/391 [00:44<00:00,  8.85it/s, acc=39.4, loss=1.86]
Eval: 100%|██████████| 79/79 [00:03<00:00, 20.57it/s, acc=56.9, loss=1.47]

Epoch 7 summary -> train_acc: 39.41%, test_acc: 56.91%



Train Epoch 8: 100%|██████████| 391/391 [00:51<00:00,  7.64it/s, acc=40.9, loss=1.83]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.41it/s, acc=58.1, loss=1.43]


 => Saved best model with acc: 58.11%
Epoch 8 summary -> train_acc: 40.89%, test_acc: 58.11%


Train Epoch 9: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s, acc=41.3, loss=1.83]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.40it/s, acc=60.7, loss=1.38]


 => Saved best model with acc: 60.74%
Epoch 9 summary -> train_acc: 41.34%, test_acc: 60.74%


Train Epoch 10: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=42.4, loss=1.8]
Eval: 100%|██████████| 79/79 [00:03<00:00, 23.57it/s, acc=60.8, loss=1.36]


 => Saved best model with acc: 60.76%
Epoch 10 summary -> train_acc: 42.42%, test_acc: 60.76%


Train Epoch 11: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=42.8, loss=1.8]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.14it/s, acc=62.5, loss=1.33]


 => Saved best model with acc: 62.45%
Epoch 11 summary -> train_acc: 42.83%, test_acc: 62.45%


Train Epoch 12: 100%|██████████| 391/391 [00:44<00:00,  8.74it/s, acc=43.1, loss=1.79]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.55it/s, acc=63, loss=1.32]


 => Saved best model with acc: 63.04%
Epoch 12 summary -> train_acc: 43.12%, test_acc: 63.04%


Train Epoch 13: 100%|██████████| 391/391 [00:44<00:00,  8.70it/s, acc=43.9, loss=1.78]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.96it/s, acc=63.6, loss=1.31]


 => Saved best model with acc: 63.59%
Epoch 13 summary -> train_acc: 43.94%, test_acc: 63.59%


Train Epoch 14: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=45, loss=1.76]
Eval: 100%|██████████| 79/79 [00:03<00:00, 24.02it/s, acc=65.3, loss=1.27]


 => Saved best model with acc: 65.35%
Epoch 14 summary -> train_acc: 45.00%, test_acc: 65.35%


Train Epoch 15: 100%|██████████| 391/391 [00:43<00:00,  8.94it/s, acc=45.2, loss=1.76]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.67it/s, acc=65, loss=1.27]

Epoch 15 summary -> train_acc: 45.18%, test_acc: 64.99%



Train Epoch 16: 100%|██████████| 391/391 [00:44<00:00,  8.75it/s, acc=46.1, loss=1.74]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.84it/s, acc=66.5, loss=1.23]


 => Saved best model with acc: 66.51%
Epoch 16 summary -> train_acc: 46.07%, test_acc: 66.51%


Train Epoch 17: 100%|██████████| 391/391 [00:45<00:00,  8.58it/s, acc=46.7, loss=1.73]
Eval: 100%|██████████| 79/79 [00:03<00:00, 22.46it/s, acc=67.4, loss=1.23]


 => Saved best model with acc: 67.41%
Epoch 17 summary -> train_acc: 46.73%, test_acc: 67.41%


Train Epoch 18: 100%|██████████| 391/391 [00:48<00:00,  8.11it/s, acc=48.4, loss=1.7]
Eval: 100%|██████████| 79/79 [00:03<00:00, 23.51it/s, acc=68.2, loss=1.21]


 => Saved best model with acc: 68.19%
Epoch 18 summary -> train_acc: 48.45%, test_acc: 68.19%


Train Epoch 19: 100%|██████████| 391/391 [00:43<00:00,  8.91it/s, acc=47.7, loss=1.71]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.75it/s, acc=69.9, loss=1.19]


 => Saved best model with acc: 69.87%
Epoch 19 summary -> train_acc: 47.74%, test_acc: 69.87%


Train Epoch 20: 100%|██████████| 391/391 [00:44<00:00,  8.82it/s, acc=48.6, loss=1.69]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.59it/s, acc=68.9, loss=1.2]

Epoch 20 summary -> train_acc: 48.63%, test_acc: 68.92%



Train Epoch 21: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=48.8, loss=1.7]
Eval: 100%|██████████| 79/79 [00:03<00:00, 23.78it/s, acc=70.2, loss=1.17]


 => Saved best model with acc: 70.23%
Epoch 21 summary -> train_acc: 48.80%, test_acc: 70.23%


Train Epoch 22: 100%|██████████| 391/391 [00:43<00:00,  8.89it/s, acc=49.9, loss=1.67]
Eval: 100%|██████████| 79/79 [00:03<00:00, 24.97it/s, acc=68.3, loss=1.2]

Epoch 22 summary -> train_acc: 49.89%, test_acc: 68.26%



Train Epoch 23: 100%|██████████| 391/391 [00:44<00:00,  8.79it/s, acc=50.3, loss=1.66]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.68it/s, acc=72.2, loss=1.13]


 => Saved best model with acc: 72.24%
Epoch 23 summary -> train_acc: 50.32%, test_acc: 72.24%


Train Epoch 24: 100%|██████████| 391/391 [00:44<00:00,  8.79it/s, acc=50, loss=1.67]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.20it/s, acc=72.8, loss=1.13]


 => Saved best model with acc: 72.77%
Epoch 24 summary -> train_acc: 50.00%, test_acc: 72.77%


Train Epoch 25: 100%|██████████| 391/391 [00:43<00:00,  8.95it/s, acc=49.8, loss=1.67]
Eval: 100%|██████████| 79/79 [00:03<00:00, 21.14it/s, acc=72.4, loss=1.12]

Epoch 25 summary -> train_acc: 49.80%, test_acc: 72.45%



Train Epoch 26: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=50.9, loss=1.66]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.81it/s, acc=72.8, loss=1.12]


 => Saved best model with acc: 72.77%
Epoch 26 summary -> train_acc: 50.86%, test_acc: 72.77%


Train Epoch 27: 100%|██████████| 391/391 [00:44<00:00,  8.84it/s, acc=51.1, loss=1.65]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.56it/s, acc=73.6, loss=1.1]


 => Saved best model with acc: 73.65%
Epoch 27 summary -> train_acc: 51.12%, test_acc: 73.65%


Train Epoch 28: 100%|██████████| 391/391 [00:44<00:00,  8.86it/s, acc=50.9, loss=1.66]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.14it/s, acc=74.8, loss=1.08]


 => Saved best model with acc: 74.78%
Epoch 28 summary -> train_acc: 50.86%, test_acc: 74.78%


Train Epoch 29: 100%|██████████| 391/391 [00:43<00:00,  8.98it/s, acc=52.9, loss=1.62]
Eval: 100%|██████████| 79/79 [00:03<00:00, 21.95it/s, acc=74.3, loss=1.08]

Epoch 29 summary -> train_acc: 52.90%, test_acc: 74.31%



Train Epoch 30: 100%|██████████| 391/391 [00:43<00:00,  8.99it/s, acc=51.4, loss=1.64]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.72it/s, acc=74.5, loss=1.08]

Epoch 30 summary -> train_acc: 51.42%, test_acc: 74.53%



Train Epoch 31: 100%|██████████| 391/391 [00:44<00:00,  8.85it/s, acc=52.3, loss=1.63]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.41it/s, acc=76.1, loss=1.05]


 => Saved best model with acc: 76.10%
Epoch 31 summary -> train_acc: 52.29%, test_acc: 76.10%


Train Epoch 32: 100%|██████████| 391/391 [00:44<00:00,  8.87it/s, acc=54.7, loss=1.58]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.15it/s, acc=76.4, loss=1.04]


 => Saved best model with acc: 76.40%
Epoch 32 summary -> train_acc: 54.66%, test_acc: 76.40%


Train Epoch 33: 100%|██████████| 391/391 [00:43<00:00,  9.07it/s, acc=53.6, loss=1.6]
Eval: 100%|██████████| 79/79 [00:03<00:00, 21.77it/s, acc=76.2, loss=1.05]

Epoch 33 summary -> train_acc: 53.62%, test_acc: 76.21%



Train Epoch 34: 100%|██████████| 391/391 [00:43<00:00,  9.02it/s, acc=53.2, loss=1.6]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.04it/s, acc=76.1, loss=1.05]

Epoch 34 summary -> train_acc: 53.20%, test_acc: 76.08%



Train Epoch 35: 100%|██████████| 391/391 [00:45<00:00,  8.69it/s, acc=53.7, loss=1.6]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.13it/s, acc=77.4, loss=1.02]


 => Saved best model with acc: 77.38%
Epoch 35 summary -> train_acc: 53.69%, test_acc: 77.38%


Train Epoch 36: 100%|██████████| 391/391 [00:45<00:00,  8.67it/s, acc=54.4, loss=1.58]
Eval: 100%|██████████| 79/79 [00:03<00:00, 25.97it/s, acc=77.7, loss=1.02]


 => Saved best model with acc: 77.65%
Epoch 36 summary -> train_acc: 54.44%, test_acc: 77.65%


Train Epoch 37: 100%|██████████| 391/391 [00:45<00:00,  8.60it/s, acc=54.7, loss=1.58]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.52it/s, acc=77.1, loss=1.02]

Epoch 37 summary -> train_acc: 54.72%, test_acc: 77.15%



Train Epoch 38: 100%|██████████| 391/391 [00:44<00:00,  8.86it/s, acc=55, loss=1.57]
Eval: 100%|██████████| 79/79 [00:03<00:00, 22.01it/s, acc=77.3, loss=1.01]

Epoch 38 summary -> train_acc: 55.04%, test_acc: 77.28%



Train Epoch 39: 100%|██████████| 391/391 [00:43<00:00,  9.02it/s, acc=54.9, loss=1.57]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.15it/s, acc=77.6, loss=1.01]

Epoch 39 summary -> train_acc: 54.93%, test_acc: 77.63%



Train Epoch 40: 100%|██████████| 391/391 [00:44<00:00,  8.73it/s, acc=55.5, loss=1.56]
Eval: 100%|██████████| 79/79 [00:03<00:00, 25.95it/s, acc=78.1, loss=1]


 => Saved best model with acc: 78.11%
Epoch 40 summary -> train_acc: 55.46%, test_acc: 78.11%


Train Epoch 41: 100%|██████████| 391/391 [00:44<00:00,  8.71it/s, acc=55.6, loss=1.56]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.78it/s, acc=78.2, loss=1]


 => Saved best model with acc: 78.15%
Epoch 41 summary -> train_acc: 55.65%, test_acc: 78.15%


Train Epoch 42: 100%|██████████| 391/391 [00:44<00:00,  8.70it/s, acc=55.6, loss=1.56]
Eval: 100%|██████████| 79/79 [00:03<00:00, 26.10it/s, acc=78.3, loss=0.998]


 => Saved best model with acc: 78.34%
Epoch 42 summary -> train_acc: 55.61%, test_acc: 78.34%


Train Epoch 43: 100%|██████████| 391/391 [00:43<00:00,  8.97it/s, acc=55.7, loss=1.56]
Eval: 100%|██████████| 79/79 [00:03<00:00, 21.13it/s, acc=77.7, loss=1.01]

Epoch 43 summary -> train_acc: 55.72%, test_acc: 77.75%



Train Epoch 44: 100%|██████████| 391/391 [00:43<00:00,  8.96it/s, acc=56.5, loss=1.54]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.74it/s, acc=78.2, loss=0.997]


Epoch 44 summary -> train_acc: 56.52%, test_acc: 78.18%


Train Epoch 45: 100%|██████████| 391/391 [00:44<00:00,  8.84it/s, acc=56.5, loss=1.54]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.28it/s, acc=78.3, loss=0.995]

Epoch 45 summary -> train_acc: 56.50%, test_acc: 78.31%



Train Epoch 46: 100%|██████████| 391/391 [00:44<00:00,  8.85it/s, acc=55.6, loss=1.55]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.75it/s, acc=78.5, loss=0.99]


 => Saved best model with acc: 78.54%
Epoch 46 summary -> train_acc: 55.62%, test_acc: 78.54%


Train Epoch 47: 100%|██████████| 391/391 [00:43<00:00,  9.00it/s, acc=55.4, loss=1.57]
Eval: 100%|██████████| 79/79 [00:03<00:00, 22.98it/s, acc=78.5, loss=0.99]

Epoch 47 summary -> train_acc: 55.36%, test_acc: 78.48%



Train Epoch 48: 100%|██████████| 391/391 [00:43<00:00,  8.98it/s, acc=56.8, loss=1.54]
Eval: 100%|██████████| 79/79 [00:02<00:00, 26.62it/s, acc=78.5, loss=0.988]

Epoch 48 summary -> train_acc: 56.76%, test_acc: 78.51%



Train Epoch 49: 100%|██████████| 391/391 [00:44<00:00,  8.82it/s, acc=55.8, loss=1.56]
Eval: 100%|██████████| 79/79 [00:02<00:00, 27.15it/s, acc=78.5, loss=0.988]

Epoch 49 summary -> train_acc: 55.82%, test_acc: 78.54%



Train Epoch 50: 100%|██████████| 391/391 [00:44<00:00,  8.77it/s, acc=56.6, loss=1.54]
Eval: 100%|██████████| 79/79 [00:03<00:00, 25.94it/s, acc=78.5, loss=0.988]

Epoch 50 summary -> train_acc: 56.59%, test_acc: 78.48%





In [None]:
best_path = "best_vit_cifar10.pth"
if os.path.exists(best_path):
    # weights_only=True restricts unpickling to tensors and plain containers, which is all a
    # state_dict needs; it also avoids the torch.load FutureWarning in recent PyTorch.
    model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))
    label = "Best"
else:
    print("No saved model found, evaluate current model instead.")
    label = "Current"
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"{label} model test accuracy: {test_acc:.2f}%")


Eval: 100%|██████████| 79/79 [00:02<00:00, 27.38it/s, acc=78.5, loss=0.99]

Best model test accuracy: 78.54%



