In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Process-wide GPU settings, applied before the remaining imports of this
# notebook run: expose only CUDA device 0 and switch off the XLA client's
# up-front memory preallocation.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "0",
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
})

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
import matplotlib as mpl
from matplotlib import rc
# Global matplotlib styling for the exported figures.
# NOTE(review): Helvetica is a sans-serif face but is listed under the
# 'serif' family here -- confirm this is the intended font selection.
rc('font', family='serif', serif=['Helvetica'])

# Render all text through LaTeX at 10 pt and load the packages used by the
# figure labels (bold math, AMS math, upright Greek letters).
mpl.rcParams.update({
    'text.usetex': True,
    'font.size': 10,
    'text.latex.preamble': r"\usepackage{bm}\usepackage{amsmath}\usepackage{upgreek}",
})

In [None]:
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id
from mc2.model_interfaces.model_interface import count_model_parameters
from mc2.utils.pretest_evaluation import produce_pretest_histograms, SCENARIO_LABELS
from mc2.data_management import load_hdf5_pretest_data

In [None]:
# exp_id = '3C90_GRU_96f322d4-b17d-4e' # 5 steps
# exp_id = '3C90_GRU_8cb34afd-919b-44' # 3 steps

# Experiments to evaluate, one per material. Each id has the form
# "<material>_<model type>_<run identifier>"; the evaluation loop further
# down splits it on "_" to recover the material name.
exp_ids = [
    "3C90_GRU_97c4047f-c2d8-48",
    "N87_GRU_8ba07f4f-c59a-42",
    "3C94_GRU_b7cf990c-33b5-49",
    "3E6_GRU_e45054a0-3df2-4c",
    "3F4_GRU_6d364e15-88db-46",
    "77_GRU_db53aa04-5f27-43",
    "78_GRU_69379b64-4b79-42",
    "N27_GRU_41070eb0-0850-4f",
    "N30_GRU_9eca1a85-eaec-4a",
    "N49_GRU_9d34af17-f5bc-46",
]

In [None]:
# Rebuild a single model for manual inspection. The id is hard-coded and is
# the same string as the first entry of `exp_ids`.
wrapped_model = reconstruct_model_from_exp_id('3C90_GRU_97c4047f-c2d8-48')

In [None]:
# Bare expression: shows the reconstructed model's repr as the cell output.
wrapped_model

In [None]:
# Total parameter count shown as the cell output: whatever
# count_model_parameters reports for the model, plus a hard-coded 7 for the
# normalization constants.
count_model_parameters(wrapped_model) + 7 # + normalization constants

In [None]:
# LaTeX (bold) labels for the three known/unknown data-split scenarios; the
# evaluation loop passes this list, reversed, to produce_pretest_histograms
# as `adapted_scenario_labels`.
#
# Raw strings keep the LaTeX backslashes literal. In a regular string "\%" is
# an invalid escape sequence, which Python reports as a DeprecationWarning
# (SyntaxWarning from 3.12 on). The resulting string values are unchanged.
ADAPTED_SCENARIO_LABELS = [
    r"\textbf{90\% known, 10\% unknown}",
    r"\textbf{50\% known, 50\% unknown}",
    r"\textbf{10\% known, 90\% unknown}",
]

In [None]:
# Produce and save the pre-test histogram figure for every experiment.
#
# The output directory is created up front so that saving does not fail with
# FileNotFoundError on a fresh checkout.
output_dir = pathlib.Path("histograms")
output_dir.mkdir(parents=True, exist_ok=True)

for exp_id in exp_ids:
    # Experiment ids look like "<material>_<model type>_<run id>".
    material_name = exp_id.split("_")[0]

    wrapped_model = reconstruct_model_from_exp_id(exp_id)
    # `loss_short` is part of the loaded tuple but is not used below.
    B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)

    # Scenario masks and both label lists are reversed together so that they
    # stay aligned with each other.
    fig, axs = produce_pretest_histograms(
        material_name,
        wrapped_model,
        B,
        T,
        H_init,
        H_true,
        loss,
        list(reversed(msks_scenarios_N_tup)),
        scenario_labels=list(reversed(SCENARIO_LABELS)),
        adapted_scenario_labels=list(reversed(ADAPTED_SCENARIO_LABELS)),
        show_plots=False,
    )

    # Save the returned figure explicitly instead of relying on pyplot's
    # implicit "current figure" (assumes `fig` is a matplotlib Figure).
    fig.savefig(output_dir / f"{exp_id}_preevaluation_results.pdf", bbox_inches="tight")
    # Release the figure; otherwise one open figure per experiment
    # accumulates in memory over the loop.
    plt.close(fig)