In [1]:
import torch
import timm
import os
import random
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

# Configurable ViT models to benchmark
vit_models = {
    "ViT-B/16": {"timm_name": "vit_base_patch16_384", "image_size": 384},
    "ViT-L/16": {"timm_name": "vit_large_patch16_384", "image_size": 384},
}

# Root dataset directory
data_root = "."

# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Confirm configuration
print("Models to run:")
for name, meta in vit_models.items():
    print(f" - {name}: {meta['timm_name']} | {meta['image_size']}x{meta['image_size']}")

print(f"\nDataset root: {data_root}")
print(f"Device: {device}")


Models to run:
 - ViT-B/16: vit_base_patch16_384 | 384x384
 - ViT-L/16: vit_large_patch16_384 | 384x384

Dataset root: .
Device: cuda


In [2]:
# Cell 2: OxfordPetsDataset and DataLoader
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os
import random

class OxfordPetsDataset(Dataset):
    def __init__(self, root_dir, split='train', image_size=384, transform=None, split_ratio=0.8, seed=42):
        self.root_dir = root_dir
        self.split = split
        self.image_size = image_size

        image_dir = os.path.join(root_dir, "images")
        all_files = sorted([
            f[:-4] for f in os.listdir(image_dir)
            if f.endswith(".jpg")
        ])

        random.seed(seed)
        random.shuffle(all_files)
        split_idx = int(len(all_files) * split_ratio)
        self.image_ids = all_files[:split_idx] if split == 'train' else all_files[split_idx:]

        self.class_names = sorted(list(set([img_id.rsplit('_', 1)[0] for img_id in all_files])))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.class_names)}

        self.transform = transform if transform else transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.root_dir, 'images', f'{img_id}.jpg')
        image = Image.open(img_path).convert('RGB')

        class_name = img_id.rsplit('_', 1)[0]
        label = self.class_to_idx[class_name]

        image = self.transform(image)
        return image, label

def get_loaders(root_dir, image_size=384, batch_size=64, num_workers=0, seed=42):
    train_dataset = OxfordPetsDataset(root_dir, split='train', image_size=image_size, seed=seed)
    val_dataset   = OxfordPetsDataset(root_dir, split='val', image_size=image_size, seed=seed)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader   = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader


In [3]:
import torch.nn as nn

def load_model(model_name, num_classes):
    model_info = vit_models[model_name]
    timm_name = model_info["timm_name"]

    # Load pretrained ViT model from timm
    model = timm.create_model(timm_name, pretrained=True)

    # Replace classifier head
    if hasattr(model, "head") and isinstance(model.head, nn.Linear):
        in_features = model.head.in_features
        model.head = nn.Linear(in_features, num_classes)
    else:
        raise ValueError("Unexpected classifier head structure.")

    return model.to(device)


In [4]:
import time
from tqdm import tqdm

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    start_time = time.time()

    loop = tqdm(dataloader, desc="Training", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        loop.set_postfix(loss=loss.item(), acc=100.0 * correct / total)

    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    elapsed = time.time() - start_time
    return avg_loss, accuracy, elapsed

def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    start_time = time.time()

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating", leave=False):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / total
    accuracy = 100.0 * correct / total
    elapsed = time.time() - start_time
    return avg_loss, accuracy, elapsed


In [5]:
from torch import nn, optim
import pandas as pd
import os
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
import time
import warnings
import contextlib
import io
from fvcore.nn import FlopCountAnalysis
import gc

results = []
num_epochs = 5
learning_rate = 1e-3
batch_size = 64
num_workers = 0

for model_name, config in vit_models.items():
    try:
        print(f"\n🔍 Running model: {model_name}")

        image_size = config["image_size"]
        train_loader, val_loader = get_loaders(data_root, image_size=image_size, batch_size=batch_size, num_workers=num_workers)

        # Initialize ViT model using load_model()
        num_classes = len(train_loader.dataset.class_to_idx)
        model = load_model(model_name, num_classes)
        model.eval()  # Ensure eval mode before FLOPs analysis

        # Compute FLOPs using fvcore, suppressing warnings and stderr
        sample_input = torch.randn(1, 3, image_size, image_size).to(device)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            with contextlib.redirect_stderr(io.StringIO()):
                flops = FlopCountAnalysis(model, sample_input).total() / 1e9  # GFLOPs
        del sample_input
        torch.cuda.empty_cache()

        # Optimizer and loss
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Store epoch history
        history = []

        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}")

            try:
                train_loss, train_acc, train_time = train(model, train_loader, criterion, optimizer, device)
                val_loss, val_acc, val_time = validate(model, val_loader, criterion, device)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(f"❌ OOM on model {model_name}, skipping...")
                    torch.cuda.empty_cache()
                    break
                else:
                    raise

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.2f}%")
            print(f"Train Time: {train_time:.2f}s, Val Time: {val_time:.2f}s\n")

            history.append((train_loss, train_acc, val_loss, val_acc, train_time, val_time))

        if history:  # Only report if training ran successfully
            final_metrics = history[-1]
            params_m = round(sum(p.numel() for p in model.parameters()) / 1e6)
            flops_g = round(flops, 1)
            throughput = round(len(val_loader.dataset) / final_metrics[5], 1)
            top1_acc = round(final_metrics[3], 1)

            results.append({
                "method": model_name,
                "image size": f"{image_size}²",
                "#params": f"{params_m}M",
                "FLOPs": f"{flops_g}G",
                "throughput (image / s)": throughput,
                "ImageNet top-1 acc.": top1_acc,
            })

    finally:
        del model, train_loader, val_loader
        torch.cuda.empty_cache()
        gc.collect()

# Final table
results_df = pd.DataFrame(results)
print("\n✅ ViT Benchmark Summary:")
display(results_df)



🔍 Running model: ViT-B/16
Epoch 1/5


                                                                              

Train Loss: 4.2633, Train Acc: 2.81%
Val Loss:   3.7855, Val Acc:   2.98%
Train Time: 115.26s, Val Time: 17.81s

Epoch 2/5


                                                                              

Train Loss: 3.7015, Train Acc: 4.11%
Val Loss:   3.6951, Val Acc:   4.94%
Train Time: 86.30s, Val Time: 10.77s

Epoch 3/5


                                                                              

Train Loss: 3.5963, Train Acc: 4.75%
Val Loss:   3.5638, Val Acc:   5.48%
Train Time: 85.25s, Val Time: 10.30s

Epoch 4/5


                                                                              

Train Loss: 3.5225, Train Acc: 6.21%
Val Loss:   3.5356, Val Acc:   5.75%
Train Time: 85.12s, Val Time: 10.35s

Epoch 5/5


                                                                              

Train Loss: 3.4614, Train Acc: 6.73%
Val Loss:   3.5491, Val Acc:   4.19%
Train Time: 86.16s, Val Time: 10.48s


🔍 Running model: ViT-L/16
Epoch 1/5


                                                

❌ OOM on model ViT-L/16, skipping...

✅ ViT Benchmark Summary:


Unnamed: 0,method,image size,#params,FLOPs,throughput (image / s),ImageNet top-1 acc.
0,ViT-B/16,384²,86M,49.4G,141.0,4.2
