In [1]:
import sys
import re
import os

sys.path.insert(0, "src/")
sys.path.insert(1, "external_repos/pytorch_cifar100/")
sys.path.insert(1, "external_repos/pytorch_cifar10/")
import numpy as np
import random
from tqdm.auto import tqdm
from src.data_utils import load_model_checkpoint, load_dict, make_load_path
from src.postprocessing_utils import (
    get_metrics_results,
    uq_funcs_with_names,
    get_uncertainty_scores,
    get_predicted_labels,
    make_aggregation,
    get_missclassification_dataframe,
    get_ood_detection_dataframe,
    get_raw_scores_dataframe,
    ravel_df,
    create_gt_embeddings,
    get_sampled_combinations_uncertainty_scores,
)

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import roc_auc_score
from itertools import combinations
from IPython.display import display

pd.set_option("display.max_rows", None)

ModuleNotFoundError: No module named 'src'

In [2]:
training_dataset_names = [
    "cifar10",
    "cifar100",
    "noisy_cifar100",
    "missed_class_cifar10",
    "noisy_cifar10",
]
temperature = 1.0
model_ids = np.arange(20)
list_extraction_datasets = [
    "cifar10",
    "cifar100",
    "svhn",
    "blurred_cifar100",
    "blurred_cifar10",
]
list_ood_datasets = [el for el in list_extraction_datasets]
loss_function_names = ["brier_score", "cross_entropy", "spherical_score"]
use_different_approximations = False
gt_prob_approx = "same"

full_dataframe = None
full_ood_rocauc_dataframe = None
full_mis_rocauc_dataframe = None

In [3]:
# for training_dataset_name in training_dataset_names:
#     if training_dataset_name not in ['missed_class_cifar10', 'noisy_cifar10', 'noisy_cifar100']:
#         architectures = ['resnet18', 'vgg']
#         training_dataset_name_aux = training_dataset_name
#     else:
#         architectures = ['resnet18']
#         training_dataset_name_aux = training_dataset_name.split('_')[-1]
#     for architecture in architectures:
#         # try:
#         uq_results, embeddings_per_dataset, targets_per_dataset = get_sampled_combinations_uncertainty_scores(
#             loss_function_names=loss_function_names,
#             training_dataset_name=training_dataset_name,
#             architecture=architecture,
#             list_extraction_datasets=list_extraction_datasets,
#             use_different_approximations=use_different_approximations,
#         )

In [4]:
res = load_dict(
    "./external_repos/pytorch_cifar10/checkpoints/resnet18/extracted_information_for_notebook_combinations.pkl"
)
len(res["uq_results"]["Total Brier Outer"]["brier_score"]["cifar10"])

5

In [5]:
def get_ood_detection_dataframe(
    ind_dataset: str,
    uq_results: dict,
    list_ood_datasets: list[str],
) -> pd.DataFrame:
    """
    The function transforms uq_results dict into pd.Dataframe
    with ROC AUC scores of OOD detection.
    """
    roc_auc_dict = {}

    for uq_name, _ in uq_funcs_with_names:
        roc_auc_dict[uq_name] = {}

        for ood_dataset in list_ood_datasets:
            roc_auc_dict[uq_name][ood_dataset] = {}
            for loss_ in uq_results[uq_name].keys():
                roc_auc_dict[uq_name][ood_dataset][loss_] = []
                for it_ in range(len(uq_results[uq_name][loss_][ood_dataset])):
                    y_true = np.hstack(
                        [
                            np.ones(
                                uq_results[uq_name][loss_][ood_dataset][it_].shape[0]
                            ),
                            np.zeros(
                                uq_results[uq_name][loss_][ind_dataset][it_].shape[0]
                            ),
                        ]
                    )
                    y_score = np.hstack(
                        [
                            uq_results[uq_name][loss_][ood_dataset][it_],
                            uq_results[uq_name][loss_][ind_dataset][it_],
                        ]
                    )
                    score = roc_auc_score(y_true=y_true, y_score=y_score)
                    roc_auc_dict[uq_name][ood_dataset][loss_].append(score)

    data_list = []
    for metric_name, datasets in roc_auc_dict.items():
        for dataset_name, loss_functions in datasets.items():
            for loss_function_name, values in loss_functions.items():
                data_list.append(
                    (metric_name, dataset_name, loss_function_name, values)
                )

    # Create a DataFrame
    df = pd.DataFrame(
        data_list, columns=["UQMetric", "Dataset", "LossFunction", "RocAucScores_array"]
    )

    return df


# def get_missclassification_dataframe(
#     ind_dataset: str,
#     uq_results: dict,
#     true_labels: np.ndarray,
#     pred_labels: np.ndarray,
# ) -> pd.DataFrame:
#     """
#     The function transforms uq_results dict into pd.Dataframe
#     with ROC AUC scores of misclassification detection.
#     """
#     roc_auc_dict = {}

