In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import pickle

import os
# Default Keras to the PyTorch backend unless one was already chosen in the environment.
# Keras reads KERAS_BACKEND when it is first imported, so this must run before any
# `import keras`. (Environment values are always strings, so `None` means "unset".)
if os.environ.get("KERAS_BACKEND") is None:
    os.environ["KERAS_BACKEND"] = "torch"

In [None]:
# Load the pre-computed evaluation metrics of the complete-pooling model variants.
# NOTE(review): pickle.load can execute arbitrary code — only open files this project
# produced itself. Presumably each object maps metric names to per-parameter values
# (the plotting cell indexes them as metrics[metric_name][param_index]) — confirm.
# NOTE(review): `complete_pooling_metrics` is not referenced by the plotting cell shown
# here; drop it if nothing else in the notebook uses it.
with open('complete_pooling_metrics.pkl', 'rb') as f:
    complete_pooling_metrics = pickle.load(f)

with open('complete_pooling_trials_metrics.pkl', 'rb') as f:
    complete_pooling_trials_metrics = pickle.load(f)

with open('complete_pooling_subjects_metrics.pkl', 'rb') as f:
    complete_pooling_subjects_metrics = pickle.load(f)

In [None]:
# Load the pre-computed evaluation metrics of the partial-pooling model
# (global-level and local-level results stored in separate pickles).
# NOTE(review): pickle.load can execute arbitrary code — only open files this project
# produced itself.
# NOTE(review): `partial_pooling_global_metrics` is not referenced by the plotting cell
# shown here; drop it if nothing else in the notebook uses it.
with open('partial_pooling_global_metrics.pkl', 'rb') as f:
    partial_pooling_global_metrics = pickle.load(f)

with open('partial_pooling_local_metrics.pkl', 'rb') as f:
    partial_pooling_local_metrics = pickle.load(f)

In [None]:
# LaTeX panel titles for the model parameters. Position j is used as the title of the
# panel that plots entry j of every per-parameter metric, so the order must match the
# order in which the metrics were computed.
pretty_param_names = [
    r'$\nu_p$',
    r'$\alpha_p$',
    r'$t_{0,p}$',
    r'$\beta_p$',
]
n_params = len(pretty_param_names)  # one subplot panel per parameter

In [None]:
# Legend labels for the three pooling variants. The order must match the `values`
# list in the plotting cell:
#   NP <- complete-pooling *trials* metrics
#   CP <- complete-pooling *subjects* metrics
#   PP <- partial-pooling *local* metrics
pooling_models = ['No Pooling (NP)', 'Complete Pooling (CP)', 'Partial Pooling (PP)']

# One colour per pooling variant, index-aligned with `pooling_models`.
colors = ['blue', 'green', 'orange']

In [None]:
# Plot the first two metrics stored in the complete-pooling subject-level results.
# NOTE(review): this picks metrics by the pickled dict's key order; the labels below
# assume key 0 is NRMSE (circle markers) and key 1 is calibration error (triangle
# markers) — confirm against the code that produced the pickle.
metrics_names = list(complete_pooling_subjects_metrics.keys())[:2]
metrics_names_pretty = [r'NRMSE ($\circ$)', r'Calibration Error ($\blacktriangledown$)']

# Fail fast rather than silently plotting fewer metrics than there are labels.
assert len(metrics_names) == len(metrics_names_pretty), (
    f"expected {len(metrics_names_pretty)} metrics, found {len(metrics_names)}: {metrics_names}"
)

In [None]:
# Compare the three pooling variants per model parameter (one panel each):
# the first metric is drawn as circles on the panel's left y-axis, the second as
# triangles on a twin right y-axis. Both axes share the same 0..y_top range.
Y_MAX = 0.25                   # default upper y-limit for both metric axes
Y_TICK_STEP = 0.05             # spacing of the explicit right-axis ticks
X_POSITIONS = [1/4, 2/4, 3/4]  # x positions of NP, CP, PP inside each panel

# metric_values[i][j] holds the NP, CP, PP values of metric i for parameter j,
# in the same order as `pooling_models` / `colors`.
metric_values = [
    [
        [
            complete_pooling_trials_metrics[metric][j],
            complete_pooling_subjects_metrics[metric][j],
            partial_pooling_local_metrics[metric][j],
        ]
        for j in range(n_params)
    ]
    for metric in metrics_names
]

# Widen the shared y-range (in whole tick steps) if any value exceeds Y_MAX, so no
# point is silently clipped off the top of a panel; otherwise the range stays 0..Y_MAX.
# (A NaN maximum leaves y_top at Y_MAX, because max() keeps Y_MAX when compared to NaN.)
data_max = float(np.nanmax(metric_values)) if metric_values else 0.0
y_top = max(Y_MAX, float(np.ceil(data_max / Y_TICK_STEP)) * Y_TICK_STEP)
y_ticks = np.linspace(0, y_top, int(round(y_top / Y_TICK_STEP)) + 1)

fig, axis = plt.subplots(nrows=1, ncols=n_params, sharex=True, sharey=True, figsize=(10, 2), layout='constrained')
for i in range(len(metrics_names)):
    marker = 'o' if i == 0 else 'v'
    for j in range(n_params):
        if i == 0:
            # First metric: the panel's own (left) axis, shared across panels via sharey.
            ax = axis[j]
            ax.set_ylim(0, y_top)
            ax.set_title(pretty_param_names[j])
        else:
            # Second metric: a twin axis with ticks on the right; only the last panel
            # shows numeric tick labels so the row reads as one shared scale.
            ax = axis[j].twinx()
            ax.set_ylim(0, y_top)
            if j != n_params - 1:
                ax.set_yticks(ticks=y_ticks, labels=[])
            else:
                ax.set_yticks(ticks=y_ticks)
        ax.scatter(x=X_POSITIONS, y=metric_values[i][j], color=colors, marker=marker, alpha=0.75)
        ax.set_xticks(ticks=X_POSITIONS, labels=['NP', 'CP', 'PP'], fontsize=10)
        ax.set_xlim(0.1, 0.9)
        # Label the left axis on the first panel and the right axis on the last panel.
        if (j == 0 and i == 0) or (j == n_params - 1 and i == 1):
            ax.set_ylabel(metrics_names_pretty[i])

# The legend explains colours (pooling variant); marker shape is explained by the y-labels.
handles = [
    Patch(color=color, label=label, alpha=0.75) for label, color in zip(pooling_models, colors)
]
fig.legend(handles=handles, bbox_to_anchor=(0.5, -0.15), loc='lower center', ncols=3, fontsize=10)
fig.savefig('pooling_metrics.pdf', bbox_inches='tight')
plt.show()