In [1]:
from source.source.postprocessing_utils import (
    get_sampled_combinations_uncertainty_scores,
    get_predicted_labels,
    get_missclassification_dataframe,
    get_ood_detection_dataframe,
    get_raw_scores_dataframe,
)

from source.datasets.constants import DatasetName
from source.losses.constants import LossName
from source.models.constants import ModelName, ModelSource
from source.metrics import (
    ApproximationType,
    GName,
    RiskType,
)

import pandas as pd
import numpy as np
import pickle
from source.source.path_config import REPOSITORY_ROOT
import os


# Render every row when a DataFrame is displayed (no truncation).
pd.options.display.max_rows = None

  from .autonotebook import tqdm as notebook_tqdm
stty: 'standard input': Inappropriate ioctl for device


In [2]:
from source.source.postprocessing_utils import remove_and_expand_list

In [3]:
# Training sets whose results are aggregated. The aggregation loop also handles
# the noisy-label variants listed (commented out) below.
training_dataset_names = [
    DatasetName.CIFAR10.value,
    DatasetName.CIFAR100.value,
    # DatasetName.CIFAR10_NOISY_LABEL.value,
    # DatasetName.CIFAR100_NOISY_LABEL.value,
]
# Temperature forwarded to `get_sampled_combinations_uncertainty_scores`.
temperature = 1.0
# Model ids 0..19 passed to the score extraction; their count is also used to
# split the stacked in-distribution targets in the aggregation loop.
model_ids = np.arange(20)

# Every loss function defined in `LossName` is included.
loss_function_names = list(LossName)

# NOTE: the extraction / OOD dataset lists depend on the training set, so they
# are built per training set inside the aggregation loop rather than here.

# Combined result tables; None until the aggregation loop populates them.
full_dataframe = None
full_ood_rocauc_dataframe = None
full_mis_rocauc_dataframe = None

In [4]:
def _extraction_datasets_for(training_dataset_name):
    """Return the expanded list of datasets to extract scores on for one training set.

    Only models trained on CIFAR10 are additionally evaluated on CIFAR10-C
    (CIFAR100-C is currently disabled for every training set). The raw list is
    passed through `remove_and_expand_list`, which presumably expands corrupted
    datasets into per-severity entries (the saved outputs show `cifar10c_1` ..
    `cifar10c_5`) -- TODO confirm against its implementation.
    """
    extraction_datasets = [
        DatasetName.CIFAR10.value,
        DatasetName.CIFAR100.value,
        DatasetName.SVHN.value,
        DatasetName.TINY_IMAGENET.value,
        DatasetName.CIFAR10_BLURRED.value,
        DatasetName.CIFAR100_BLURRED.value,
    ]
    if training_dataset_name == DatasetName.CIFAR10.value:
        # Keep CIFAR10-C first, exactly as in the original per-branch lists.
        extraction_datasets.insert(0, DatasetName.CIFAR10C.value)
    return remove_and_expand_list(extraction_datasets)


# Per-(training set, architecture) frames, concatenated once after the loop.
# These are fresh lists rather than globals grown in place, so re-running
# this cell rebuilds the tables instead of appending duplicate rows.
scores_frames = []
ood_frames = []
mis_frames = []

# architectures = [ModelName.RESNET18, ModelName.VGG19]
architectures = [ModelName.RESNET18]

for training_dataset_name in training_dataset_names:
    list_extraction_datasets = _extraction_datasets_for(training_dataset_name)
    list_ood_datasets = list(list_extraction_datasets)

    # For noisy-label training sets, the in-distribution dataset key is the last
    # "_"-separated token of the training-set name; otherwise it is the name itself.
    if training_dataset_name in (
        DatasetName.CIFAR10_NOISY_LABEL.value,
        DatasetName.CIFAR100_NOISY_LABEL.value,
    ):
        training_dataset_name_aux = training_dataset_name.split("_")[-1]
    else:
        training_dataset_name_aux = training_dataset_name

    for architecture in architectures:
        uq_results, embeddings_per_dataset, targets_per_dataset = (
            get_sampled_combinations_uncertainty_scores(
                loss_function_names=loss_function_names,
                training_dataset_name=training_dataset_name,
                architecture=architecture,
                model_ids=model_ids,
                list_extraction_datasets=list_extraction_datasets,
                temperature=temperature,
                model_source=ModelSource.OUR_MODELS.value,
                use_cached=True,
            )
        )

        df_ood = get_ood_detection_dataframe(
            ind_dataset=training_dataset_name_aux,
            uq_results=uq_results,
            list_ood_datasets=list_ood_datasets,
        )

        # Keep only the first 1/len(model_ids) of the in-distribution targets.
        # Presumably the labels are stacked once per model -- TODO confirm.
        ind_targets = targets_per_dataset[training_dataset_name_aux]
        max_ind = ind_targets.shape[0] // len(model_ids)
        true_labels = ind_targets[:max_ind]

        pred_labels = get_predicted_labels(
            embeddings_per_dataset=embeddings_per_dataset,
            training_dataset_name=training_dataset_name_aux,
        )

        df_misclassification = get_missclassification_dataframe(
            ind_dataset=training_dataset_name_aux,
            uq_results=uq_results,
            true_labels=true_labels,
            pred_labels=pred_labels,
        )

        scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)

        # Tag every result table with the run it came from.
        for frame in (scores_df_unravel, df_ood, df_misclassification):
            frame["architecture"] = architecture.value
            frame["training_dataset"] = training_dataset_name

        scores_frames.append(scores_df_unravel)
        ood_frames.append(df_ood)
        mis_frames.append(df_misclassification)

