# ViT Baseline Comparison on CIFAR-10

In [1]:
!pip install timm torchvision --quiet


[notice] A new release of pip is available: 25.0.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


## Step 1: Import Libraries

In [2]:
import torch
import torch.nn as nn
import timm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


## Step 2: CIFAR-10 Dataset

In [3]:
# Shared preprocessing settings: every image is resized to 224x224 and
# normalised per channel with the statistics below.
# NOTE(review): these are the usual ImageNet mean/std values; some timm ViT
# checkpoints are pretrained with mean/std 0.5 instead — confirm against the
# pretrained config of the models used in this notebook.
input_size = (224, 224)
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# The training pipeline adds a random horizontal flip as its only augmentation.
train_steps = [
    transforms.Resize(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
# The test pipeline is deterministic: resize, tensor conversion, normalisation.
test_steps = [
    transforms.Resize(input_size),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
transform_train = transforms.Compose(train_steps)
transform_test = transforms.Compose(test_steps)

# Both splits live under ./data and are downloaded on first use.
cifar_kwargs = dict(root='./data', download=True)
train_dataset = datasets.CIFAR10(train=True, transform=transform_train, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, transform=transform_test, **cifar_kwargs)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

## Step 3: Load and Modify Models

In [4]:
def load_vit(model_name, num_classes=10):
    """Create a pretrained timm model with a fresh classification head.

    Args:
        model_name: timm model identifier, e.g. ``'vit_small_patch16_224'``.
        num_classes: Size of the new classifier output. Defaults to 10, the
            value this notebook previously hard-coded.

    Returns:
        The model created by ``timm.create_model`` with pretrained weights
        requested and a classifier sized for ``num_classes``.
    """
    # Letting timm build the classifier (instead of overwriting ``model.head``
    # by hand) keeps the model's own class-count bookkeeping consistent with
    # the head and does not assume the classifier attribute is named ``head``.
    return timm.create_model(model_name, pretrained=True, num_classes=num_classes)

# Checkpoints compared in this notebook. `model_large` holds the ViT-Base
# variant; the name is kept because later cells refer to it.
VIT_SMALL_NAME = 'vit_small_patch16_224'
VIT_BASE_NAME = 'vit_base_patch16_224'

model_small = load_vit(model_name=VIT_SMALL_NAME)
model_large = load_vit(model_name=VIT_BASE_NAME)

## Step 4: Training Function

In [5]:
def train(model, train_loader, device, epochs=1, lr=3e-4):
    """Fine-tune ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module mapping an input batch to class logits.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate. Defaults to 3e-4, the value that used to be
            hard-coded, so existing calls behave as before.

    Returns:
        A list with the mean batch loss of each epoch, in order. An epoch in
        which the loader yields no batches contributes ``nan``.
    """
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        # Count batches while iterating rather than calling len(), so loaders
        # without __len__ work and an empty loader cannot divide by zero.
        num_batches = 0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
        epoch_losses.append(mean_loss)
    return epoch_losses

## Step 5: Evaluation Function

In [6]:
def evaluate(model, test_loader, device):
    """Compute top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Module mapping an input batch to class logits.
        test_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        Fraction of samples whose arg-max prediction equals the target, as a
        float in [0, 1]. Returns 0.0 when the loader yields no samples.
    """
    # Move the model explicitly so evaluation does not depend on train()
    # having been called on it first.
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += targets.size(0)
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0
    return correct / total

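## Step 6: Train & Evaluate Both Models

Each model is fine-tuned for a single epoch with the same `train` settings and then scored on the CIFAR-10 test split, so the accuracies below are a quick baseline rather than converged results.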

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Re-seed PyTorch's RNG before each run so results are repeatable across
# notebook executions and both models start training from the same RNG state
# (and therefore the same first-epoch shuffle order of train_loader).
SEED = 42

print("Training ViT-Small...")
torch.manual_seed(SEED)
train(model_small, train_loader, device, epochs=1)
acc_small = evaluate(model_small, test_loader, device)
print(f"ViT-Small Accuracy: {acc_small:.2%}")

# `model_large` holds the ViT-Base checkpoint.
print("\nTraining ViT-Base...")
torch.manual_seed(SEED)
train(model_large, train_loader, device, epochs=1)
acc_large = evaluate(model_large, test_loader, device)
print(f"ViT-Base Accuracy: {acc_large:.2%}")

Training ViT-Small...


100%|██████████| 782/782 [1:12:59<00:00,  5.60s/it]


Epoch 1, Loss: 0.2638
ViT-Small Accuracy: 93.36%

Training ViT-Base...


100%|██████████| 782/782 [3:54:19<00:00, 17.98s/it]  


Epoch 1, Loss: 1.2682
ViT-Base Accuracy: 71.50%
