In [3]:
import torch
import torchvision.models as models
from torchvision import transforms
from imagenetv2_pytorch import ImageNetV2Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [18]:
# 1. Custom Dataset Wrapper to handle PIL Images
class PreprocessedImageNetV2(torch.utils.data.Dataset):
    """Wrap ``ImageNetV2Dataset`` so every image is preprocessed on access.

    Args:
        variant: Variant name forwarded unchanged to ``ImageNetV2Dataset``.
        transform: Optional callable applied to each image returned by the
            wrapped dataset. When omitted, the default pipeline is used:
            resize to 256, center-crop to 224, convert to a tensor and
            normalize with the per-channel mean/std listed below.
    """

    def __init__(self, variant="matched-frequency", transform=None):
        self.dataset = ImageNetV2Dataset(variant)
        if transform is None:
            # Default evaluation pipeline (values unchanged from the
            # original hard-coded version).
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        self.preprocess = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, label)`` for the sample at ``idx``."""
        img, label = self.dataset[idx]
        return self.preprocess(img), label

# 2. Load pretrained ResNet18
# `pretrained=True` is deprecated in newer torchvision releases in favour of
# the `weights=` argument; IMAGENET1K_V1 is the checkpoint the legacy flag
# selected. Fall back to the legacy flag when the weights enum is unavailable.
if hasattr(models, "ResNet18_Weights"):
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model = models.resnet18(pretrained=True)
model = model.to(device)
# Switch to inference mode; as the last expression this also displays the
# model architecture in the cell output.
model.eval()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [19]:
# 3. Build the preprocessed dataset and wrap it in a batched loader
print("Loading ImageNetV2 dataset...")
# Other variants accepted here: "threshold-0.7", "top-images"
dataset = PreprocessedImageNetV2(variant="matched-frequency")
dataloader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
)

Loading ImageNetV2 dataset...


In [20]:
# 4. Evaluation function
def evaluate_model(model, dataloader):
    """Compute the top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    Args:
        model: Callable mapping an image batch to per-class scores of shape
            ``(batch, num_classes)``.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Percentage of samples whose highest-scoring class equals the label,
        or ``0.0`` when the loader yields no samples.
    """
    correct = 0
    total = 0

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            # Index of the highest score along the class dimension.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

# 5. Kick off the top-1 evaluation and report the score
print("Starting evaluation...")
accuracy = evaluate_model(model=model, dataloader=dataloader)
print(f"\nResNet18 Top-1 Accuracy on ImageNetV2: {accuracy:.2f}%")

Starting evaluation...


Evaluating: 100%|██████████| 157/157 [00:10<00:00, 14.36it/s]


ResNet18 Top-1 Accuracy on ImageNetV2: 57.29%





In [21]:
def top5_accuracy(model, dataloader, k=5):
    """Compute the top-k accuracy (in percent) of ``model`` over ``dataloader``.

    Args:
        model: Callable mapping an image batch to per-class scores of shape
            ``(batch, num_classes)``.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.
        k: Number of highest-scoring classes considered per sample.
            Defaults to 5; clamped to the number of classes the model emits.

    Returns:
        Percentage of samples whose label is among the ``k`` highest-scoring
        classes, or ``0.0`` when the loader yields no samples.
    """
    correct = 0
    total = 0

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # topk raises if asked for more entries than there are classes.
            effective_k = min(k, outputs.size(1))
            _, top_indices = outputs.topk(effective_k, dim=1, largest=True, sorted=True)
            # Broadcast each label against its row of candidate indices; a
            # label matches at most one column, so the sum counts hits.
            correct += top_indices.eq(labels.view(-1, 1)).sum().item()
            total += labels.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

# Evaluate first, then report, mirroring the top-1 cell above.
top5_acc = top5_accuracy(model, dataloader)
print(f"Top-5 Accuracy: {top5_acc:.2f}%")

Evaluating: 100%|██████████| 157/157 [00:10<00:00, 14.41it/s]

Top-5 Accuracy: 79.91%





In [22]:
from collections import defaultdict

def class_accuracy(model, dataloader):
    """Compute the top-1 accuracy (in percent) separately for every class.

    Args:
        model: Callable mapping an image batch to per-class scores of shape
            ``(batch, num_classes)``.
        dataloader: Iterable yielding ``(images, labels)`` tensor batches.

    Returns:
        Dict mapping each class label seen in ``dataloader`` to the
        percentage of its samples that were predicted correctly. Empty when
        the loader yields no samples.
    """
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)

            predicted = model(images).argmax(dim=1)

            # Convert each batch to plain Python ints once, rather than
            # calling .item() / comparing tensors for every single sample.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                class_total[label] += 1
                if label == pred:
                    class_correct[label] += 1

    return {cls: 100 * class_correct[cls] / class_total[cls]
            for cls in class_total}

# Per-class accuracies, then the two extremes across all classes.
class_acc = class_accuracy(model, dataloader)
best_class_acc = max(class_acc.values())
worst_class_acc = min(class_acc.values())
print(f"Best class accuracy: {best_class_acc:.2f}%")
print(f"Worst class accuracy: {worst_class_acc:.2f}%")


Evaluating: 100%|██████████| 157/157 [00:10<00:00, 14.31it/s]

Best class accuracy: 100.00%
Worst class accuracy: 0.00%