#     for uq_name, _ in uq_funcs_with_names:
#         roc_auc_dict[uq_name] = {}
#         for loss_ in uq_results[uq_name].keys():
#             y_true = (true_labels != pred_labels[loss_]).astype(np.int32)
#             y_score = uq_results[uq_name][loss_][ind_dataset]

#             score = roc_auc_score(y_true=y_true, y_score=y_score)
#             roc_auc_dict[uq_name][loss_] = score

#             # print(
#             #     f'InD: {ind_dataset} \t loss: {loss_} \t roc_auc: {score}')

#     data_list_misclassification = []
#     for metric_name, loss_function in roc_auc_dict.items():
#         for loss_function_name, value in loss_function.items():
#             data_list_misclassification.append((metric_name, loss_function_name, value))

#     # Create a DataFrame
#     df_misclassification = pd.DataFrame(
#         data_list_misclassification,
#         columns=[
#             "UQMetric",
#             "LossFunction",
#             "RocAucScore",
#         ],
#     )
#     return df_misclassification

In [None]:
full_dataframe = None

for training_dataset_name in training_dataset_names:
    if training_dataset_name not in [
        "missed_class_cifar10",
        "noisy_cifar10",
        "noisy_cifar100",
    ]:
        architectures = ["resnet18", "vgg"]
        training_dataset_name_aux = training_dataset_name
    else:
        architectures = ["resnet18"]
        training_dataset_name_aux = training_dataset_name.split("_")[-1]
    for architecture in architectures:
        ###
        folder_path = make_load_path(
            architecture=architecture,
            dataset_name=training_dataset_name,
            loss_function_name="NaN",
            model_id="NaN",
        )
        extracted_embeddings_file_path = os.path.join(
            *folder_path.split("/")[:-3],
            "extracted_information_for_notebook_combinations.pkl",
        )

        res_dict = load_dict(extracted_embeddings_file_path)
        uq_results, embeddings_per_dataset, targets_per_dataset = (
            res_dict["uq_results"],
            res_dict["embeddings_per_dataset"],
            res_dict["targets_per_dataset"],
        )
        ###

        df_ood = get_ood_detection_dataframe(
            ind_dataset=training_dataset_name_aux,
            uq_results=uq_results,
            list_ood_datasets=list_ood_datasets,
        )
        df_ood["architecture"] = architecture
        df_ood["training_dataset"] = training_dataset_name

        scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)
        scores_df_unravel["architecture"] = architecture
        scores_df_unravel["training_dataset"] = training_dataset_name

        if full_dataframe is None:
            full_dataframe = scores_df_unravel
            full_ood_rocauc_dataframe = df_ood
        else:
            full_dataframe = pd.concat([full_dataframe, scores_df_unravel])
            full_ood_rocauc_dataframe = pd.concat([full_ood_rocauc_dataframe, df_ood])

In [None]:
full_ood_rocauc_dataframe["RocAucScoresMean"] = full_ood_rocauc_dataframe[
    "RocAucScores_array"
].apply(lambda x: np.array(x).mean())
full_ood_rocauc_dataframe["RocAucScoresStd"] = full_ood_rocauc_dataframe[
    "RocAucScores_array"
].apply(lambda x: np.array(x).std())

In [None]:
full_ood_rocauc_dataframe.sample(10)

In [None]:
full_ood_rocauc_dataframe.reset_index(drop=True).loc[2576]

In [None]:
full_ood_rocauc_dataframe.reset_index(drop=True).loc[2872]

In [None]:
full_ood_rocauc_dataframe = full_ood_rocauc_dataframe.reset_index(drop=True)

In [None]:
full_ood_rocauc_dataframe["UQMetric"].unique()

In [13]:
base_score_dict = {
    "cross_entropy": "Logscore",
    "brier_score": "Brier",
    "spherical_score": "Spherical",
}

In [14]:
pattern_baserule = r"(Logscore|Brier|Neglog|Maxprob|Spherical)"
pattern_risk = r"(Total|Bayes|Excess|Reverse Bregman Information|Bregman Information|Expected Pairwise Bregman Information|MVBI|MV|BiasBI|Bias)"

full_ood_rocauc_dataframe["base_rule"] = full_ood_rocauc_dataframe[
    "UQMetric"
].str.extract(pattern_baserule)
full_ood_rocauc_dataframe["RiskType"] = full_ood_rocauc_dataframe[
    "UQMetric"
].str.extract(pattern_risk)
full_ood_rocauc_dataframe["LossFunction"] = full_ood_rocauc_dataframe[
    "LossFunction"
].replace(base_score_dict)


full_dataframe["base_rule"] = full_dataframe["UQMetric"].str.extract(pattern_baserule)
full_dataframe["RiskType"] = full_dataframe["UQMetric"].str.extract(pattern_risk)
full_dataframe["LossFunction"] = full_dataframe["LossFunction"].replace(base_score_dict)

