In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from pathlib import Path

from torchvision.models import (
    resnet50, ResNet50_Weights,
    vit_b_16, ViT_B_16_Weights
)
import torchvision.transforms.functional as F


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [None]:
resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
old_resnet_w = resnet.conv1.weight.data
resnet.conv1 = nn.Conv2d(
    1, 64, kernel_size=7, stride=2, padding=3, bias=False
)
resnet.conv1.weight.data = old_resnet_w.mean(dim=1, keepdim=True)
resnet = resnet.to(device).eval()

vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).to(device).eval()

resnet_labels = ResNet50_Weights.DEFAULT.meta["categories"]
vit_labels    = ViT_B_16_Weights.DEFAULT.meta["categories"]

In [5]:
class EditedImagesDataset3channel(Dataset):
    def __init__(self, folder):
        self.folder = Path(folder)
        self.files = sorted(self.folder.glob("*.jpg"))
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
        self.std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path)
        img = img.convert("L")
        img = img.convert("RGB")
        img = F.resize(img, [224,224])
        x = F.to_tensor(img)   
        x = (x - self.mean) / self.std            
        y = -1
        return x, y
    
class EditedImagesDataset1channel(Dataset):
    def __init__(self, folder):
        self.folder = Path(folder)
        self.files = sorted(self.folder.glob("*.jpg"))
        self.mean = torch.tensor([0.449]).view(1,1,1)
        self.std  = torch.tensor([0.226]).view(1,1,1)


    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        img = Image.open(img_path)
        img = img.convert("L")
        img = F.resize(img, [224,224])
        x = F.to_tensor(img)      
        x = (x - self.mean) / self.std         
        y = -1
        return x, y

In [6]:
def run_inference(model, dl, name):
    print(f"--- {name} ---")
    if isinstance(model, torch.nn.Module) and "vit" in name.lower():
        labels = vit_labels
    else:
        labels = resnet_labels

    traffic_total = 0
    streetsign_total = 0
    total_preds = 0

    with torch.no_grad():
        for x, _ in dl:
            x = x.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1).cpu().tolist()
            class_names = [labels[i] for i in preds]
            traffic_total += sum(1 for x in class_names if x == "traffic light")
            streetsign_total += sum(1 for x in class_names if x == "street sign")
            total_preds += len(class_names)
            #print(class_names)
        print(f"traffic light: {traffic_total} | street sign: {streetsign_total} | total: {total_preds}")
        print(f"error rate: {(total_preds - (traffic_total + streetsign_total))/total_preds}")
        print(f"Street sign/Traffic light ratio: {streetsign_total/traffic_total}")


In [7]:
test_ds3 = EditedImagesDataset3channel("greyFinal/")
test_ds1 = EditedImagesDataset1channel("greyFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")


--- ResNet50 ---
traffic light: 106 | street sign: 169 | total: 346
error rate: 0.20520231213872833
Street sign/Traffic light ratio: 1.5943396226415094
--- ViT-B/16 ---
traffic light: 53 | street sign: 197 | total: 346
error rate: 0.2774566473988439
Street sign/Traffic light ratio: 3.7169811320754715


In [8]:
test_ds3 = EditedImagesDataset3channel("smallFinal/")
test_ds1 = EditedImagesDataset1channel("smallFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")


--- ResNet50 ---
traffic light: 82 | street sign: 124 | total: 346
error rate: 0.4046242774566474
Street sign/Traffic light ratio: 1.5121951219512195
--- ViT-B/16 ---
traffic light: 51 | street sign: 167 | total: 346
error rate: 0.3699421965317919
Street sign/Traffic light ratio: 3.2745098039215685


In [9]:
test_ds3 = EditedImagesDataset3channel("histeqFinal/")
test_ds1 = EditedImagesDataset1channel("histeqFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")


--- ResNet50 ---
traffic light: 76 | street sign: 192 | total: 346
error rate: 0.2254335260115607
Street sign/Traffic light ratio: 2.526315789473684
--- ViT-B/16 ---
traffic light: 49 | street sign: 198 | total: 346
error rate: 0.2861271676300578
Street sign/Traffic light ratio: 4.040816326530612


