# Model Comparison:

**Purpose:** In-depth comparison of multiple models

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Both environment variables are set before `jax` is imported further down in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import pickle
import collections
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
# NOTE(review): the platform is forced to "cpu" here although a CUDA device is
# selected via CUDA_VISIBLE_DEVICES above — confirm which device is intended.
jax.config.update("jax_platform_name", "cpu")
# 64-bit mode is explicitly switched off for this notebook.
jax.config.update("jax_enable_x64", False)
import equinox as eqx

from rhmag.model_setup import setup_normalizer
from rhmag.data_management import MaterialSet, FINAL_MATERIALS, EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT
from rhmag.models.jiles_atherton import JAStatic, JAWithGRU
from rhmag.models.RNN import GRU
from rhmag.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from rhmag.utils.final_data_evaluation import generate_metrics_from_exp_ids_without_seed, visualize_df
from rhmag.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids, evaluate_cross_validation, load_parameterization, get_exp_ids_without_seed
from rhmag.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL
from IPython.display import display, HTML
from rhmag.utils.pretest_evaluation import HOSTS_VALUES_DICT, evaluate_pretest_scenarios, create_multilevel_df, SCENARIO_LABELS

---

**TODO (open question):** It probably makes sense to store the pretest results of the different models under their `exp_id`?

## Check out available models:

In [None]:
# Show the material names imported from rhmag.data_management.
FINAL_MATERIALS

In [None]:
# Print every stored experiment id, grouped by material and sorted alphabetically.
# To inspect a single material interactively, evaluate e.g.
#   sorted(get_exp_ids(material_name="E", model_type=None))
for material_name in FINAL_MATERIALS:
    print("MATERIAL:", material_name)
    # A plain loop is used for printing; a comprehension would only build a
    # throwaway list of `None` values.
    for exp_id in sorted(get_exp_ids(material_name=material_name, model_type=None)):
        print(f"    '{exp_id}'")
    print()

In [None]:
def _drop_last_token(exp_id):
    """Return `exp_id` without its final "_"-separated token."""
    return "_".join(exp_id.split("_")[:-1])


# All experiment ids of material "A" that belong to the "pareto-front-f32" experiment.
full_exp_ids = get_exp_ids(material_name="A", model_type=None, exp_name="pareto-front-f32")

# Strip the trailing token of every id and keep each remaining prefix once (sorted).
unique_prefixes = np.unique(list(map(_drop_last_token, full_exp_ids)))
exp_ids_without_seed = unique_prefixes.tolist()
exp_ids_without_seed

## Iterate over models and store results:

In [None]:
material_name = "A"

# Compute the metrics for all "pareto-front-f32" experiments of this material.
# The experiment lookup uses `material_name` as well, so the lookup, the
# evaluation and the dump file name always refer to the same material.
df, _ = generate_metrics_from_exp_ids_without_seed(
    exp_ids_without_seed=get_exp_ids_without_seed(
        material_name=material_name,
        model_type=None,
        exp_name="pareto-front-f32",
    ),
    material_name=material_name,
    loader_key=jax.random.PRNGKey(99),
)

# Persist the resulting frame so the visualization cells can reload it later.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with open(file_path_pickle, 'wb') as f:
    pickle.dump(df, f)

In [None]:
material_name = "B"

# Experiment ids (given without the trailing seed token) that are evaluated for
# this material. Commented-out entries are candidates that are currently excluded.
selected_exp_ids = [
    'B_LSTM7_demonstration_b1ccde72',
    'B_GRU8_reduced-features-f32_c785b2c3',
    'B_GRU8_final-reduced-features-f32_6437bf39',
    # 'B_GRU8_final-f32_c314f005',
    # 'B_GRU10_reduced-features-f32_a7ef751f',
    # 'B_GRU8_final-f32_2f803a5a',
    # 'B_GRU8_default-f32_51ca9159',
    # 'B_GRU8_long-f32_27d7b57d',
    # 'B_GRU8_shift-f32_5bbe867e',
    # 'B_GRU8_transformed-f32_4d6004e5',
]

# Only the first returned element (the metrics frame) is kept.
df, _ = generate_metrics_from_exp_ids_without_seed(
    exp_ids_without_seed=selected_exp_ids,
    material_name=material_name,
    loader_key=jax.random.PRNGKey(99),
)

# Persist the frame for the visualization cells.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with Path(file_path_pickle).open('wb') as dump_file:
    pickle.dump(df, dump_file)

In [None]:
material_name = "C"

