# Evaluation of Vision Transformers on CIFAR-10-C (corrupted CIFAR-10)



In [15]:
# Install required package
!pip install timm --quiet


[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [16]:
# Import packages

import timm
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import TensorDataset, DataLoader
import timm
from PIL import Image

In [17]:
# Device configuration: use the GPU when PyTorch reports one, otherwise the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda' if cuda_available else 'cpu')
print("Using device:", device)

# Preprocessing pipeline applied to every image before it is fed to the model.
# NOTE(review): presumably these steps mirror the preprocessing used when the
# checkpoint was trained — confirm before changing order or statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Resize((224, 224)),  # upscale to the 224x224 input size
    normalize,
]
transform = transforms.Compose(preprocessing_steps)

Using device: cuda


## Load and preprocess CIFAR-10-C

In [18]:
# Root directory of the extracted CIFAR-10-C archive.
# Download & extract from https://zenodo.org/record/2535967
CIFAR10C_PATH = "./cifar-10-c"

# Corruption types to evaluate, in evaluation order.
corruptions = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]

# ==== Custom Dataset ====
from torch.utils.data import Dataset, DataLoader
class CIFAR10CDataset(Dataset):
    """Map-style dataset over an in-memory image array and its labels.

    Args:
        data_np: Indexable collection of images. Each item is passed to
            ``PIL.Image.fromarray``, so it must be an array that function accepts.
        labels_np: Indexable collection of labels, aligned index-by-index
            with ``data_np``.
        transform: Optional callable applied to each PIL image before it is
            returned.

    Raises:
        ValueError: If ``data_np`` and ``labels_np`` have different lengths.
    """

    def __init__(self, data_np, labels_np, transform=None):
        # Fail fast: a length mismatch would otherwise only surface as a late
        # IndexError (or as silently ignored trailing labels).
        if len(data_np) != len(labels_np):
            raise ValueError(
                f"data and labels differ in length: {len(data_np)} vs {len(labels_np)}"
            )
        self.data = data_np
        self.labels = labels_np
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a PIL image, or the output of ``transform`` when one
        was supplied.
        """
        img = Image.fromarray(self.data[idx])
        label = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# ==== Load Pretrained ViT Model ====
CHECKPOINT_PATH = "vit_cifar10.pth"  # path to trained ViT model

# Build the architecture only (pretrained=False); the weights come from the checkpoint.
model = timm.create_model('vit_tiny_patch16_224', pretrained=False, num_classes=10)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors/containers, which is all a
# state dict needs, and avoids the torch.load FutureWarning.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

  model.load_state_dict(torch.load("vit_cifar10.pth"))  # path to your trained ViT model


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()


## Load trained models

In [19]:
def evaluate_corruption(data_np, labels_np, batch_size=64):
    """Compute the accuracy of the global ``model`` on one array of images.

    Uses the notebook-level ``model``, ``device`` and ``transform``.

    Args:
        data_np: Images, wrapped in a ``CIFAR10CDataset``.
        labels_np: Labels aligned index-by-index with ``data_np``.
        batch_size: Number of images per forward pass (default 64).

    Returns:
        Accuracy in percent, i.e. ``100 * correct / total``.

    Raises:
        ValueError: If ``data_np`` is empty (the accuracy would be undefined).
    """
    # An empty input would otherwise end in a bare ZeroDivisionError below.
    if len(data_np) == 0:
        raise ValueError("evaluate_corruption received an empty dataset")

    dataset = CIFAR10CDataset(data_np, labels_np, transform=transform)
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    use_pinned = device.type == 'cuda'
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        pin_memory=use_pinned)

    # Do not rely on an earlier cell having switched the model to eval mode.
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device, non_blocking=use_pinned)
            lbls = lbls.to(device, non_blocking=use_pinned)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == lbls).sum().item()
            total += lbls.size(0)
    return 100 * correct / total

# ==== Load Common Labels (Same for All) ====
# One label file is reused for every corruption array.
labels_file = os.path.join(CIFAR10C_PATH, "labels.npy")
labels = np.load(labels_file)


## Load labels and define the evaluation function

In [20]:
# ==== Evaluate All Corruptions ====
# Maps corruption name -> accuracy in percent, filled in evaluation order.
results = {}
for corruption_name in corruptions:
    print(f"Evaluating: {corruption_name}")
    corruption_file = os.path.join(CIFAR10C_PATH, f"{corruption_name}.npy")
    corruption_accuracy = evaluate_corruption(np.load(corruption_file), labels)
    results[corruption_name] = corruption_accuracy
    print(f"{corruption_name}: {corruption_accuracy:.2f}%")

# ==== Summary ====
# Unweighted mean over the per-corruption accuracies.
avg_acc = np.mean(list(results.values()))
print("\n===== CIFAR-10-C Evaluation Complete =====")
for corruption_name, corruption_accuracy in results.items():
    print(f"{corruption_name:<20}: {corruption_accuracy:.2f}%")
print(f"\nAverage Corruption Accuracy: {avg_acc:.2f}%")

Evaluating: gaussian_noise
gaussian_noise: 51.27%
Evaluating: shot_noise
shot_noise: 61.22%
Evaluating: impulse_noise
impulse_noise: 63.99%
Evaluating: defocus_blur
defocus_blur: 88.41%
Evaluating: glass_blur
glass_blur: 67.66%
Evaluating: motion_blur
motion_blur: 82.21%
Evaluating: zoom_blur
zoom_blur: 84.39%
Evaluating: snow
snow: 88.04%
Evaluating: frost
frost: 86.91%
Evaluating: fog
fog: 86.77%
Evaluating: brightness
brightness: 93.94%
Evaluating: contrast
contrast: 84.50%
Evaluating: elastic_transform
elastic_transform: 84.84%
Evaluating: pixelate
pixelate: 85.19%
Evaluating: jpeg_compression
jpeg_compression: 78.87%

===== CIFAR-10-C Evaluation Complete =====
gaussian_noise      : 51.27%
shot_noise          : 61.22%
impulse_noise       : 63.99%
defocus_blur        : 88.41%
glass_blur          : 67.66%
motion_blur         : 82.21%
zoom_blur           : 84.39%
snow                : 88.04%
frost               : 86.91%
fog                 : 86.77%
brightness          : 93.94%
contras