In [1]:
import pandas as pd
import numpy as np
from source.source.table_utils import (
    collect_scores_into_dict,
    extract_same_different_dataframes,
    ood_detection_pairs_,
    aggregate_over_measures,
)
from source.source.path_config import REPOSITORY_ROOT
from source.metrics.constants import GName
from source.losses.constants import LossName
from IPython.display import display

pd.set_option("display.max_rows", None)

stty: 'standard input': Inappropriate ioctl for device
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
REPOSITORY_ROOT

'/home/nkotelevskii/github/uncertainty_from_proper_scoring_rules'

In [3]:
full_ood_rocauc = pd.read_pickle(
    f"{REPOSITORY_ROOT}/tables/central_tables/full_ood_rocauc.pkl"
)

In [4]:
full_ood_rocauc.sample(10)

Unnamed: 0,UQMetric,Dataset,LossFunction,RocAucScores_array,architecture,training_dataset,base_rule,RiskType
472,BrierScore ExcessRisk outer inner,svhn,BrierScore,"[0.6544549708051629, 0.669543815304241, 0.6660...",vgg19,noisy_cifar100,BrierScore,outer inner
892,ZeroOneScore ExcessRisk central central,svhn,BrierScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",vgg19,cifar100,ZeroOneScore,central central
1089,SphericalScore ExcessRisk outer outer,blurred_cifar100,CrossEntropy,"[0.671128615, 0.673208845, 0.6726021449999999,...",vgg19,noisy_cifar100,SphericalScore,outer outer
73,LogScore TotalRisk inner inner,blurred_cifar10,BrierScore,"[0.874240705, 0.87264773, 0.87150896, 0.874508...",resnet18,noisy_cifar100,LogScore,inner inner
300,LogScore BayesRisk central,cifar10,CrossEntropy,"[0.5, 0.49999999999999994, 0.5, 0.5, 0.5]",vgg19,noisy_cifar10,LogScore,central
880,ZeroOneScore ExcessRisk central inner,blurred_cifar100,BrierScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,noisy_cifar10,ZeroOneScore,central inner
920,ZeroOneScore BayesRisk inner,cifar100,SphericalScore,"[0.5, 0.5, 0.5, 0.5, 0.4999999999999999]",resnet18,cifar100,ZeroOneScore,inner
1158,SphericalScore ExcessRisk inner central,cifar100,CrossEntropy,"[0.8670927150000001, 0.8657973050000001, 0.865...",resnet18,cifar10,SphericalScore,inner central
1286,LogScore energy inner,blurred_cifar100,SphericalScore,"[0.9567715449999998, 0.95150472, 0.95037681, 0...",resnet18,cifar10,LogScore,energy inner
473,BrierScore ExcessRisk outer inner,svhn,SphericalScore,"[0.3849337430854333, 0.39125100261216966, 0.36...",resnet18,cifar100,BrierScore,outer inner


In [5]:
full_ood_rocauc.columns

Index(['UQMetric', 'Dataset', 'LossFunction', 'RocAucScores_array',
       'architecture', 'training_dataset', 'base_rule', 'RiskType'],
      dtype='object')

In [6]:
type(full_ood_rocauc.RocAucScores_array.values[0])

list

In [7]:
full_ood_rocauc.RiskType.unique()

array(['outer outer', 'outer inner', 'outer central', 'inner outer',
       'inner inner', 'inner central', 'central outer', 'central inner',
       'central central', 'outer', 'inner', 'central', 'energy outer',
       'energy inner'], dtype=object)

In [8]:
full_ood_rocauc.Dataset.unique()

array(['cifar10', 'cifar100', 'svhn', 'blurred_cifar100',
       'blurred_cifar10'], dtype=object)

In [9]:
# full_ood_rocauc.UQMetric.unique()

In [10]:
def selector(
    df,
    ind_dataset,
    ood_dataset,
    architecture,
    UQMetric,
):
    arr = np.array(
        df[
            (df.UQMetric == UQMetric)
            & (df.training_dataset == ind_dataset)
            & (df.Dataset == ood_dataset)
            & (df.architecture == architecture)
        ].RocAucScores_array.values[0]
    )
    # print(arr)

    # return f"Mean: {arr.mean()}, Std: {arr.std()}"
    return float(arr.mean()), float(arr.std())

In [25]:
ind_dataset = "cifar10"
architecture = "resnet18"

