# Vision Transformers vs. CNN on CIFAR-10

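This notebook compares the performance of a Vision Transformer (ViT) with a standard CNN (ResNet18) on the CIFAR-10 image classification task.

**TODO:** only the ViT-Tiny fine-tuning run is implemented below; the ResNet18 baseline still needs a training cell before the comparison plots can include it.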

In [1]:
# Install required package
!pip install timm --quiet


[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# Import packages
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import timm
import time
import matplotlib.pyplot as plt
import tqdm

In [3]:
# Fix the global PyTorch RNG so weight initialisation and DataLoader shuffling repeat
# across Restart & Run All (GPU kernels may still introduce some nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load and preprocess CIFAR-10

In [4]:
# ImageNet per-channel statistics used for normalisation.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
BATCH_SIZE = 64

# Upsample to 224 px so images fit the `vit_tiny_patch16_224` model used below,
# convert to tensors, then normalise channel-wise.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Both splits share one transform; the archive is only fetched if it is not
# already present under ./data.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root='./data', train=is_train, download=True, transform=transform
    )
    for is_train in (True, False)
)

# Shuffle only the training split; keep evaluation order fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## Define training and evaluation utilities

In [5]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch over ``train_loader``, updating ``model`` in place.

    Args:
        model: ``nn.Module`` to train; it is switched to training mode here.
        train_loader: iterable of ``(images, labels)`` batches (e.g. a DataLoader,
            optionally wrapped in a progress bar).
        criterion: loss called as ``criterion(outputs, labels)``; assumed to return
            the batch *mean* (the default ``reduction`` of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample mean training loss and the
        training accuracy in percent, both accumulated on the fly during the epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch's mean loss by its size so a short final batch does
        # not skew the epoch average (dividing by the batch count would).
        running_loss += loss.item() * batch_size
        predicted = outputs.detach().argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("train_loader yielded no samples")
    return running_loss / total, 100 * correct / total

def evaluate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    The model is switched to eval mode and gradients are disabled for the pass.

    Args:
        model: ``nn.Module`` whose output holds one score per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            batch_labels = batch_labels.to(device)
            logits = model(batch_images.to(device))
            # The index of the highest score in each row is the predicted class.
            hits = logits.max(dim=1).indices == batch_labels
            n_correct += hits.sum().item()
            n_seen += batch_labels.size(0)
    return 100 * n_correct / n_seen

## Train ViT

In [9]:
# Fine-tune a pretrained ViT-Tiny with a 10-way output on CIFAR-10.
VIT_EPOCHS = 5
VIT_LR = 1e-4
VIT_CHECKPOINT = "vit_cifar10.pth"

vit = timm.create_model('vit_tiny_patch16_224', pretrained=True, num_classes=10)
vit = vit.to(device)

print("\nTraining ViT Tiny...")
criterion = nn.CrossEntropyLoss()
optimizer_vit = torch.optim.Adam(vit.parameters(), lr=VIT_LR)
train_losses_vit, train_accs_vit, test_accs_vit = [], [], []

for epoch in range(1, VIT_EPOCHS + 1):
    # Reuse the shared `train` helper instead of an inline copy of its loop;
    # wrapping the loader in tqdm keeps the per-epoch progress bar.
    progress = tqdm.tqdm(train_loader, desc=f"ViT Epoch {epoch}", leave=False)
    avg_loss, train_acc = train(vit, progress, criterion, optimizer_vit, device)
    test_acc = evaluate(vit, test_loader, device)

    train_losses_vit.append(avg_loss)
    train_accs_vit.append(train_acc)
    test_accs_vit.append(test_acc)

    print(f"ViT Epoch {epoch}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")
    # Checkpoint after every epoch so an interrupted run (e.g. KeyboardInterrupt)
    # still leaves the latest weights on disk.
    torch.save(vit.state_dict(), VIT_CHECKPOINT)



Training ViT Tiny...


                                                                                     

ViT Epoch 1: Train Acc=92.25%, Test Acc=95.80%


                                                                                      

ViT Epoch 2: Train Acc=97.09%, Test Acc=95.77%


                                                                                   

KeyboardInterrupt: 

## Plot and Compare

In [None]:
# Gather each model's per-epoch history as (train_losses, train_accs, test_accs).
# Only the ViT run exists in this notebook so far; add a "ResNet18" entry here
# once its training cell is added.
results = {
    "ViT Tiny": (train_losses_vit, train_accs_vit, test_accs_vit),
}

# Tuple positions of the accuracy curves inside each history above.
ACCURACY_METRICS = {"Train Accuracy": 1, "Test Accuracy": 2}

for metric_name, metric_idx in ACCURACY_METRICS.items():
    fig, ax = plt.subplots(figsize=(8, 5))
    for name, history in results.items():
        values = history[metric_idx]
        # Training logs number epochs from 1, so label the x-axis the same way.
        ax.plot(range(1, len(values) + 1), values, marker="o", label=name)
    ax.set_title(metric_name + " per Epoch")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy (%)")
    ax.legend()
    ax.grid(True)
    plt.show()