# Quick evaluation:
**Purpose:** Quickly inspect the performance of a newly trained model: stored predictions and loss trends, test-set plots, cross-validation metrics, and pretest scenarios.

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Environment switches, applied before the `import jax` further down in this cell:
# expose only device "0" and set XLA_PYTHON_CLIENT_PREALLOCATE to "false"
# (per the JAX docs this turns off up-front GPU memory preallocation).
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

import glob
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
# jax.config.update("jax_platform_name", "cpu")
#jax.config.update("jax_enable_x64", True)
import equinox as eqx

from rhmag.model_setup import setup_normalizer
from rhmag.data_management import EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT, DATA_ROOT, MaterialSet
from rhmag.models.jiles_atherton import JAStatic, JAWithGRU
from rhmag.models.RNN import GRU
from rhmag.model_interfaces.model_interface import ModelInterface, load_model, count_model_parameters

In [None]:
from rhmag.utils.model_evaluation import reconstruct_model_from_file, get_exp_ids

In [None]:
from rhmag.data_management import AVAILABLE_MATERIALS

In [None]:
# List the stored experiment ids for one material / model type, sorted for a stable display.
# To browse everything instead, loop over AVAILABLE_MATERIALS and print
# get_exp_ids(material_name=material_name, model_type=None) for each material.
exp_ids = list(get_exp_ids(material_name="E", model_type="GRULinearOut"))
exp_ids.sort()
exp_ids

In [None]:
# Experiment to evaluate. Judging from the examples below, ids look like
# "<material>_<model type>_..._seed<N>"; earlier choices are kept for reference.
# exp_id = 'B_LSTM7_demonstration_b1ccde72_seed1'
# exp_id = 'B_GRU8_reduced-features-f32_c785b2c3_seed12'
# exp_id = 'D_GRU10_default-f32_11f19655_seed12'
exp_id = 'A_GRU8_a8d9fab2_seed201'

# The first two underscore-separated fields are the material and the model type.
material_name, model_type = exp_id.split("_")[:2]

# Seed suffix of the id, kept as a string. This replaces the former placeholder
# `seed = 0`, which never matched the experiment and was overwritten before use.
seed = exp_id.split("seed")[-1]

In [None]:
# Reconstruct the model that belongs to `exp_id`; every evaluation cell below uses `wrapped_model`.
wrapped_model = reconstruct_model_from_file(exp_id)

## Look at stored predictions + losses:

In [None]:
from rhmag.utils.model_evaluation import (
    load_gt_and_pred, plot_worst_predictions, plot_first_predictions, plot_loss_trends
)

In [None]:
# The seed is whatever follows the last "seed" marker in the experiment id (kept as a string).
seed = exp_id.rsplit("seed", 1)[-1]

# Ground truth and predictions stored for this experiment, for frequency index 0.
gt, pred = load_gt_and_pred(exp_id=exp_id, seed=seed, freq_idx=0)

In [None]:
# Worst stored predictions; the trailing `;` suppresses the cell's return-value repr.
plot_worst_predictions(gt, pred);

In [None]:
# First stored predictions; the trailing `;` suppresses the cell's return-value repr.
plot_first_predictions(gt, pred);

In [None]:
# Loss trends logged for this experiment id and seed.
plot_loss_trends(exp_id, seed);

In [None]:
# TODO: visualize the learning rate (this cell is still an empty placeholder)


### Further Plotting:

In [None]:
from rhmag.training.jax_routine import val_test
from rhmag.model_setup import setup_normalizer, setup_dataset

In [None]:
# Build the train / eval / test splits for the material parsed from `exp_id`.
train_set, eval_set, test_set = setup_dataset(
    material_name,
    subsampling_freq=1,
    use_all_data=False,
)

In [None]:
# Run the reloaded model on the test split; returns the loss plus lists of
# predictions and ground truths (indexed with [0] in the plotting cells below).
test_loss, test_pred_l, test_gt_l = val_test(
    test_set,
    wrapped_model,
    past_size=100,
)

In [None]:
# Worst predictions of the reloaded model on the test split, using the first entry
# of the `val_test` outputs (presumably the first frequency — TODO confirm).
# To compare with the performance stored at the end of training, use instead:
#   plot_worst_predictions(gt, pred)
# The trailing `;` suppresses the return-value repr, as in the sibling plotting cells.
plot_worst_predictions(test_gt_l[0], test_pred_l[0]);

In [None]:
# First predictions on the test split, using the same [0] entry as the cell above.
plot_first_predictions(test_gt_l[0], test_pred_l[0]);

In [None]:
# Parameter count as reported by the wrapped model.
wrapped_model.n_params

In [None]:
# Deliberate stop: a bare `raise` with no active exception fails with a RuntimeError.
# Presumably used to keep "Run All" from continuing into the sections below.
raise

