In [1]:
from pathlib import Path

import numpy as np
import pandas as pd

from source.source.postprocessing_utils import (
    get_sampled_combinations_uncertainty_scores,
    get_predicted_labels,
    get_missclassification_dataframe,
    get_ood_detection_dataframe,
    get_raw_scores_dataframe,
)
from source.datasets.constants import DatasetName
from source.losses.constants import LossName
from source.models.constants import ModelName
from source.metrics import (
    ApproximationType,
    GName,
    RiskType,
)

# Display option: never truncate the number of rows pandas renders.
pd.options.display.max_rows = None

stty: 'standard input': Inappropriate ioctl for device


In [2]:
# --- Experiment grid ---------------------------------------------------------
# Training sets whose results are aggregated; this order fixes the row order
# of the concatenated result tables.
training_dataset_names = [
    "cifar10",
    "cifar100",
    "noisy_cifar100",
    "noisy_cifar10",
]

# Forwarded unchanged to get_sampled_combinations_uncertainty_scores:
# model ids 0..19 and the temperature value.
model_ids = np.arange(20)
temperature = 1.0

# Datasets whose outputs are extracted for every model.
list_extraction_datasets = [
    "cifar10",
    "cifar100",
    "svhn",
    "blurred_cifar100",
    "blurred_cifar10",
]
# Independent copy of the extraction list, passed as the OOD candidate sets.
list_ood_datasets = list(list_extraction_datasets)
# Every member of the LossName enum.
loss_function_names = list(LossName)

# Placeholders for the three concatenated result tables.
full_dataframe = full_ood_rocauc_dataframe = full_mis_rocauc_dataframe = None

In [3]:
# Model families evaluated for every training set.
architectures = [ModelName.RESNET18, ModelName.VGG19]

# Per-(training set, architecture) frames are collected in lists and
# concatenated once after the loop. Repeated pd.concat inside the loop is
# quadratic. Seeding from the config cell's None placeholders also made this
# cell non-idempotent: re-running it appended every table a second time.
raw_score_frames = []
ood_frames = []
misclassification_frames = []

for training_dataset_name in training_dataset_names:
    # Noisy-label training sets are evaluated on their clean counterpart
    # ("noisy_cifar10" -> "cifar10"); clean sets map to themselves.
    if training_dataset_name in ("noisy_cifar10", "noisy_cifar100"):
        training_dataset_name_aux = training_dataset_name.split("_")[-1]
    else:
        training_dataset_name_aux = training_dataset_name

    for architecture in architectures:
        uq_results, embeddings_per_dataset, targets_per_dataset = (
            get_sampled_combinations_uncertainty_scores(
                loss_function_names=loss_function_names,
                training_dataset_name=training_dataset_name,
                architecture=architecture,
                model_ids=model_ids,
                list_extraction_datasets=list_extraction_datasets,
                temperature=temperature,
                use_cached=True,
            )
        )

        df_ood = get_ood_detection_dataframe(
            ind_dataset=training_dataset_name_aux,
            uq_results=uq_results,
            list_ood_datasets=list_ood_datasets,
        )

        # NOTE(review): the targets appear to be stacked once per model id, so
        # only the first block is kept as ground truth. Confirm against
        # get_sampled_combinations_uncertainty_scores.
        ind_targets = targets_per_dataset[training_dataset_name_aux]
        max_ind = ind_targets.shape[0] // len(model_ids)
        true_labels = ind_targets[:max_ind]

        pred_labels = get_predicted_labels(
            embeddings_per_dataset=embeddings_per_dataset,
            training_dataset_name=training_dataset_name_aux,
        )

        df_misclassification = get_missclassification_dataframe(
            ind_dataset=training_dataset_name_aux,
            uq_results=uq_results,
            true_labels=true_labels,
            pred_labels=pred_labels,
        )

        scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)

        # Tag every row with its model family and the *original* training set
        # name (e.g. "noisy_cifar10", not the clean InD name). assign() returns
        # new frames instead of mutating the helpers' return values.
        tags = {
            "architecture": architecture.value,
            "training_dataset": training_dataset_name,
        }
        scores_df_unravel = scores_df_unravel.assign(**tags)
        df_ood = df_ood.assign(**tags)
        df_misclassification = df_misclassification.assign(**tags)

        raw_score_frames.append(scores_df_unravel)
        ood_frames.append(df_ood)
        misclassification_frames.append(df_misclassification)

