In [None]:
"""Code for loading studies from the database"""

from optuna import load_study

# Experiment grid. Every axis of the grid is encoded in the study name (see the f-string in the loop below).
name = "2023_03_11_13_00_00"
n_trials = 2000
categories = [
    "carpet", "grid", "leather", "tile", "wood",
    "bottle", "cable", "capsule", "hazelnut", "metal_nut",
    "pill", "screw", "toothbrush", "transistor", "zipper",
]
seed_list = [0, 1, 2]
k_list = [1, 2, 4]
search_types = ["few", "full"]
storage = "sqlite:///studies.db"
study_names = []

# studies[k][category][search_type][seed] -> the study's best (Pareto-optimal) trials,
# ordered by their first objective value, largest first.
studies = {}
for k in k_list:
    per_category = studies[k] = {}
    for category in categories:
        per_search = per_category[category] = {}
        for search_type in search_types:
            per_seed = per_search[search_type] = {}
            for seed in seed_list:
                # The "full" search type corresponds to the True/False test-set-search flag in the study name.
                test_set_search = (search_type == "full")
                study_name = f"{name}_n{n_trials}_k{k}_s{seed}_{category}_{test_set_search}"
                study = load_study(study_name=study_name, storage=storage)
                per_seed[seed] = sorted(study.best_trials, key=lambda trial: trial.values[0], reverse=True)
                print(f"{study_name} loaded")

In [None]:
"""Code for evaluating architectures on k=4 test set, used to generate table 2 in the paper"""

from logging import getLogger, ERROR
import search
from mvtec import MVTecDataModule
from copy import deepcopy
from statistics import mean, stdev
from pytorch_lightning import seed_everything

getLogger("pytorch_lightning").setLevel(ERROR)

# Per-category results, one entry per category (in `categories` order):
# the mean over seeds and the sample standard deviation over seeds, one pair of lists per metric.
auroc_all = []
partial_auroc_all = []
ap_all = []
wAP_all = []
gflops_all = []
stdev_auroc_all = []
stdev_partial_auroc_all = []
stdev_ap_all = []
stdev_wAP_all = []
stdev_gflops_all = []
k = 1  # architectures are taken from the studies searched with this k (key into `studies`)
search_type = "few"
dataset_dir = "../MVTec"
gpu = 0
# NOTE(review): `seed_everything` is imported but never called, so this cell does not seed the evaluation.
# TODO confirm whether search.objective / MVTecDataModule seed themselves before relying on exact reproducibility.

# Reporting table: (printed label, decimals, per-category means, per-category stdevs).
# The list order is the order in which lines are printed.
metric_report = [
    ("AUROC", 2, auroc_all, stdev_auroc_all),
    ("partial AUROC", 2, partial_auroc_all, stdev_partial_auroc_all),
    ("AP", 1, ap_all, stdev_ap_all),
    ("wAP", 1, wAP_all, stdev_wAP_all),
    ("GFLOPS", 2, gflops_all, stdev_gflops_all),
]


def _mean_and_stdev(values):
    """Return (mean, sample standard deviation) of `values`.

    `statistics.stdev` needs at least two data points. With a single seed the spread is undefined,
    so NaN is returned instead of raising StatisticsError (it prints as "nan").
    """
    return mean(values), (stdev(values) if len(values) > 1 else float("nan"))


for category in categories:
    datamodule = MVTecDataModule(dataset_dir=dataset_dir, category=category, img_size=224, batch_size=391, k=4)
    auroc = []
    partial_auroc = []
    ap = []
    wAP = []
    gflops = []
    for seed, trials in studies[k][category][search_type].items():
        # trials[0] is the first trial after the descending sort on values[0] done when the studies were loaded.
        # It is deep-copied, presumably so that search.objective re-recording the test_* user_attrs
        # leaves the trial cached in `studies` untouched.
        # The trial is NOT required to already carry "test_wAP": the metrics are read only after re-evaluation.
        trial = deepcopy(trials[0])
        objectives = search.objective(
            trial, datamodule, dict(accelerator="gpu", devices=[gpu], enable_progress_bar=False), 224,
            # Alternative fixed-supernet configuration, kept for reference:
            #fixed_supernet_name="ofa_mbv3_d234_e346_k357_w1.2", fixed_kernel_size=7, fixed_expand_ratio=6,
        )
        # Scale the recorded metrics by 100 so they are reported as percentages.
        auroc.append(trial.user_attrs["test_AUROC"] * 100)
        partial_auroc.append(trial.user_attrs["test_partial_AUROC"] * 100)
        ap.append(trial.user_attrs["test_AP"] * 100)
        wAP.append(trial.user_attrs["test_wAP"] * 100)
        # The first objective, divided by 1e9, is reported as GFLOPS.
        gflops.append(objectives[0] / 1e9)
    # Same order as `metric_report`.
    per_seed_values = [auroc, partial_auroc, ap, wAP, gflops]
    for (label, decimals, means, spreads), values in zip(metric_report, per_seed_values):
        centre, spread = _mean_and_stdev(values)
        print(f"{category} {label}: {centre:.{decimals}f} +- {spread:.{decimals}f}")
        means.append(centre)
        spreads.append(spread)