## Cross validation:

In [None]:
from IPython.display import display, HTML
from rhmag.utils.model_evaluation import evaluate_cross_validation
from rhmag.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL
from rhmag.utils.pretest_evaluation import create_multilevel_df

In [None]:
# Fixed PRNG key; it is reused by the frequency-sweep and manual-rollout cells below.
loader_key = jax.random.PRNGKey(32)

# Scenarios registered for the material of the loaded test set.
cross_val_scenarios = FINAL_SCENARIOS_PER_MATERIAL[test_set.material_name]

metrics = evaluate_cross_validation(
    wrapped_model=wrapped_model,
    test_set=test_set,
    scenarios=cross_val_scenarios,
    sequence_length=1000,
    batch_size_per_frequency=1000,
    loader_key=loader_key,
)

In [None]:
# One-model table keyed by the model type parsed from the experiment id,
# rendered transposed with four decimals.
df_models_3C90 = create_multilevel_df({model_type: metrics})
cross_val_table_html = df_models_3C90.T.to_html(float_format="%.4f", bold_rows=False)
display(HTML(cross_val_table_html))

In [None]:
# Persist the cross-validation metrics so the comparison cells can reload them by experiment id.
# NOTE(review): the pretest section further down writes to the same file name, so
# running both sections for one experiment overwrites one set of metrics.
metrics_save_path = DATA_ROOT / "metric_values" / f"metrics_{exp_id}.json"
# Create the target folder on first use; without it `open(..., 'w')` raises FileNotFoundError.
metrics_save_path.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_save_path, 'w') as f:
    json.dump(metrics, f)

---
---
---

In [None]:
# Show which experiment the results in this section belong to.
exp_id

In [None]:
# Visualize cross-validation trajectories with past_size=1
# (presumably the number of leading samples handed to the model as known history).
from rhmag.utils.model_evaluation import (
    plot_model_frequency_sweep,
    get_mixed_frequency_arrays,
)

plot_model_frequency_sweep(wrapped_model, test_set, loader_key, past_size=1)

In [None]:
# Same frequency sweep as with past_size=1, now with past_size=100 for comparison.
from rhmag.utils.model_evaluation import (
    plot_model_frequency_sweep,
    get_mixed_frequency_arrays,
)

plot_model_frequency_sweep(wrapped_model, test_set, loader_key, past_size=100)

In [None]:
# Parameter count as reported by the wrapped model (compare with the hand count below).
wrapped_model.n_params

In [None]:
# Show the model's repr to inspect its structure.
wrapped_model

In [None]:
# Hand count of parameters, presumably to cross-check `wrapped_model.n_params`.
# NOTE(review): the meaning of the individual terms is not documented here — confirm against the model.
print("gru_cell:", 24*6 + 24*8 + 8 + 24)

print("normalizer:", 1 + 1 + 1 + 4) # Normalizations of B,H,T and 4 features
# NOTE(review): the trailing comment names three items but the sum has only two terms.
print("featurization (questionable?):", 1 + 1) # featurization, n_s, time_shift

# Displayed total: gru_cell + normalizer terms only; the featurization count is not included.
24*6 + 24*8 + 8 + 24 + 1 + 1 + 1 + 4

In [None]:
# Draw one sequence batch from the test set with the shared loader key.
H, B, T = get_mixed_frequency_arrays(test_set, sequence_length=1000, batch_size=1, key=loader_key)

# Split along axis 1: the first `past_size` samples are the known history, the rest is predicted.
past_size = 1
H_past, H_future = H[:, :past_size], H[:, past_size:]
B_past, B_future = B[:, :past_size], B[:, past_size:]

# The model gets the B/H history, the future B and T, and returns the predicted H.
H_pred = wrapped_model(B_past, H_past, B_future, T)

In [None]:
# One column per test frequency. Rows: B over step k; H (reference, prediction and
# their difference) over k; H over B.
n_freqs = len(test_set.frequencies)
# The column count follows the data instead of a hard-coded 7, and squeeze=False keeps
# `axs` two-dimensional even for a single frequency, so axs[row, col] always works.
fig, axs = plt.subplots(3, n_freqs, figsize=(30, 8), squeeze=False)
for freq_idx in range(n_freqs):
    axs[0, freq_idx].plot(B_future[freq_idx])
    axs[1, freq_idx].plot(H_future[freq_idx])
    axs[1, freq_idx].plot(H_pred[freq_idx])
    # Prediction error, drawn in the same axes as a dashed red line.
    axs[1, freq_idx].plot(H_future[freq_idx] - H_pred[freq_idx], color="tab:red", linestyle="--")

    axs[2, freq_idx].plot(B_future[freq_idx], H_future[freq_idx])
    axs[2, freq_idx].plot(B_future[freq_idx], H_pred[freq_idx])

    axs[0, freq_idx].grid(True, alpha=0.3)
    axs[1, freq_idx].grid(True, alpha=0.3)
    axs[2, freq_idx].grid(True, alpha=0.3)

    axs[0, freq_idx].set_ylabel("B")
    axs[0, freq_idx].set_xlabel("k")
    axs[1, freq_idx].set_ylabel("H")
    axs[1, freq_idx].set_xlabel("k")
    axs[2, freq_idx].set_ylabel("H")
    axs[2, freq_idx].set_xlabel("B")

