In [1]:
import json
import os

import helper
import torch
from torchvision import datasets
from torchvision import transforms as T
from tqdm import tqdm

# Fix PyTorch's global RNG seed so any randomness in this run is repeatable.
torch.manual_seed(2022)

In [2]:
# Evaluate every fine-tuned backbone on the clean test split and on each
# perturbed copy found under `dataset/`, collecting the metrics in `result`:
#   result[<backbone>]["default"]               -> metrics for dataset/test
#   result[<backbone>][<perturbation>][<value>] -> metrics for
#                                                  dataset/<perturbation>/<value>
backbones = ["resnet26d", "convnext_nano", "resmlp_12_224", "densenet169"]
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
criterion = torch.nn.CrossEntropyLoss().cuda().eval()

# Directories under `dataset/` that hold dataset splits rather than perturbations.
SPLITS = ("train", "valid", "test")
BATCH_SIZE = 512


def _subdirs(root):
    """Return the sorted names of the non-hidden directories directly under `root`.

    `os.listdir` yields entries in arbitrary order and also returns plain files
    and hidden entries (e.g. `.DS_Store`, `.ipynb_checkpoints`); treating one of
    those as an image folder would abort the whole run part-way through.
    """
    return sorted(
        name
        for name in os.listdir(root)
        if not name.startswith(".") and os.path.isdir(os.path.join(root, name))
    )


def _evaluate_dir(classifier, datadir):
    """Evaluate `classifier` on the image folder `datadir` and return the metrics.

    The loader is wrapped in a dict keyed by "test" because `helper.evaluate` is
    called with that phase name — presumably it looks the loader up by phase;
    confirm against `helper`.
    """
    dataset = datasets.ImageFolder(datadir, transform=transform)
    dataloader = {"test": torch.utils.data.DataLoader(dataset, BATCH_SIZE)}
    return helper.evaluate(dataloader, classifier, criterion, "test")


result = {}
datalist = _subdirs("dataset")
# One progress step per (backbone, directory) pair; split directories are
# skipped below but still counted so the bar reaches 100%.
with tqdm(total=len(backbones) * len(datalist)) as pbar:
    for b in backbones:
        result[b] = {}
        backbone = helper.pretrained(b)
        # 8 is presumably the number of target classes — confirm against helper.
        classifier = helper.generate_classifier(backbone, 8)
        # Load onto the CPU first so the checkpoint does not depend on the GPU
        # index it was saved from; `.cuda()` below moves the model afterwards.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(f"weights/{b}.pt", map_location="cpu")
        classifier.load_state_dict(state_dict)
        classifier = classifier.cuda().eval()

        # Baseline: the unperturbed test split.
        result[b]["default"] = _evaluate_dir(classifier, "dataset/test")
        pbar.set_postfix(backbone=b, **result[b]["default"])

        for p in datalist:
            if p in SPLITS:
                pbar.update()
                continue

            result[b][p] = {}
            path = os.path.join("dataset", p)
            # Each sub-directory of a perturbation holds one strength/value.
            for v in _subdirs(path):
                result[b][p][v] = _evaluate_dir(classifier, os.path.join(path, v))
                pbar.set_postfix(backbone=b, perturbation=p, value=v, **result[b][p][v])
            pbar.update()
        # Drop the references so the model can be freed before the next one loads.
        del backbone, classifier, state_dict

100%|██████████| 44/44 [03:48<00:00,  5.20s/it, Accuracy=0.654, F1-Macro=0.653, Loss=1.11, Subset=Test, backbone=densenet169, perturbation=occlusion, value=25]             


In [3]:
# Persist the collected metrics as JSON. Create the output directory first so
# a missing `logs/` folder cannot discard the results of the evaluation run.
os.makedirs("logs", exist_ok=True)
with open("logs/evaluation.json", "w", encoding="utf-8") as f:
    json.dump(result, f)