In [15]:
full_dataframe.to_csv("./tables/full_dataframe_with_std.csv")
full_ood_rocauc_dataframe.to_csv("./tables/full_ood_rocauc_with_std.csv")

In [2]:
full_ood_rocauc = pd.read_csv("./tables/full_ood_rocauc_with_std.csv", index_col=0)
full_ood_rocauc = full_ood_rocauc[~full_ood_rocauc.UQMetric.str.endswith("Inner Inner")]

In [3]:
full_ood_rocauc.sample(10)

Unnamed: 0,UQMetric,Dataset,LossFunction,RocAucScores_array,architecture,training_dataset,RocAucScoresMean,RocAucScoresStd,base_rule,RiskType
2178,BiasBI Spherical,cifar100,Brier,"[0.87530019, 0.8743967600000001, 0.869154215, ...",vgg,cifar10,0.873934,0.002521493,Spherical,BiasBI
5194,Reverse Bregman Information Logscore,cifar100,Logscore,"[0.49999999999999994, 0.5, 0.5, 0.500000000000...",resnet18,noisy_cifar100,0.5,5.5511150000000004e-17,Logscore,Reverse Bregman Information
6053,Excess Spherical Inner Outer,svhn,Spherical,"[0.7608834146435157, 0.8440206822372465, 0.937...",resnet18,missed_class_cifar10,0.883869,0.07169191,Spherical,Excess
380,Excess Maxprob Outer Outer,cifar100,Spherical,"[0.7393491849999999, 0.7364092799999997, 0.740...",resnet18,cifar10,0.738947,0.001928681,Maxprob,Excess
7602,MV Logscore,blurred_cifar10,Brier,"[0.7703630399999999, 0.82002716, 0.84421677, 0...",resnet18,noisy_cifar10,0.812146,0.02419833,Logscore,MV
3042,Expected Pairwise Bregman Information Maxprob,blurred_cifar10,Brier,"[0.7814324700000002, 0.7914582050000001, 0.786...",resnet18,cifar100,0.7861,0.003551094,Maxprob,Expected Pairwise Bregman Information
6573,Bias Maxprob,cifar100,Brier,"[0.23352607, 0.223433815, 0.175052895, 0.17497...",resnet18,missed_class_cifar10,0.197005,0.0259204,Maxprob,Bias
7673,MVBI Brier,svhn,Spherical,"[0.90416464159496, 0.8932192417025201, 0.88949...",resnet18,noisy_cifar10,0.88809,0.02276114,Brier,MVBI
2328,Total Brier Inner,cifar100,Brier,"[0.5, 0.49999999999999994, 0.5, 0.5, 0.5]",resnet18,cifar100,0.5,2.4825340000000002e-17,Brier,Total
3114,MVBI Logscore,blurred_cifar100,Brier,"[0.7879475200000001, 0.7851096399999999, 0.784...",resnet18,cifar100,0.787928,0.003134457,Logscore,MVBI


In [4]:
# full_ood_rocauc[
# (full_ood_rocauc.base_rule == 'Neglog')
# & (full_ood_rocauc.Dataset == 'svhn')
# & (full_ood_rocauc.training_dataset == 'cifar10')
# & (full_ood_rocauc.RiskType == 'Excess')
# ].sort_values(by=['RocAucScoresMean'], ascending=False)

In [5]:
# full_ood_rocauc[
# # (full_ood_rocauc.base_rule != 'Neglog')
# (full_ood_rocauc.Dataset == 'svhn')
# & (full_ood_rocauc.training_dataset == 'cifar10')
# & (full_ood_rocauc.RiskType == 'Excess')
# ].sort_values(by=['RocAucScoresMean'], ascending=False)

In [6]:
import sys
import os
import re
import numpy as np

sys.path.insert(0, "src/")

import pandas as pd
from src.table_utils import (
    extract_same_different_dataframes,
    collect_scores_into_dict_with_std,
    ood_detection_pairs_,
    aggregate_over_measures,
)

from IPython.display import display

pd.set_option("display.max_rows", None)

full_ood_rocauc = pd.read_csv("./tables/full_ood_rocauc_with_std.csv", index_col=0)

full_ood_rocauc = full_ood_rocauc[
    full_ood_rocauc.Dataset != full_ood_rocauc.training_dataset
]
full_ood_rocauc = full_ood_rocauc[~full_ood_rocauc.UQMetric.str.endswith("Inner Inner")]
full_ood_rocauc = full_ood_rocauc[full_ood_rocauc.base_rule != "Neglog"]


full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Bayes")
    & full_ood_rocauc.UQMetric.str.endswith("Outer"),
    "RiskType",
] = "Bayes Outer"
full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Bayes")
    & full_ood_rocauc.UQMetric.str.endswith("Inner"),
    "RiskType",
] = "Bayes Inner"