In [26]:
def get_specific_stats(
    ind_dataset_,
    architecture_,
    loss_function_,
    base_rule_,
):
    selected_results = full_ood_rocauc[
        (full_ood_rocauc.base_rule == base_rule_)
        & (full_ood_rocauc.LossFunction == loss_function_)
    ]

    full_res = {}
    for ood_dataset in [el for el in full_ood_rocauc.Dataset.unique()]:
        res_dict = {}
        for uqmetric_name in [
            el for el in full_ood_rocauc.UQMetric.unique() if el.startswith(base_rule_)
        ]:
            # for uqmetric_name in [
            #     # f"{base_rule_} energy outer",
            #     # f"{base_rule_} energy inner",
            #     f"{base_rule_} ExcessRisk central outer",
            #     f"{base_rule_} ExcessRisk inner central",
            #     f"{base_rule_} ExcessRisk central inner",
            #     f"{base_rule_} ExcessRisk outer outer",
            #     f"{base_rule_} ExcessRisk central inner",
            #     f"{base_rule_} BayesRisk inner",
            # ]:
            mean, std = selector(
                df=selected_results,
                UQMetric=uqmetric_name,
                ind_dataset=ind_dataset_,
                ood_dataset=ood_dataset,
                architecture=architecture_,
            )
            res_dict[uqmetric_name] = {"mean": mean, "std": std}
        full_res[ood_dataset] = res_dict

    return full_res

In [27]:
ce_full_res = get_specific_stats(
    ind_dataset_=ind_dataset,
    architecture_=architecture,
    loss_function_=LossName.CROSS_ENTROPY.value,
    base_rule_=GName.LOG_SCORE.value,
)

pd.DataFrame.from_dict(ce_full_res)

