In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from pathlib import Path

from torchvision.models import (
    resnet50, ResNet50_Weights,
    vit_b_16, ViT_B_16_Weights
)
import torchvision.transforms.functional as F


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [5]:
resnet = resnet50(weights=ResNet50_Weights.DEFAULT).to(device).eval()
vit    = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).to(device).eval()
resnet_labels = ResNet50_Weights.DEFAULT.meta["categories"]
vit_labels    = ViT_B_16_Weights.DEFAULT.meta["categories"]

In [6]:
class EditedImagesDataset(Dataset):
    def __init__(self, folder):
        self.folder = Path(folder)
        self.files = sorted(self.folder.glob("*.jpg"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path).convert("RGB")
        img = F.resize(img, [224,224])
        x = F.to_tensor(img)               
        y = -1
        return x, y

In [7]:
test_ds = EditedImagesDataset("edited/")
test_dl = DataLoader(test_ds, batch_size=8, shuffle=False)

In [10]:
def run_inference(model, dl, name):
    print(f"--- {name} ---")
    if isinstance(model, torch.nn.Module) and "vit" in name.lower():
        labels = vit_labels
    else:
        labels = resnet_labels

    traffic_total = 0
    streetsign_total = 0
    total_preds = 0

    with torch.no_grad():
        for x, _ in dl:
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1).cpu().tolist()
            class_names = [labels[i] for i in preds]
            traffic_total += sum(1 for x in class_names if x == "traffic light")
            streetsign_total += sum(1 for x in class_names if x == "street sign")
            total_preds += len(class_names)
            print(class_names)
        print(f"traffic light: {traffic_total} | street sign: {streetsign_total} | total: {total_preds}")


In [11]:
run_inference(resnet, test_dl, "ResNet50")
run_inference(vit, test_dl, "ViT-B/16")

--- ResNet50 ---
['traffic light', 'traffic light', 'street sign', 'ice lolly', 'traffic light', 'candle', 'picket fence', 'traffic light']
['traffic light', 'book jacket', 'traffic light', 'traffic light', 'traffic light', 'traffic light', 'traffic light', 'traffic light']
['traffic light', 'traffic light', 'switch', 'switch', 'shower curtain', 'shower curtain', 'traffic light', 'window shade']
['prison', 'traffic light', 'traffic light', 'traffic light', 'bolo tie', 'pickelhaube', 'traffic light', 'cowboy hat']
['traffic light', 'parking meter', 'pop bottle', 'pop bottle', 'bottlecap', 'car mirror', 'traffic light', 'traffic light']
['hourglass', 'street sign', 'cassette', 'toilet tissue', 'street sign', 'traffic light', 'birdhouse', 'street sign']
['mobile home', 'street sign', 'medicine chest']
traffic light: 22 | street sign: 5 | total: 51
--- ViT-B/16 ---
['traffic light', 'digital clock', 'street sign', 'switch', 'traffic light', 'digital clock', 'switch', 'iPod']
['street sign'