full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Total")
    & full_ood_rocauc.UQMetric.str.endswith("Outer"),
    "RiskType",
] = "Total Outer"
full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Total")
    & full_ood_rocauc.UQMetric.str.endswith("Inner"),
    "RiskType",
] = "Total Inner"


grouped_df = extract_same_different_dataframes(
    dataframe_=full_ood_rocauc,
)

same_dict_mean, same_dict_std = collect_scores_into_dict_with_std(
    dataframes_list=[
        grouped_df.logscore_logscore,
        grouped_df.brier_brier,
        grouped_df.spherical_spherical,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)
same_df_mean = pd.DataFrame.from_dict(same_dict_mean)
same_df_std = pd.DataFrame.from_dict(same_dict_std)

same_agg_df_mean = aggregate_over_measures(
    dataframe_=same_df_mean,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

same_agg_df_std = aggregate_over_measures(
    dataframe_=same_df_std,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

different_dict_mean, different_dict_std = collect_scores_into_dict_with_std(
    dataframes_list=[
        grouped_df.logscore_not_logscore,
        grouped_df.brier_not_brier,
        grouped_df.spherical_not_spherical,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)


different_df_mean = pd.DataFrame.from_dict(different_dict_mean)
different_df_std = pd.DataFrame.from_dict(different_dict_std)

different_agg_df_mean = aggregate_over_measures(
    dataframe_=different_df_mean,
    agg_func_="mean",
    by_=["InD", "OOD"],
)
different_agg_df_std = aggregate_over_measures(
    dataframe_=different_df_std,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

all_dict_mean, all_dict_std = collect_scores_into_dict_with_std(
    dataframes_list=[
        full_ood_rocauc,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)
all_df_mean = pd.DataFrame.from_dict(all_dict_mean)
all_df_std = pd.DataFrame.from_dict(all_dict_std)

all_agg_df_mean = aggregate_over_measures(
    dataframe_=all_df_mean,
    agg_func_="mean",
    by_=["InD", "OOD"],
)
all_agg_df_std = aggregate_over_measures(
    dataframe_=all_df_std,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

In [7]:
display(all_agg_df_mean)
display(all_agg_df_std)

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.826072,0.821161,0.821161,0.80081,0.808991,0.808889,0.808158,0.809926,0.591934,0.753126,0.808991,0.692797
cifar10,blurred_cifar100,0.934231,0.932439,0.932439,0.918125,0.906515,0.907438,0.904727,0.907381,0.620273,0.823016,0.907174,0.741044
cifar10,cifar100,0.89268,0.894875,0.894875,0.894342,0.848855,0.849681,0.847892,0.848992,0.600554,0.786962,0.849357,0.708305
cifar10,svhn,0.929429,0.931731,0.931731,0.930334,0.885519,0.886415,0.884501,0.88564,0.611235,0.813256,0.886221,0.726453
cifar100,blurred_cifar10,0.881626,0.867802,0.867802,0.84644,0.791173,0.797296,0.782728,0.793494,0.568576,0.714237,0.794112,0.639677
cifar100,blurred_cifar100,0.72653,0.707566,0.707566,0.687757,0.712594,0.710662,0.712087,0.715032,0.5637,0.664063,0.710781,0.621372
cifar100,cifar10,0.778428,0.781168,0.781168,0.777045,0.695321,0.700819,0.690131,0.695013,0.529135,0.64609,0.698315,0.582409
cifar100,svhn,0.827033,0.830303,0.830303,0.828556,0.711152,0.71725,0.705897,0.710309,0.536091,0.656635,0.713821,0.583142


Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.006415,0.006481,0.006481,0.006797,0.006056,0.006032,0.006,0.006135,0.004187,0.00469,0.006036,0.006137
cifar10,blurred_cifar100,0.003436,0.003705,0.003705,0.004326,0.003167,0.003259,0.003132,0.003111,0.002425,0.001907,0.003237,0.002907
cifar10,cifar100,0.001489,0.001515,0.001515,0.001695,0.001965,0.002019,0.001965,0.001911,0.001302,0.001297,0.002017,0.001651
cifar10,svhn,0.008176,0.00833,0.00833,0.009112,0.011637,0.011788,0.011355,0.011767,0.007671,0.006089,0.011729,0.008419
cifar100,blurred_cifar10,0.002317,0.002433,0.002433,0.002545,0.00347,0.003476,0.003417,0.003516,0.002312,0.002784,0.00349,0.003398
cifar100,blurred_cifar100,0.003205,0.00298,0.00298,0.003141,0.003653,0.003621,0.003675,0.003662,0.002686,0.002767,0.003651,0.003565
cifar100,cifar10,0.001997,0.002047,0.002047,0.002158,0.002774,0.002697,0.002816,0.002809,0.002243,0.002042,0.002689,0.002703
cifar100,svhn,0.012803,0.013264,0.013264,0.014041,0.01297,0.013166,0.012606,0.013139,0.011632,0.009907,0.013077,0.014107


In [8]:
display(same_agg_df_mean)
display(same_agg_df_std)

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.827607,0.821493,0.821493,0.801022,0.839059,0.838089,0.839734,0.839354,0.715337,0.83825,0.838372,0.83906
cifar10,blurred_cifar100,0.935158,0.933333,0.933333,0.918944,0.9319,0.931995,0.931698,0.932009,0.777801,0.931217,0.931735,0.931397
cifar10,cifar100,0.892287,0.895729,0.895729,0.895145,0.883591,0.884482,0.882807,0.883482,0.745695,0.883531,0.883995,0.883079
cifar10,svhn,0.92948,0.932688,0.932688,0.931244,0.918413,0.919206,0.917672,0.91836,0.769165,0.918459,0.918871,0.917842
cifar100,blurred_cifar10,0.886792,0.870221,0.870221,0.848485,0.812526,0.815845,0.808197,0.813536,0.696163,0.806982,0.813343,0.814094
cifar100,blurred_cifar100,0.732984,0.709404,0.709404,0.689309,0.73384,0.729821,0.736137,0.73556,0.656341,0.729202,0.73115,0.735573
cifar100,cifar10,0.777776,0.782585,0.782585,0.778405,0.715351,0.720717,0.710379,0.714956,0.624165,0.714004,0.717538,0.714874
cifar100,svhn,0.82788,0.834085,0.834085,0.83162,0.732692,0.73877,0.72696,0.732347,0.642112,0.728872,0.734135,0.733829


Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.006417,0.006462,0.006462,0.00681,0.005856,0.005911,0.005798,0.005859,0.003263,0.005841,0.005891,0.00584
cifar10,blurred_cifar100,0.003399,0.003772,0.003772,0.004382,0.002485,0.00253,0.002449,0.002477,0.002075,0.002433,0.002475,0.002541
cifar10,cifar100,0.001444,0.00151,0.00151,0.001698,0.001747,0.001761,0.001735,0.001745,0.000974,0.001703,0.001725,0.001822
cifar10,svhn,0.008038,0.008332,0.008332,0.009214,0.008126,0.008119,0.00815,0.008109,0.007633,0.00798,0.008042,0.008101
cifar100,blurred_cifar10,0.002318,0.002317,0.002317,0.002414,0.003435,0.003389,0.003516,0.003399,0.00226,0.003363,0.003327,0.003477
cifar100,blurred_cifar100,0.003155,0.002992,0.002992,0.003062,0.003618,0.003544,0.00366,0.00365,0.001972,0.003622,0.0036,0.003697
cifar100,cifar10,0.001825,0.001901,0.001901,0.002052,0.002597,0.002567,0.002613,0.002611,0.001739,0.002475,0.002513,0.002716
cifar100,svhn,0.012098,0.013432,0.013432,0.014082,0.012588,0.012673,0.012575,0.012515,0.010591,0.011824,0.012057,0.013231


In [9]:
# different_df_mean[(different_df_mean.InD == 'cifar100') & (different_df_mean.OOD == 'svhn')]
# # ['Bregman Information'].mean()

In [10]:
display(different_agg_df_mean * 100)
display(different_agg_df_std * 100)

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,82.556083,82.105015,82.105015,80.073907,79.89683,79.915545,79.763263,80.011682,55.079938,72.475101,79.919804,64.404253
cifar10,blurred_cifar100,93.392257,93.21416,93.21416,91.785211,89.805376,89.925251,89.573662,89.917215,56.776332,78.69496,89.898727,67.759256
cifar10,cifar100,89.281045,89.459031,89.459031,89.407397,83.727651,83.808053,83.625335,83.749566,55.217397,75.477256,83.781139,65.004756
cifar10,svhn,92.941238,93.141148,93.141148,93.003108,87.455446,87.548474,87.344466,87.473398,55.859244,77.818846,87.533718,66.265669
cifar100,blurred_cifar10,87.990377,86.69953,86.69953,84.575823,78.4055,79.111291,77.423796,78.681413,52.604696,68.332133,78.770181,58.153821
cifar100,blurred_cifar100,72.437877,70.695318,70.695318,68.724034,70.551176,70.42749,70.407049,70.818989,53.281973,64.235048,70.399112,58.330516
cifar100,cifar10,77.864512,78.069548,78.069548,77.659245,68.864436,69.418649,68.338109,68.836548,49.745784,62.345218,69.190725,53.825417
cifar100,svhn,82.67501,82.904258,82.904258,82.753503,70.397161,71.007626,69.887592,70.296264,50.075038,63.25562,70.70502,53.291298


Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.64136,0.648745,0.648745,0.679279,0.612233,0.607246,0.606701,0.622751,0.449538,0.430631,0.608419,0.623522
cifar10,blurred_cifar100,0.344847,0.36822,0.36822,0.430788,0.339475,0.350243,0.336003,0.33218,0.2542,0.17322,0.349077,0.302865
cifar10,cifar100,0.150432,0.151687,0.151687,0.169361,0.20382,0.210545,0.204206,0.19671,0.14117,0.116228,0.211458,0.15948
cifar10,svhn,0.82219,0.832888,0.832888,0.907787,1.280697,1.301131,1.242319,1.298642,0.768351,0.54582,1.295872,0.852563
cifar100,blurred_cifar10,0.231606,0.2472,0.2472,0.258823,0.348147,0.350501,0.338397,0.355543,0.232981,0.259172,0.3544,0.33718
cifar100,blurred_cifar100,0.322124,0.297611,0.297611,0.31674,0.366424,0.364672,0.367984,0.366616,0.292408,0.248264,0.366815,0.352114
cifar100,cifar10,0.205425,0.209574,0.209574,0.219357,0.283305,0.273978,0.288408,0.287527,0.241052,0.189713,0.274806,0.269855
cifar100,svhn,1.303775,1.32085,1.32085,1.402769,1.30975,1.332998,1.26158,1.334671,1.197875,0.926753,1.341707,1.439931


In [11]:
# full_scores = pd.read_csv('./tables/full_dataframe_with_std.csv', )

In [12]:
# full_scores.columns

In [13]:
def enhance_latex_table(input_latex):
    lines = input_latex.split("\n")
    enhanced_lines = []

    for i, line in enumerate(lines):
        if "\\begin{tabular}" in line:
            # Start centering the table
            enhanced_lines.append(r"\begin{center}")

        if "\\toprule" in line:
            # Add multicolumn headers
            enhanced_lines.append(line)
            enhanced_lines.append(
                r"\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\"
            )
            enhanced_lines.append(r"\cmidrule(lr){1-2} \cmidrule(lr){3-7}")
            continue

        # Add row coloring
        if "\\midrule" in line:
            enhanced_lines.append(line)
            enhanced_lines.append(r"\rowcolor{gray!10}")
        elif "\\bottomrule" in line:
            enhanced_lines.append(r"\end{tabular}")
            enhanced_lines.append(r"\end{center}")
        else:
            enhanced_lines.append(line)

    return "\n".join(enhanced_lines)

In [14]:
index_pairs = [
    ("CIFAR10", "Blurred CIFAR10"),
    ("CIFAR10", "Blurred CIFAR100"),
    ("CIFAR10", "CIFAR100"),
    ("CIFAR10", "SVHN"),
    ("CIFAR100", "Blurred CIFAR10"),
    ("CIFAR100", "Blurred CIFAR100"),
    ("CIFAR100", "CIFAR10"),
    ("CIFAR100", "SVHN"),
]


def get_nice_df(df_):
    df_.index = pd.MultiIndex.from_tuples(index_pairs, names=["InD", "OOD"])
    df_.columns = [
        # 'Bayes',
        # 'Excess',
        # 'Total',
        "Bayes(O)",
        "Bayes(I)",
        "Total(O)",
        "Total(I)",
        "BI",
        "RBI",
        "EPBI",
        # 'Bias',
        # 'MV',
        # 'MVBI',
        # 'BiasBI',
    ]
    # df_ = df_[['Bayes', 'Excess', 'Total', 'BI', 'RBI']]
    df_ = (100 * df_).round(2)

    display(df_)

    return df_, df_.to_latex(float_format="%.2f")

In [15]:
measures = [
    "Bayes Outer",
    "Bayes Inner",
    "Total Outer",
    "Total Inner",
    "Bregman Information",
    "Reverse Bregman Information",
    "Expected Pairwise Bregman Information",
]


# measures = [
#     'Bayes',
#     'Excess',
#     'Total'
# ]

In [16]:
nice_same = get_nice_df(same_agg_df_mean[measures].copy())
enhanced_latex = enhance_latex_table(nice_same[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,80.1,82.15,82.76,82.15,83.81,83.97,83.94
CIFAR10,Blurred CIFAR100,91.89,93.33,93.52,93.33,93.2,93.17,93.2
CIFAR10,CIFAR100,89.51,89.57,89.23,89.57,88.45,88.28,88.35
CIFAR10,SVHN,93.12,93.27,92.95,93.27,91.92,91.77,91.84
CIFAR100,Blurred CIFAR10,84.85,87.02,88.68,87.02,81.58,80.82,81.35
CIFAR100,Blurred CIFAR100,68.93,70.94,73.3,70.94,72.98,73.61,73.56
CIFAR100,CIFAR10,77.84,78.26,77.78,78.26,72.07,71.04,71.5
CIFAR100,SVHN,83.16,83.41,82.79,83.41,73.88,72.7,73.23


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 80.10 & 82.15 & 82.76 & 82.15 & 83.81 & 83.97 & 83.94 \\
 & Blurred CIFAR100 & 91.89 & 93.33 & 93.52 & 93.33 & 93.20 & 93.17 & 93.20 \\
 & CIFAR100 & 89.51 & 89.57 & 89.23 & 89.57 & 88.45 & 88.28 & 88.35 \\
 & SVHN & 93.12 & 93.27 & 92.95 & 93.27 & 91.92 & 91.77 & 91.84 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 84.85 & 87.02 & 88.68 & 87.02 & 81.58 & 80.82 & 81.35 \\
 & Blurred CIFAR100 & 68.93 & 70.94 & 73.30 & 70.94 & 72.98 & 73.61 & 73.56 \\
 & CIFAR10 & 77.84 & 78.26 & 77.78 & 78.26 & 72.07 & 71.04 & 71.50 \\
 & SVHN & 83.16 & 83.41 & 82.79 & 83.41 & 73.88 & 72.70 & 73.23 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}



In [18]:
nice_same = get_nice_df(same_agg_df_std[measures].copy())
enhanced_latex = enhance_latex_table(nice_same[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,0.68,0.65,0.64,0.65,0.59,0.58,0.59
CIFAR10,Blurred CIFAR100,0.44,0.38,0.34,0.38,0.25,0.24,0.25
CIFAR10,CIFAR100,0.17,0.15,0.14,0.15,0.18,0.17,0.17
CIFAR10,SVHN,0.92,0.83,0.8,0.83,0.81,0.81,0.81
CIFAR100,Blurred CIFAR10,0.24,0.23,0.23,0.23,0.34,0.35,0.34
CIFAR100,Blurred CIFAR100,0.31,0.3,0.32,0.3,0.35,0.37,0.37
CIFAR100,CIFAR10,0.21,0.19,0.18,0.19,0.26,0.26,0.26
CIFAR100,SVHN,1.41,1.34,1.21,1.34,1.27,1.26,1.25


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 0.68 & 0.65 & 0.64 & 0.65 & 0.59 & 0.58 & 0.59 \\
 & Blurred CIFAR100 & 0.44 & 0.38 & 0.34 & 0.38 & 0.25 & 0.24 & 0.25 \\
 & CIFAR100 & 0.17 & 0.15 & 0.14 & 0.15 & 0.18 & 0.17 & 0.17 \\
 & SVHN & 0.92 & 0.83 & 0.80 & 0.83 & 0.81 & 0.81 & 0.81 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 0.24 & 0.23 & 0.23 & 0.23 & 0.34 & 0.35 & 0.34 \\
 & Blurred CIFAR100 & 0.31 & 0.30 & 0.32 & 0.30 & 0.35 & 0.37 & 0.37 \\
 & CIFAR10 & 0.21 & 0.19 & 0.18 & 0.19 & 0.26 & 0.26 & 0.26 \\
 & SVHN & 1.41 & 1.34 & 1.21 & 1.34 & 1.27 & 1.26 & 1.25 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}



In [19]:
# different_agg_df_mean

In [20]:
# display(full_ood_rocauc[(full_ood_rocauc.training_dataset == 'cifar100') & (full_ood_rocauc.Dataset == 'svhn') & full_ood_rocauc.UQMetric.str.endswith('Inner Outer')].head(10))
# # full_ood_rocauc[full_ood_rocauc.RocAucScoresMean.values.isclose(0.666047)]
# display(full_ood_rocauc[(full_ood_rocauc.training_dataset == 'cifar100') & (full_ood_rocauc.Dataset == 'svhn') & (full_ood_rocauc.RiskType == 'Reverse Bregman Information')].head(10))

In [21]:
# full_ood_rocauc.loc[2586]

In [22]:
# full_ood_rocauc.loc[2872]

In [23]:
different_df_mean.sample(10)

Unnamed: 0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI,OOD,InD,ScoringRule
23,0.880747,0.863043,0.863043,0.839023,0.728368,0.735029,0.720185,0.72989,0.483691,0.632492,0.73313,0.526921,blurred_cifar10,cifar100,[Spherical]
13,0.827065,0.8263,0.8263,0.825795,0.729555,0.737446,0.723678,0.727542,0.554977,0.658624,0.732065,0.563549,svhn,cifar100,[Brier]
8,0.892713,0.894592,0.894592,0.894228,0.839297,0.84044,0.838064,0.839388,0.634593,0.755594,0.839954,0.651157,cifar100,cifar10,[Brier]
16,0.889302,0.891456,0.891456,0.890634,0.832077,0.833074,0.831182,0.831974,0.516087,0.749927,0.832907,0.646047,cifar100,cifar10,[Spherical]
11,0.935122,0.933238,0.933238,0.919937,0.899281,0.900697,0.896842,0.900303,0.660454,0.78743,0.900233,0.678079,blurred_cifar100,cifar10,[Brier]
3,0.936491,0.93584,0.93584,0.921372,0.899925,0.900809,0.89779,0.901177,0.515921,0.788833,0.90072,0.6778,blurred_cifar100,cifar10,[Logscore]
5,0.826246,0.832356,0.832356,0.833322,0.733237,0.73816,0.727778,0.733774,0.488891,0.659003,0.736238,0.557753,svhn,cifar100,[Logscore]
7,0.881782,0.875221,0.875221,0.850487,0.81552,0.8223,0.803278,0.820982,0.496525,0.71053,0.820096,0.606594,blurred_cifar10,cifar100,[Logscore]
6,0.73176,0.722372,0.722372,0.700992,0.715151,0.714932,0.712523,0.717998,0.503697,0.650679,0.714556,0.586516,blurred_cifar100,cifar100,[Logscore]
14,0.720087,0.703382,0.703382,0.688986,0.712488,0.709476,0.712777,0.71521,0.584775,0.648712,0.70926,0.593182,blurred_cifar100,cifar100,[Brier]


In [24]:
nice_same = get_nice_df(different_agg_df_mean[measures].copy())
enhanced_latex = enhance_latex_table(nice_same[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,80.07,82.11,82.56,82.11,79.92,79.76,80.01
CIFAR10,Blurred CIFAR100,91.79,93.21,93.39,93.21,89.93,89.57,89.92
CIFAR10,CIFAR100,89.41,89.46,89.28,89.46,83.81,83.63,83.75
CIFAR10,SVHN,93.0,93.14,92.94,93.14,87.55,87.34,87.47
CIFAR100,Blurred CIFAR10,84.58,86.7,87.99,86.7,79.11,77.42,78.68
CIFAR100,Blurred CIFAR100,68.72,70.7,72.44,70.7,70.43,70.41,70.82
CIFAR100,CIFAR10,77.66,78.07,77.86,78.07,69.42,68.34,68.84
CIFAR100,SVHN,82.75,82.9,82.68,82.9,71.01,69.89,70.3


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 80.07 & 82.11 & 82.56 & 82.11 & 79.92 & 79.76 & 80.01 \\
 & Blurred CIFAR100 & 91.79 & 93.21 & 93.39 & 93.21 & 89.93 & 89.57 & 89.92 \\
 & CIFAR100 & 89.41 & 89.46 & 89.28 & 89.46 & 83.81 & 83.63 & 83.75 \\
 & SVHN & 93.00 & 93.14 & 92.94 & 93.14 & 87.55 & 87.34 & 87.47 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 84.58 & 86.70 & 87.99 & 86.70 & 79.11 & 77.42 & 78.68 \\
 & Blurred CIFAR100 & 68.72 & 70.70 & 72.44 & 70.70 & 70.43 & 70.41 & 70.82 \\
 & CIFAR10 & 77.66 & 78.07 & 77.86 & 78.07 & 69.42 & 68.34 & 68.84 \\
 & SVHN & 82.75 & 82.90 & 82.68 & 82.90 & 71.01 & 69.89 & 70.30 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}



In [25]:
nice_same = get_nice_df(different_agg_df_std[measures].copy())
enhanced_latex = enhance_latex_table(nice_same[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,0.68,0.65,0.64,0.65,0.61,0.61,0.62
CIFAR10,Blurred CIFAR100,0.43,0.37,0.34,0.37,0.35,0.34,0.33
CIFAR10,CIFAR100,0.17,0.15,0.15,0.15,0.21,0.2,0.2
CIFAR10,SVHN,0.91,0.83,0.82,0.83,1.3,1.24,1.3
CIFAR100,Blurred CIFAR10,0.26,0.25,0.23,0.25,0.35,0.34,0.36
CIFAR100,Blurred CIFAR100,0.32,0.3,0.32,0.3,0.36,0.37,0.37
CIFAR100,CIFAR10,0.22,0.21,0.21,0.21,0.27,0.29,0.29
CIFAR100,SVHN,1.4,1.32,1.3,1.32,1.33,1.26,1.33


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 0.68 & 0.65 & 0.64 & 0.65 & 0.61 & 0.61 & 0.62 \\
 & Blurred CIFAR100 & 0.43 & 0.37 & 0.34 & 0.37 & 0.35 & 0.34 & 0.33 \\
 & CIFAR100 & 0.17 & 0.15 & 0.15 & 0.15 & 0.21 & 0.20 & 0.20 \\
 & SVHN & 0.91 & 0.83 & 0.82 & 0.83 & 1.30 & 1.24 & 1.30 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 0.26 & 0.25 & 0.23 & 0.25 & 0.35 & 0.34 & 0.36 \\
 & Blurred CIFAR100 & 0.32 & 0.30 & 0.32 & 0.30 & 0.36 & 0.37 & 0.37 \\
 & CIFAR10 & 0.22 & 0.21 & 0.21 & 0.21 & 0.27 & 0.29 & 0.29 \\
 & SVHN & 1.40 & 1.32 & 1.30 & 1.32 & 1.33 & 1.26 & 1.33 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}