Unnamed: 0,cifar10,cifar100,svhn,blurred_cifar100,blurred_cifar10
LogScore TotalRisk outer outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9080571840000001, 'std': 0.00036716...","{'mean': 0.953863343577136, 'std': 0.007085300...","{'mean': 0.9596558270000001, 'std': 0.00161261...","{'mean': 0.8871565119999998, 'std': 0.00357141..."
LogScore TotalRisk outer inner,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9129897650000001, 'std': 0.00051389...","{'mean': 0.9606241510448678, 'std': 0.00491236...","{'mean': 0.9584844369999999, 'std': 0.00164442...","{'mean': 0.878910048, 'std': 0.003578268926909..."
LogScore TotalRisk outer central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9115055569999999, 'std': 0.00051647...","{'mean': 0.9591314389981562, 'std': 0.00559997...","{'mean': 0.96026257, 'std': 0.0015374769479475...","{'mean': 0.885344473, 'std': 0.003496752955318..."
LogScore TotalRisk inner outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.908057157, 'std': 0.000367180266607...","{'mean': 0.9538633408881376, 'std': 0.00708530...","{'mean': 0.959655841, 'std': 0.001612617785510...","{'mean': 0.887156512, 'std': 0.003571416991943..."
LogScore TotalRisk inner inner,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9129897550000001, 'std': 0.00051390...","{'mean': 0.9606241514290105, 'std': 0.00491237...","{'mean': 0.958484433, 'std': 0.001644427299030...","{'mean': 0.878910042, 'std': 0.003578285745728..."
LogScore TotalRisk inner central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.911505555, 'std': 0.000516460563276...","{'mean': 0.9591314416871543, 'std': 0.00559997...","{'mean': 0.960262556, 'std': 0.001537470749898...","{'mean': 0.885344454, 'std': 0.00349674806107531}"
LogScore TotalRisk central outer,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.907233386, 'std': 0.000372314613094...","{'mean': 0.9544046508143824, 'std': 0.00526546...","{'mean': 0.9549947209999999, 'std': 0.00176633...","{'mean': 0.870735714, 'std': 0.003683895913432..."
LogScore TotalRisk central inner,"{'mean': 0.5, 'std': 3.510833468576701e-17}","{'mean': 0.909522012, 'std': 0.000398843793252...","{'mean': 0.9579711558850645, 'std': 0.00441916...","{'mean': 0.9538404719999999, 'std': 0.00189558...","{'mean': 0.864681672, 'std': 0.003675805090984..."
LogScore TotalRisk central central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.909344899, 'std': 0.000345986631972...","{'mean': 0.9576028318992009, 'std': 0.00414633...","{'mean': 0.950249854, 'std': 0.002181744797583...","{'mean': 0.8525495009999998, 'std': 0.00369358..."
LogScore ExcessRisk outer outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9017172160000001, 'std': 0.00063449...","{'mean': 0.9407609787953289, 'std': 0.01017231...","{'mean': 0.9545872879999999, 'std': 0.00167401...","{'mean': 0.894843928, 'std': 0.003109634483314..."


In [None]:
# проверить: амплитуды значений?

# равномерное распределение для OOD? overconfident на InD?

In [28]:
ce_full_res = get_specific_stats(
    ind_dataset_=ind_dataset,
    architecture_=architecture,
    loss_function_=LossName.BRIER_SCORE.value,
    base_rule_=GName.BRIER_SCORE.value,
)

pd.DataFrame.from_dict(ce_full_res)

Unnamed: 0,cifar10,cifar100,svhn,blurred_cifar100,blurred_cifar10
BrierScore TotalRisk outer outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9004229269999999, 'std': 0.00184132...","{'mean': 0.9583011493546405, 'std': 0.00342279...","{'mean': 0.951915082, 'std': 0.002837802694636...","{'mean': 0.8754695009999999, 'std': 0.00474153..."
BrierScore TotalRisk outer inner,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9037557060000001, 'std': 0.00178891...","{'mean': 0.9626046269975415, 'std': 0.00341075...","{'mean': 0.9517880389999999, 'std': 0.00337229...","{'mean': 0.8714103339999999, 'std': 0.00490787..."
BrierScore TotalRisk outer central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9037557060000001, 'std': 0.00178891...","{'mean': 0.9626046269975415, 'std': 0.00341075...","{'mean': 0.9517880389999999, 'std': 0.00337229...","{'mean': 0.8714103339999999, 'std': 0.00490787..."
BrierScore TotalRisk inner outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9004229300000001, 'std': 0.00184133...","{'mean': 0.9583011489704978, 'std': 0.00342279...","{'mean': 0.951915088, 'std': 0.002837799209185...","{'mean': 0.875469481, 'std': 0.004741517187949..."
BrierScore TotalRisk inner inner,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.9037557210000001, 'std': 0.00178893...","{'mean': 0.9626046258451136, 'std': 0.00341075...","{'mean': 0.951788043, 'std': 0.003372293117237...","{'mean': 0.8714103479999998, 'std': 0.00490785..."
BrierScore TotalRisk inner central,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.9037557210000001, 'std': 0.00178893...","{'mean': 0.9626046258451136, 'std': 0.00341075...","{'mean': 0.951788043, 'std': 0.003372293117237...","{'mean': 0.8714103479999998, 'std': 0.00490785..."
BrierScore TotalRisk central outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9004229300000001, 'std': 0.00184133...","{'mean': 0.9583011489704978, 'std': 0.00342279...","{'mean': 0.951915088, 'std': 0.002837799209185...","{'mean': 0.875469481, 'std': 0.004741517187949..."
BrierScore TotalRisk central inner,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.9037557210000001, 'std': 0.00178893...","{'mean': 0.9626046258451136, 'std': 0.00341075...","{'mean': 0.951788043, 'std': 0.003372293117237...","{'mean': 0.8714103479999998, 'std': 0.00490785..."
BrierScore TotalRisk central central,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.9037557210000001, 'std': 0.00178893...","{'mean': 0.9626046258451136, 'std': 0.00341075...","{'mean': 0.951788043, 'std': 0.003372293117237...","{'mean': 0.8714103479999998, 'std': 0.00490785..."
BrierScore ExcessRisk outer outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.8923923420000002, 'std': 0.00192955...","{'mean': 0.9419448321296866, 'std': 0.00487508...","{'mean': 0.943664447, 'std': 0.002066786144001...","{'mean': 0.8804309659999999, 'std': 0.00486060..."


In [29]:
ce_full_res = get_specific_stats(
    ind_dataset_=ind_dataset,
    architecture_=architecture,
    loss_function_=LossName.SPHERICAL_SCORE.value,
    base_rule_=GName.SPHERICAL_SCORE.value,
)

pd.DataFrame.from_dict(ce_full_res)

Unnamed: 0,cifar10,cifar100,svhn,blurred_cifar100,blurred_cifar10
SphericalScore TotalRisk outer outer,"{'mean': 0.5, 'std': 4.2998752849492583e-17}","{'mean': 0.9024201730000001, 'std': 0.00033360...","{'mean': 0.9600153334357714, 'std': 0.00597145...","{'mean': 0.9517489510000001, 'std': 0.00414517...","{'mean': 0.870025593, 'std': 0.008851167628688..."
SphericalScore TotalRisk outer inner,"{'mean': 0.5, 'std': 4.965068306494546e-17}","{'mean': 0.9041684219999999, 'std': 0.00029834...","{'mean': 0.9621895978027043, 'std': 0.00535563...","{'mean': 0.9501645360000002, 'std': 0.00461991...","{'mean': 0.8658952449999999, 'std': 0.00910068..."
SphericalScore TotalRisk outer central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9039240450000001, 'std': 0.00029071...","{'mean': 0.961985999539029, 'std': 0.005519133...","{'mean': 0.951014762, 'std': 0.004493768158445...","{'mean': 0.867609401, 'std': 0.009057706114455..."
SphericalScore TotalRisk inner outer,"{'mean': 0.5, 'std': 3.510833468576701e-17}","{'mean': 0.902420156, 'std': 0.000333581505713...","{'mean': 0.9600153315150584, 'std': 0.00597144...","{'mean': 0.9517489400000001, 'std': 0.00414517...","{'mean': 0.870025598, 'std': 0.008851159491483..."
SphericalScore TotalRisk inner inner,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.904168439, 'std': 0.000298347101718...","{'mean': 0.9621896001075599, 'std': 0.00535563...","{'mean': 0.950164526, 'std': 0.004619913043442...","{'mean': 0.865895264, 'std': 0.00910068591805771}"
SphericalScore TotalRisk inner central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.903924043, 'std': 0.000290714284299...","{'mean': 0.9619860033804548, 'std': 0.00551913...","{'mean': 0.951014762, 'std': 0.004493771022058...","{'mean': 0.867609439, 'std': 0.009057709240543..."
SphericalScore TotalRisk central outer,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.8982018270000001, 'std': 0.00059593...","{'mean': 0.9535957298709281, 'std': 0.00675321...","{'mean': 0.9509990899999998, 'std': 0.00341377...","{'mean': 0.8746209830000001, 'std': 0.00842897..."
SphericalScore TotalRisk central inner,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9002729730000001, 'std': 0.00047715...","{'mean': 0.9569725714505225, 'std': 0.00651046...","{'mean': 0.952266094, 'std': 0.003638578710021...","{'mean': 0.8746428009999999, 'std': 0.00855689..."
SphericalScore TotalRisk central central,"{'mean': 0.5, 'std': 0.0}","{'mean': 0.9008602910000001, 'std': 0.00043751...","{'mean': 0.9578292616779349, 'std': 0.00639877...","{'mean': 0.9522525900000002, 'std': 0.00376049...","{'mean': 0.873904717, 'std': 0.008601317312451..."
SphericalScore ExcessRisk outer outer,"{'mean': 0.5, 'std': 2.482534153247273e-17}","{'mean': 0.89110868, 'std': 0.001461348907581685}","{'mean': 0.9377348620928089, 'std': 0.00655154...","{'mean': 0.943039711, 'std': 0.002293229415035...","{'mean': 0.878327284, 'std': 0.007189090450579..."


In [30]:
full_dataframe = pd.read_pickle(
    f"{REPOSITORY_ROOT}/tables/central_tables/full_dataframe.pkl"
)

In [33]:
np.vstack(full_dataframe["Scores"].values[0]).shape

(5, 10000)

In [6]:
full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Bayes")
    & full_ood_rocauc.UQMetric.str.endswith("Outer"),
    "RiskType",
] = "Bayes Outer"
full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Bayes")
    & full_ood_rocauc.UQMetric.str.endswith("Inner"),
    "RiskType",
] = "Bayes Inner"

full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Total")
    & full_ood_rocauc.UQMetric.str.endswith("Outer"),
    "RiskType",
] = "Total Outer"
full_ood_rocauc.loc[
    (full_ood_rocauc.RiskType == "Total")
    & full_ood_rocauc.UQMetric.str.endswith("Inner"),
    "RiskType",
] = "Total Inner"

In [7]:
# trunc_df = full_ood_rocauc[
# ~full_ood_rocauc.RiskType.isin(['Bias', 'MV', 'MVBI', 'BiasBI', 'Bregman Information', 'Reverse Bregman Information']) &
# # full_ood_rocauc.base_rule.isin(['Brier', 'Logscore', 'Spherical']) &
# # full_ood_rocauc.LossFunction.isin(['Brier', 'Logscore', 'Spherical']) &
# ~(np.isclose(full_ood_rocauc.RocAucScore, np.float64(0.5)))
# ]

# # trunc_df.sort_values(by='RocAucScore')

# trunc_df.to_csv(os.path.join('tables', 'full_ood_rocauc_only_risks.csv'), index=False)

In [8]:
# full_ood_rocauc[
# (full_ood_rocauc.RiskType != 'Bias') &
# (full_ood_rocauc.base_rule == 'Neglog')
# ].sort_values(by=['RocAucScore'])

In [9]:
# full_ood_rocauc = full_ood_rocauc[full_ood_rocauc.base_rule != 'Neglog']

In [10]:
grouped_df = extract_same_different_dataframes(
    dataframe_=full_ood_rocauc,
)

In [11]:
same_dict, _ = collect_scores_into_dict(
    dataframes_list=[
        grouped_df.logscore_logscore,
        grouped_df.brier_brier,
        grouped_df.spherical_spherical,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)
same_df = pd.DataFrame.from_dict(same_dict)

same_agg_df = aggregate_over_measures(
    dataframe_=same_df,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

In [12]:
different_dict, _ = collect_scores_into_dict(
    dataframes_list=[
        grouped_df.logscore_not_logscore,
        grouped_df.brier_not_brier,
        grouped_df.spherical_not_spherical,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)
different_df = pd.DataFrame.from_dict(different_dict)

different_agg_df = aggregate_over_measures(
    dataframe_=different_df,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

In [13]:
all_dict, _ = collect_scores_into_dict(
    dataframes_list=[
        full_ood_rocauc,
    ],
    ood_detection_pairs=ood_detection_pairs_,
)
all_df = pd.DataFrame.from_dict(all_dict)

all_agg_df = aggregate_over_measures(
    dataframe_=all_df,
    agg_func_="mean",
    by_=["InD", "OOD"],
)

In [14]:
display(all_agg_df)
display(same_agg_df)
display(different_agg_df)

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.866536,0.85309,0.85309,0.826965,0.868306,0.8678,0.867567,0.869551,0.66066,0.802464,0.865765,0.750017
cifar10,blurred_cifar100,0.956324,0.9508,0.9508,0.935119,0.950863,0.951213,0.949054,0.952321,0.674925,0.863725,0.950518,0.787422
cifar10,cifar100,0.90591,0.90992,0.90992,0.909563,0.885812,0.887123,0.884213,0.8861,0.648434,0.818162,0.88798,0.744753
cifar10,svhn,0.941423,0.945143,0.945143,0.944335,0.92178,0.922266,0.920479,0.922596,0.661076,0.84447,0.924062,0.76313
cifar100,blurred_cifar10,0.878724,0.889896,0.889896,0.864338,0.784151,0.787492,0.769291,0.79567,0.541619,0.730607,0.807259,0.620265
cifar100,blurred_cifar100,0.747711,0.725706,0.725706,0.695672,0.73775,0.728504,0.739952,0.744794,0.568505,0.699635,0.744755,0.62885
cifar100,cifar10,0.752243,0.791126,0.791126,0.788203,0.658923,0.664299,0.650057,0.662414,0.482271,0.623093,0.671924,0.541911
cifar100,svhn,0.803006,0.849334,0.849334,0.848893,0.671667,0.679762,0.660314,0.674925,0.486484,0.625822,0.680426,0.539022


Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.862998,0.854926,0.854926,0.832594,0.878007,0.876301,0.879345,0.878376,0.742875,0.876733,0.876695,0.879063
cifar10,blurred_cifar100,0.956178,0.952183,0.952183,0.937482,0.956472,0.956019,0.956756,0.956642,0.788508,0.954771,0.955415,0.957182
cifar10,cifar100,0.907467,0.909927,0.909927,0.911068,0.90077,0.901993,0.899631,0.900687,0.754901,0.899628,0.900749,0.900815
cifar10,svhn,0.943443,0.945436,0.945436,0.945453,0.933077,0.933888,0.932275,0.933069,0.776912,0.93176,0.932753,0.933303
cifar100,blurred_cifar10,0.914538,0.89184,0.89184,0.867418,0.853572,0.858332,0.844763,0.85762,0.705564,0.84149,0.856035,0.860863
cifar100,blurred_cifar100,0.756761,0.726193,0.726193,0.701755,0.775527,0.767559,0.780346,0.778675,0.683394,0.767353,0.770708,0.779179
cifar100,cifar10,0.790967,0.794445,0.794445,0.792496,0.726959,0.734864,0.718211,0.727801,0.609855,0.722217,0.730338,0.72968
cifar100,svhn,0.843255,0.848508,0.848508,0.849544,0.73994,0.750673,0.728042,0.741104,0.625403,0.729552,0.742918,0.746241


Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,0.86742,0.852631,0.852631,0.825558,0.865881,0.865675,0.864623,0.867344,0.640106,0.783897,0.863032,0.717755
cifar10,blurred_cifar100,0.956361,0.950454,0.950454,0.934528,0.94946,0.950012,0.947128,0.951241,0.64653,0.840964,0.949293,0.744982
cifar10,cifar100,0.905521,0.909918,0.909918,0.909187,0.882072,0.883405,0.880358,0.882453,0.621818,0.797796,0.884787,0.705737
cifar10,svhn,0.940918,0.945069,0.945069,0.944056,0.918956,0.919361,0.91753,0.919977,0.632117,0.822647,0.92189,0.720587
cifar100,blurred_cifar10,0.869771,0.88941,0.88941,0.863569,0.766796,0.769782,0.750424,0.780182,0.500633,0.702887,0.795065,0.560116
cifar100,blurred_cifar100,0.745449,0.725584,0.725584,0.694151,0.728306,0.71874,0.729854,0.736324,0.539783,0.682705,0.738267,0.591268
cifar100,cifar10,0.742562,0.790296,0.790296,0.787129,0.641915,0.646658,0.633018,0.646067,0.450375,0.598313,0.65732,0.494969
cifar100,svhn,0.792944,0.849541,0.849541,0.848731,0.654599,0.662034,0.643382,0.658381,0.451754,0.59989,0.664803,0.487217


In [15]:
(same_agg_df - different_agg_df) / different_agg_df * 100

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,-0.509738,0.269179,0.269179,0.852162,1.400479,1.22738,1.702765,1.271907,16.054943,11.842995,1.583135,22.473856
cifar10,blurred_cifar100,-0.019136,0.181881,0.181881,0.316115,0.738526,0.63231,1.016533,0.567799,21.960057,13.532907,0.644933,28.483916
cifar10,cifar100,0.214914,0.001027,0.001027,0.206975,2.119807,2.10412,2.189179,2.066305,21.402198,12.764164,1.803994,27.641685
cifar10,svhn,0.268292,0.038803,0.038803,0.147923,1.536685,1.58017,1.607072,1.42303,22.906476,13.26357,1.178338,29.51983
cifar100,blurred_cifar10,5.147023,0.273158,0.273158,0.445772,11.31668,11.503281,12.571472,9.925638,40.934306,19.719178,7.668554,53.693719
cifar100,blurred_cifar100,1.517402,0.083983,0.083983,1.095461,6.483645,6.79228,6.918211,5.751633,26.605389,12.398986,4.39412,31.780974
cifar100,cifar10,6.518565,0.525031,0.525031,0.68181,13.248542,13.640329,13.458176,12.650998,35.410572,20.709056,11.108369,47.419366
cifar100,svhn,6.344892,-0.121529,-0.121529,0.095839,13.037129,13.388909,13.158715,12.564581,38.438683,21.614327,11.750093,53.163875


In [16]:
full_scores = pd.read_csv(
    "./tables/full_dataframe.csv",
)

In [19]:
def enhance_latex_table(input_latex):
    lines = input_latex.split("\n")
    enhanced_lines = []

    for i, line in enumerate(lines):
        if "\\begin{tabular}" in line:
            # Start centering the table
            enhanced_lines.append(r"\begin{center}")

        if "\\toprule" in line:
            # Add multicolumn headers
            enhanced_lines.append(line)
            enhanced_lines.append(
                r"\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\"
            )
            enhanced_lines.append(r"\cmidrule(lr){1-2} \cmidrule(lr){3-7}")
            continue

        # Add row coloring
        if "\\midrule" in line:
            enhanced_lines.append(line)
            enhanced_lines.append(r"\rowcolor{gray!10}")
        elif "\\bottomrule" in line:
            enhanced_lines.append(r"\end{tabular}")
            enhanced_lines.append(r"\end{center}")
        else:
            enhanced_lines.append(line)

    return "\n".join(enhanced_lines)

In [20]:
index_pairs = [
    ("CIFAR10", "Blurred CIFAR10"),
    ("CIFAR10", "Blurred CIFAR100"),
    ("CIFAR10", "CIFAR100"),
    ("CIFAR10", "SVHN"),
    ("CIFAR100", "Blurred CIFAR10"),
    ("CIFAR100", "Blurred CIFAR100"),
    ("CIFAR100", "CIFAR10"),
    ("CIFAR100", "SVHN"),
]


def get_nice_df(df_):
    df_.index = pd.MultiIndex.from_tuples(index_pairs, names=["InD", "OOD"])
    df_.columns = [
        # 'Bayes',
        # 'Excess',
        # 'Total',
        "Bayes(O)",
        "Bayes(I)",
        "Total(O)",
        "Total(I)",
        "BI",
        "RBI",
        "EPBI",
        # 'Bias',
        # 'MV',
        # 'MVBI',
        # 'BiasBI',
    ]
    # df_ = df_[['Bayes', 'Excess', 'Total', 'BI', 'RBI']]
    df_ = (100 * df_).round(2)

    display(df_)

    return df_, df_.to_latex(float_format="%.2f")

In [21]:
# measures = [c for c in same_agg_df.columns if c not in ['OOD', 'InD', 'ScoringRule']]
# measures

measures = [
    "Bayes Outer",
    "Bayes Inner",
    "Total Outer",
    "Total Inner",
    "Bregman Information",
    "Reverse Bregman Information",
    "Expected Pairwise Bregman Information",
]

# measures = ['Bayes', 'Excess', 'Total', 'Bregman Information', 'Reverse Bregman Information', 'Expected Pairwise Bregman Information']

In [22]:
nice_same = get_nice_df(same_agg_df[measures].copy())
enhanced_latex = enhance_latex_table(nice_same[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,83.26,85.49,86.3,85.49,87.63,87.93,87.84
CIFAR10,Blurred CIFAR100,93.75,95.22,95.62,95.22,95.6,95.68,95.66
CIFAR10,CIFAR100,91.11,90.99,90.75,90.99,90.2,89.96,90.07
CIFAR10,SVHN,94.55,94.54,94.34,94.54,93.39,93.23,93.31
CIFAR100,Blurred CIFAR10,86.74,89.18,91.45,89.18,85.83,84.48,85.76
CIFAR100,Blurred CIFAR100,70.18,72.62,75.68,72.62,76.76,78.03,77.87
CIFAR100,CIFAR10,79.25,79.44,79.1,79.44,73.49,71.82,72.78
CIFAR100,SVHN,84.95,84.85,84.33,84.85,75.07,72.8,74.11


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 83.26 & 85.49 & 86.30 & 85.49 & 87.63 & 87.93 & 87.84 \\
 & Blurred CIFAR100 & 93.75 & 95.22 & 95.62 & 95.22 & 95.60 & 95.68 & 95.66 \\
 & CIFAR100 & 91.11 & 90.99 & 90.75 & 90.99 & 90.20 & 89.96 & 90.07 \\
 & SVHN & 94.55 & 94.54 & 94.34 & 94.54 & 93.39 & 93.23 & 93.31 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 86.74 & 89.18 & 91.45 & 89.18 & 85.83 & 84.48 & 85.76 \\
 & Blurred CIFAR100 & 70.18 & 72.62 & 75.68 & 72.62 & 76.76 & 78.03 & 77.87 \\
 & CIFAR10 & 79.25 & 79.44 & 79.10 & 79.44 & 73.49 & 71.82 & 72.78 \\
 & SVHN & 84.95 & 84.85 & 84.33 & 84.85 & 75.07 & 72.80 & 74.11 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}



In [23]:
nice_same[0].std()

Bayes(O)    8.124515
Bayes(I)    7.692756
Total(O)    7.170457
Total(I)    7.692756
BI          8.584582
RBI         9.116546
EPBI        8.741953
dtype: float64

In [24]:
nice_same[0].mean()

Bayes(O)    85.47375
Bayes(I)    86.54125
Total(O)    87.19625
Total(I)    86.54125
BI          84.74625
RBI         84.24125
EPBI        84.67500
dtype: float64

In [25]:
nice_different = get_nice_df(different_agg_df[measures].copy())
enhanced_latex = enhance_latex_table(nice_different[1])
print(enhanced_latex)

Unnamed: 0_level_0,Unnamed: 1_level_0,Bayes(O),Bayes(I),Total(O),Total(I),BI,RBI,EPBI
InD,OOD,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
CIFAR10,Blurred CIFAR10,82.56,85.26,86.74,85.26,86.57,86.46,86.73
CIFAR10,Blurred CIFAR100,93.45,95.05,95.64,95.05,95.0,94.71,95.12
CIFAR10,CIFAR100,90.92,90.99,90.55,90.99,88.34,88.04,88.25
CIFAR10,SVHN,94.41,94.51,94.09,94.51,91.94,91.75,92.0
CIFAR100,Blurred CIFAR10,86.36,88.94,86.98,88.94,76.98,75.04,78.02
CIFAR100,Blurred CIFAR100,69.42,72.56,74.54,72.56,71.87,72.99,73.63
CIFAR100,CIFAR10,78.71,79.03,74.26,79.03,64.67,63.3,64.61
CIFAR100,SVHN,84.87,84.95,79.29,84.95,66.2,64.34,65.84


\begin{center}
\begin{tabular}{llrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 82.56 & 85.26 & 86.74 & 85.26 & 86.57 & 86.46 & 86.73 \\
 & Blurred CIFAR100 & 93.45 & 95.05 & 95.64 & 95.05 & 95.00 & 94.71 & 95.12 \\
 & CIFAR100 & 90.92 & 90.99 & 90.55 & 90.99 & 88.34 & 88.04 & 88.25 \\
 & SVHN & 94.41 & 94.51 & 94.09 & 94.51 & 91.94 & 91.75 & 92.00 \\
\cline{1-9}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 86.36 & 88.94 & 86.98 & 88.94 & 76.98 & 75.04 & 78.02 \\
 & Blurred CIFAR100 & 69.42 & 72.56 & 74.54 & 72.56 & 71.87 & 72.99 & 73.63 \\
 & CIFAR10 & 78.71 & 79.03 & 74.26 & 79.03 & 64.67 & 63.30 & 64.61 \\
 & SVHN & 84.87 & 84.95 & 79.29 & 84.95 & 66.20 & 64.34 & 65.84 \\
\cline{1-9}
\end{tabular}
\end{center}
\end{tabular}



In [26]:
nice_different[0].mean()

Bayes(O)    85.08750
Bayes(I)    86.41125
Total(O)    85.26125
Total(I)    86.41125
BI          80.19625
RBI         79.57875
EPBI        80.52500
dtype: float64

In [27]:
nice_different[0].std()

Bayes(O)     8.324288
Bayes(I)     7.721749
Total(O)     8.370061
Total(I)     7.721749
BI          11.836036
RBI         12.289760
EPBI        11.745028
dtype: float64

In [28]:
print(
    enhance_latex_table(
        pd.concat([nice_same[0], nice_different[0]], axis=1).to_latex(
            float_format="%.2f"
        )
    )
)

\begin{center}
\begin{tabular}{llrrrrrrrrrrrrrr}
\toprule
\multicolumn{2}{c}{Dataset} & \multicolumn{5}{c}{Metrics} \\
\cmidrule(lr){1-2} \cmidrule(lr){3-7}
 &  & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI & Bayes(O) & Bayes(I) & Total(O) & Total(I) & BI & RBI & EPBI \\
InD & OOD &  &  &  &  &  &  &  &  &  &  &  &  &  &  \\
\midrule
\rowcolor{gray!10}
\multirow[t]{4}{*}{CIFAR10} & Blurred CIFAR10 & 83.26 & 85.49 & 86.30 & 85.49 & 87.63 & 87.93 & 87.84 & 82.56 & 85.26 & 86.74 & 85.26 & 86.57 & 86.46 & 86.73 \\
 & Blurred CIFAR100 & 93.75 & 95.22 & 95.62 & 95.22 & 95.60 & 95.68 & 95.66 & 93.45 & 95.05 & 95.64 & 95.05 & 95.00 & 94.71 & 95.12 \\
 & CIFAR100 & 91.11 & 90.99 & 90.75 & 90.99 & 90.20 & 89.96 & 90.07 & 90.92 & 90.99 & 90.55 & 90.99 & 88.34 & 88.04 & 88.25 \\
 & SVHN & 94.55 & 94.54 & 94.34 & 94.54 & 93.39 & 93.23 & 93.31 & 94.41 & 94.51 & 94.09 & 94.51 & 91.94 & 91.75 & 92.00 \\
\cline{1-16}
\multirow[t]{4}{*}{CIFAR100} & Blurred CIFAR10 & 86.74 & 89.18 & 91.45

In [29]:
(same_agg_df - all_agg_df) > 0

Unnamed: 0_level_0,Unnamed: 1_level_0,Total Outer,Total Inner,Bayes Inner,Bayes Outer,Excess,Bregman Information,Reverse Bregman Information,Expected Pairwise Bregman Information,Bias,MV,MVBI,BiasBI
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean,mean
InD,OOD,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
cifar10,blurred_cifar10,False,True,True,True,True,True,True,True,True,True,True,True
cifar10,blurred_cifar100,False,True,True,True,True,True,True,True,True,True,True,True
cifar10,cifar100,True,True,True,True,True,True,True,True,True,True,True,True
cifar10,svhn,True,True,True,True,True,True,True,True,True,True,True,True
cifar100,blurred_cifar10,True,True,True,True,True,True,True,True,True,True,True,True
cifar100,blurred_cifar100,True,True,True,True,True,True,True,True,True,True,True,True
cifar100,cifar10,True,True,True,True,True,True,True,True,True,True,True,True
cifar100,svhn,True,False,False,True,True,True,True,True,True,True,True,True
