In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only GPU 0 to this process and turn off XLA's up-front GPU memory
# preallocation. These are set here, ahead of the `import jax` in a later cell.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    }
)

from functools import partial
import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet, NormalizedFrequencySet, MODEL_DUMP_ROOT

In [None]:
processed_dir = pathlib.Path("../../data/processed")
dataset = DataSet.load_from_file(processed_dir / "ten_mat_data.pickle")

# NOTE(review): the original note on this cell says N49 is deleted from the
# dataset for now because its data is incomplete (50 kHz and 80 kHz are missing,
# 320 kHz has no data at 25 degrees). The code below does not remove anything:
# the complete list of material names is handed to filter_materials.
# Confirm whether N49 should still be excluded.
available_materials = deepcopy(dataset.material_names)
print(available_materials, len(available_materials), sep="\n")

dataset = dataset.filter_materials(available_materials)
assert dataset.material_names == available_materials

In [None]:
# Narrow the dataset down to a single operating point: material 3C90,
# frequency 50_000 and temperature 25.
all_relevant_data = dataset.at_material("3C90").at_frequency(50_000).filter_temperatures([25])

# NOTE(review): the normalizer is derived from all selected sequences, i.e. also
# from the ones that end up in testing_data below — confirm this is intended.
normalizer = all_relevant_data.normalize(transform_H=True).normalizer


def _take_sequences(rows):
    """Return a FrequencySet holding the sequences of all_relevant_data selected by `rows`."""
    return FrequencySet(
        all_relevant_data.material_name,
        all_relevant_data.frequency,
        all_relevant_data.H[rows, :],
        all_relevant_data.B[rows, :],
        all_relevant_data.T[rows],
    )


# The first 200 sequences are used for training, everything after for testing.
training_data = _take_sequences(slice(None, 200))
testing_data = _take_sequences(slice(200, None))

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
from mc2.models.NODE import HiddenStateNeuralEulerODE
from mc2.features.features_jax import add_fe as add_features
from mc2.features.features_jax import compute_fe_single
from mc2.models.model_interface import ModelInterface, NODEwInterface, load_model

In [None]:
# Legacy loading path, kept for reference: build the architecture by hand and
# deserialise the stored leaves into it. This requires a model of exactly the
# form that was used when the model was stored (to be extended/fixed).
# model = HiddenStateNeuralEulerODE(obs_dim=1, state_dim=10, action_dim=5, width_size=64, depth=2, obs_func=lambda x: x[0], key=jax.random.key(0))
# model = eqx.tree_deserialise_leaves(path_or_file=pathlib.Path("../../data/models") / "normalized_NODE_for_interface_tests.eqx", like=model)

# Current loading path: load the stored model via the project's load_model helper.
model_file = MODEL_DUMP_ROOT / "953ca0fe-e78d-48.eqx"
model = load_model(model_file, model_class=HiddenStateNeuralEulerODE)

In [None]:
def featurize(norm_B_past, norm_H_past, norm_B_future, temperature):
    """Compute the features for the future part of a B sequence.

    The past and future B samples are joined into one sequence, passed through
    ``compute_fe_single`` with ``n_s=10``, and the leading entries that belong
    to the past samples are dropped again.

    Args:
        norm_B_past: Past B samples (presumably normalized — confirm against
            the caller). Its size along axis 0 determines how many leading
            feature entries are dropped.
        norm_H_past: Not used by this implementation; presumably part of the
            signature expected for the ``featurize`` callback.
        norm_B_future: Future B samples, appended to ``norm_B_past``.
        temperature: Not used by this implementation; see ``norm_H_past``.

    Returns:
        The output of ``compute_fe_single`` from index ``past_length`` onwards.
    """
    past_length = norm_B_past.shape[0]

    # Featurize the joined sequence rather than the future alone, then keep
    # only the entries from the start of the future window onwards.
    full_B = jnp.hstack([norm_B_past, norm_B_future])
    featurized_B = compute_fe_single(full_B, n_s=10)

    return featurized_B[past_length:]


# Bundle the loaded model with the normalizer and the feature function so it
# can be called directly with B_past / H_past / B_future / T.
interfaced_model = NODEwInterface(model, normalizer=normalizer, featurize=featurize)

In [None]:
# Predict H for the first few training sequences.
n_sequences = 10
# Index at which each sequence is split into past and future. This was 10,
# while the plotting cell compares the prediction against ground truth sliced
# from sample 15 on, so prediction and ground truth were shifted by 5 samples.
n_past = 15

H_pred = interfaced_model(
    B_past=training_data.B[:n_sequences, :n_past],
    H_past=training_data.H[:n_sequences, :n_past],
    B_future=training_data.B[:n_sequences, n_past:],
    T=training_data.T[:n_sequences],
)

print(H_pred.shape)

In [None]:
# The plotted sequences come from the training set, so take the temperature
# from the training set as well (previously read from testing_data). It does
# not change between iterations, hence it is computed once up front.
temperatures = jnp.unique(training_data.T)

# Ground truth is sliced from sample 15 on; this offset has to match the
# past/future split that was used to compute H_pred.
for H_p, H, B in zip(H_pred, training_data.H[:10, 15:], training_data.B[:10, 15:]):
    fig, axs = plot_single_sequence(B, H, temperatures)
    axs[-1].plot(H_p, label="pred")
    fig.legend()
    plt.show()

In [None]:
# Predict H for one training sequence. Indexing with [None] puts a leading
# axis of size 1 back in front of every input.
# Note: this rebinds H_pred, replacing the batched prediction computed earlier.
seq_idx = 0
n_past = 15

H_pred = interfaced_model(
    B_past=training_data.B[seq_idx, :n_past][None],
    H_past=training_data.H[seq_idx, :n_past][None],
    B_future=training_data.B[seq_idx, n_past:][None],
    T=training_data.T[seq_idx][None],
)
H_pred.shape

In [None]:
# The plotted sequence is a training sequence, so the temperature is taken from
# the training set as well (previously read from testing_data).
fig, axs = plot_single_sequence(
    training_data.B[0, 15:],
    training_data.H[0, 15:],
    jnp.unique(training_data.T),
)

# H_pred was computed from inputs with a leading axis of size 1 and presumably
# carries that axis too (TODO confirm). Matplotlib interprets a 2-D (1, N) array
# as N separate one-point series, so singleton axes are squeezed out to get a
# single curve of N samples.
axs[-1].plot(jnp.squeeze(H_pred), label="pred")
fig.legend()
plt.show()

## Eval Metrics:

In [None]:
from mc2.metrics import evaluate_model

In [None]:
# Evaluate on the test sequences: only the first sample of each sequence is
# given as past, the remaining samples form the future that is scored.
eval_split = 1

evaluate_model(
    interfaced_model,
    B_past=testing_data.B[:, :eval_split],
    H_past=testing_data.H[:, :eval_split],
    B_future=testing_data.B[:, eval_split:],
    H_future=testing_data.H[:, eval_split:],
    T=testing_data.T[:],
    reduce_to_scalar=True,
)

In [None]:
# TODO: where should the data come from?

In [None]:
# Show the shape of the B array of the test split.
testing_data.B.shape