In [1]:
import yaml
import glob

In [3]:
file = "results.yaml"

with open(file, "r") as f:
    results = yaml.load(f, Loader=yaml.FullLoader)

In [22]:
def add_dicts(dest, src):
    for k, v in src.items():
        if k in dest:
            if isinstance(v, dict):
                add_dicts(dest[k], v)
            else:
                raise ValueError(f"Key {k} already exists in {dest}")
        else:
            dest[k] = v

In [45]:
results = {}
for file in glob.glob("results/*.yaml"):
    print(file)
    with open(file, "r") as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
        add_dicts(results, data)

results/VESDE_dinov2_imagenet_sub.yaml
results/subVPSDE_vit_imagenet.yaml
results/VESDE_vit_imagenet200.yaml
results/Residual_resnet18_32x32_cifar100_open_ood_cifar100.yaml
results/VESDE_swin_imagenet.yaml
results/Residual_resnet50d_imagenet200.yaml
results/Residual_dinov2_imagenet_sub.yaml
results/VESDE_resnet50d_imagenet.yaml
results/VESDE_deit_imagenet.yaml
results/VESDE_dinov2_imagenet.yaml
results/VESDE_swin_imagenet200.yaml
results/VESDE_dino_imagenet.yaml
results/subVPSDE_resnet50d_imagenet.yaml
results/VPSDE_clip_imagenet.yaml
results/Residual_repvgg_imagenet_sub.yaml
results/Residual_resnet18_32x32_cifar10_open_ood_cifar10.yaml
results/VPSDE_dino_imagenet_sub.yaml
results/VESDE_repvgg_imagenet_sub.yaml
results/VPSDE_resnet18_224x224_imagenet200_open_ood_imagenet200.yaml
results/Residual_swin_imagenet200.yaml
results/VPSDE_dino_imagenet.yaml
results/Residual_vit_imagenet200.yaml
results/subVPSDE_resnet18_224x224_imagenet200_open_ood_imagenet200.yaml
results/VESDE_clip_imagenet.

In [57]:
results.keys()

dict_keys(['VESDE', 'subVPSDE', 'Residual', 'VPSDE'])

In [75]:
def multicol(name, space=2, bars=False):
    if bars:
        return r"\multicolumn{" +str(space)+r"}{|c|}{\textbf{" + name + r"}}"
    return r"\multicolumn{" +str(space)+r"}{c}{\textbf{" + name + r"}}"

def get_names(data, dataset):
    far_names = None
    near_names = None
    for d in data.values():
        if dataset not in d:
            continue
        far_names = []
        near_names = []

        for _dataset in d[dataset]["farood"]:
            far_names.append(_dataset["dataset"])
       
        for _dataset in d[dataset]["nearood"]:
            near_names.append(_dataset["dataset"])

        return far_names, near_names
    return far_names, near_names
def create_latex_table(result, method, dataset, encoders):
    far_names, near_names = get_names(result[method], dataset)
    if far_names is None:
        return ""
    
    print(far_names, near_names)

    total_items = len(far_names) + len(near_names)

    header = r"\begin{table}[ht]"+ "\n"
    header += r"\caption{" +f'Result for '+ method.replace("_", "\\_") + ' on '+ dataset.replace("_", "\\_") + r"}"+ "\n"

    header += r"""\centering
\resizebox{\textwidth}{!}{% Resize table to fit within \textwidth horizontally
"""
    header += r"\begin{tabular}{@{}l*{" + str(total_items+1) + r"}{SS}@{}}" + "\n"
    header += r"\toprule" + "\n"

    description = r"\textbf{Encoder} & "  + " & ".join(multicol(name.replace("_", r"\_")) for name in near_names+far_names ) + multicol("Average") + r" \\" + "\n"
    description +=  r" & {\footnotesize AUROC} $\uparrow$ & {\footnotesize FPR95} $\downarrow$ "*(total_items +1)+ r" \\" + "\n"
    midrule = r"\midrule" + "\n"
    footer = r"\label{tab:" + f"{method}_{dataset}" + r"}" + "\n"
    footer += r"""
\bottomrule
\end{tabular}
}
\end{table}
"""+"\n"

    rows = [" & "   + multicol("Near OOD", len(near_names)*2, True) + " & " + multicol("Far OOD", len(far_names)*2, True) + r" & \\" + "\n"]
    # rows = []

    lookup = {}
    max_data = {}
    for encoder in sorted(encoders):
        if encoder not in result[method]:
            continue
        if dataset not in result[method][encoder]:
            continue
        for name, data in result[method][encoder][dataset].items():
            if name not in ["nearood", "farood"]:
                continue
            lookup_tmp = {d["dataset"]: d["metrics"] for d in data}
            if encoder not in lookup:
                lookup[encoder] = {}
            lookup[encoder].update(lookup_tmp)
 
    for e in encoders:
        if e not in lookup:
            continue
        for data_name in lookup[e]:
            if data_name not in lookup[e]:
                continue
            if data_name not in max_data:
                max_data[data_name] = {
                    "AUC": 0,
                    "FPR_95": 1
                }
            if lookup[e][data_name]["AUC"] > max_data[data_name]["AUC"]:
                max_data[data_name]["AUC"] = lookup[e][data_name]["AUC"]
            if lookup[e][data_name]["FPR_95"] < max_data[data_name]["FPR_95"]:
                max_data[data_name]["FPR_95"] = lookup[e][data_name]["FPR_95"]

    for encoder in sorted(encoders):
        if encoder not in result[method]:
            continue
        if dataset not in result[method][encoder]:
            continue
        if "resnet18_" in encoder:
            row = [r"resnet18\_open\_ood"]
        elif "resnet50_" in encoder:
            row = [r"resnet50\_open\_ood"]
        else:
            row = [encoder]
        avg_auc = 0
        avg_fpr = 0
        count = 0
        for data_names in near_names + far_names:
            data_res = max_data[data_names]
            max_auc = data_res["AUC"]
            min_fpr = data_res["FPR_95"]
            if data_names in lookup[encoder]:
                metrics = lookup[encoder][data_names]
                if metrics["AUC"] == max_auc:
                    row.append(r"\textbf{" + f"{metrics['AUC']*100:.2f}" + r"}")
                else:
                    row.append(f"{metrics['AUC']*100:.2f}")
                avg_auc += metrics["AUC"]
                if metrics["FPR_95"] == min_fpr:
                    row.append(r"\textbf{" + f"{metrics['FPR_95']*100:.2f}" + r"}")
                else:
                    row.append(f"{metrics['FPR_95']*100:.2f}")
                avg_fpr += metrics["FPR_95"]
                count += 1
            else:
                row.append("-")
                row.append("-")
        print(row)
        row.append(f"{avg_auc/count*100:.2f}")
        row.append(f"{avg_fpr/count*100:.2f}")
        rows.append(" & ".join(row) + r" \\")

    return header + description + midrule + "\n".join(rows) + footer

