# Model inspection:

This notebook gives a basic overview of how to load trained models and inspect their performance on a data set.

**Contents:**
- Loading of data
- Loading of models
- Inference on the data + Exemplary plotting
- Evaluation of metrics

In [None]:
# Optional setup: enable automatic reloading of edited modules and configure the
# accelerator. The environment variables below are set before `import jax` so that
# they are already in place when JAX starts up.
%load_ext autoreload
%autoreload 2

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # choose cuda-device (index of the device made visible to this process)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"  # disable preallocation of memory

import jax
#jax.config.update("jax_platform_name", "cpu")  # optionally run on cpu (uncomment to enable)

In [None]:
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import jax.numpy as jnp

from mc2.utils.pretest_evaluation import create_multilevel_df
from mc2.data_management import FINAL_MATERIALS, MaterialSet, DataSet
from mc2.utils.data_plotting import plot_sequence_prediction, plot_hysteresis_prediction
from mc2.utils.model_evaluation import reconstruct_model_from_file, plot_model_frequency_sweep, evaluate_cross_validation
from mc2.utils.final_data_evaluation import FINAL_SCENARIOS_PER_MATERIAL

## Load data:

In [None]:
# show the material names that are passed to the data-set loader in the next cell
FINAL_MATERIALS

In [None]:
# load the data for all materials listed in FINAL_MATERIALS
data_set = DataSet.from_material_names(FINAL_MATERIALS)

# If the full data set is too big for your memory, you can load a single material instead
# (the material selected for inspection further down must then be one that was loaded):
# data_set = DataSet.from_material_names(["B"])

In [None]:
# Print every material in the data set, each followed by the frequencies of its
# frequency sets.
# The loop variables deliberately avoid the name `material_set`: that name holds the
# material selected for inspection further down, and re-running this cell must not
# silently replace that selection (the selected `model` would no longer match it).
for listed_material_set in data_set:
    print(listed_material_set.material_name)
    for listed_frequency_set in listed_material_set:
        print(listed_frequency_set.frequency)

## Load trained models:

A small collection of trained models is provided together with the repository. Models that you have trained yourself can be loaded in the same way, using their experiment id (`exp_id`).

In [None]:
# experiment id of the trained model to load, per material name
exp_ids = dict(
    A='A_GRU8_reduced-features-f32_2a1473b6_seed12',
    B='B_GRU8_reduced-features-f32_c785b2c3_seed12',
    C='C_GRU8_reduced-features-f32_348e220c_seed12',
    D='D_GRU8_reduced-features-f32_b6ac55b5_seed12',
    E='E_GRU8_reduced-features-f32_e88a2583_seed12',
)

# reconstruct one model per material from its experiment id, keyed by material name
models = {
    name: reconstruct_model_from_file(model_exp_id)
    for name, model_exp_id in exp_ids.items()
}

## Inference + Visualization:

In [None]:
# Select what to inspect: change the material name here to look at the other
# models / material sets.
material_name = "B"
model = models[material_name]
material_set = data_set.at_material(material_name)

In [None]:
# --- choose the subset of data ---
past_size = 100  # leading samples of B and H handed to the model as known past
sequence_length = 2000  # number of samples kept per sequence (past + predicted part)
frequency = 50_000  # frequency whose sequences are inspected

relevant_frequency_set = material_set.at_frequency(jnp.array([frequency]))

# B and H are cropped along the second axis; T is taken in full
T = relevant_frequency_set.T[:]
H = relevant_frequency_set.H[:, :sequence_length]
B = relevant_frequency_set.B[:, :sequence_length]

# --- report the shapes of the inputs ---
print("Shape of the arrays (n_sequences, sequence_length) B:", B.shape)
print("Shape of the arrays (n_sequences, sequence_length) H:", H.shape)
print("Shape of the arrays (n_sequences,) T:", T.shape)

# --- prediction: the model receives the past of B and H, the remainder of B, and T ---
model_inputs = dict(
    B_past=B[:, :past_size],
    B_future=B[:, past_size:],
    H_past=H[:, :past_size],
    T=T,
)
H_pred = model(**model_inputs)

print("Shape of the prediction (n_sequences, sequence_length - past_size), H_pred:", H_pred.shape)


# --- visualization: plot at most `max_n_plots` predicted trajectories ---
max_n_plots = 5
n_plots = min(H_pred.shape[0], max_n_plots)
for idx in range(n_plots):
    plot_sequence_prediction(B[idx], H[idx], T[idx], H_pred[idx], past_size=past_size, figsize=(4, 4))
    plt.show()

In [None]:
# Run the model on random trajectories at all frequencies of the selected material.
loader_key = jax.random.PRNGKey(seed=12)  # key for pseudorandom sampling

# Reuse `past_size` from the inference cell above instead of repeating the literal 100,
# so that both inspections stay in sync when the value is changed there.
plot_model_frequency_sweep(model, material_set, loader_key, past_size=past_size);
plt.show()

## Evaluation of metrics:

Evaluate the SRE and NERE metrics of the selected model on the test split of the selected material.

In [None]:
# Split the selected material's data into train / validation / test parts
# (fractions 0.7 / 0.15 / 0.15) using a fixed seed.
# NOTE(review): for `test_set` to be unseen by the model, this split presumably has to
# match the one used during training — confirm.
train_set, val_set, test_set = material_set.split_into_train_val_test(
    train_frac=0.7,
    val_frac=0.15,
    test_frac=0.15,
    seed=0,
)

In [None]:
# scenarios registered for the selected material
material_scenarios = FINAL_SCENARIOS_PER_MATERIAL[material_name]
# fixed key passed to the evaluation as `loader_key`
metrics_loader_key = jax.random.PRNGKey(0)

# evaluate the selected model on the test split
metrics = evaluate_cross_validation(
    model,
    test_set,
    scenarios=material_scenarios,
    sequence_length=1000,
    batch_size_per_frequency=100,
    loader_key=metrics_loader_key,
)

In [None]:
# Collect the metrics in a table keyed by the experiment id of the inspected model.
# The variable is named after its content rather than a fixed material, because the
# material is whatever `material_name` selects.
metrics_df = create_multilevel_df({exp_ids[material_name]: metrics})

# render the transposed table as HTML with four decimal places
display(HTML(metrics_df.T.to_html(float_format="%.4f", bold_rows=False)))