# Experiment ids (given without the trailing seed token) that are evaluated for
# this material. Commented-out entries are candidates that are currently excluded.
selected_exp_ids = [
    'C_GRU8_final-f32_07005abe',
    'C_GRU8_default-f32_0b214b26',
    'C_GRU8_final-f32_0b011e20',
    'C_GRU8_final-f32_e8fe195e',
    'C_GRU10_larger-kernel-f32_f46418cd',
    'C_GRU8_reduced-features-f32_348e220c',
    'C_GRU8_final-reduced-features-f32_5fe02cfa',
    # 'C_GRU8_default-f32_98f4ae79',
    # 'C_GRU8_shift-f32_515dc679',
    # 'C_GRU8_transformed-f32_11aa2385',
]

# Only the first returned element (the metrics frame) is kept.
df, _ = generate_metrics_from_exp_ids_without_seed(
    exp_ids_without_seed=selected_exp_ids,
    material_name=material_name,
    loader_key=jax.random.PRNGKey(99),
)

# Persist the frame for the visualization cells.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with Path(file_path_pickle).open('wb') as dump_file:
    pickle.dump(df, dump_file)

In [None]:
material_name = "D"

# Experiment ids (given without the trailing seed token) that are evaluated for
# this material. Commented-out entries are candidates that are currently excluded.
selected_exp_ids = [
    'D_GRU8_final-f32_09d3ce02',
    'D_GRU10_reduced-features-f32_3c349983',
    'D_GRU8_final-f32_b7cb3edb',
    'D_GRU8_default-f32_726c3a66',
    'D_GRU10_default-f32_11f19655',
    'D_GRU8_reduced-features-f32_b6ac55b5',
    'D_GRU8_final-reduced-features-f32_3d0f8de4',
    # 'D_GRU10_reduced-features-f32_3eda5160',
]

# Only the first returned element (the metrics frame) is kept.
df, _ = generate_metrics_from_exp_ids_without_seed(
    exp_ids_without_seed=selected_exp_ids,
    material_name=material_name,
    loader_key=jax.random.PRNGKey(99),
)

# Persist the frame for the visualization cells.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with Path(file_path_pickle).open('wb') as dump_file:
    pickle.dump(df, dump_file)

In [None]:
material_name = "E"

# Experiment ids (given without the trailing seed token) that are evaluated for
# this material.
selected_exp_ids = [
    'E_GRU10_final-f32_7ac8b027',
    'E_GRU8_final-f32_c6c7dc08',
    'E_GRU8_default-f32_8015a369',
    'E_GRU10_final-f32_0e90f783',
    'E_GRU8_final-f32_cfbcb9e6',
    'E_GRU8_reduced-features-f32_e88a2583',
    'E_GRU8_final-reduced-features-f32_8f8a200e',
]

# Only the first returned element (the metrics frame) is kept.
df, _ = generate_metrics_from_exp_ids_without_seed(
    exp_ids_without_seed=selected_exp_ids,
    material_name=material_name,
    loader_key=jax.random.PRNGKey(99),
)

# Persist the frame for the visualization cells.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with Path(file_path_pickle).open('wb') as dump_file:
    pickle.dump(df, dump_file)

In [None]:
# Explicit stop cell: halts "Run All" after the result dumps were written.
# An explicit RuntimeError keeps the exception type of a bare `raise` without an
# active exception, but states the reason instead of "No active exception to reraise".
raise RuntimeError(
    "Execution halted by an explicit stop cell: the result dumps have been "
    "written; run the visualization cells below manually."
)

---

## Visualize model performance:

In [None]:
material_name = "A"

# Reload the metrics frame dumped for this material.
# NOTE: pickle executes arbitrary code on load — only open dumps created by this notebook.
file_path_pickle = f'results_dump_material_{material_name}.pkl'
with Path(file_path_pickle).open('rb') as dump_file:
    loaded_all_results = pickle.load(dump_file)

# Order the rows by experiment id and renumber the index from zero.
df = loaded_all_results.sort_values(by="exp_id").reset_index(drop=True)

# Plot the "sre" and "nere" metrics for the scenarios of this material with
# `scale_log=True`. `x_label` may alternatively be set to "seed".
plot_options = dict(
    scenarios=FINAL_SCENARIOS_PER_MATERIAL[material_name],
    metrics=["sre", "nere"],
    x_label=None,
    scale_log=True,
)
fig, axs = visualize_df(df, **plot_options)
plt.show()

In [None]:
# Render the complete sorted results frame of the material loaded above.
display(df)

### Store plot for each material

In [None]:
# For every material: reload its result dump, plot the metrics and save the figure as PNG.
# NOTE(review): unlike the single-material cell above, `scale_log` is not passed
# here — confirm whether the stored plots are meant to use the default scaling.
for material_name in FINAL_MATERIALS:
    file_path_pickle = f'results_dump_material_{material_name}.pkl'
    with Path(file_path_pickle).open('rb') as dump_file:
        loaded_all_results = pickle.load(dump_file)

    # Order the rows by experiment id and renumber the index from zero.
    df = loaded_all_results.sort_values(by="exp_id").reset_index(drop=True)

    plot_options = dict(
        scenarios=FINAL_SCENARIOS_PER_MATERIAL[material_name],
        metrics=["sre", "nere"],
        x_label=None,
    )
    fig, axs = visualize_df(df, **plot_options)

    # Saves the currently active matplotlib figure.
    plt.savefig(f"{material_name}_model_comparison.png", bbox_inches="tight")

