# Assignment: Vision Transformers on CIFAR10

In [6]:
#imports
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [7]:
#loading the dataset
dataset = dset.CIFAR10(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
nc=3

dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)


In [8]:
#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Tasks:
* try to get the best test Accuracy on Cifar10 using a transformer model
* pre-trained models allowed


In [9]:
#install
!pip install transformers



In [14]:
# Vision Transformer on CIFAR-10 kompletter Notebook-Code


# 0) Imports & Dynamo-Bypass
import torch
# Deaktiviere Torch Dynamo, um Circular-Import in Transformers zu verhindern
if hasattr(torch, '_dynamo'):
    try:
        torch._dynamo.disable()
    except Exception:
        pass

from __future__ import print_function
import os
import random
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig, get_linear_schedule_with_warmup
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig, get_linear_schedule_with_warmup

# 1) Datenvorbereitung
# Lade FeatureExtractor für Mean/Std und Patch-Size
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')

transform = transforms.Compose([
    transforms.Resize(64),               # CIFAR-ähnliche Auflösung 64x64
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std
    )
])

dataset = dset.CIFAR10(root="./data", train=True, download=True, transform=transform)

# DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2
)

# 2) Device-Check
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# 3) Modellinitialisierung mit 64x64-Konfiguration
torch.backends.cudnn.benchmark = True

config = ViTConfig.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    image_size=64,
    num_labels=10,
    id2label={i: cls for i, cls in enumerate(dataset.classes)},
    label2id={cls: i for i, cls in enumerate(dataset.classes)}
)

model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    config=config,
    ignore_mismatched_sizes=True    # interpoliert Position-Embeddings
)
model.to(device)

# 4) Optimizer & Scheduler
epochs = 5
total_steps = len(dataloader) * epochs
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

# 5) Trainingsschleife
for epoch in range(1, epochs + 1):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)

        outputs = model(pixel_values=imgs, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item()
        preds = logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = running_loss / len(dataloader)
    train_acc = correct / total
    print(f"Epoch {epoch}/{epochs} - Loss: {avg_loss:.4f}, Train Acc: {train_acc:.4f}")

# 6) Evaluation auf dem Testset
# Dataset und DataLoader für Testdaten
test_dataset = dset.CIFAR10(root="./data", train=False, download=True, transform=transform)

test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2
)

# Modell in Eval-Modus versetzen
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(pixel_values=imgs)
        preds = outputs.logits.argmax(dim=-1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_acc = correct / total
print(f"Test Accuracy: {test_acc:.4f}")

# 7) Modell speichern
torch.save(model.state_dict(), 'vit64_cifar10.pth')
print("Training und Evaluation abgeschlossen. Modell gespeichert.")


Using device: cuda


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- vit.embeddings.position_embeddings: found shape torch.Size([1, 197, 768]) in the checkpoint and torch.Size([1, 17, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/5 - Loss: 1.3669, Train Acc: 0.5592
Epoch 2/5 - Loss: 0.4419, Train Acc: 0.8682
Epoch 3/5 - Loss: 0.2114, Train Acc: 0.9437
Epoch 4/5 - Loss: 0.0962, Train Acc: 0.9806
Epoch 5/5 - Loss: 0.0505, Train Acc: 0.9938
Test Accuracy: 0.8838
Training und Evaluation abgeschlossen. Modell gespeichert.
