In [1]:
import __init__
import json
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import make_smoothing_spline
from scipy.stats import spearmanr, pearsonr

from phcad.experiments.constants import EXPROOT, NUMSEEDS
from phcad.train.losses import LOSS_MAP, SEG_LOSS_MAP
from phcad.data.constants import DS_TO_LABELS_MAP
from constants import TABLEDIR, FIGDIR

# Detection

In [143]:
detection_dir = EXPROOT / "detection"
datasets = ["fmnist", "cifar10", "mvtec", "mpdd"]
dmap = {k: v for v, k in enumerate(datasets)}
losses = ["dsvdd", "ssim", "bce", "hsc"]
lmap = {k: v for v, k in enumerate(losses)}
n_cpts = 101

In [144]:
experiment_types = ['full-oe',
'partial-oe-phtrain-oe',
'partial-oe-phtrain-spectral',
'partial-oe-platt-oe',
'partial-oe-platt-spectral',
'partial-oe-beta-oe',
'partial-oe-beta-spectral']
emap = {k: i for i, k in enumerate(experiment_types)}

t = emap.copy()
for k, v in t.items():
    tmp = k.split("-")
    tmp[1] = "none"
    emap["-".join(tmp)] = v

t = emap.copy()
for k, v in t.items():
    emap[f"{k}-perturb"] = v

In [145]:
aurocs = np.zeros((len(datasets)*len(losses), 7, 2))
eces = np.zeros((len(datasets)*len(losses), 7))
mces = np.zeros((len(datasets)*len(losses), 7))
roc_curves = np.zeros((len(datasets), len(losses),7,2,n_cpts,2))

label_aurocs = [None] * 4
label_eces = [None] * 4
label_mces = [None] * 4

for d in dmap:
    labels = DS_TO_LABELS_MAP[d]
    dlabel_aurocs = np.zeros((len(losses), len(labels), 7, 2))
    dlabel_eces = np.zeros((len(losses), len(labels), 7))
    dlabel_mces = np.zeros((len(losses), len(labels), 7))
    
    for loss in lmap:
        i = dmap[d] * len(lmap) + lmap[loss]

        
        # ------
        with open(detection_dir / d / f"{loss}.json") as f:
            results = json.loads(f.read())
        for e,v in results["all"].items():
            if "pert" in e:
                aurocs[i, emap[e],1] = v
            else:
                aurocs[i, emap[e], 0] = v

        avg_auroc_res = results["avg"]
        for lblidx, label in enumerate(labels):
            for etype, v in avg_auroc_res[label].items():
                e = "-".join(etype.split("-")[len(label.split("-")):])
                if "pert" in e:
                    dlabel_aurocs[lmap[loss], lblidx, emap[e], 1] = v
                else:
                    dlabel_aurocs[lmap[loss], lblidx, emap[e], 0] = v
        label_aurocs[dmap[d]] = dlabel_aurocs

        for e, v in results["roc"].items():
            if "pert" in e:
                roc_curves[dmap[d], lmap[loss], emap[e], 1, :, 0] = v["x"]
                roc_curves[dmap[d], lmap[loss], emap[e], 1, :, 1] = v["y"]
            else:
                roc_curves[dmap[d], lmap[loss], emap[e], 0, :, 0] = v["x"]
                roc_curves[dmap[d], lmap[loss], emap[e], 0, :, 1] = v["y"]

        

        # ------
        with open(detection_dir / d / f"{loss}-cal.json") as f:
            results = json.loads(f.read())

        # ECE ALL
        for e,v in results["all"]["ece"].items():
            eces[i, emap[e]] = v
            
        

        # MCE ALL
        for e,v in results["all"]["mce"].items():
            mces[i, emap[e]] = v

        # ECE + MCE Labels
        avg_res = results["avg"]
        for lblidx, label in enumerate(labels):
            avg_lab_res = avg_res[label]
            for etype, v in avg_lab_res["ece"].items():
                e = "-".join(etype.split("-")[len(label.split("-")):])
                dlabel_eces[lmap[loss], lblidx, emap[e]] = v
            for etype, v in avg_lab_res["mce"].items():
                e = "-".join(etype.split("-")[len(label.split("-")):])
                dlabel_mces[lmap[loss], lblidx, emap[e]] = v
        label_eces[dmap[d]] = dlabel_eces
        label_mces[dmap[d]] = dlabel_mces

#### Correlation computation

In [116]:
for di, d in enumerate(datasets):
    dauroc, dauroc_pert = label_aurocs[di][:, :, :, 0], label_aurocs[di][:, :, :, 1]
    dauroc_increase = np.nan_to_num(((dauroc_pert - dauroc) / (1 - dauroc)).flatten())
    #print(auroc_increase_norm)

    print(dauroc.shape)
    dece, dmce = label_eces[di].flatten(), label_mces[di].flatten()

    print(spearmanr(dauroc_increase, dece))
    print(spearmanr(dauroc_increase, dmce))

