In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import pathlib
from typing import Type
import json

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import equinox as eqx

from mc2.runners.model_setup_jax import get_normalizer
from mc2.data_management import EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, NORMALIZATION_ROOT, load_hdf5_pretest_data
from mc2.models.jiles_atherton import JAStatic, JAWithGRU
from mc2.models.RNN import GRU
from mc2.models.model_interface import ModelInterface, load_model, RNNwInterface, JAwInterface

In [None]:
# --- Run selection ---------------------------------------------------------
# Material identifier handed to the data/model loading helpers in this notebook.
material_name = "3C90"

# Experiment id of the run under inspection. Other ids kept for quick switching:
#   "b2d68157-3620-49"
#   "b6423937-2f51-46"
exp_id = "82b01a31-1ce3-4c"

# Seed passed to the helpers that load stored predictions and loss trends.
seed = 0

## Look at stored predictions + losses:

In [None]:
from mc2.utils.model_evaluation import (
    load_gt_and_pred, plot_worst_predictions, plot_first_predictions, plot_loss_trends
)

In [None]:
# Load the stored ground truth / prediction pair for the run selected above.
# freq_idx=0 picks the first frequency index (exact meaning is defined by load_gt_and_pred).
gt, pred = load_gt_and_pred(
    exp_id=exp_id, material_name=material_name, seed=seed, freq_idx=0
)

In [None]:
# Compare stored predictions with the ground truth (presumably the highest-error cases, going by the helper's name).
# The trailing ';' keeps the call's return value from being echoed below the cell.
plot_worst_predictions(gt, pred);

In [None]:
# Same gt/pred pair as above, shown via plot_first_predictions (presumably the leading entries — confirm in the helper).
# The trailing ';' keeps the call's return value from being echoed below the cell.
plot_first_predictions(gt, pred);

In [None]:
# Loss trends logged for the selected experiment / material / seed.
# The trailing ';' keeps the call's return value from being echoed below the cell.
plot_loss_trends(exp_id, material_name, seed);

## Loading and evaluating stored models:

In [None]:
from copy import deepcopy
from mc2.runners.model_setup_jax import setup_model
from mc2.models.model_interface import load_model

In [None]:
# TODO: store and extract model_label to automate
# TODO: (second note was left unfinished in the original cell)

# Step 1 - build a freshly initialised wrapped model. Per the original note this
# is only needed to obtain the model classes and the normalizer; the data splits
# returned alongside are reused further down.
fresh_wrapped_model, _, params, (train_set, val_set, test_set) = setup_model(
    model_label="GRU",
    material_name=material_name,
    model_key=jax.random.PRNGKey(0),
)

# Step 2 - read the stored model for `exp_id` from the dump directory, passing
# the class of the fresh inner model to the loader.
stored_model_cls = type(fresh_wrapped_model.model)
model_path = MODEL_DUMP_ROOT / f"{exp_id}.eqx"
model = load_model(model_path, stored_model_cls)

# Step 3 - swap the randomly initialised inner model for the stored one; the
# rest of the wrapper is left unchanged.
wrapped_model = eqx.tree_at(
    lambda wrapper: wrapper.model,
    fresh_wrapped_model,
    model,
)
wrapped_model

### Plotting:

In [None]:
from mc2.training.jax_routine import val_test

In [None]:
# Run the restored model over the test split; yields the loss together with
# the prediction and ground-truth collections used for plotting below.
test_loss, test_pred_l, test_gt_l = val_test(
    test_set,
    wrapped_model,
    past_size=1,
)

In [None]:
# Predictions of the reloaded model on the test set, using the first entry of
# the ground-truth / prediction collections returned by val_test.
# To compare to the performance at the end of training, plot the stored results instead:
# plot_worst_predictions(gt, pred)
# The trailing ';' suppresses the echoed return value, matching the other plotting cells.
plot_worst_predictions(test_gt_l[0], test_pred_l[0]);

### Parameter counting:

**Q:** The parameter count assumes that all parameters are `jax.Array`s. Is this valid? It is essentially the same criterion as filtering on `requires_grad` in PyTorch.

In [None]:
from mc2.models.model_interface import count_model_parameters

In [None]:
# Parameter count of the restored wrapped model; the result is shown as the cell output.
count_model_parameters(wrapped_model)

In [None]:
# test_model, _, params, data_tuple = setup_model(
#     model_label = "GRU",
#     material_name=material_name,
#     model_key=jax.random.PRNGKey(0),
# )
# count_model_parameters(test_model)

# test_model, _, params, data_tuple = setup_model(
#     model_label = "JAWithExternGRU",
#     material_name=material_name,
#     model_key=jax.random.PRNGKey(0),
# )

## Pretest eval:

In [None]:
from IPython.display import display, HTML
from mc2.utils.pretest_evaluation import evaluate_pretest_scenarios_custom, create_multilevel_df, HOSTS_VALUES_DICT, SCENARIO_LABELS

In [None]:
# Fetch the pretest data set for the selected material and unpack its seven components.
pretest_data = load_hdf5_pretest_data(material_name)
B, T, H_init, H_true, loss, loss_short, msks_scenarios_N_tup = pretest_data

# Sanity check: show the shapes of the main arrays as the cell output.
tuple(arr.shape for arr in (B, T, H_init, H_true))

In [None]:
# Score the restored model on the pretest scenarios; plotting is disabled so
# only the metrics are collected.
metrics = evaluate_pretest_scenarios_custom(
    wrapped_model, B, T, H_init, H_true, loss, msks_scenarios_N_tup,
    scenario_labels=SCENARIO_LABELS,
    show_plots=False,
)

In [None]:
# Side-by-side table of the host reference values and the metrics of the model
# evaluated above. The host values are looked up via `material_name` so the
# table stays correct when a different material is selected in the config cell.
# NOTE: the variable name keeps its historical "3C90" suffix so that any other
# cell referring to it continues to work.
df_models_3C90 = create_multilevel_df(
    {
        "hosts": HOSTS_VALUES_DICT[material_name],
        "GRU": metrics,
    }
)
# Render transposed with three decimals for readability.
display(HTML(df_models_3C90.T.to_html(float_format="%.3f", bold_rows=False)))