# Assignment: Vision Transformers on CIFAR10

In [11]:
#imports
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


In [12]:
#loading the dataset
dataset = dset.CIFAR10(root="./data", download=True,
                           transform=transforms.Compose([
                               transforms.Resize(64),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
nc=3

dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)


In [13]:
#checking the availability of cuda devices
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Tasks:
* try to get the best test Accuracy on Cifar10 using a transformer model
* pre-trained models allowed


In [14]:
#install
!pip install transformers



In [None]:
#imports
from transformers import (
    ViTConfig,
    ViTFeatureExtractor,
    ViTForImageClassification,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW

In [None]:
#seeden f. reproduzierbarkeit
seed = 42
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [None]:
#bilder vom pretrained model korrekt skalieren/transformioeren
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

vit_transform = transforms.Compose(
    [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=feature_extractor.image_mean, std=feature_extractor.image_std
        ),
    ]
)


In [None]:
#train/test Daten laden
train_dataset = dset.CIFAR10(
    root="./data", train=True, download=False, transform=vit_transform
)
test_dataset = dset.CIFAR10(
    root="./data", train=False, download=False, transform=vit_transform
)

#Dataloader
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2
)


In [None]:

torch.backends.cudnn.benchmark = True  # schneller auf GPU

config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    image_size=64, #img size
    num_labels=10, #klassen
    id2label={i: c for i, c in enumerate(train_dataset.classes)}, #bidirektionales mapping
    label2id={c: i for i, c in enumerate(train_dataset.classes)},
)

model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    config=config,
    ignore_mismatched_sizes=True,  #positionembeddings werden interpoliert
).to(device)

In [None]:
#Optimizer-adamw, scheduler m. warmup
epochs = 20
optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.05)
total_steps = len(train_loader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps,
)

In [15]:
#train u valid.
for ep in range(epochs):
    model.train()
    ep_loss = 0.0
    ep_correct = 0
    ep_samples = 0

    for imgs, lbls in train_loader:
        imgs, lbls = imgs.to(device), lbls.to(device)

        out = model(pixel_values=imgs, labels=lbls)
        loss = out.loss

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        scheduler.step()

        ep_loss += loss.item() * imgs.size(0)
        ep_correct += (out.logits.argmax(dim=-1) == lbls).sum().item()
        ep_samples += lbls.size(0)

    train_acc = ep_correct / ep_samples
    train_loss = ep_loss / ep_samples

    # ─ Validation nach jeder Epoche
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for imgs, lbls in test_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            preds = model(pixel_values=imgs).logits.argmax(dim=-1)
            val_correct += (preds == lbls).sum().item()
            val_total += lbls.size(0)

    val_acc = val_correct / val_total
    print(
        f"[{ep+1:02}/{epochs}]  "
        f"loss={train_loss:.4f}  "
        f"train_acc={train_acc:.3f}  "
        f"val_acc={val_acc:.3f}"
    )



Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized because the shapes did not match:
- vit.embeddings.position_embeddings: found shape torch.Size([1, 197, 768]) in the checkpoint and torch.Size([1, 17, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[01/20]  loss=2.0393  train_acc=0.320  val_acc=0.621
[02/20]  loss=0.8984  train_acc=0.747  val_acc=0.823
[03/20]  loss=0.4548  train_acc=0.868  val_acc=0.859
[04/20]  loss=0.2553  train_acc=0.932  val_acc=0.868
[05/20]  loss=0.1418  train_acc=0.966  val_acc=0.868
[06/20]  loss=0.0784  train_acc=0.984  val_acc=0.863
[07/20]  loss=0.0477  train_acc=0.992  val_acc=0.875
[08/20]  loss=0.0323  train_acc=0.995  val_acc=0.877
[09/20]  loss=0.0205  train_acc=0.997  val_acc=0.878
[10/20]  loss=0.0153  train_acc=0.998  val_acc=0.879
[11/20]  loss=0.0125  train_acc=0.998  val_acc=0.873
[12/20]  loss=0.0126  train_acc=0.998  val_acc=0.874
[13/20]  loss=0.0092  train_acc=0.998  val_acc=0.874
[14/20]  loss=0.0076  train_acc=0.999  val_acc=0.876
[15/20]  loss=0.0049  train_acc=0.999  val_acc=0.880
[16/20]  loss=0.0033  train_acc=1.000  val_acc=0.881
[17/20]  loss=0.0028  train_acc=1.000  val_acc=0.879
[18/20]  loss=0.0026  train_acc=1.000  val_acc=0.880
[19/20]  loss=0.0024  train_acc=1.000  val_acc