fig.tight_layout(pad=-0.2)

In [None]:
# Deliberate stop: a bare `raise` with no active exception fails with a RuntimeError.
# Presumably used to keep "Run All" from continuing into the comparison cells below.
raise

In [None]:
# Display label -> experiment id whose stored metrics file is loaded for the comparison table.
comparisons = {
    "float64 GRU8": "A_GRU8_default_setup_93f8a137_seed3",
    "float32 GRU8": "A_GRU8_default_setup_7e988c2b_seed3",
}

In [None]:
# Alternative comparison set that was used before:
#   "default setup GRU8": 'D_GRU8_default_setup_de584153_seed3'
#   "shifted GRU8":       'D_GRU8_default_setup_4cb4bf75_seed3'
# The result dict can optionally be pre-filled with {"hosts": HOSTS_VALUES_DICT[material_name]}.
metrics_for_comparison = {}

# Load the stored metrics file of every run listed in `comparisons`.
for run_label, run_id in comparisons.items():
    run_metrics_file = DATA_ROOT / "metric_values" / f"metrics_{run_id}.json"
    with open(run_metrics_file) as metrics_file:
        metrics_for_comparison[run_label] = json.load(metrics_file)

In [None]:
# Side-by-side table of all loaded runs, rendered transposed with three decimals.
df_models_3C90 = create_multilevel_df(metrics_for_comparison)
comparison_table_html = df_models_3C90.T.to_html(float_format="%.3f", bold_rows=False)
display(HTML(comparison_table_html))

In [None]:
# Deliberate stop: a bare `raise` with no active exception fails with a RuntimeError.
# Presumably used to keep "Run All" from continuing into the pretest section below.
raise

## Pretest evaluation:

In [None]:
# Pretest helpers, imported from the same `rhmag` package as the rest of this notebook
# (the cross-validation section already takes `create_multilevel_df` from
# `rhmag.utils.pretest_evaluation`) instead of the `mc2` package name.
# NOTE(review): confirm that all names below are exported by the rhmag module.
from IPython.display import display, HTML
from rhmag.utils.pretest_evaluation import (
    evaluate_pretest_scenarios,
    create_multilevel_df,
    HOSTS_VALUES_DICT,
    SCENARIO_LABELS,
    load_hdf5_pretest_data,
)

In [None]:
# Load the pretest arrays for the material under evaluation.
(
    B,
    T,
    H_init,
    H_true,
    loss,
    loss_short,
    msks_scenarios_N_tup,
) = load_hdf5_pretest_data(material_name)

# Quick shape check of the main arrays.
tuple(pretest_array.shape for pretest_array in (B, T, H_init, H_true))

In [None]:
# Inspect the per-scenario entries returned by `load_hdf5_pretest_data`.
msks_scenarios_N_tup

In [None]:
# Scenario masks and labels are passed as lists, in their original order.
scenario_masks = list(msks_scenarios_N_tup)
scenario_names = list(SCENARIO_LABELS)

# Note: this rebinds `metrics`, replacing the cross-validation result from above.
metrics = evaluate_pretest_scenarios(
    wrapped_model,
    B,
    T,
    H_init,
    H_true,
    loss,
    scenario_masks,
    scenario_labels=scenario_names,
    show_plots=False,
)

In [None]:
# Save the pretest metrics so the comparison cells can reload them by experiment id.
# NOTE(review): the cross-validation section above writes to the same file name, so
# running both sections for one experiment overwrites one set of metrics.
metrics_save_path = DATA_ROOT / "metric_values" / f"metrics_{exp_id}.json"
# Create the target folder on first use; without it `open(..., 'w')` raises FileNotFoundError.
metrics_save_path.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_save_path, 'w') as f:
    json.dump(metrics, f)

In [None]:
# Show which experiment the pretest results belong to.
exp_id

In [None]:
# This model's pretest metrics next to the `HOSTS_VALUES_DICT` entry for the same material.
pretest_rows = {
    model_type: metrics,
    "hosts": HOSTS_VALUES_DICT[material_name],
}
df_models_3C90 = create_multilevel_df(pretest_rows)
pretest_table_html = df_models_3C90.T.to_html(float_format="%.3f", bold_rows=False)
display(HTML(pretest_table_html))