# Summary over categories: mean of per-category means +- mean of per-category stdevs.
for label, decimals, means, spreads in metric_report:
    print(f"Mean {label}: {mean(means):.{decimals}f} +- {mean(spreads):.{decimals}f}")

In [None]:
"""Code for generating figure 3 in the paper"""

import matplotlib.pyplot as plt

# One column of panels per k; must stay aligned with the rows of `patchcore_baselines`.
k_list = [1, 2, 4]
# Number of defect types per category, in `categories` order (shown in each panel title).
num_types = [5, 5, 5, 5, 5, 3, 8, 5, 4, 4, 7, 5, 1, 4, 7]
# PatchCore reference drawn as a dashed line: rows follow `k_list`, columns follow `categories`.
patchcore_baselines = [
    [0.5997576117515564, 0.3929395377635956, 0.5172891020774841, 0.6604946851730347, 0.636153519153595, 0.8081642389297485, 0.6672737002372742, 0.35833027958869934, 0.6666530966758728, 0.8983450531959534, 0.7295960187911987, 0.36510705947875977, 0.33493757247924805, 0.6988446116447449, 0.6415698528289795],
    [0.6075531840324402, 0.3912016451358795, 0.5178176164627075, 0.6669076085090637, 0.6299311518669128, 0.8124648928642273, 0.6639524698257446, 0.383198082447052, 0.6687524318695068, 0.8977358341217041, 0.7215025424957275, 0.35935840010643005, 0.3367730975151062, 0.6897060871124268, 0.6469941139221191],
    [0.5966348052024841, 0.3844696879386902, 0.5127370357513428, 0.6831071972846985, 0.6313714385032654, 0.810759425163269, 0.6644478440284729, 0.3955104649066925, 0.6875630617141724,  0.8911733031272888, 0.7105074524879456, 0.35898569226264954, 0.3484451472759247, 0.6971675157546997, 0.646364688873291],
]
assert len(patchcore_baselines) == len(k_list), "need one PatchCore baseline row per k"
assert all(len(row) == len(categories) for row in patchcore_baselines), "need one PatchCore baseline per category"
assert len(num_types) == len(categories), "need one defect-type count per category"

n_rows = len(categories)
last_row = n_rows - 1
# squeeze=False keeps `axs` 2-D even for a single category or a single k, so axs[row, col] always works.
fig, axs = plt.subplots(n_rows, len(k_list), figsize=(16.54,21.25), sharex=True, sharey=True, squeeze=False) #23.38
seed = 1  # only the studies of this seed are plotted


def _flops_and_test_wap(trials):
    """Return (xs, ys) for the trials that have a recorded "test_wAP" user attribute.

    xs are the first objective values (plotted on the GFLOPS axis) and ys the test wAP values.
    Trials without "test_wAP" are skipped.
    """
    xs, ys = [], []
    for trial in trials:
        if "test_wAP" not in trial.user_attrs:
            continue
        xs.append(trial.values[0])
        ys.append(trial.user_attrs["test_wAP"])
    return xs, ys


for i_k, k in enumerate(k_list):
    for i, category in enumerate(categories):
        ax = axs[i, i_k]
        x_search_on_k, y_search_on_k = _flops_and_test_wap(studies[k][category]["few"][seed])
        x_search_on_all, y_search_on_all = _flops_and_test_wap(studies[k][category]["full"][seed])

        ax.scatter(x_search_on_all, y_search_on_all, label="Test search", alpha=0.5)
        ax.scatter(x_search_on_k, y_search_on_k, label="Val search", alpha=0.5)
        ax.axhline(y=patchcore_baselines[i_k][i], linestyle="--", label="PatchCore")
        ax.set_xlim((0, 1e9))
        ax.set_ylim((0, 1.0))
        if i == last_row:
            ax.set_xlabel("GFLOPS")
        if i_k == 0:
            ax.set_ylabel("Test wAP (%)")
        # A single legend, on the bottom-right panel.
        if k == k_list[-1] and i == last_row:
            ax.legend(loc="upper right")
        ax.set_title(f"{category.capitalize()} ({num_types[i]} types) k={k}", loc="right", y=0)
        # Raw x values are relabelled in units of 1e9 and y fractions as percentages.
        ax.set_xticks([0.2 * 1e9, 0.4 * 1e9, 0.6 * 1e9, 0.8 * 1e9])
        ax.set_xticklabels([0.2, 0.4, 0.6, 0.8])
        ax.set_yticks([0.25, 0.5, 0.75])
        ax.set_yticklabels([25, 50, 75])
fig.subplots_adjust(hspace=0, wspace=0)