# pd.concat raises on an empty list; with nothing collected the previous
# values are left untouched, as the original loop did.
if raw_score_frames:
    full_dataframe = pd.concat(raw_score_frames)
    full_ood_rocauc_dataframe = pd.concat(ood_frames)
    full_mis_rocauc_dataframe = pd.concat(misclassification_frames)

In [4]:
# Fixed seed so the preview is reproducible across Restart & Run All.
full_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,Dataset,Scores,architecture,training_dataset
115,LogScore TotalRisk central inner,SphericalScore,cifar10,"[[4.4017296, 4.2440276, 4.4650974, 4.46192, 4....",resnet18,noisy_cifar100
266,LogScore ExcessRisk central central,SphericalScore,cifar100,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",resnet18,cifar100
159,LogScore ExcessRisk outer inner,BrierScore,blurred_cifar10,"[[0.5813884, 0.11008257, 0.6064113, 0.29617035...",resnet18,cifar100
684,ZeroOneScore TotalRisk inner outer,BrierScore,blurred_cifar10,"[[0.60292804, 0.572847, 0.7184887, 0.5714789, ...",resnet18,noisy_cifar10
58,LogScore TotalRisk inner outer,SphericalScore,blurred_cifar100,"[[1.5256356, 2.0666738, 1.7377801, 2.3354049, ...",resnet18,noisy_cifar10
1085,SphericalScore ExcessRisk outer outer,BrierScore,cifar10,"[[2.0140214e-06, 2.46303e-06, 2.191584e-05, 3....",vgg19,cifar10
545,BrierScore ExcessRisk central outer,BrierScore,cifar10,"[[0.018202525, 0.07262258, 0.18093716, 0.15145...",vgg19,noisy_cifar100
440,BrierScore TotalRisk central central,BrierScore,cifar10,"[[0.013322556, 0.020798609, 0.038993478, 0.031...",vgg19,cifar10
424,BrierScore TotalRisk central inner,CrossEntropy,blurred_cifar10,"[[0.81523675, 0.6605634, 0.73896885, 0.787258,...",vgg19,cifar100
898,ZeroOneScore ExcessRisk central central,SphericalScore,blurred_cifar100,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",vgg19,noisy_cifar10


In [5]:
# Regexes that recover the scoring rule and the risk-estimator variant from a
# UQMetric label such as "LogScore TotalRisk central inner".
pattern_baserule = r'(LogScore|BrierScore|ZeroOneScore|SphericalScore)'
# Python's re takes the first alternative that matches at a position, so the
# two-word variants are listed before the bare "outer"/"inner"/"central".
pattern_risk = r'(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)'