# A single concat per table (original row indices are preserved); None if the
# loop produced nothing.
full_dataframe = pd.concat(scores_frames) if scores_frames else None
full_ood_rocauc_dataframe = pd.concat(ood_frames) if ood_frames else None
full_mis_rocauc_dataframe = pd.concat(mis_frames) if mis_frames else None

In [5]:
# NOTE: `uq_results` is left over from the last iteration of the aggregation loop
# (last training set / architecture), so this lists only that run's datasets.
uq_results["LogScore energy outer"]["CrossEntropy"].keys()

dict_keys(['cifar10', 'cifar100', 'svhn', 'tiny_imagenet', 'blurred_cifar10', 'blurred_cifar100'])

In [6]:
# Fixed `random_state` so the displayed rows are reproducible across re-runs.
full_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,Dataset,Scores,architecture,training_dataset
972,ZeroOneScore ExcessRisk inner outer,CrossEntropy,cifar10,"[[0.07764289, 0.010377191, 0.109281726, 0.0674...",resnet18,cifar100
203,LogScore TotalRisk central outer,CrossEntropy,blurred_cifar100,"[[0.76353204, 1.5226543, 2.150982, 2.283438, 1...",resnet18,cifar10
1721,ZeroOneScore ExcessRisk outer inner,CrossEntropy,blurred_cifar100,"[[0.20656396, 0.007073313, 0.2654513, 0.285878...",resnet18,cifar10
2106,SphericalScore TotalRisk outer outer,SphericalScore,blurred_cifar100,"[[0.44282287, 0.5256607, 0.6059738, 0.65760016...",resnet18,cifar10
369,LogScore BayesRisk central,BrierScore,tiny_imagenet,"[[4.406076, 4.410195, 4.257076, 4.380786, 4.35...",resnet18,cifar100
2108,SphericalScore TotalRisk outer outer,SphericalScore,cifar10c_2,"[[0.6518314, 0.7633928, 0.715951, 0.42040896, ...",resnet18,cifar10
54,LogScore TotalRisk inner outer,CrossEntropy,cifar10,"[[2.0534458, 2.8518505, 3.1770434, 3.870422, 3...",resnet18,cifar100
2301,SphericalScore TotalRisk central outer,SphericalScore,svhn,"[[0.5756672344107909, 0.565931580893034, 0.551...",resnet18,cifar10
2257,SphericalScore TotalRisk inner central,BrierScore,svhn,"[[0.42734102276277497, 0.4957879162095734, 0.4...",resnet18,cifar10
774,ZeroOneScore TotalRisk outer inner,CrossEntropy,cifar10,"[[0.43504608, 0.72646177, 0.6964442, 0.7666945...",resnet18,cifar100


In [7]:
# Datasets present in the OOD ROC-AUC table.
full_ood_rocauc_dataframe["Dataset"].unique()

array(['cifar10', 'cifar100', 'svhn', 'tiny_imagenet', 'blurred_cifar10',
       'blurred_cifar100', 'cifar10c_1', 'cifar10c_2', 'cifar10c_3',
       'cifar10c_4', 'cifar10c_5'], dtype=object)

In [8]:
# Split each UQ metric name into its base scoring rule and its risk/approximation
# combination. Within the alternation, the two-word combinations come before the
# single words they start with. At a given match position the earlier alternative
# wins, so e.g. "outer inner" is captured instead of plain "outer".
pattern_baserule = r"(LogScore|BrierScore|ZeroOneScore|SphericalScore)"
pattern_risk = r"(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)"

for results_frame in (
    full_ood_rocauc_dataframe,
    full_mis_rocauc_dataframe,
    full_dataframe,
):
    uq_metric_names = results_frame["UQMetric"].str
    results_frame["base_rule"] = uq_metric_names.extract(pattern_baserule)
    results_frame["RiskType"] = uq_metric_names.extract(pattern_risk)

