In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from source.source.postprocessing_utils import get_metrics_results
from source.source.table_utils import pretty_matplotlib_config
# Disable row truncation when DataFrames are displayed.
# Caution: displaying a large frame will then dump every row; prefer .head()/.sample().
pd.options.display.max_rows = None

In [None]:
# --- Experiment configuration -------------------------------------------------
# Options: 'vgg', 'resnet18'
architecture = 'resnet18'
# Options: 'cifar10', 'cifar100', 'noisy_cifar10', 'noisy_cifar100', 'missed_class_cifar10'
training_dataset_name = 'cifar10'

# For the derived training sets (noisy / missed-class variants), the dataset to
# extract from is the base dataset, i.e. the last '_'-separated token of the name.
dataset_name_to_extract_from = (
    training_dataset_name.split('_')[-1]
    if training_dataset_name in ('missed_class_cifar10', 'noisy_cifar10', 'noisy_cifar100')
    else training_dataset_name
)
print(dataset_name_to_extract_from)

temperature = 1.0

# Ids of the models to load (a smaller alternative previously used: np.arange(6)).
model_ids = np.arange(20)

list_extraction_datasets = ['cifar10', 'cifar100', 'svhn', 'blurred_cifar100', 'blurred_cifar10']
# Independent copy, so later edits to one list do not affect the other.
list_ood_datasets = list(list_extraction_datasets)
loss_function_names = ['brier_score', 'cross_entropy', 'spherical_score']

use_different_approximations = False
# Ground-truth probability approximation. Options:
#   'same'   - same as the predictive prob approximation
#   'flat'   - uniform
#   'diracs' - mixture of diracs
#   'random' - random samples
gt_prob_approx = 'same'

In [None]:
# Load the per-model evaluation metrics for every configured loss function.
metrics_df = get_metrics_results(
    loss_function_names=loss_function_names,
    training_dataset_name=training_dataset_name,
    architecture=architecture,
    model_ids=model_ids,
)

# Violin plot: one group per evaluation metric, one violin per loss function
# (columns "Metric", "Value", "Loss Function" are read from metrics_df).
pretty_matplotlib_config(fontsize=25, fonttype=42)
fig, ax = plt.subplots(figsize=(12, 8), dpi=150)
sns.violinplot(
    data=metrics_df,
    x="Metric",
    y="Value",
    hue="Loss Function",
    linewidth=1.5,
    ax=ax,
)
ax.set(
    title='Distribution of Evaluation Metrics for Ensemble Members by Loss Function',
    xlabel='Evaluation Metric',
    ylabel='Metric Value',
)
ax.legend(title='Loss Function')
fig.tight_layout()

# savefig does not create missing directories, so ensure imgs/ exists;
# otherwise a fresh checkout fails here with FileNotFoundError.
figure_path = Path('imgs') / f'trainstats_{architecture}_{training_dataset_name}.pdf'
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path)
plt.show()