def add_rule_and_risk_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``frame`` with parsed ``base_rule`` and ``RiskType`` columns.

    Both columns are extracted from the ``UQMetric`` strings using
    ``pattern_baserule`` and ``pattern_risk``. Rows whose label matches
    neither pattern get NaN.
    """
    metric = frame["UQMetric"].str
    # expand=False yields a Series for a single capture group, matching the
    # single destination column.
    return frame.assign(
        base_rule=metric.extract(pattern_baserule, expand=False),
        RiskType=metric.extract(pattern_risk, expand=False),
    )


full_ood_rocauc_dataframe = add_rule_and_risk_columns(full_ood_rocauc_dataframe)
full_mis_rocauc_dataframe = add_rule_and_risk_columns(full_mis_rocauc_dataframe)
full_dataframe = add_rule_and_risk_columns(full_dataframe)

In [11]:
# Fixed seed so the preview is reproducible across Restart & Run All.
full_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,Dataset,Scores,architecture,training_dataset,base_rule,RiskType
336,BrierScore TotalRisk outer inner,BrierScore,cifar100,"[[0.9494289, 0.6158097, 0.3813681, 0.8758469, ...",resnet18,cifar100,BrierScore,outer inner
655,ZeroOneScore TotalRisk outer inner,SphericalScore,cifar10,"[[0.79131055, 0.9066206, 0.94712746, 0.9631396...",resnet18,cifar100,ZeroOneScore,outer inner
35,LogScore TotalRisk outer central,BrierScore,cifar10,"[[1.0061693, 0.84668225, 0.8956466, 0.88200366...",vgg19,noisy_cifar10,LogScore,outer central
136,LogScore ExcessRisk outer outer,CrossEntropy,cifar100,"[[4.0369096, 2.143008, 0.11844545, 0.002299733...",vgg19,noisy_cifar100,LogScore,outer outer
295,LogScore BayesRisk inner,SphericalScore,cifar10,"[[0.4150831, 0.4211406, 0.32187587, 0.52776235...",resnet18,cifar10,LogScore,inner
273,LogScore BayesRisk outer,CrossEntropy,blurred_cifar100,"[[1.6843468, 0.29034778, 0.28406045, 0.4205756...",vgg19,cifar100,LogScore,outer
1087,SphericalScore ExcessRisk outer outer,BrierScore,svhn,"[[0.18032727, 0.007887023, 0.028852534, 0.0307...",resnet18,cifar100,SphericalScore,outer outer
1002,SphericalScore TotalRisk inner outer,SphericalScore,svhn,"[[0.4003244, 0.08564179, 0.5865064, 0.27314433...",vgg19,cifar10,SphericalScore,inner outer
1196,SphericalScore ExcessRisk central inner,SphericalScore,cifar100,"[[0.000364002023360655, 0.0052225366006558305,...",resnet18,cifar100,SphericalScore,central inner
95,LogScore TotalRisk central outer,BrierScore,cifar10,"[[3.0931637, 3.4161448, 4.0479693, 4.348086, 3...",resnet18,cifar100,LogScore,central outer


In [18]:
# Fixed seed so the preview is reproducible across Restart & Run All.
full_ood_rocauc_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,Dataset,LossFunction,RocAucScores_array,architecture,training_dataset,base_rule,RiskType
73,LogScore TotalRisk inner inner,blurred_cifar10,BrierScore,"[0.7287710849999999, 0.752614795, 0.7695673949...",vgg19,noisy_cifar10,LogScore,inner inner
196,LogScore ExcessRisk inner inner,cifar10,BrierScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",vgg19,cifar100,LogScore,inner inner
1081,SphericalScore ExcessRisk outer outer,cifar10,BrierScore,"[0.49999999999999994, 0.5, 0.5, 0.5, 0.5]",resnet18,noisy_cifar10,SphericalScore,outer outer
892,ZeroOneScore ExcessRisk central central,svhn,BrierScore,"[0.5, 0.5, 0.5, 0.5, 0.5]",vgg19,cifar100,ZeroOneScore,central central
808,ZeroOneScore ExcessRisk outer central,blurred_cifar10,BrierScore,"[0.7947829099999999, 0.8059960649999999, 0.800...",resnet18,cifar100,ZeroOneScore,outer central
1209,SphericalScore ExcessRisk central central,blurred_cifar100,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",vgg19,cifar10,SphericalScore,central central
787,ZeroOneScore ExcessRisk outer inner,svhn,BrierScore,"[0.6848994506760909, 0.7097015461739397, 0.690...",resnet18,cifar100,ZeroOneScore,outer inner
564,BrierScore ExcessRisk central inner,blurred_cifar100,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,noisy_cifar10,BrierScore,central inner
418,BrierScore TotalRisk central outer,blurred_cifar10,BrierScore,"[0.890928635, 0.8924638700000002, 0.8872899399...",resnet18,noisy_cifar100,BrierScore,central outer
412,BrierScore TotalRisk central outer,svhn,BrierScore,"[0.897865559695759, 0.8854699312384756, 0.9045...",vgg19,noisy_cifar10,BrierScore,central outer


In [10]:
# Persist the three aggregated tables. The target directory is created if
# missing, so the export also works on a fresh checkout.
# NOTE(review): `Scores` / `RocAucScores_array` hold array-like cells. CSV stores
# their str() form, which numpy abbreviates with "..." for large arrays.
# Confirm the saved tables are not lossy, or consider a binary format.
central_tables_dir = Path("../../tables/central_tables")
central_tables_dir.mkdir(parents=True, exist_ok=True)

tables_to_save = {
    "full_dataframe.csv": full_dataframe,
    "full_ood_rocauc.csv": full_ood_rocauc_dataframe,
    "full_mis_rocauc.csv": full_mis_rocauc_dataframe,
}
for file_name, table in tables_to_save.items():
    table.to_csv(central_tables_dir / file_name)