In [10]:
test_ds3 = EditedImagesDataset3channel("gammaFinal/")
test_ds1 = EditedImagesDataset1channel("gammaFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 111 | street sign: 147 | total: 346
error rate: 0.2543352601156069
Street sign/Traffic light ratio: 1.3243243243243243
--- ViT-B/16 ---
traffic light: 54 | street sign: 196 | total: 346
error rate: 0.2774566473988439
Street sign/Traffic light ratio: 3.6296296296296298


In [11]:
test_ds3 = EditedImagesDataset3channel("deblurFinal/")
test_ds1 = EditedImagesDataset1channel("deblurFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 106 | street sign: 169 | total: 346
error rate: 0.20520231213872833
Street sign/Traffic light ratio: 1.5943396226415094
--- ViT-B/16 ---
traffic light: 53 | street sign: 197 | total: 346
error rate: 0.2774566473988439
Street sign/Traffic light ratio: 3.7169811320754715


In [12]:
test_ds3 = EditedImagesDataset3channel("unsharpFinal/")
test_ds1 = EditedImagesDataset1channel("unsharpFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 119 | street sign: 142 | total: 346
error rate: 0.24566473988439305
Street sign/Traffic light ratio: 1.1932773109243697
--- ViT-B/16 ---
traffic light: 74 | street sign: 163 | total: 346
error rate: 0.315028901734104
Street sign/Traffic light ratio: 2.2027027027027026


In [13]:
test_ds3 = EditedImagesDataset3channel("medianFinal/")
test_ds1 = EditedImagesDataset1channel("medianFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 86 | street sign: 168 | total: 346
error rate: 0.2658959537572254
Street sign/Traffic light ratio: 1.9534883720930232
--- ViT-B/16 ---
traffic light: 56 | street sign: 190 | total: 346
error rate: 0.28901734104046245
Street sign/Traffic light ratio: 3.392857142857143


In [14]:
test_ds3 = EditedImagesDataset3channel("MAFinal/")
test_ds1 = EditedImagesDataset1channel("MAFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 90 | street sign: 165 | total: 346
error rate: 0.2630057803468208
Street sign/Traffic light ratio: 1.8333333333333333
--- ViT-B/16 ---
traffic light: 56 | street sign: 194 | total: 346
error rate: 0.2774566473988439
Street sign/Traffic light ratio: 3.4642857142857144


In [15]:
test_ds3 = EditedImagesDataset3channel("waveletFinal/")
test_ds1 = EditedImagesDataset1channel("waveletFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 102 | street sign: 173 | total: 346
error rate: 0.20520231213872833
Street sign/Traffic light ratio: 1.696078431372549
--- ViT-B/16 ---
traffic light: 52 | street sign: 196 | total: 346
error rate: 0.2832369942196532
Street sign/Traffic light ratio: 3.769230769230769


In [16]:
test_ds3 = EditedImagesDataset3channel("spectralFinal/")
test_ds1 = EditedImagesDataset1channel("spectralFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 103 | street sign: 129 | total: 346
error rate: 0.32947976878612717
Street sign/Traffic light ratio: 1.2524271844660195
--- ViT-B/16 ---
traffic light: 62 | street sign: 146 | total: 346
error rate: 0.3988439306358382
Street sign/Traffic light ratio: 2.3548387096774195


In [17]:
test_ds3 = EditedImagesDataset3channel("bsplineFinal/")
test_ds1 = EditedImagesDataset1channel("bsplineFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 81 | street sign: 134 | total: 346
error rate: 0.3786127167630058
Street sign/Traffic light ratio: 1.654320987654321
--- ViT-B/16 ---
traffic light: 61 | street sign: 153 | total: 346
error rate: 0.3815028901734104
Street sign/Traffic light ratio: 2.5081967213114753


In [18]:
test_ds3 = EditedImagesDataset3channel("bsaiFinal/")
test_ds1 = EditedImagesDataset1channel("bsaiFinal/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 68 | street sign: 138 | total: 346
error rate: 0.4046242774566474
Street sign/Traffic light ratio: 2.0294117647058822
--- ViT-B/16 ---
traffic light: 57 | street sign: 161 | total: 346
error rate: 0.3699421965317919
Street sign/Traffic light ratio: 2.824561403508772


In [19]:
test_ds3 = EditedImagesDataset3channel("chain1Final/")
test_ds1 = EditedImagesDataset1channel("chain1Final/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 83 | street sign: 114 | total: 346
error rate: 0.430635838150289
Street sign/Traffic light ratio: 1.3734939759036144
--- ViT-B/16 ---
traffic light: 63 | street sign: 117 | total: 346
error rate: 0.4797687861271676
Street sign/Traffic light ratio: 1.8571428571428572


In [20]:
test_ds3 = EditedImagesDataset3channel("chain2Final/")
test_ds1 = EditedImagesDataset1channel("chain2Final/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 82 | street sign: 123 | total: 346
error rate: 0.407514450867052
Street sign/Traffic light ratio: 1.5
--- ViT-B/16 ---
traffic light: 72 | street sign: 110 | total: 346
error rate: 0.47398843930635837
Street sign/Traffic light ratio: 1.5277777777777777


In [21]:
test_ds3 = EditedImagesDataset3channel("chain3Final/")
test_ds1 = EditedImagesDataset1channel("chain3Final/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 84 | street sign: 85 | total: 346
error rate: 0.5115606936416185
Street sign/Traffic light ratio: 1.0119047619047619
--- ViT-B/16 ---
traffic light: 72 | street sign: 57 | total: 346
error rate: 0.6271676300578035
Street sign/Traffic light ratio: 0.7916666666666666


In [22]:
test_ds3 = EditedImagesDataset3channel("chain4Final/")
test_ds1 = EditedImagesDataset1channel("chain4Final/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 53 | street sign: 148 | total: 346
error rate: 0.4190751445086705
Street sign/Traffic light ratio: 2.792452830188679
--- ViT-B/16 ---
traffic light: 51 | street sign: 130 | total: 346
error rate: 0.476878612716763
Street sign/Traffic light ratio: 2.549019607843137


In [23]:
test_ds3 = EditedImagesDataset3channel("small5Final/")
test_ds1 = EditedImagesDataset1channel("small5Final/")
test_res = DataLoader(test_ds1, batch_size=8, shuffle=False)
test_vit = DataLoader(test_ds3, batch_size=8, shuffle=False)
run_inference(resnet, test_res, "ResNet50")
run_inference(vit, test_vit, "ViT-B/16")

--- ResNet50 ---
traffic light: 57 | street sign: 130 | total: 346
error rate: 0.4595375722543353
Street sign/Traffic light ratio: 2.280701754385965
--- ViT-B/16 ---
traffic light: 46 | street sign: 123 | total: 346
error rate: 0.5115606936416185
Street sign/Traffic light ratio: 2.6739130434782608