In [9]:
# Fixed `random_state` so the displayed rows are reproducible across re-runs.
full_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,Dataset,Scores,architecture,training_dataset,base_rule,RiskType
395,LogScore ExcessRisk outer central,SphericalScore,cifar10c_5,"[[0.6235098, 0.85425436, 0.9694624, 1.3917482,...",resnet18,cifar10,LogScore,outer central
1130,ZeroOneScore BayesRisk central,SphericalScore,svhn,"[[0.91857195, 0.84478194, 0.8735173, 0.9314295...",resnet18,cifar100,ZeroOneScore,central
1025,BrierScore ExcessRisk outer inner,CrossEntropy,svhn,"[[0.26314476, 0.42994598, 0.36967888, 0.292465...",resnet18,cifar10,BrierScore,outer inner
352,LogScore BayesRisk inner,BrierScore,blurred_cifar10,"[[3.2149823, 4.3325047, 2.4687939, 4.021557, 3...",resnet18,cifar100,LogScore,inner
1038,ZeroOneScore ExcessRisk central outer,SphericalScore,cifar10,"[[0.035975236, 0.031987347, 0.008710482, 0.012...",resnet18,cifar100,ZeroOneScore,central outer
974,ZeroOneScore ExcessRisk inner outer,CrossEntropy,svhn,"[[0.16517946, 0.0, 0.13314305, 0.066269964, 7....",resnet18,cifar100,ZeroOneScore,inner outer
1701,ZeroOneScore ExcessRisk outer outer,BrierScore,cifar10c_2,"[[0.6040471, 0.661398, 0.6624508, 0.46783128, ...",resnet18,cifar10,ZeroOneScore,outer outer
839,ZeroOneScore TotalRisk inner inner,BrierScore,blurred_cifar100,"[[0.8441912, 0.65572673, 0.4625911, 0.88744223...",resnet18,cifar100,ZeroOneScore,inner inner
338,LogScore ExcessRisk outer inner,CrossEntropy,cifar10c_3,"[[1.3458617, 1.2578782, 0.7609928, 1.0299634, ...",resnet18,cifar10,LogScore,outer inner
1456,SphericalScore ExcessRisk central central,SphericalScore,blurred_cifar10,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",resnet18,cifar100,SphericalScore,central central


In [10]:
# Fixed `random_state` so the displayed rows are reproducible across re-runs.
full_ood_rocauc_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,Dataset,LossFunction,RocAucScores_array,architecture,training_dataset,base_rule,RiskType
1241,BrierScore ExcessRisk central inner,cifar10c_1,SphericalScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10,BrierScore,central inner
415,BrierScore TotalRisk outer central,cifar10,BrierScore,"[0.804186665, 0.803577, 0.79764932, 0.80111374...",resnet18,cifar100,BrierScore,outer central
2487,SphericalScore ExcessRisk inner outer,blurred_cifar10,CrossEntropy,"[0.879263605, 0.8769025899999999, 0.8854249049...",resnet18,cifar10,SphericalScore,inner outer
1629,ZeroOneScore TotalRisk central inner,blurred_cifar10,CrossEntropy,"[0.87418027, 0.8735818, 0.882402825, 0.8772126...",resnet18,cifar10,ZeroOneScore,central inner
161,LogScore TotalRisk central central,blurred_cifar100,SphericalScore,"[0.7435547450000001, 0.7402604199999999, 0.749...",resnet18,cifar100,LogScore,central central
2683,SphericalScore BayesRisk outer,tiny_imagenet,BrierScore,"[0.94782045, 0.94602842, 0.9491533299999998, 0...",resnet18,cifar10,SphericalScore,outer
1234,BrierScore ExcessRisk central inner,blurred_cifar10,BrierScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10,BrierScore,central inner
1124,ZeroOneScore BayesRisk central,svhn,SphericalScore,"[0.864843990089121, 0.8578520513214505, 0.8680...",resnet18,cifar100,ZeroOneScore,central
1823,ZeroOneScore ExcessRisk inner inner,svhn,SphericalScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10,ZeroOneScore,inner inner
2260,SphericalScore TotalRisk inner central,blurred_cifar100,BrierScore,"[0.94810841, 0.95029975, 0.9559166900000001, 0...",resnet18,cifar10,SphericalScore,inner central


In [11]:
# full_dataframe.to_csv('../../tables/central_tables/full_dataframe.csv')
# full_ood_rocauc_dataframe.to_csv('../../tables/central_tables/full_ood_rocauc.csv')
# full_mis_rocauc_dataframe.to_csv('../../tables/central_tables/full_mis_rocauc.csv')

In [12]:
# Persist the three combined tables under <repo>/tables/central_tables.
central_tables_dir = os.path.join(REPOSITORY_ROOT, "tables/central_tables")
# `DataFrame.to_pickle` will not write into a missing directory, so create it first.
os.makedirs(central_tables_dir, exist_ok=True)

full_dataframe.to_pickle(os.path.join(central_tables_dir, "full_dataframe.pkl"))
full_ood_rocauc_dataframe.to_pickle(os.path.join(central_tables_dir, "full_ood_rocauc.pkl"))
full_mis_rocauc_dataframe.to_pickle(os.path.join(central_tables_dir, "full_mis_rocauc.pkl"))

In [13]:
# full_dataframe.to_pickle(os.path.join(REPOSITORY_ROOT, "tables/central_tables/full_dataframe.csv"))
# full_ood_rocauc_dataframe.to_pickle(os.path.join(REPOSITORY_ROOT, "tables/central_tables/full_ood_rocauc.csv"))
# full_mis_rocauc_dataframe.to_pickle(os.path.join(REPOSITORY_ROOT, "tables/central_tables/full_mis_rocauc.csv"))