# Generate the LaTeX table
datasets = ['imagenet_sub', 'imagenet','imagenet200', 'cifar10', 'cifar100']
encoders = [ 'dino', 'dinov2', 'vit', 'clip', 'repvgg', 'resnet50d', 'swin', 'deit'] + ['resnet18_32x32_cifar10_open_ood', 'resnet18_32x32_cifar100_open_ood', 'resnet18_224x224_imagenet200_open_ood', 'resnet50_224x224_imagenet_open_ood']
with open("out.txt", "w") as f:
    for method, result in results.items():
        for dataset in datasets:
            f.write(create_latex_table(results, method, dataset, encoders))



['openimageo', 'inaturalist', 'textures'] ['imagenet-o']
['deit', '75.25', '89.80', '87.92', '66.69', '90.04', '63.21', '82.17', '78.39']
['dino', '81.59', '71.65', '86.71', '60.79', '87.97', '66.29', '\\textbf{97.43}', '\\textbf{12.87}']
['dinov2', '85.56', '65.45', '93.88', '31.41', '\\textbf{99.33}', '\\textbf{1.27}', '93.43', '32.38']
['repvgg', '76.00', '79.85', '80.95', '70.91', '83.06', '72.77', '92.59', '31.71']
['resnet50d', '79.76', '72.65', '85.52', '61.16', '85.20', '68.03', '93.87', '28.99']
['swin', '85.43', '72.40', '93.53', '39.62', '98.07', '9.33', '89.96', '50.52']
['vit', '\\textbf{90.18}', '\\textbf{44.95}', '\\textbf{95.16}', '\\textbf{25.20}', '98.68', '6.95', '93.39', '28.76']
['textures', 'openimageo', 'inaturalist'] ['imagenet-o', 'ninco', 'ssb_hard']
['clip', '67.74', '90.75', '59.69', '93.96', '53.27', '95.67', '70.85', '94.15', '76.78', '86.05', '61.69', '99.43']
['deit', '75.81', '89.80', '78.99', '82.14', '64.24', '91.61', '83.08', '76.80', '87.08', '71.04

In [60]:
for k, v in results['Residual'].items():
    print(k, v.keys()) 

resnet18_32x32_cifar100_open_ood dict_keys(['cifar100'])
resnet50d dict_keys(['imagenet200', 'imagenet', 'imagenet_sub'])
dinov2 dict_keys(['imagenet_sub', 'imagenet', 'imagenet200'])
repvgg dict_keys(['imagenet_sub', 'imagenet200', 'imagenet'])
resnet18_32x32_cifar10_open_ood dict_keys(['cifar10'])
swin dict_keys(['imagenet200', 'imagenet_sub', 'imagenet'])
vit dict_keys(['imagenet200', 'imagenet', 'imagenet_sub'])
deit dict_keys(['imagenet200', 'imagenet_sub', 'imagenet'])
resnet18_224x224_imagenet200_open_ood dict_keys(['imagenet200'])
clip dict_keys(['imagenet200', 'imagenet'])
dino dict_keys(['imagenet200', 'imagenet_sub', 'imagenet'])
resnet50_224x224_imagenet_open_ood dict_keys(['imagenet'])
