In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
# Load the pre-processed data set from the pickle file (path is relative to this notebook).
dataset = DataSet.load_from_file(pathlib.Path("../../data/processed") / "ten_mat_data.pickle")

# NOTE(review): this cell was originally annotated with "deleting N49 from dataset
# for now, since the data is incomplete" (50 kHz and 80 kHz are missing, 320 kHz has
# no data at 25 degrees). The code below does NOT remove any material:
# `available_materials` is a copy of all material names, so the filter keeps
# everything. TODO confirm whether N49 should still be dropped here.

available_materials = deepcopy(dataset.material_names)
print(available_materials)
print(len(available_materials))

# Filtering on the full name list; the assert checks that the material names are unchanged.
dataset = dataset.filter_materials(available_materials)
assert dataset.material_names == available_materials

## Preliminaries

In [None]:
import jax
# jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
# `tau` is passed to the model as its third call argument in the cells below.
# NOTE(review): the disabled factor suggests tau used to be 1 / (16 * 1e6),
# presumably a sampling period — TODO confirm which value is intended.
tau = 1  # disabled alternative: 1 / (16 * 1e6)

In [None]:
# Restrict the data set to material "3C90", frequency 50_000 and temperature 25.
all_relevant_data = dataset.at_material("3C90").at_frequency(50_000).filter_temperatures([25])
# NOTE(review): the normalizer is derived from the full selection, i.e. also from the
# sequences that later form the test split — confirm that this is intended.
normalizer = all_relevant_data.normalize(transform_H=True).normalizer
all_relevant_data

In [None]:
# normalizer = Normalizer.from_material_set(material_set)

In [None]:
# all_relevant_data = FrequencySet(
#     material_name='3C90',
#     frequency=50000.0,
#     H=all_relevant_data.H.astype(jnp.float64),
#     B=all_relevant_data.B.astype(jnp.float64),
#     T=all_relevant_data.T.astype(jnp.float64),
# )

In [None]:
# Train/test split: the first 200 sequences are used for training, the remaining
# ones for testing. Every 5th sample along the second axis of H and B is kept.
_train_rows = slice(None, 200)
_test_rows = slice(200, None)
_every_fifth = slice(None, None, 5)


def _subsampled_subset(rows):
    """Build a FrequencySet from the given sequence rows of `all_relevant_data`."""
    return FrequencySet(
        all_relevant_data.material_name,
        all_relevant_data.frequency,
        all_relevant_data.H[rows, _every_fifth],
        all_relevant_data.B[rows, _every_fifth],
        all_relevant_data.T[rows],
    )


# Both subsets are normalized with the same, previously created normalizer.
training_data = _subsampled_subset(_train_rows)
norm_training_data = training_data.normalize(normalizer=normalizer)

testing_data = _subsampled_subset(_test_rows)
norm_testing_data = testing_data.normalize(normalizer=normalizer)

In [None]:
def evaluate_on_test_data(test_data, model, max_plots=20):
    """Roll the model out over all test sequences and plot prediction against measurement.

    Relies on the notebook globals `tau` and `add_features`.

    Args:
        test_data: Data object providing `H`, `B` and `T`. `H` and `B` are indexed
            here as (sequence, time) arrays.
        model: Callable that is vmapped over the sequences and invoked as
            `model(h0, features, tau)`; its second return value is plotted as the
            predicted H trajectory.
        max_plots: Upper bound on the number of sequences that are plotted
            (defaults to 20, the previously hard-coded value).
    """
    # Append a trailing axis of size one to H and B.
    batched_H = test_data.H[..., None]
    batched_B = test_data.B[..., None]

    # The first H sample of each sequence is the initial value; the inputs are
    # the features computed from B starting at the second sample.
    _, pred_H = jax.vmap(model, in_axes=(0, 0, None))(
        batched_H[:, 0, :], add_features(batched_B[:, 1:, 0], n_s=10), tau
    )

    # Loop-invariant, so computed once instead of once per figure.
    temperatures = jnp.unique(test_data.T)

    for i in range(min(batched_H.shape[0], max_plots)):
        fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
        axs[-1].plot(pred_H[i], label="pred")
        fig.legend()
        plt.show()
        # This function is called repeatedly during training; release the figure
        # so that open figures do not accumulate.
        plt.close(fig)

In [None]:
from mc2.training.optimization import make_step
from mc2.training.data_sampling import draw_data_uniformly

from mc2.models.NODE import HiddenStateNeuralEulerODE
from mc2.features.features_jax import add_fe as add_features

In [None]:
@eqx.filter_value_and_grad
def grad_loss(model, u, x, tau, featurize):
    """Mean squared error between featurized predicted and true trajectories.

    Because of the `eqx.filter_value_and_grad` decorator, calling this function
    yields the loss together with its gradients (see `make_step`).

    Args:
        model: Callable vmapped over the batch and invoked as `model(x0, features, tau)`;
            its second return value is taken as the predicted trajectory.
        u: Batched input sequences, indexed as `u[:, 1:, 0]`.
        x: Batched target sequences, indexed as `x[:, 0, :]` for the initial value.
        tau: Passed through to the model unchanged (not batched).
        featurize: Function applied per sequence to both prediction and target.
    """
    initial_values = x[:, 0, :]
    # NOTE: open question from the original author — does the first "action" belong
    # between x0 and x1 or between x_{-1} and x0?
    input_features = add_features(u[:, 1:, 0], n_s=10)

    batched_model = jax.vmap(model, in_axes=(0, 0, None))
    _, pred_x = batched_model(initial_values, input_features, tau)

    batched_featurize = jax.vmap(featurize, in_axes=(0))
    residual = batched_featurize(pred_x) - batched_featurize(x)

    return jnp.mean(residual ** 2)


@eqx.filter_jit
def make_step(model, u, x, tau, opt_state, featurize, optim):
    """Run a single optimizer update and return `(loss, model, opt_state)`.

    NOTE(review): this definition shadows the `make_step` imported from
    `mc2.training.optimization` earlier in the notebook.

    Args:
        model: Model to be updated.
        u, x, tau, featurize: Forwarded to `grad_loss`.
        opt_state: Current optimizer state.
        optim: Optimizer providing `update(grads, opt_state)`.
    """
    loss_value, gradients = grad_loss(model, u, x, tau, featurize)
    param_updates, next_opt_state = optim.update(gradients, opt_state)
    updated_model = eqx.apply_updates(model, param_updates)
    return loss_value, updated_model, next_opt_state

In [None]:
# --- batching ---------------------------------------------------------------
n_sequences, full_sequence_length = norm_training_data.H.shape
training_batch_size = 512
sequence_length = 64


def identity(x):
    """Featurizer that returns its input unchanged."""
    return x


# --- learning rate ----------------------------------------------------------
# Exponential decay from 1e-3 towards 1e-4, starting after 2_000 steps.
# (Constant alternative that was tried: lr = 1e-3)
lr = optax.schedules.exponential_decay(
    init_value=1e-3,
    transition_steps=1_000_000,
    transition_begin=2_000,
    decay_rate=0.1,
    end_value=1e-4,
)

# --- random keys ------------------------------------------------------------
key, model_key, loader_key = jax.random.split(jax.random.key(111), 3)

# --- model and optimizer ----------------------------------------------------
# Alternatives that were tried: width_size=32, and width_size=32 combined with
# obs_func=lambda x: jnp.sum(x).
model = HiddenStateNeuralEulerODE(
    obs_dim=1,
    state_dim=10,
    action_dim=5,
    width_size=64,
    depth=2,
    obs_func=lambda x: x[0],
    key=model_key,
)

optim = optax.adam(lr)
# The optimizer state is initialised on the inexact-array leaves of the model only.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
model

In [None]:
# Training loop: one batch from `draw_data_uniformly` and one `make_step` update
# per iteration; the loss of every step is recorded.
training_losses = []

for step in tqdm(range(1_000_000)):
    batched_H, batched_B, loader_key = draw_data_uniformly(
        norm_training_data, sequence_length, training_batch_size, loader_key
    )
    loss, model, opt_state = make_step(
        model,
        u=batched_B,
        x=batched_H,
        tau=tau,
        opt_state=opt_state,
        featurize=identity,
        optim=optim,
    )
    training_losses.append(loss)

    # Intermediate report every 100_000 steps, but not at step 0.
    if step <= 0 or step % 100_000 != 0:
        continue

    plt.suptitle(f"Training loss over training steps at {step} steps")
    plt.plot(np.log(training_losses))
    plt.show()

    evaluate_on_test_data(norm_testing_data, model)

In [None]:
# Final training curve. The natural logarithm of the loss is plotted, so the
# y-axis is labelled accordingly. (Title fixed: duplicated word "over" removed.)
plt.suptitle("Final Training loss over training steps")
plt.plot(np.log(training_losses))
plt.xlabel("training step")
plt.ylabel("log(loss)")
plt.show()

## Consider Testing data

In [None]:
evaluate_on_test_data(norm_testing_data, model)

In [None]:
# Roll the model out on the test sequences, using only the B samples 1..999 as input.
batched_H = norm_testing_data.H[..., None]
batched_B = norm_testing_data.B[..., None]

_rollout = jax.vmap(model, in_axes=(0, 0, None))
_features = add_features(batched_B[:, 1:1000, 0], n_s=10)
_, pred_H = _rollout(batched_H[:, 0, :], _features, tau)

In [None]:
# Measured vs. predicted H over the first 1000 samples, for at most 20 test sequences.
test_temperatures = jnp.unique(norm_testing_data.T)  # loop-invariant, computed once

for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_single_sequence(batched_B[i, :1000], batched_H[i, :1000], test_temperatures)
    axs[-1].plot(pred_H[i], label="pred")
    fig.legend()

### Hysteresis plots:

In [None]:
# Roll the model out over the complete test sequences (input: B from the second sample on).
batched_H = norm_testing_data.H[..., None]
batched_B = norm_testing_data.B[..., None]

_rollout = jax.vmap(model, in_axes=(0, 0, None))
_features = add_features(batched_B[:, 1:, 0], n_s=10)
_, pred_H = _rollout(batched_H[:, 0, :], _features, tau)

In [None]:
# Hysteresis plots over the full sequences: measurement first, prediction drawn
# into the same figure. At most 20 test sequences are shown.
test_temperatures = jnp.unique(norm_testing_data.T)  # loop-invariant, computed once

for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i], batched_H[i], test_temperatures)
    fig, axs = plot_hysteresis(batched_B[i], pred_H[i], test_temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# Hysteresis plots restricted to the first 1_000 samples: measurement first,
# prediction drawn into the same figure. At most 20 test sequences are shown.
test_temperatures = jnp.unique(norm_testing_data.T)  # loop-invariant, computed once
first_samples = slice(None, 1_000)

for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i, first_samples], batched_H[i, first_samples], test_temperatures)
    fig, axs = plot_hysteresis(
        batched_B[i, first_samples], pred_H[i, first_samples], test_temperatures, fig=fig, axs=axs
    )
    fig.legend()

In [None]:
# Hysteresis plots restricted to the sample window [start, end): measurement first,
# prediction drawn into the same figure. At most 20 test sequences are shown.
start = 5_000
end = 6_000

window = slice(start, end)
test_temperatures = jnp.unique(norm_testing_data.T)  # loop-invariant, computed once

for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i, window], batched_H[i, window], test_temperatures)
    fig, axs = plot_hysteresis(batched_B[i, window], pred_H[i, window], test_temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# NOTE(review): this cell repeats the earlier "first 1_000 samples" hysteresis cell
# unchanged — consider removing one of the two.
test_temperatures = jnp.unique(norm_testing_data.T)  # loop-invariant, computed once
first_samples = slice(None, 1_000)

for i in range(min(batched_H.shape[0], 20)):
    fig, axs = plot_hysteresis(batched_B[i, first_samples], batched_H[i, first_samples], test_temperatures)
    fig, axs = plot_hysteresis(
        batched_B[i, first_samples], pred_H[i, first_samples], test_temperatures, fig=fig, axs=axs
    )
    fig.legend()

## Loading and Saving:

In [None]:
from mc2.models.NODE import save_model, load_model
import mc2

### Save model:

In [None]:
jax.random.key_data(model_key)

In [None]:
# models_path = mc2.data_management.MODEL_DUMP_ROOT 

# save_model(
#     filename=models_path / "normalized_NODE_for_interface_tests.json",
#     hyperparams=dict(
#         obs_dim=1, state_dim=10, action_dim=5, width_size=64, depth=2, obs_func=None, key=jax.random.key_data(model_key).tolist()
#     ),
#     model=model,
# )

In [None]:
model_path = pathlib.Path("../../data/models") / "normalized_NODE_for_interface_tests.eqx"

In [None]:
# Create the target directory first so that writing the file does not fail when
# the models folder does not exist yet (e.g. on a fresh checkout).
model_path.parent.mkdir(parents=True, exist_ok=True)
eqx.tree_serialise_leaves(model_path, model)

### Load model:

In [None]:
# model = eqx.tree_deserialise_leaves(path_or_file=pathlib.Path("../../data/models") / "naiveNODE.eqx", like=model)  # requires you to have a model with the proper form (as has been used when the model was stored, To be extended/fixed...)

In [None]:
# model

In [None]:
model

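Layer sizes used for the parameter count below (these match the commented-out `width_size=32` variant; the model defined above uses `width_size=64`): 15 -> 32 -> 32 -> 10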

In [None]:
# Parameter count of a fully connected 15 -> 32 -> 32 -> 10 network:
# weights (n_in * n_out) plus biases (n_out) for each layer.
layer_sizes = [15, 32, 32, 10]
sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))