In [None]:
# Look at the stored parameterization of one experiment.
# NOTE(review): `exp_id=None` looks like a placeholder — confirm which experiment
# id should be inspected here before running this cell.
params = load_parameterization(exp_id=None)
params

In [None]:
# Explicit stop cell: halts "Run All" before the pretest section.
# An explicit RuntimeError keeps the exception type of a bare `raise` without an
# active exception, but states the reason instead of "No active exception to reraise".
raise RuntimeError(
    "Execution halted by an explicit stop cell; run the pretest cells below manually."
)

## Pretest materials:

In [None]:
# Labels for the pretest models; position i belongs to exp_ids[i] and is later
# used as the key of that model's metrics.
shifts = ["5_f32", "3_f32", "0_f64", "0_f32"]

# Experiment ids of the pretest models to compare.
exp_ids = [
    '3C90_GRU_23db58e4-948c-42',
    '3C90_GRU_996b1949-71d9-4c',
    '3C90_GRU_97c4047f-c2d8-48',
    '3C90_GRU_fbfaa278-d274-46',  # f32
]

# The first two "_"-separated tokens of the first id give material and model type.
material_name, model_type = exp_ids[0].split("_")[:2]

In [None]:
# Rebuild one model per experiment id, in the same order as `exp_ids`.
models = list(map(reconstruct_model_from_exp_id, exp_ids))

In [None]:
# Load the pretest data of the selected material and show the array shapes.
# NOTE(review): `load_hdf5_pretest_data` is not imported in any visible cell of
# this notebook, so this line raises NameError on a fresh kernel — add the
# import from the module that defines it.
# `loss_short` is unpacked here but not used by the cells visible below.
B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)
B.shape, T.shape, H_init.shape, H_true.shape

In [None]:
def _pretest_metrics(model):
    """Evaluate a single model on the pretest scenarios without showing plots."""
    return evaluate_pretest_scenarios(
        model,
        B,
        T,
        H_init,
        H_true,
        loss,
        msks_scenarios_N_tup,
        scenario_labels=SCENARIO_LABELS,
        show_plots=False,
    )


# One metrics entry per model, keyed by the label from `shifts` at the same position.
all_metrics = {
    str(shift): _pretest_metrics(model)
    for model, shift in zip(models, shifts)
}

In [None]:
# Add the "3C90" entry of HOSTS_VALUES_DICT under the key "hosts" so it appears
# next to the model results.
all_metrics["hosts"] = HOSTS_VALUES_DICT["3C90"]

# Build the multi-level frame and render its transpose as an HTML table with
# four decimal places.
df_models_3C90 = create_multilevel_df(all_metrics)
metrics_table_html = df_models_3C90.T.to_html(float_format="%.4f", bold_rows=False)
display(HTML(metrics_table_html))

In [None]:
# Print the parameter count of every reconstructed model, one per line.
for parameter_count in map(count_model_parameters, models):
    print(parameter_count)

### Testing CSV saving of the results

In [None]:
# Explicit stop cell: halts "Run All" before the CSV export test.
# An explicit RuntimeError keeps the exception type of a bare `raise` without an
# active exception, but states the reason instead of "No active exception to reraise".
raise RuntimeError(
    "Execution halted by an explicit stop cell; run the CSV export cells below manually."
)

In [None]:
from rhmag.utils.pretest_evaluation import store_predictions_to_csv

In [None]:
# Show the scenario labels imported from rhmag.utils.pretest_evaluation.
SCENARIO_LABELS

In [None]:
# The scenario masks and their labels are both passed in reversed order, so they
# stay aligned with each other.
reversed_scenario_masks = list(reversed(msks_scenarios_N_tup))
reversed_scenario_labels = list(reversed(SCENARIO_LABELS))

# Export the results of the last experiment/model pair.
store_predictions_to_csv(
    exp_ids[-1],
    models[-1],
    B,
    T,
    H_init,
    H_true,
    loss,
    reversed_scenario_masks,
    scenario_labels=reversed_scenario_labels,
)

In [None]:
# Read the "<exp_id>_pred.csv" / "<exp_id>_meas.csv" files of the last experiment
# (presumably the files written by store_predictions_to_csv — verify). They are
# read without a header row.
last_exp_id = exp_ids[-1]
pred = pd.read_csv(f'{last_exp_id}_pred.csv', header=None)
meas = pd.read_csv(f'{last_exp_id}_meas.csv', header=None)