In [1]:
from psruq.source.postprocessing_utils import (
    get_predicted_labels,
    get_missclassification_dataframe,
    get_ood_detection_dataframe,
    get_raw_scores_dataframe,
)

from source.datasets.constants import DatasetName
from source.losses.constants import LossName
from source.models.constants import ModelName
from source.metrics import (
    ApproximationType,
    GName,
    RiskType,
)
from torch_uncertainty_models.source.notebook_utils import (
    get_new_models_sampled_combinations_uncertainty_scores,
)

import pandas as pd
import numpy as np

# Never truncate rows when a DataFrame is rendered (None = unlimited).
pd.options.display.max_rows = None

stty: 'standard input': Inappropriate ioctl for device
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run configuration -----------------------------------------------------------
# Forwarded unchanged to get_new_models_sampled_combinations_uncertainty_scores.
temperature = 1.0

N_MODELS = 20
model_ids = np.arange(N_MODELS)  # model ids 0 .. N_MODELS - 1

loss_function_names = [LossName.CROSS_ENTROPY]
training_dataset_names = [
    dataset.value for dataset in (DatasetName.CIFAR10, DatasetName.CIFAR100)
]

# Combined result tables; they stay None until the processing loop assigns them.
full_dataframe = full_ood_rocauc_dataframe = full_mis_rocauc_dataframe = None

In [3]:
# Build the raw-score, OOD-ROC-AUC and misclassification-ROC-AUC tables for every
# training dataset. Per-dataset frames are collected in lists and concatenated once
# at the end. This avoids the quadratic grow-by-concat pattern and makes the cell
# idempotent: re-running it no longer appends a second copy of the results onto
# frames left over from a previous run.

architecture = ModelName.RESNET18  # identical for every training dataset

# Datasets whose scores are extracted for each training dataset. A corrupted-dataset
# base name (CIFAR10C / CIFAR100C) is expanded into suffixed variants further below.
extraction_datasets_by_training_dataset = {
    DatasetName.CIFAR10.value: [
        DatasetName.CIFAR10.value,
        DatasetName.CIFAR100.value,
        DatasetName.CIFAR10C.value,
        DatasetName.TINY_IMAGENET.value,
    ],
    DatasetName.CIFAR100.value: [
        DatasetName.CIFAR10.value,
        DatasetName.CIFAR100.value,
        # DatasetName.CIFAR100C.value,  # disabled; uncomment to add CIFAR100C variants
        DatasetName.TINY_IMAGENET.value,
    ],
}

# Suffixes _1 .. _5 appended to corrupted-dataset names
# (presumably the CIFAR-C corruption severity levels -- confirm).
CORRUPTION_LEVELS = range(1, 6)


def _expand_corrupted_datasets(dataset_names):
    """Replace each corrupted-dataset base name with its level-suffixed variants.

    Non-corrupted names keep their relative order. The suffixed variants are appended
    afterwards, CIFAR10C first and then CIFAR100C, e.g. ``cifar10c`` ->
    ``cifar10c_1`` ... ``cifar10c_5``. The input list is not modified.
    """
    corrupted_bases = [DatasetName.CIFAR10C.value, DatasetName.CIFAR100C.value]
    expanded = [name for name in dataset_names if name not in corrupted_bases]
    for base in corrupted_bases:
        if base in dataset_names:
            expanded.extend(f"{base}_{level}" for level in CORRUPTION_LEVELS)
    return expanded


scores_frames, ood_frames, mis_frames = [], [], []

