In [None]:
# Enable IPython's autoreload extension in mode 2 so that edited modules
# (e.g. the local `mc2` package) are re-imported before each cell runs,
# without having to restart the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Both variables are set here, before `jax` is imported in a later cell, so
# they are already in place when JAX initializes its backend.
# Expose only CUDA device 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Ask the XLA client not to preallocate GPU memory up front
# (see the JAX GPU memory allocation documentation).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import json
from functools import partial
import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx

In [None]:
from mc2.data_management import MaterialSet, EXPERIMENT_LOGS_ROOT, MODEL_DUMP_ROOT, load_data_into_pandas_df
from mc2.features.features_jax import compute_fe_single
from mc2.models.NODE import HiddenStateNeuralEulerODE
from mc2.models.model_interface import NODEwInterface, load_model
from mc2.metrics import evaluate_model, evaluate_model_on_test_set

In [None]:
def featurize(norm_B_past, norm_H_past, norm_B_future, temperature):
    """Compute features for the future part of a normalized B sequence.

    The past and future B samples are stacked with `jnp.hstack` and handed to
    `compute_fe_single` (with `n_s=10`) as one sequence; the leading entries
    that belong to the past samples are then dropped from the result.

    Args:
        norm_B_past: Normalized past B samples. The length of its first axis
            determines how many leading feature entries are discarded.
        norm_H_past: Not used by this featurizer. Kept in the signature
            because the function is passed around as the `featurize`
            callback (presumably with a fixed 4-argument signature — confirm).
        norm_B_future: Normalized future B samples.
        temperature: Not used by this featurizer; kept for the same reason
            as `norm_H_past`.

    Returns:
        The output of `compute_fe_single` for the stacked sequence, sliced
        from index `norm_B_past.shape[0]` onwards along the first axis.
    """
    past_length = norm_B_past.shape[0]

    # Featurize over the full (past + future) sequence, then keep only the
    # part aligned with the future samples.
    featurized_B = compute_fe_single(jnp.hstack([norm_B_past, norm_B_future]), n_s=10)

    return featurized_B[past_length:]

In [None]:
# Load the data for material "3C90" and wrap it in a MaterialSet.
data_dict = load_data_into_pandas_df(material="3C90")
mat_set = MaterialSet.from_pandas_dict(data_dict)
# 70 % / 15 % / 15 % train/val/test split with a fixed seed.
# NOTE(review): the normalizer below is derived from this particular split;
# presumably the split fractions and seed must match the ones used when the
# loaded model was trained — confirm.
train_set, val_set, test_set = mat_set.split_into_train_val_test(
    train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12
)
# Normalize the training split (with `transform_H=True` and the `featurize`
# callback) and keep the resulting normalizer for reuse in later cells.
train_set_norm = train_set.normalize(transform_H=True, featurize=featurize)
normalizer = train_set_norm.normalizer

In [None]:
# Identifier of the experiment to inspect; later cells use it to locate the
# model dump (`<exp_id>.eqx`) and the experiment log (`<exp_id>.json`).
exp_id = "dcc2150a-7c6b-49"

In [None]:
# Restore the trained model for the selected experiment from its dump file ...
model_dump_path = MODEL_DUMP_ROOT / f"{exp_id}.eqx"
restored_model = load_model(model_dump_path, model_class=HiddenStateNeuralEulerODE)

# ... and wrap it together with the training-split normalizer and the
# featurizer, so the wrapped model can be called directly in later cells.
interfaced_model = NODEwInterface(
    restored_model,
    normalizer=normalizer,
    featurize=featurize,
)

In [None]:
# Normalize the test split with the normalizer obtained from the training split.
# NOTE(review): `norm_test_set` is not referenced in any of the cells visible
# here (the evaluation cell receives the un-normalized `test_set`) — confirm
# whether this cell is still needed.
norm_test_set = test_set.normalize(normalizer)

In [None]:
# Evaluate the wrapped model on the held-out test split. The un-normalized
# `test_set` is passed in; the wrapper was built with the normalizer.
eval_metrics = evaluate_model_on_test_set(
    interfaced_model,
    test_set,
)
# Bare expression: shows the metrics as the cell output.
eval_metrics

# plots:

In [None]:
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis

In [None]:
# Restrict the test split to the data at frequency 80_000
# (units presumably Hz — TODO confirm).
data = test_set.at_frequency(80_000)

# Predict H for the first 10 sequences of the subset.
# Indices 0..9 along the second axis of B and H are passed first, indices
# 10..499 of B third, and the first 10 entries of T last.
# NOTE(review): presumably the argument order is
# (B_past, H_past, B_future, temperature) — verify against NODEwInterface.
H_pred = interfaced_model(
    data.B[:10, :10],
    data.H[:10, :10],
    data.B[:10, 10:500],
    data.T[:10]
)

In [None]:
# `jnp.unique(data.T)` does not depend on the loop variables, so it is
# computed once instead of once per figure.
# NOTE(review): this hands the unique temperatures of the whole `data` subset
# to every plot rather than the temperature of the plotted sequence — confirm
# that this is what `plot_single_sequence` expects.
unique_temperatures = jnp.unique(data.T)

# One figure per predicted sequence: B and H over sample window 10..499 are
# plotted by `plot_single_sequence`, then the predicted H is drawn onto the
# last axis for comparison.
for H_p, H, B in zip(H_pred, data.H[:10, 10:500], data.B[:10, 10:500]):
    fig, axs = plot_single_sequence(B, H, unique_temperatures)
    axs[-1].plot(H_p, label="pred")
    fig.legend()
    plt.show()

In [None]:
# Load the JSON log that belongs to the selected experiment.
# The encoding is given explicitly so that reading the file does not depend
# on the platform's locale default.
with open(EXPERIMENT_LOGS_ROOT / "jax_experiments" / f"{exp_id}.json", encoding="utf-8") as f:
    exp_results = json.load(f)

In [None]:
# Bare expression: shows the loaded experiment log as the cell output.
exp_results