# Model Comparison

**Purpose:** In-depth comparison of multiple trained models, evaluated side by side on the same pretest scenarios.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", False)
import equinox as eqx

from mc2.runners.model_setup_jax import get_normalizer
from mc2.data_management import EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT, load_hdf5_pretest_data
from mc2.models.jiles_atherton import JAStatic, JAWithGRU
from mc2.models.RNN import GRU
from mc2.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids
from IPython.display import display, HTML
from mc2.utils.pretest_evaluation import HOSTS_VALUES_DICT, evaluate_pretest_scenarios, create_multilevel_df, SCENARIO_LABELS

---

**TODO:** It probably makes sense to store the pretest results of each model under its `exp_id`.

In [None]:
# Experiment IDs of the trained models to compare. Each ID has the form
# "<material>_<model type>_<hash>".
exp_ids = [
    '3C90_GRU_23db58e4-948c-42',
    '3C90_GRU_996b1949-71d9-4c',
    '3C90_GRU_97c4047f-c2d8-48',
    '3C90_GRU_fbfaa278-d274-46', # f32
]
# Display label for each entry of `exp_ids`, in the same order.
shifts = ["5_f32","3_f32","0_f64","0_f32"]

# Material and model type are read from the first experiment ID only.
material_name, model_type = exp_ids[0].split("_")[:2]

In [None]:
# Rebuild each trained model from its experiment ID (same order as `exp_ids`).
models = list(map(reconstruct_model_from_exp_id, exp_ids))

In [None]:
# Load the pretest data for this material; show the array shapes as a sanity check.
(B, T, H_init, H_true,
 loss, loss_short, msks_scenarios_N_tup) = load_hdf5_pretest_data(material_name)
tuple(arr.shape for arr in (B, T, H_init, H_true))

In [None]:
# Evaluate every model on the pretest scenarios. Results are keyed by the
# model's label from `shifts`, which must line up 1:1 with `models`.
# `zip` silently truncates to the shorter input, so verify the pairing first.
assert len(models) == len(shifts), (
    f"{len(models)} models but {len(shifts)} labels in `shifts`"
)
# Duplicate labels would silently overwrite each other's metrics in the dict.
assert len({str(shift) for shift in shifts}) == len(shifts), (
    f"duplicate labels in `shifts`: {shifts}"
)

all_metrics = {
    str(shift): evaluate_pretest_scenarios(
        model,
        B,
        T,
        H_init,
        H_true,
        loss,
        msks_scenarios_N_tup,
        scenario_labels=SCENARIO_LABELS,
        show_plots=False,
    )
    for model, shift in zip(models, shifts)
}

In [None]:
# Add the "hosts" reference values for the material under comparison.
# The key comes from `material_name` (derived from `exp_ids`) rather than a
# hard-coded material, so a different material cannot pick up the wrong
# reference values. NOTE(review): presumably baseline metrics provided by the
# hosts — confirm against HOSTS_VALUES_DICT.
all_metrics["hosts"] = HOSTS_VALUES_DICT[material_name]

# Variable name kept as-is for downstream cells that may reference it.
df_models_3C90 = create_multilevel_df(all_metrics)
display(HTML(df_models_3C90.T.to_html(float_format="%.4f", bold_rows=False)))

In [None]:
# Report each model's parameter count, labelled with its experiment ID so the
# numbers can be matched to the models in the comparison table.
for exp_id, model in zip(exp_ids, models):
    print(f"{exp_id}: {count_model_parameters(model)}")

In [None]:
from mc2.utils.pretest_evaluation import store_predictions_to_csv

In [None]:
# Inspect the scenario label order (the CSV export below passes it reversed).
SCENARIO_LABELS

In [None]:
# Export predictions of the last model in `exp_ids` to CSV.
# The scenario masks and their labels are reversed together so each mask stays
# paired with its label. NOTE(review): the reason for reversing the scenario
# order is not visible here — confirm.
latest_exp_id, latest_model = exp_ids[-1], models[-1]
masks_reversed = list(reversed(msks_scenarios_N_tup))
labels_reversed = list(reversed(SCENARIO_LABELS))

store_predictions_to_csv(
    latest_exp_id, latest_model,
    B, T, H_init, H_true, loss,
    masks_reversed,
    scenario_labels=labels_reversed,
)

In [None]:
# Read back the exported predictions and measurements (no header row).
# NOTE(review): assumes store_predictions_to_csv wrote these files to the
# current working directory using this naming scheme — confirm.
exported_id = exp_ids[-1]
pred = pd.read_csv(f"{exported_id}_pred.csv", header=None)
meas = pd.read_csv(f"{exported_id}_meas.csv", header=None)