for training_dataset_name in training_dataset_names:
    if training_dataset_name not in extraction_datasets_by_training_dataset:
        raise NotImplementedError("Need to implement")
    list_extraction_datasets = _expand_corrupted_datasets(
        extraction_datasets_by_training_dataset[training_dataset_name]
    )

    # Noisy-label variants ("noisy_cifar10", "noisy_cifar100") are looked up under
    # the plain dataset name (the part after the last underscore).
    if training_dataset_name in ("noisy_cifar10", "noisy_cifar100"):
        training_dataset_name_aux = training_dataset_name.split("_")[-1]
    else:
        training_dataset_name_aux = training_dataset_name

    # Kept as globals: a later cell inspects `uq_results` from the last iteration.
    uq_results, embeddings_per_dataset, targets_per_dataset = (
        get_new_models_sampled_combinations_uncertainty_scores(
            loss_function_names=loss_function_names,
            training_dataset_name=training_dataset_name,
            model_ids=model_ids,
            list_extraction_datasets=list_extraction_datasets,
            temperature=temperature,
            use_cached=True,
        )
    )

    df_ood = get_ood_detection_dataframe(
        ind_dataset=training_dataset_name_aux,
        uq_results=uq_results,
        list_ood_datasets=list_extraction_datasets,
    )

    # The in-distribution targets are sliced to their first 1/len(model_ids) part.
    # NOTE(review): this assumes the labels are stacked once per model -- confirm in
    # notebook_utils. Integer division plus a check replaces the silent float truncation.
    all_targets = targets_per_dataset[training_dataset_name_aux]
    n_targets = all_targets.shape[0]
    max_ind, remainder = divmod(n_targets, len(model_ids))
    assert remainder == 0, (
        f"targets for {training_dataset_name_aux!r} ({n_targets}) are not a whole "
        f"multiple of len(model_ids) ({len(model_ids)})"
    )
    true_labels = all_targets[:max_ind]

    pred_labels = get_predicted_labels(
        embeddings_per_dataset=embeddings_per_dataset,
        training_dataset_name=training_dataset_name_aux,
    )

    df_misclassification = get_missclassification_dataframe(
        ind_dataset=training_dataset_name_aux,
        uq_results=uq_results,
        true_labels=true_labels,
        pred_labels=pred_labels,
    )

    scores_df_unravel = get_raw_scores_dataframe(uq_results=uq_results)

    # Tag every frame with its provenance so the combined tables can be filtered.
    for frame in (scores_df_unravel, df_ood, df_misclassification):
        frame["architecture"] = architecture.value
        frame["training_dataset"] = training_dataset_name

    scores_frames.append(scores_df_unravel)
    ood_frames.append(df_ood)
    mis_frames.append(df_misclassification)

# Row indices are kept as produced (not reset), matching the previously saved tables.
# They repeat across training datasets, so filter on `training_dataset` rather than
# relying on index labels.
full_dataframe = pd.concat(scores_frames) if scores_frames else None
full_ood_rocauc_dataframe = pd.concat(ood_frames) if ood_frames else None
full_mis_rocauc_dataframe = pd.concat(mis_frames) if mis_frames else None

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 110/110 [09:59<00:00,  5.45s/it]


In [4]:
# Datasets with stored results for one metric/loss pair. `uq_results` is whatever the
# last iteration of the processing loop produced (the last training dataset).
# NOTE(review): the recorded output lists cifar10c_* keys although the last training
# dataset (cifar100) does not request them -- the cached results appear to contain
# more datasets than requested; confirm in notebook_utils.
energy_inner_results = uq_results["LogScore energy inner"]["CrossEntropy"]
energy_inner_results.keys()

dict_keys(['cifar10', 'cifar100', 'tiny_imagenet', 'cifar10c_1', 'cifar10c_2', 'cifar10c_3', 'cifar10c_4', 'cifar10c_5'])

In [5]:
# Fixed random_state so the displayed preview is reproducible across re-runs.
full_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,Dataset,Scores,architecture,training_dataset
232,BrierScore TotalRisk central central,CrossEntropy,cifar10,"[[0.40422922, 0.75544775, 0.75003016, 0.775929...",resnet18,cifar100
551,SphericalScore TotalRisk inner central,CrossEntropy,cifar10c_5,"[[0.5503747513506348, 0.5926362846990155, 0.50...",resnet18,cifar10
303,BrierScore ExcessRisk central inner,CrossEntropy,cifar10c_5,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",resnet18,cifar100
290,BrierScore ExcessRisk central outer,CrossEntropy,tiny_imagenet,"[[0.40737048, 0.44480655, 0.4267369, 0.5548498...",resnet18,cifar100
303,BrierScore ExcessRisk central inner,CrossEntropy,cifar10c_5,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",resnet18,cifar10
230,BrierScore TotalRisk central inner,CrossEntropy,cifar10c_4,"[[0.63418686, 0.7032609, 0.69550395, 0.6251498...",resnet18,cifar100
227,BrierScore TotalRisk central inner,CrossEntropy,cifar10c_1,"[[0.6454982, 0.64261407, 0.6028866, 0.7708387,...",resnet18,cifar10
519,SphericalScore TotalRisk outer inner,CrossEntropy,cifar10c_5,"[[0.4969115, 0.55103815, 0.45745933, 0.4439876...",resnet18,cifar100
381,ZeroOneScore TotalRisk inner central,CrossEntropy,cifar10c_3,"[[0.50417453, 0.50004387, 0.43402344, 0.575773...",resnet18,cifar100
585,SphericalScore ExcessRisk outer inner,CrossEntropy,cifar100,"[[0.23793499, 0.37648946, 0.22644593, 0.339729...",resnet18,cifar100