(4, 10, 7)
SignificanceResult(statistic=np.float64(-0.03165470968765451), pvalue=np.float64(0.597883022010702))
SignificanceResult(statistic=np.float64(-0.01359850019215447), pvalue=np.float64(0.8207822446079237))
(4, 10, 7)
SignificanceResult(statistic=np.float64(-0.45012792407435154), pvalue=np.float64(2.2497581192302603e-15))
SignificanceResult(statistic=np.float64(0.18397177839554585), pvalue=np.float64(0.001994223823077685))
(4, 15, 7)
SignificanceResult(statistic=np.float64(0.04098450101040988), pvalue=np.float64(0.4021536257425248))
SignificanceResult(statistic=np.float64(-0.00010109261098239182), pvalue=np.float64(0.99835188480711))
(4, 6, 7)
SignificanceResult(statistic=np.float64(-0.17982028696492913), pvalue=np.float64(0.019680661688867982))
SignificanceResult(statistic=np.float64(-0.18415882746548576), pvalue=np.float64(0.016865583672672824))


  dauroc_increase = np.nan_to_num(((dauroc_pert - dauroc) / (1 - dauroc)).flatten())


#### Table - Master AUROC - Main Thesis

In [147]:
pre = r"""\begin{table}[h]
\centering
\begin{tabular}{c c c c c c c c c c}
"""

post = r"""\end{tabular}
\end{table}"""

