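# Quick evaluation
**Purpose:** Quickly inspect the performance of a newly trained model: its stored predictions and loss curves, re-evaluated test-set predictions, frequency sweeps, cross-validation metrics, and pretest-scenario metrics.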

In [None]:
# Automatically re-import edited modules (e.g. the mc2 package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# GPU/XLA environment settings are written before `jax` is imported further down in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # expose only GPU 0 to CUDA
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate device memory on demand instead of preallocating

# NOTE(review): glob, pd, np, Path, pathlib, Type, json, eqx, MODEL_DUMP_ROOT,
# NORMALIZATION_ROOT, JAStatic, JAWithGRU, GRU, ModelInterface, load_model and
# count_model_parameters are not used by the visible cells (json and
# EXPERIMENT_LOGS_ROOT only appear in commented-out code) — consider pruning.
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
# Run JAX on the CPU backend and enable 64-bit (double precision) dtypes.
# NOTE(review): with the CPU platform forced, the CUDA/XLA GPU settings above
# likely have no effect — confirm whether GPU evaluation is intended.
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
import equinox as eqx

from mc2.runners.model_setup_jax import get_normalizer
from mc2.data_management import EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT
from mc2.models.jiles_atherton import JAStatic, JAWithGRU
from mc2.models.RNN import GRU
from mc2.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

In [None]:
from mc2.data_management import AVAILABLE_MATERIALS

In [None]:
# To browse the experiments of every material instead, uncomment:
# for name in AVAILABLE_MATERIALS:
#     print(name)
#     print(get_exp_ids(material_name=name, model_type=None))
#     print()

# Experiment ids for material "B"; model_type=None presumably applies no model-type filter.
exp_ids = get_exp_ids(model_type=None, material_name="B")
exp_ids

In [None]:
# Previously evaluated experiments, kept for reference. To evaluate one of them,
# uncomment its line and remove the `exp_ids[0]` default below.
# exp_id = '3C90_GRU_6dca3c2e-c8a0-47'
# exp_id = '3C90_JA_4ec8f810-298b-49'
# exp_id = '3C90_GRUwLinearModel_22c14f81-a820-44'
# exp_id = '3C90_GRU_96f322d4-b17d-4e' # 5 steps
# exp_id = '3C90_GRU_8cb34afd-919b-44' # 3 steps
# exp_id = '3C90_GRU_a5e52cbd-9708-45' # 0 steps
# exp_id = '3C90_GRU_23db58e4-948c-42'
# exp_id = 'N87_GRU_8ba07f4f-c59a-42'
# exp_id = '3C94_GRU_b7cf990c-33b5-49'
# exp_id = '3E6_GRU_e45054a0-3df2-4c'
# exp_id = '3F4_GRU_6d364e15-88db-46'
# exp_id = '77_GRU_db53aa04-5f27-43'
# exp_id = '3C90_GRU_31895366-dd82-4f'
# exp_id = 'N49_GRU_9d34af17-f5bc-46'
# exp_id = '3C90_HNODE_9f68493b-bb55-46'
# exp_id = '3C90_HNODE_4ccefbbd-4fbd-47'
# exp_id = '3C90_GRU_72562eee-55a6-48'

# Fail with a clear message instead of a bare IndexError when no experiment was found.
if len(exp_ids) == 0:
    raise ValueError("get_exp_ids returned no experiments; cannot select an exp_id.")
exp_id = exp_ids[0]

# Experiment ids look like "<material>_<model type>_<hash>", e.g. "3C90_GRU_6dca3c2e-c8a0-47".
# material_name = "N87"  # manual override
material_name, model_type = exp_id.split("_")[:2]

seed = 0

In [None]:
# Optional: inspect the raw parameters stored in the experiment log.
# experiment_path = EXPERIMENT_LOGS_ROOT / "jax_experiments"
# with open(experiment_path / f"{exp_id}.json", "r") as f:
#     params = json.load(f)["params"]
# params

# Reconstruct the wrapped model for `exp_id`; the last expression displays it.
wrapped_model = reconstruct_model_from_exp_id(exp_id)
wrapped_model

In [None]:
# Parameter count reported by the model wrapper.
wrapped_model.n_params

## Inspect stored predictions and losses

In [None]:
from mc2.utils.model_evaluation import (
    load_gt_and_pred, plot_worst_predictions, plot_first_predictions, plot_loss_trends
)

In [None]:
# Load the ground truth and predictions stored for run `seed` of this experiment.
# NOTE(review): which frequency `freq_idx=1` selects is not visible here — confirm.
seed=0
gt, pred = load_gt_and_pred(
    exp_id=exp_id,
    seed=seed,
    freq_idx=1
)

In [None]:
# Worst stored predictions; the trailing `;` suppresses the return value's repr.
plot_worst_predictions(gt, pred);

In [None]:
# First stored predictions, for an unselected view of model quality.
plot_first_predictions(gt, pred);

In [None]:
# Loss curves recorded during training for this experiment and seed.
plot_loss_trends(exp_id, seed);

### Further plotting: re-evaluate on the test set

In [None]:
from mc2.training.jax_routine import val_test
from mc2.runners.model_setup_jax import get_normalizer

In [None]:
# Build the normalized train/eval/test sets for this material with the model's featurizer.
# The first return value (presumably the normalizer itself) is not needed here.
_, (train_set, eval_set, test_set) = get_normalizer(material_name, wrapped_model.featurize, subsampling_freq=1, do_normalization=True)

In [None]:
# Re-run the model on the test set. The prediction and ground-truth lists are indexed
# with [0] in the following cells. NOTE(review): confirm what `past_size=1` controls.
test_loss, test_pred_l, test_gt_l = val_test(test_set, wrapped_model, past_size=1)

In [None]:
# Worst re-evaluated test-set predictions. Compare with `plot_worst_predictions(gt, pred)`
# (the stored predictions) to see the performance at the end of training.
# Trailing `;` suppresses the return value's repr, as in the other plotting cells.
plot_worst_predictions(test_gt_l[0], test_pred_l[0]);

In [None]:
# First re-evaluated test-set predictions.
plot_first_predictions(test_gt_l[0], test_pred_l[0]);

In [None]:
# Display the model object held by the wrapper.
wrapped_model.model

In [None]:
from mc2.utils.model_evaluation import plot_model_frequency_sweep

In [None]:
# Frequency-sweep plots for several `past_size` values and PRNG seeds.
# Loop variables use distinct names so this cell does not overwrite the global `seed`
# used by the stored-prediction and loss-trend cells.
SWEEP_PAST_SIZES = [900, 500, 100]
SWEEP_SEEDS = range(10, 20)

for sweep_past_size in SWEEP_PAST_SIZES:
    print("past_size:", sweep_past_size)
    for sweep_seed in SWEEP_SEEDS:
        print("seed:", sweep_seed)
        plot_model_frequency_sweep(
            wrapped_model,
            test_set,
            jax.random.PRNGKey(sweep_seed),
            past_size=sweep_past_size,
            figsize=(18, 6),
        )
        plt.show()

In [None]:
# NOTE(review): duplicate of the earlier `wrapped_model.n_params` cell.
wrapped_model.n_params

## Cross-validation

In [None]:
from IPython.display import display, HTML
from mc2.utils.model_evaluation import evaluate_cross_validation
from mc2.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL
from mc2.utils.pretest_evaluation import create_multilevel_df

In [None]:
# Cross-validation settings, named instead of inline magic numbers.
CROSS_VAL_LOADER_SEED = 32
CROSS_VAL_SEQUENCE_LENGTH = 1000
CROSS_VAL_BATCH_SIZE_PER_FREQUENCY = 1000

# A fixed PRNG key makes the loader's draws repeatable across reruns.
loader_key = jax.random.PRNGKey(CROSS_VAL_LOADER_SEED)

metrics = evaluate_cross_validation(
    wrapped_model=wrapped_model,
    test_set=test_set,
    scenarios=FINAL_SCENARIOS_PER_MATERIAL[test_set.material_name],
    sequence_length=CROSS_VAL_SEQUENCE_LENGTH,
    batch_size_per_frequency=CROSS_VAL_BATCH_SIZE_PER_FREQUENCY,
    loader_key=loader_key,
)

# Alternative evaluations kept for reference. get_metrics_per_sequence,
# get_mixed_frequency_arrays and `scenarios` are not defined in the visible cells;
# import/define them before uncommenting.
# metrics_per_sequence = get_metrics_per_sequence(
#     wrapped_model,
#     test_set,
#     scenarios[test_set.material_name],
#     loader_key,
# )
# H, B, T = get_mixed_frequency_arrays(test_set, sequence_length=1000, batch_size=100, key=loader_key)


In [None]:
# Tabulate the cross-validation metrics for this model type. The old name
# `df_models_3C90` was misleading: the material comes from `exp_id`, not always 3C90.
# `.T` transposes the table before it is rendered as HTML.
df_cross_validation = create_multilevel_df({model_type: metrics})
display(HTML(df_cross_validation.T.to_html(float_format="%.4f", bold_rows=False)))

---

In [None]:
# Deliberate stop so "Run All" ends here; run the pretest sections below manually.
# A bare `raise` with no active exception only yields the confusing
# "No active exception to reraise" RuntimeError, so raise one with a clear message.
raise RuntimeError("Intentional stop before the pretest evaluation; run the cells below manually.")

## Pretest evaluation

In [None]:
from IPython.display import display, HTML
from mc2.utils.pretest_evaluation import evaluate_pretest_scenarios, create_multilevel_df, HOSTS_VALUES_DICT, SCENARIO_LABELS, load_hdf5_pretest_data

In [None]:
# Load this material's pretest arrays and (presumably) per-scenario masks, then show the array shapes.
# NOTE(review): `loss_short` is unpacked but not used in the visible cells.
B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = load_hdf5_pretest_data(material_name)
B.shape, T.shape, H_init.shape, H_true.shape

In [None]:
# Inspect the scenario data returned by load_hdf5_pretest_data.
msks_scenarios_N_tup

In [None]:
# Evaluate the model on each pretest scenario. Masks and labels are both passed
# back-to-front so that they stay aligned with each other.
metrics = evaluate_pretest_scenarios(
    wrapped_model,
    B, T, H_init, H_true, loss,
    list(msks_scenarios_N_tup)[::-1],
    scenario_labels=list(SCENARIO_LABELS)[::-1],
    show_plots=False,
)

In [None]:
# Put this model's pretest metrics next to the "hosts" entry of HOSTS_VALUES_DICT
# (presumably reference values) for the same material. Renamed from the misleading
# `df_models_3C90`, because the material comes from `exp_id`.
df_pretest = create_multilevel_df(
    {
        model_type: metrics,
        "hosts": HOSTS_VALUES_DICT[material_name],
    }
)
display(HTML(df_pretest.T.to_html(float_format="%.3f", bold_rows=False)))

In [None]:
# Deliberate stop so "Run All" ends here; run the histogram and cross-data cells manually.
# A bare `raise` with no active exception only yields the confusing
# "No active exception to reraise" RuntimeError, so raise one with a clear message.
raise RuntimeError("Intentional stop before the histogram/cross-data sections; run the cells below manually.")

## Create histograms

In [None]:
from mc2.utils.pretest_evaluation import produce_pretest_histograms

In [None]:
# LaTeX-formatted scenario labels passed to produce_pretest_histograms as
# `adapted_scenario_labels`. Raw strings keep `\t` and `\%` literal. In a plain
# string, "\%" is an invalid escape sequence (DeprecationWarning; SyntaxWarning
# since Python 3.12), even though it yields the same text.
ADAPTED_SCENARIO_LABELS = [
    r"\textbf{90\% known, 10\% unknown}",
    r"\textbf{50\% known, 50\% unknown}",
    r"\textbf{10\% known, 90\% unknown}",
]

# Histograms per pretest scenario. Masks, labels and adapted labels are all
# passed back-to-front, using the same ordering as the pretest evaluation cell.
produce_pretest_histograms(
    material_name,
    wrapped_model,
    B, T, H_init, H_true, loss,
    list(msks_scenarios_N_tup)[::-1],
    scenario_labels=list(SCENARIO_LABELS)[::-1],
    adapted_scenario_labels=list(ADAPTED_SCENARIO_LABELS)[::-1],
    show_plots=False,
);

## Cross-data modelling
Evaluate the loaded model on the pretest data of every available material.

In [None]:
# Short alias for the wrapped model.
# NOTE(review): the cells below keep using `wrapped_model`, so this alias appears unused.
model = wrapped_model

In [None]:
# Display the model being evaluated across materials.
model

In [None]:
# Materials whose pretest data will be evaluated below.
AVAILABLE_MATERIALS

In [None]:
def evaluate_pretest_for_material(eval_material_name):
    """Evaluate `wrapped_model` on the pretest data of a single material.

    All loaded arrays stay local. The previous loop rebound the notebook-level
    `material_name`, `B`, `T`, `H_init`, `H_true`, `loss`,
    `msks_scenarios_N_tup` and `metrics`, which broke any earlier cell re-run
    afterwards.

    Args:
        eval_material_name: material whose pretest HDF5 data is loaded.

    Returns:
        The metrics returned by `evaluate_pretest_scenarios` for that material.
    """
    B_m, T_m, H_init_m, H_true_m, loss_m, _loss_short_m, masks_m = load_hdf5_pretest_data(eval_material_name)
    return evaluate_pretest_scenarios(
        wrapped_model,
        B_m,
        T_m,
        H_init_m,
        H_true_m,
        loss_m,
        list(reversed(masks_m)),
        scenario_labels=list(reversed(SCENARIO_LABELS)),
        show_plots=False,
    )


# Metrics keyed by material name, in AVAILABLE_MATERIALS order.
cross_material_metrics = {
    eval_material_name: evaluate_pretest_for_material(eval_material_name)
    for eval_material_name in AVAILABLE_MATERIALS
}

In [None]:
# One column group per evaluated material. Renamed from the misleading
# `df_models_3C90`, since this table spans all materials.
df_cross_material = create_multilevel_df(cross_material_metrics)
display(HTML(df_cross_material.T.to_html(float_format="%.3f", bold_rows=False)))