In [6]:
# Fixed random_state so the displayed preview is reproducible across re-runs.
full_ood_rocauc_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,Dataset,LossFunction,RocAucScores_array,architecture,training_dataset
100,BrierScore ExcessRisk inner outer,cifar100,CrossEntropy,"[0.49999999999999994, 0.49999999999999994, 0.4...",resnet18,cifar100
183,ZeroOneScore BayesRisk inner,cifar10,CrossEntropy,"[0.53501334, 0.524730345, 0.525307985, 0.52766...",resnet18,cifar100
401,ZeroOneScore TotalRisk central central,cifar100,CrossEntropy,"[0.46498666, 0.47526965499999996, 0.4746920149...",resnet18,cifar10
680,LogScore energy inner,cifar10,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10
180,ZeroOneScore BayesRisk outer,cifar10,CrossEntropy,"[0.067820305, 0.06909691999999999, 0.066266789...",resnet18,cifar100
301,BrierScore ExcessRisk central inner,cifar10c_3,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10
206,SphericalScore TotalRisk inner central,tiny_imagenet,CrossEntropy,"[0.51673053, 0.5150643, 0.5025341400000001, 0....",resnet18,cifar100
172,ZeroOneScore ExcessRisk central outer,cifar100,CrossEntropy,"[0.5, 0.5, 0.5, 0.5000000000000001, 0.5]",resnet18,cifar100
32,LogScore TotalRisk inner inner,cifar10,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.4999999999999999]",resnet18,cifar10
112,BrierScore ExcessRisk central inner,cifar100,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar100


In [7]:
# Fixed random_state so the displayed preview is reproducible across re-runs.
full_mis_rocauc_dataframe.sample(10, random_state=0)

Unnamed: 0,UQMetric,LossFunction,RocAucScores_array,architecture,training_dataset
41,BrierScore BayesRisk central,CrossEntropy,"[0.4839095403184386, 0.42035389201542195, 0.51...",resnet18,cifar100
83,SphericalScore BayesRisk central,CrossEntropy,"[0.49742277496514775, 0.44295383527143833, 0.5...",resnet18,cifar100
76,SphericalScore ExcessRisk inner inner,CrossEntropy,"[0.5, 0.5, 0.5, 0.5, 0.5]",resnet18,cifar10
46,ZeroOneScore TotalRisk inner inner,CrossEntropy,"[0.5094209727419473, 0.41703767256389535, 0.50...",resnet18,cifar100
42,ZeroOneScore TotalRisk outer outer,CrossEntropy,"[0.6407233994702266, 0.5109178346306207, 0.491...",resnet18,cifar10
26,BrierScore TotalRisk inner central,CrossEntropy,"[0.6431605422166953, 0.5070854470725878, 0.488...",resnet18,cifar10
82,SphericalScore BayesRisk inner,CrossEntropy,"[0.6431591969205616, 0.5070844336085581, 0.488...",resnet18,cifar10
84,LogScore energy outer,CrossEntropy,"[0.5128014963451294, 0.510691458770523, 0.5072...",resnet18,cifar10
2,LogScore TotalRisk outer central,CrossEntropy,"[0.4549673490351457, 0.45748184969840183, 0.50...",resnet18,cifar100
39,BrierScore BayesRisk outer,CrossEntropy,"[0.5130052980324211, 0.5084377281294196, 0.500...",resnet18,cifar10


In [8]:
# Split each UQMetric name into components, e.g.
#   "BrierScore TotalRisk central inner" -> base_rule "BrierScore", RiskType "central inner"
#   "LogScore energy outer"              -> base_rule "LogScore",   RiskType "energy outer"
# Two-word alternatives are listed before single words, so at each match position the
# full pair wins over its one-word prefix.
# NOTE: despite its name, the RiskType column holds this approximation suffix, not
# TotalRisk/ExcessRisk/BayesRisk.
pattern_baserule = r"(LogScore|BrierScore|ZeroOneScore|SphericalScore)"
pattern_risk = r"(outer outer|outer inner|outer central|inner outer|inner inner|inner central|central outer|central inner|central central|energy inner|energy outer|outer|inner|central)"

for results_frame in (
    full_ood_rocauc_dataframe,
    full_mis_rocauc_dataframe,
    full_dataframe,
):
    metric_names = results_frame["UQMetric"].str
    results_frame["base_rule"] = metric_names.extract(pattern_baserule)
    results_frame["RiskType"] = metric_names.extract(pattern_risk)

In [9]:
from psruq.source.path_config import REPOSITORY_ROOT
import os

In [10]:
# Persist the combined tables for downstream use. Pickle keeps the per-row score
# arrays intact; only unpickle these files from trusted sources.
central_tables_dir = os.path.join(REPOSITORY_ROOT, "tables/central_tables")
# DataFrame.to_pickle does not create missing parent directories.
os.makedirs(central_tables_dir, exist_ok=True)

for frame, file_name in (
    (full_dataframe, "new_models_full_dataframe.pkl"),
    (full_ood_rocauc_dataframe, "new_models_full_ood_rocauc.pkl"),
    (full_mis_rocauc_dataframe, "new_models_full_mis_rocauc.pkl"),
):
    frame.to_pickle(os.path.join(central_tables_dir, file_name))