heading_cols = [r"Loss", r"Metric", r"Base", r"CalHead\\OE", r"CalHead\\Spectral", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

heading = r"\hline& "
for header in heading_cols:
    heading += r"\multirowcell{2}{" + header + "} & "
heading = heading[:-2] + r"\\" + "\n" + " ".join(["&"] * 8) + r"\\" + "\n" + r"\hline\hline" + "\n"

lines = ""
tbls = []
for i, lrow in enumerate(aurocs):
    critv = lrow.max()
    for j, row in enumerate(lrow.T):
        if i % 4 == 0 and j == 0:
            dname = datasets[i // 4].upper()
            if dname == "MVTEC":
                dname = dname[:3] + dname[-2:].lower() + "AD"
            lines += r"\multirow{16}{*}{\rotatebox[origin=c]{90}{" + dname + "}} & "
            pass
        else:
            lines += "& "
        if j == 0:
            loss = losses[i%4].upper()
            if loss == "DSVDD": loss = "SVDD"
            if loss == "BCE": loss = "LGS"
            lines += "\multirow{4}{*}{" + loss + r"} & \multirow{2}{*}{AUROC} & "
        else:
            lines += "& & "
        for val in row:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(val*100):.2f}"
            if val == critv:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\"
        lines += "\n"

    lines += r"\cline{3-10}" + "\n"
    lines += "& & MCE & "
    critmce = mces[i].min()
    for k, mceval in enumerate(mces[i]):
        app = f"{(mceval*100):.2f}"
        if mceval == critmce:
            app = r"\textbf{" + app + "}"
        app += " & "
        #if losses[i % 4] == "dsvdd" and k == 0:
        #    lines += " - & "
        #else:
        lines += app
    lines = lines[:-2]
    lines += r"\\"
    lines += "\n"

    lines += r"\cline{3-10}" + "\n"
    lines += "& & ECE & "
    critece = eces[i].min()
    for k, eceval in enumerate(eces[i]):
        app = f"{(eceval*100):.2f}"
        if eceval == critece:
            app = r"\textbf{" + app + "}"
        app += " & "
        #if losses[i % 4] == "dsvdd" and k == 0:
        #    lines += " - & "
        #else:
        lines += app
    lines = lines[:-2]
    lines += r"\\" + "\n"
    
    
    
    if (i + 1) % 4 == 0:
        lines += "\hline\n"
    else:
        lines += "\cline{2-10}\n"

    if (i + 1) % 8 == 0:
        tbls += [lines]
        lines = ""

for i, tbl in enumerate(tbls):
    table = pre + heading + tbl + post
    with open(TABLEDIR / f"detection-results-master-{i}.tex", "w") as f:
        f.write(table)

#### TODO: Table - Master MCE + ECE - Main Thesis

#### Table - Label Tabels

In [150]:
pre = r"""\begin{table}[H]
\centering
\begin{tabular}{c c c c c c c c}
"""

post = r"\end{table}"

heading_cols = ["Label", r"Fully\\Trained", r"CalHead\\OE", r"CalHead\\Spectral", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

header = r"\hline" + "\n"
for heading in heading_cols:
    header += r"\multirowcell{2}{" + heading + "} & "
header = header[:-2]
header += r"\\" + "\n" + " ".join(["&"] * 7) + r"\\" + "\n" + r"\hline\hline" + "\n"

tbls = {}
lines = ""
for di, dname in enumerate(datasets):
    tbls[dname] = {}
    for lsi, loss in enumerate(losses):
        tbls[dname][loss] = {}
        
        for lbi, label in enumerate(DS_TO_LABELS_MAP[dname]):
            labname = label
            if "_" in labname:
                labname = " ".join(labname.split("_"))
            elif "-" in labname:
                labname = " ".join(labname.split("-"))
                
            lines += r"\multirow{2}{*}{" + " ".join(" ".join(label.split("_")).split("-")) + r"} & "
            
            rocrows = label_aurocs[di][lsi, lbi].T
            crit_av = rocrows.max()
            for j, row in enumerate(rocrows):
                if j == 1:
                    lines += "& "
                for val in row:
                    if j == 1:
                        lines += r"\cellcolor{gray!15}"
                    app = f"{(val*100):.2f}"
                    if val == crit_av:
                        app = r"\textbf{" + app + "}"
                    app += " & "
                    lines += app
                lines = lines[:-2]
                lines += r"\\"
                lines += "\n"
            lines += "\hline" + "\n"

        lines += "\hline \n"
        lsrocs_avgs = label_aurocs[di][lsi].mean(0).T
        critval = lsrocs_avgs.max()
        for j, row in enumerate(lsrocs_avgs):
            if j == 1:
                lines += "& "
            else:
                lines += r"\multirow{2}{*}{Average} & "
            for val in row:
                if j == 1:
                    lines += r"\cellcolor{gray!15}"
                app = f"{(val*100):.2f}"
                if val == critval:
                    app = r"\textbf{" + app + "}"
                app += " & "
                lines += app
            lines = lines[:-2]
            lines += r"\\"
            lines += "\n"
        lines += "\hline" + "\n"
        
        tbls[dname][loss] = lines
        lines = ""
                    
def caption_label(d, l):
    l2 = l
    l = l.upper()
    if l == "BCE":
        l = "logistic"
    elif l == "DSVDD":
        l = "SVDD"
    d2 = d
    d = d.upper()
    if d == "CIFAR10":
        d = "CIFAR-10"
    elif d == "FMNIST":
        d = "Fashion MNIST"
    elif d == "MVTEC":
        d = "MVTecAD"
    lines = "\end{tabular}" + "\n"
    lines += r"\caption[Per-Class Detection Results for " + l + " on the " + d + r" Dataset]{Average \% AUROC over 5 seeds with the " + l + r" loss for each class of the " + d + r" dataset. AUROC over non-perturbed and perturbed test inputs is shown in the white and gray rows respectively, with the largest value per class emphasized in bold font.}" + "\n"
    lines += r"\label{tab:detection-" + f"{d2}-{l2}" + "}\n"
    return lines

table = pre + header + tbls["mvtec"]["bce"]+ post
#print(table)
for d, lt in tbls.items():
    for l, tbl in lt.items():
        with open(TABLEDIR / f"detection-{d}-{l}.tex", "w") as f:
            table = pre + header + tbl + caption_label(d, l) +  post
            f.write(table)
            print(TABLEDIR / f"detection-{d}-{l}.tex")
#with open(TABLEDIR / "detection-auroc-master.tex", "w") as f:
#    f.write(table)

/home/svarq/Code/calibration-ad/save/tables/detection-fmnist-dsvdd.tex
/home/svarq/Code/calibration-ad/save/tables/detection-fmnist-ssim.tex
/home/svarq/Code/calibration-ad/save/tables/detection-fmnist-bce.tex
/home/svarq/Code/calibration-ad/save/tables/detection-fmnist-hsc.tex
/home/svarq/Code/calibration-ad/save/tables/detection-cifar10-dsvdd.tex
/home/svarq/Code/calibration-ad/save/tables/detection-cifar10-ssim.tex
/home/svarq/Code/calibration-ad/save/tables/detection-cifar10-bce.tex
/home/svarq/Code/calibration-ad/save/tables/detection-cifar10-hsc.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mvtec-dsvdd.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mvtec-ssim.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mvtec-bce.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mvtec-hsc.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mpdd-dsvdd.tex
/home/svarq/Code/calibration-ad/save/tables/detection-mpdd-ssim.tex
/home/svarq/Code/calibrati

#### Detection Results Per Loss - Defence Slides

In [119]:
pre = r"""\begin{table}
\centering
\resizebox{0.85\textwidth}{!}{
\begin{tabular}{c c c c c c c c c}
"""

post = r"""\end{tabular}
}
\end{table}"""

heading_cols = [r"Metric", r"Base", r"CalHead\\OE", r"CalHead\\Spectral", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

header = r"\hline" + "\n& "
for heading in heading_cols:
    header += r"\multirowcell{2}{" + heading + "} & "
header = header[:-2] + r"\\" + "\n" + " ".join(["&"] * 7) + r"\\" + "\n" + r"\hline\hline" + "\n"

tlines = []
lines = ""
for i, lrow in enumerate(aurocs):
    critv = lrow.max()
    for j, row in enumerate(lrow.T):
        if j == 0:
            lines += "\multirow{4}{*}{" + losses[i%4].upper() + "} & \multirow{2}{*}{AUROC} & "
        else:
            lines += "& & "
        for val in row:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(val*100):.2f}"
            if val == critv:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\" + "\n"

        critece = eces[i].min()
        critmce = mces[i].min()
        if losses[i % 4] == "dsvdd":
            critece = eces[i, 1:].min()
            critmce = mces[i, 1:].min()
        if j == 1:
            lines += r"\cline{2-9}" + "\n"
            lines += "& MCE & "
            for k, mceval in enumerate(mces[i]):
                app = f"{(mceval*100):.2f}"
                if mceval == critmce:
                    app = r"\textbf{" + app + "}"
                app += " & "
                if losses[i % 4] == "dsvdd" and k == 0:
                    lines += " - & "
                else:
                    lines += app
            lines = lines[:-2]
            lines += r"\\" + "\n" + r"\cline{2-9}" + "\n"

            
            lines += "& ECE & "
            for k, eceval in enumerate(eces[i]):
                app = f"{(eceval*100):.2f}"
                if eceval == critece:
                    app = r"\textbf{" + app + "}"
                app += " & "
                if losses[i % 4] == "dsvdd" and k == 0:
                    lines += " - & "
                else:
                    lines += app
            lines = lines[:-2]
            lines += r"\\" + "\n" + r"\hline" + "\n"

    if (i + 1) % 4 == 0:
        tlines += [lines]
        lines = ""

for d, i in dmap.items():
    table = pre + header + tlines[i] + post
    with open(TABLEDIR / f"detection-results-{d}.tex", "w") as f:
        f.write(table)

### Plots

#### Calibration curves

In [None]:
# np.zeros((4,4,7,n_calbins,2))
# roc_curves = np.zeros((4,4,7,2,n_cpts,2))

fmndsv = roc_curves[3, 3, :, 1, :, :]

fig, ax = plt.subplots()

ax.set_xlim(-0.003, 1.003)
ax.set_ylim(-0.003, 1)
ax.set_xticks([0, 1])
ax.set_xticks(np.linspace(0.1,0.9,9), minor=True)
ax.set_yticks([1])
ax.set_yticks(np.linspace(0.1,0.9,9), minor=True)
ax.set_xlabel("FPR")
ax.set_ylabel("TPR")
ax.set_facecolor((0.5, 0.5, 0.5, 0.1))
ax.tick_params(which="both", bottom=False, left=False)
ax.grid(which="both", axis="both", alpha=0.2)


plt.plot([0,1],[0,1], color="black", linestyle="--")

for _, sp in ax.spines.items():
    sp.set_color((0,0,0, 0))

linestyles = ["-", "--"]
alphas = [0.6,1]
for i, curve in enumerate(fmndsv[1:]):
    fpr = np.array([-0.0001] + list(curve[:,0]))
    spl = make_smoothing_spline(fpr, [-0.0001] + list(curve[:,1]))
    tpr = list(spl(fpr))

    
    j = i % 2
    ax.plot([0] + list(fpr), [0] + tpr, alpha=alphas[j], linestyle=linestyles[j])


set_size(5,5)

fig.savefig(FIGDIR / "tmp.pdf")

# Segmentation

In [136]:
seg_dir = EXPROOT / "localization"
datasets = ["mvtec", "mpdd"]
dmap = {k: v for v, k in enumerate(datasets)}
losses = ["fcdd", "ssim", "bce"]
lmap = {k: v for v, k in enumerate(losses)}

In [137]:
experiment_types = ['full-oe',
'partial-oe-platt-oe',
'partial-oe-platt-spectral',
'partial-oe-beta-oe',
'partial-oe-beta-spectral']
emap = {k: i for i, k in enumerate(experiment_types)}

t = emap.copy()
for k, v in t.items():
    tmp = k.split("-")
    tmp[1] = "none"
    emap["-".join(tmp)] = v

t = emap.copy()
for k, v in t.items():
    emap[f"{k}-perturb"] = v

In [138]:
print(label_aupros[1])

[[[[0.23589088 0.27461503]
   [0.09212465 0.09219726]
   [0.07053073 0.07060185]
   [0.14761083 0.14766496]
   [0.14736292 0.14742   ]]

  [[0.67603147 0.7446678 ]
   [0.15204914 0.15205055]
   [0.14415626 0.14415771]
   [0.14869366 0.1486945 ]
   [0.1584189  0.15841965]]

  [[0.53262226 0.63999171]
   [0.16866489 0.16906172]
   [0.13773612 0.138025  ]
   [0.14729448 0.1475778 ]
   [0.16569082 0.16603246]]

  [[0.61344089 0.65478233]
   [0.15856914 0.16100031]
   [0.12992165 0.13192962]
   [0.14707772 0.14924172]
   [0.1522132  0.15402923]]

  [[0.74414579 0.72662501]
   [0.47549553 0.5032966 ]
   [0.47505691 0.50321447]
   [0.4736468  0.50229523]
   [0.47351259 0.50232924]]

  [[0.56293683 0.59536447]
   [0.23285839 0.23663072]
   [0.22598782 0.25207452]
   [0.22439877 0.22951732]
   [0.23627772 0.24528402]]]


 [[[0.81648421 0.8165629 ]
   [0.76912864 0.76964689]
   [0.76902491 0.76954099]
   [0.78095161 0.78156505]
   [0.78311925 0.78371066]]

  [[0.76887878 0.77506431]
   [0.781395

In [142]:
pre = r"""\begin{table}
\centering
\begin{tabular}{c c c c c c c c}
"""

post = r"""\end{tabular}
\end{table}"""

heading_cols = [r"Loss", r"Metric", r"Base", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

lines = r"\hline& "
for heading in heading_cols:
    lines += r"\multirowcell{2}{" + heading + "} & "
lines = lines[:-2] + r"\\" + "\n" + " ".join(["&"] * 7) + r"\\" + "\n" + r"\hline\hline" + "\n"

for i, lrows in enumerate(zip(aurocs, aupros, mces, eces)):
    rocrow, prorow, mcerow, ecerow = lrows
    rocrow, prorow, ecerow = rocrow.T, prorow.T, ecerow.T
    
    critroc, critpro, critmce, critece = rocrow.max(), prorow.max(), mcerow.min(), ecerow.min()
    for j in range(len(rocrow)):
        if i % 3 == 0 and j == 0:
            dname = datasets[i // 3].upper()
            if dname == "MVTEC":
                dname = dname[:3] + dname[-2:].lower() + "AD"
            lines += r"\multirow{18}{*}{\rotatebox[origin=c]{90}{\Large " + dname + "}} & "
            pass
        else:
            lines += "& "
        if j == 0:
            loss = losses[i%3].upper()
            if loss == "BCE": loss = "LGS"
            lines += "\multirow{6}{*}{" + loss + "} & \multirow{2}{*}{AUROC} & "
        else:
            lines += "& & "

        for rocval in rocrow[j][:]:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(rocval*100):.2f}"
            if rocval == critroc:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\"
        lines += "\n"

    lines += "\\cline{3-8}\n"
    for j in range(len(prorow)):
        if j == 0:
            lines += "& & \multirow{2}{*}{AUPRO} & "
        else:
            lines += "& & & "

        for proval in prorow[j][:]:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(proval*100):.2f}"
            if proval == critpro:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\"
        lines += "\n"

    lines += "\\cline{3-8}\n"
    lines += "& & MCE & "
    for mceval in mcerow:
        app = f"{(mceval*100):.2f}"
        if mceval == critmce:
            app = r"\textbf{" + app + "}"
        app += " & "
        lines += app
    
    lines = lines[:-2]
    lines += r"\\"
    lines += "\n"
    
    lines += "\\cline{3-8}\n"
    lines += "& & ECE & "
    for eceval in ecerow:
        app = f"{(eceval*100):.2f}"
        if eceval == critece:
            app = r"\textbf{" + app + "}"
        app += " & "
        lines += app
    
    lines = lines[:-2]
    lines += r"\\"
    lines += "\n"

    
    if (i + 1) % 3 == 0:
        lines += "\\hline\n"
    else:
        lines += "\\cline{2-8}\n"

table = pre + lines + post
print(table)
with open(TABLEDIR / "localization-results-master.tex", "w") as f:
    f.write(table)

\begin{table}
\centering
\begin{tabular}{c c c c c c c c}
\hline& \multirowcell{2}{Loss} & \multirowcell{2}{Metric} & \multirowcell{2}{Base} & \multirowcell{2}{Platt\\OE} & \multirowcell{2}{Platt\\Spectral} & \multirowcell{2}{$\beta$\\OE} & \multirowcell{2}{$\beta$\\Spectral} \\
& & & & & & &\\
\hline\hline
\multirow{18}{*}{\rotatebox[origin=c]{90}{\Large MVTecAD}} & \multirow{6}{*}{FCDD} & \multirow{2}{*}{AUROC} & 72.53 & 74.73 & 71.63 & 73.80 & 73.80 & 73.47 & 73.80 \\
& & & \cellcolor{gray!15}72.67 & \cellcolor{gray!15}\textbf{78.08} & \cellcolor{gray!15}75.13 & \cellcolor{gray!15}77.92 & \cellcolor{gray!15}77.92 & \cellcolor{gray!15}77.52 & \cellcolor{gray!15}77.92 \\
\cline{3-8}
& & \multirow{2}{*}{AUPRO} & 48.57 & 30.73 & 30.92 & 30.36 & 30.82 \\
& & & \cellcolor{gray!15}\textbf{51.22} & \cellcolor{gray!15}33.27 & \cellcolor{gray!15}33.55 & \cellcolor{gray!15}33.16 & \cellcolor{gray!15}33.75 \\
\cline{3-8}
& & MCE & \textbf{1.79} & 26.03 & 25.18 & 24.84 & 27.80 & 22.87 & 23.88 \\

In [141]:
pre = r"""\begin{table}
\centering
\begin{tabular}{c c c c c c}
"""

post = r"\end{table}"

heading_cols = ["Label", r"Fully\\Trained", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

header = r"\hline" + "\n"
for heading in heading_cols:
    header += r"\multirowcell{2}{" + heading + "} & "
header = header[:-2]
header += r"\\" + "\n" + " ".join(["&"] * 5) + r"\\" + "\n" + r"\hline\hline" + "\n"

tbls = {}
lines = ""
for di, dname in enumerate(datasets):
    tbls[dname] = {}
    for lsi, loss in enumerate(losses):
        tbls[dname][loss] = {}
        
        for lbi, label in enumerate(DS_TO_LABELS_MAP[dname]):
            labname = label
            if "_" in labname:
                labname = " ".join(labname.split("_"))
            elif "-" in labname:
                labname = " ".join(labname.split("-"))
                
            lines += r"\multirow{2}{*}{" + " ".join(" ".join(label.split("_")).split("-")) + r"} & "
            
            rocrows = label_aurocs[di][lsi, lbi].T
            crit_av = rocrows.max()
            for j, row in enumerate(rocrows):
                if j == 1:
                    lines += "& "
                for val in row:
                    if j == 1:
                        lines += r"\cellcolor{gray!15}"
                    app = f"{(val*100):.2f}"
                    if val == crit_av:
                        app = r"\textbf{" + app + "}"
                    app += " & "
                    lines += app
                lines = lines[:-2]
                lines += r"\\"
                lines += "\n"
            lines += "\hline" + "\n"

        lines += "\hline \n"
        lsrocs_avgs = label_aurocs[di][lsi].mean(0).T
        critval = lsrocs_avgs.max()
        for j, row in enumerate(lsrocs_avgs):
            if j == 1:
                lines += "& "
            else:
                lines += r"\multirow{2}{*}{Average} & "
            for val in row:
                if j == 1:
                    lines += r"\cellcolor{gray!15}"
                app = f"{(val*100):.2f}"
                if val == critval:
                    app = r"\textbf{" + app + "}"
                app += " & "
                lines += app
            lines = lines[:-2]
            lines += r"\\"
            lines += "\n"
        lines += "\hline" + "\n"
        
        tbls[dname][loss] = lines
        lines = ""
                    
def caption_label(d, l):
    l2 = l
    l = l.upper()
    if l == "BCE":
        l = "logistic"
    elif l == "DSVDD":
        l = "SVDD"
    d2 = d
    d = d.upper()
    if d == "CIFAR10":
        d = "CIFAR-10"
    elif d == "FMNIST":
        d = "Fashion MNIST"
    elif d == "MVTEC":
        d = "MVTecAD"
    lines = "\end{tabular}" + "\n"
    lines += r"\caption[Per-Class Localization AUROC for " + l + " on the " + d + r" Dataset]{Average \% per-pixel AUROC over 5 seeds with the " + l + r" loss for each class of the " + d + r" dataset. AUROC over non-perturbed and perturbed test inputs is shown in the white and gray rows respectively, with the largest value per class emphasized in bold font.}" + "\n"
    lines += r"\label{tab:localization-" + f"{d2}-{l2}-auroc" + "}\n"
    return lines

table = pre + header + tbls["mvtec"]["bce"]+ post
#print(table)
for d, lt in tbls.items():
    for l, tbl in lt.items():
        with open(TABLEDIR / f"localization-{d}-{l}-auroc.tex", "w") as f:
            table = pre + header + tbl + caption_label(d, l) +  post
            f.write(table)
            print(TABLEDIR / f"localization-{d}-{l}-auroc.tex")

IndexError: index 10 is out of bounds for axis 1 with size 10

In [140]:
pre = r"""\begin{table}
\centering
\begin{tabular}{c c c c c c}
"""

post = r"\end{table}"

heading_cols = ["Label", r"Fully\\Trained", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

header = r"\hline" + "\n"
for heading in heading_cols:
    header += r"\multirowcell{2}{" + heading + "} & "
header = header[:-2]
header += r"\\" + "\n" + " ".join(["&"] * 5) + r"\\" + "\n" + r"\hline\hline" + "\n"

tbls = {}
lines = ""
for di, dname in enumerate(datasets):
    tbls[dname] = {}
    for lsi, loss in enumerate(losses):
        tbls[dname][loss] = {}
        
        for lbi, label in enumerate(DS_TO_LABELS_MAP[dname]):
            labname = label
            if "_" in labname:
                labname = " ".join(labname.split("_"))
            elif "-" in labname:
                labname = " ".join(labname.split("-"))
                
            lines += r"\multirow{2}{*}{" + " ".join(" ".join(label.split("_")).split("-")) + r"} & "
            
            rocrows = label_aupros[di][lsi, lbi].T
            crit_av = rocrows.max()
            for j, row in enumerate(rocrows):
                if j == 1:
                    lines += "& "
                for val in row:
                    if j == 1:
                        lines += r"\cellcolor{gray!15}"
                    app = f"{(val*100):.2f}"
                    if val == crit_av:
                        app = r"\textbf{" + app + "}"
                    app += " & "
                    lines += app
                lines = lines[:-2]
                lines += r"\\"
                lines += "\n"
            lines += "\hline" + "\n"

        lines += "\hline \n"
        lsrocs_avgs = label_aupros[di][lsi].mean(0).T
        critval = lsrocs_avgs.max()
        for j, row in enumerate(lsrocs_avgs):
            if j == 1:
                lines += "& "
            else:
                lines += r"\multirow{2}{*}{Average} & "
            for val in row:
                if j == 1:
                    lines += r"\cellcolor{gray!15}"
                app = f"{(val*100):.2f}"
                if val == critval:
                    app = r"\textbf{" + app + "}"
                app += " & "
                lines += app
            lines = lines[:-2]
            lines += r"\\"
            lines += "\n"
        lines += "\hline" + "\n"
        
        tbls[dname][loss] = lines
        lines = ""
                    
def caption_label(d, l):
    l2 = l
    l = l.upper()
    if l == "BCE":
        l = "logistic"
    elif l == "DSVDD":
        l = "SVDD"
    d2 = d
    d = d.upper()
    if d == "CIFAR10":
        d = "CIFAR-10"
    elif d == "FMNIST":
        d = "Fashion MNIST"
    elif d == "MVTEC":
        d = "MVTecAD"
    lines = "\end{tabular}" + "\n"
    lines += r"\caption[Per-Class Localization AUPRO for " + l + " on the " + d + r" Dataset]{Average \% AUPRO over 5 seeds with the " + l + r" loss for each class of the " + d + r" dataset. AUPRO over non-perturbed and perturbed test inputs is shown in the white and gray rows respectively, with the largest value per class emphasized in bold font.}" + "\n"
    lines += r"\label{tab:localization-" + f"{d2}-{l2}-aupro" + "}\n"
    return lines

table = pre + header + tbls["mvtec"]["bce"]+ post
#print(table)
for d, lt in tbls.items():
    for l, tbl in lt.items():
        with open(TABLEDIR / f"localization-{d}-{l}-aupro.tex", "w") as f:
            table = pre + header + tbl + caption_label(d, l) +  post
            f.write(table)
            print(TABLEDIR / f"localization-{d}-{l}-aupro.tex")

/home/svarq/Code/calibration-ad/save/tables/localization-mvtec-fcdd-aupro.tex
/home/svarq/Code/calibration-ad/save/tables/localization-mvtec-ssim-aupro.tex
/home/svarq/Code/calibration-ad/save/tables/localization-mvtec-bce-aupro.tex
/home/svarq/Code/calibration-ad/save/tables/localization-mpdd-fcdd-aupro.tex
/home/svarq/Code/calibration-ad/save/tables/localization-mpdd-ssim-aupro.tex
/home/svarq/Code/calibration-ad/save/tables/localization-mpdd-bce-aupro.tex


### DEFENCE

In [None]:
pre = r"""\begin{table}
\centering
\resizebox{0.65\textwidth}{!}{
\begin{tabular}{c c c c c c c}
"""

post = r"""\end{tabular}
}
\end{table}"""

heading_cols = [r"Metric", r"Base", r"Platt\\OE", r"Platt\\Spectral", r"$\beta$\\OE", r"$\beta$\\Spectral"]

header = r"\hline" + "\n" + r"& "
for heading in heading_cols:
    header += r"\multirowcell{2}{" + heading + "} & "
header = header[:-2]
header += r"\\" + "\n" + " ".join(["&"] * 6) + r"\\" + "\n" + r"\hline\hline" + "\n"

tlines = []
lines = ""
for i, lrows in enumerate(zip(aurocs, aupros, eces, mces)):
    rocrow, prorow, ecerow, mcerow = lrows
    rocrow, prorow, ecerow, mcerow = rocrow.T, prorow.T, ecerow.T, mcerow.T
    
    critroc, critpro, critece, critmce = rocrow.max(), prorow.max(), ecerow.min(), mcerow.min()
    for j in range(len(rocrow)):
        if j == 0:
            lines += "\multirow{6}{*}{" + losses[i%3].upper() + "} & \multirow{2}{*}{AUROC} & "
        else:
            lines += "& & "

        for rocval in rocrow[j][:]:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(rocval*100):.2f}"
            if rocval == critroc:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\"
        lines += "\n"

    lines += "\\cline{2-7}\n"
    for j in range(len(prorow)):
        if j == 0:
            lines += "& \multirow{2}{*}{AUPRO} & "
        else:
            lines += "& & "

        for proval in prorow[j][:]:
            if j == 1:
                lines += r"\cellcolor{gray!15}"
            app = f"{(proval*100):.2f}"
            if proval == critpro:
                app = r"\textbf{" + app + "}"
            app += " & "
            lines += app
        lines = lines[:-2]
        lines += r"\\"
        lines += "\n"


    lines += "\\cline{2-7}\n"
    lines += "& MCE & "
    for mceval in mcerow:
        app = f"{(mceval*100):.2f}"
        if mceval == critmce:
            app = r"\textbf{" + app + "}"
        app += " & "
        lines += app
    
    lines = lines[:-2]
    lines += r"\\"
    lines += "\n"
    
    lines += "\\cline{2-7}\n"
    lines += "& ECE & "
    for eceval in ecerow:
        app = f"{(eceval*100):.2f}"
        if eceval == critece:
            app = r"\textbf{" + app + "}"
        app += " & "
        lines += app
    
    lines = lines[:-2]
    lines += r"\\"
    lines += "\n" + "\hline" + "\n"


    
    if (i + 1) % 3 == 0:
        tlines += [lines]
        lines = ""

for d, i in dmap.items():
    table = pre + header + tlines[i] + post
    with open(TABLEDIR / f"localization-results-{d}.tex", "w") as f:
        f.write(table)

In [None]:
mces

# Spectral Images

In [None]:
from phcad.data.spectral_natural_images import generate_natural_image_from_spectrum
from PIL import Image

ims = [None] * 4
fig, axs = plt.subplots(1, 4, sharex=True, sharey=True)
for i in range(len(ims)):
    im = generate_natural_image_from_spectrum()
    axs[i].imshow(im)

ax = axs[0]
ax.tick_params(which="both", bottom=False, left=False)
ax.set_xticks([])
ax.set_yticks([])

fig.set_size_inches(16, 4)
plt.autoscale(tight=True)
fig.savefig("whatever.pdf")

# Beta and Platt graphs

In [None]:
x, y = np.linspace(0, 1, 100001), np.linspace(0, 1, 100001)

def sigmoid(z, T=1, c=0):
    return 1/(1+np.exp(-z/T + c))

def logit(p, a=1, b=1, c=0):
    p = p.clip(10e-16, 1-10e-16)
    return a * np.log(p) - b * np.log(1 - p) + c

In [None]:
def get_sq_aspect(xlim, ylim):
    return (xlim[1] - xlim[0]) / (ylim[1] - ylim[0])

def get_plt(logits=True):
    fig, ax = plt.subplots(figsize=(4,4))

    ylim = [-0.01, 1.01]
    majyticks = [0, 1]
    minyticks = np.linspace(0.1, 0.9, 9)
    if logits:
        xlim = [-6, 6]
        minxticks = np.linspace(-6, 6, 13)
        majxticks = [-5, 0, 5]
    else:
        xlim = ylim
        minxticks = minyticks
        majxticks = [1]
        
    ax.set_xlim(xlim[0], xlim[1])
    ax.set_ylim(ylim[0], ylim[1])
    ax.set_xticks(majxticks)
    ax.set_xticks(minxticks, minor=True)
    ax.set_yticks(majyticks)
    ax.set_yticks(minyticks, minor=True)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.set_facecolor((0.5, 0.5, 0.5, 0.1))
    ax.set_aspect(get_sq_aspect(xlim, ylim))
    ax.tick_params(which="both", bottom=False, left=False)
    ax.grid(which="both", axis="both", alpha=0.2)
    return(fig, ax)

In [None]:
Ttp, ctp = [(2, 1/2), (1, 1), (2, 1/2)], [(0, 0), (1, -1), (1, -1)]

fs = 32
plt.rcParams['text.usetex'] = True
plt.rcParams['figure.subplot.left'] = 0
plt.rcParams['figure.subplot.bottom'] = 0
plt.rcParams['figure.subplot.right'] = 1
plt.rcParams['figure.subplot.top'] = 1
plt.rcParams['font.size'] = 20
for i, pairs in enumerate(zip(Ttp, ctp)):
    tpair, cpair = pairs
    
    fig, ax = get_plt(logits=True)
    ax.plot(logit(x), y, color="black", linestyle="--")
    for j, pair in enumerate(zip(tpair, cpair)):
        t, c = pair
        color = "red" if j == 0 else "blue"
        ax.plot(logit(x), sigmoid(logit(y)/t + c), color=color)
        
    ax.set_xlabel(r"$z$", fontsize=fs)
    ax.set_ylabel(r"$\hat{\eta}_P$", fontsize=fs)
    fig.set_tight_layout(True)
    fig.savefig(FIGDIR / f"platt-{i}-logits.pdf", pad_inches=0)
    
    fig, ax = get_plt(logits=False)
    ax.plot(x, sigmoid(logit(y)), color="black", linestyle="--")
    for j, pair in enumerate(zip(tpair, cpair)):
        t, c = pair
        color = "red" if j == 0 else "blue"
        ax.plot(x, sigmoid(logit(y)/t + c), color=color)
    ax.set_xlabel(r"$\hat{\eta}$", fontsize=fs)
    ax.set_ylabel(r"$\hat{\eta}_P$", fontsize=fs)
    fig.set_tight_layout(True)
    fig.savefig(FIGDIR / f"platt-{i}-pests.pdf")

In [None]:
atp = [(2, 1/2), (1, 1), (1, 1), (2, 1/2)]
btp = [(1, 1), (2, 1/2), (1, 1), (1/2, 1.5)]
ctp = [(0, 0), (0, 0), (1, -1), (0.5, -1)]

In [None]:
#atp = [(2, 1/2), (1, 1), (1, 1), (2, 1/2)]
#btp = [(1, 1), (2, 1/2), (1, 1), (8/10, 1.5)]
#ctp = [(0, 0), (0, 0), (1, -1), (1, -1)]

fig, ax = plt.subplots(figsize=(4,4))
plt.rcParams['text.usetex'] = True
ax.plot([1],[1], color="red")
ax.plot([1],[1], color="blue")
ax.legend([r"$a = 2,\ b=\frac{1}{2},$" + "\n" + r"$c=\frac{1}{2}$", r"$a = \frac{1}{2},\ b=\frac{3}{2},$" + "\n" +  r"$c=-1$"])

In [None]:
logit(x, a, b, c)

In [None]:
atp = [(2, 1/2), (1, 1), (1, 1), (2, 1/2)]
btp = [(1, 1), (2, 1/2), (1, 1), (1/2, 1.5)]
ctp = [(0, 0), (0, 0), (1, -1), (0.5, -1)]

fs = 32
plt.rcParams['text.usetex'] = True
plt.rcParams['figure.subplot.left'] = 0
plt.rcParams['figure.subplot.bottom'] = 0
plt.rcParams['figure.subplot.right'] = 1
plt.rcParams['figure.subplot.top'] = 1
plt.rcParams['font.size'] = 20
for i, pairs in enumerate(zip(atp, btp, ctp)):
    apair, bpair, cpair = pairs
    
    fig, ax = get_plt(logits=True)
    ax.plot(logit(x), y, color="black", linestyle="--")
    for j, pair in enumerate(zip(apair, bpair, cpair)):
        a, b, c = pair
        color = "red" if j == 0 else "blue"
        ax.plot(logit(x), sigmoid(logit(y, a, b, c)), color=color)
    ax.set_xlabel(r"$z$", fontsize=fs)
    ax.set_ylabel(r"$\hat{\eta}_\beta$", fontsize=fs)
    fig.set_tight_layout(True)
    fig.savefig(FIGDIR / f"beta-{i}-logits.pdf", pad_inches=0)
    
    fig, ax = get_plt(logits=False)
    ax.plot(sigmoid(logit(x)), y, color="black", linestyle="--")
    for j, pair in enumerate(zip(apair, bpair, cpair)):
        a, b, c = pair
        color = "red" if j == 0 else "blue"
        ax.plot(x, sigmoid(logit(y, a, b, c)), color=color)
    ax.set_xlabel(r"$\hat{\eta}$", fontsize=fs)
    ax.set_ylabel(r"$\hat{\eta}_\beta$", fontsize=fs)
    fig.set_tight_layout(True)
    fig.savefig(FIGDIR / f"beta-{i}-pests.pdf")

In [None]:
m = 0.2*np.log(1 - 0.5) - 0.2*np.log(0.5)
print(m)