In [None]:
# Deliberate stop: a bare `raise` with no active exception fails with a RuntimeError.
# Presumably used to keep "Run All" from continuing into the comparison cells below.
raise

In [None]:
# Display label -> experiment id of the stored pretest metrics to compare.
comparisons = {
    "no features GRU8": "3C90_GRU8_9900e560-82ef-4d",
    "transform GRU8": "3C90_GRU8_bf3b3fd7-7d1b-4b",
    "default GRU8": "3C90_GRU8_331bf5f1-c1fe-46",
}

# Start with the `HOSTS_VALUES_DICT` entry for the current material, then add one entry per run.
metrics_for_comparison = {"hosts": HOSTS_VALUES_DICT[material_name]}

for run_label, run_id in comparisons.items():
    run_metrics_file = DATA_ROOT / "metric_values" / f"metrics_{run_id}.json"
    with open(run_metrics_file) as metrics_file:
        metrics_for_comparison[run_label] = json.load(metrics_file)

In [None]:
# Table of the host values and all loaded runs, rendered transposed with three decimals.
df_models_3C90 = create_multilevel_df(metrics_for_comparison)
pretest_comparison_html = df_models_3C90.T.to_html(float_format="%.3f", bold_rows=False)
display(HTML(pretest_comparison_html))

In [None]:
# Deliberate stop: a bare `raise` with no active exception fails with a RuntimeError.
# Presumably used to keep "Run All" from continuing into the histogram section below.
raise

## Create histograms:

In [None]:
# Imported from `rhmag`, the package used throughout this notebook, instead of `mc2`.
# NOTE(review): confirm that `produce_pretest_histograms` is exported by the rhmag module.
from rhmag.utils.pretest_evaluation import produce_pretest_histograms

In [None]:
# LaTeX-formatted scenario labels. Raw strings keep the backslashes literal: in a normal
# string `\%` is an invalid escape sequence, which newer Python versions warn about.
# The resulting string values are unchanged.
ADAPTED_SCENARIO_LABELS = [
    r"\textbf{90\% known, 10\% unknown}",
    r"\textbf{50\% known, 50\% unknown}",
    r"\textbf{10\% known, 90\% unknown}",
]

# Produce the pretest histograms for the current material and model.
# The masks, the scenario labels and the adapted labels are all reversed together,
# so (assuming they were aligned index-by-index before) they stay aligned.
produce_pretest_histograms(
    material_name,
    wrapped_model,
    B,
    T,
    H_init,
    H_true,
    loss,
    list(reversed(msks_scenarios_N_tup)),
    scenario_labels=list(reversed(SCENARIO_LABELS)),
    adapted_scenario_labels=list(reversed(ADAPTED_SCENARIO_LABELS)),
    show_plots=False,
);

## Cross-Data Modelling:

In [None]:
# Alias only: `model` and `wrapped_model` refer to the same object.
model = wrapped_model

In [None]:
# Show the model's repr.
model

In [None]:
# Materials that the cross-material evaluation below iterates over.
AVAILABLE_MATERIALS

In [None]:
def _pretest_metrics_for(other_material_name):
    """Evaluate `wrapped_model` on the pretest scenarios of one material.

    The data is loaded inside this helper so that the per-material arrays do not
    overwrite the notebook-level `material_name`, `B`, `T`, `H_init`, `H_true`,
    `loss`, `msks_scenarios_N_tup` and `metrics`, which belong to the material the
    model was selected for and are read by other cells.

    Args:
        other_material_name: Name of the material whose pretest data is loaded.

    Returns:
        Whatever `evaluate_pretest_scenarios` returns for that material.
    """
    (
        B_other,
        T_other,
        H_init_other,
        H_true_other,
        loss_other,
        _loss_short_other,  # loaded but not needed for the evaluation
        masks_other,
    ) = load_hdf5_pretest_data(other_material_name)
    # Masks and labels are reversed together, as in the histogram cell.
    return evaluate_pretest_scenarios(
        wrapped_model,
        B_other,
        T_other,
        H_init_other,
        H_true_other,
        loss_other,
        list(reversed(masks_other)),
        scenario_labels=list(reversed(SCENARIO_LABELS)),
        show_plots=False,
    )


# Material name -> pretest metrics of the same model evaluated on that material's data.
cross_material_metrics = {
    other_material_name: _pretest_metrics_for(other_material_name)
    for other_material_name in AVAILABLE_MATERIALS
}

In [None]:
# One column group per material, rendered transposed with three decimals.
df_models_3C90 = create_multilevel_df(cross_material_metrics)
cross_material_html = df_models_3C90.T.to_html(float_format="%.3f", bold_rows=False)
display(HTML(cross_material_html))