In [None]:
# Reload edited project modules automatically before each cell is executed,
# so changes to the imported `mc2` package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only the GPU with index 1 to this process. This is set before JAX is
# imported further down so that the restriction is in place when it starts up.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Ask XLA not to preallocate GPU memory up front (allocate on demand instead).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
# Load the preprocessed dataset from disk.
# NOTE(review): the file is a pickle -- only load files from a trusted source.
processed_data_dir = pathlib.Path("../../data/processed")
dataset = DataSet.load_from_file(processed_data_dir / "ten_mat_data.pickle")

# N49 is removed from the dataset for now because its data is incomplete:
#   - 50 kHz and 80 kHz are missing
#   - 320 kHz has no data at 25 degrees
# Work on a copy so the material list held by the dataset is not mutated.
available_materials = deepcopy(dataset.material_names)
print(available_materials)
print(len(available_materials))

# Raises ValueError if "N49" is not among the loaded materials.
available_materials.remove("N49")
print(available_materials)
print(len(available_materials))

# Keep only the remaining materials and check that the filter kept exactly those.
dataset = dataset.filter_materials(available_materials)
assert dataset.material_names == available_materials

## Preliminaries

In [None]:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

In [None]:
# `tau` is forwarded as the third argument to every model call in this notebook.
# NOTE(review): presumably the integration step size of the Euler model -- confirm.
tau = 1 #
# Disabled scaling of tau, kept for reference:
#/ (16 * 1e6)

In [None]:
# Narrow the dataset down to a single material, a single frequency and a
# single temperature; this subset is what gets split into train and test below.
all_relevant_data = (
    dataset
    .at_material("3C90")
    .at_frequency(50_000)
    .filter_temperatures([25])
)
all_relevant_data

In [None]:
# Number of leading sequences used for training; every sequence after that is
# held out for testing. The split is positional (no shuffling).
n_training_sequences = 200


def _select_sequences(data, selection):
    """Build a FrequencySet holding only the sequences of `data` picked by `selection`.

    Args:
        data: Object exposing `material_name`, `frequency`, `H`, `B` and `T`.
        selection: Index (e.g. a `slice`) applied to the first axis of `H`, `B`
            and `T`.

    Returns:
        A new `FrequencySet` with the same material name and frequency as `data`.
    """
    return FrequencySet(
        data.material_name,
        data.frequency,
        data.H[selection, :],
        data.B[selection, :],
        data.T[selection],
    )


training_data = _select_sequences(all_relevant_data, slice(None, n_training_sequences))
testing_data = _select_sequences(all_relevant_data, slice(n_training_sequences, None))

In [None]:
def evaluate_on_test_data(testing_data, model, max_plots=20):
    """Plot model predictions against the measured H sequences of a test set.

    The model is evaluated for every test sequence in one vmapped call: it is
    given the first H value of the sequence, the B sequence from index 1
    onward, and the module-level `tau`. Each of the first `max_plots`
    sequences is then drawn with `plot_single_sequence`, and the prediction is
    overlaid on the last returned axis with the label "pred".

    Args:
        testing_data: Object exposing 2D arrays `H` and `B` (first axis indexes
            the sequence) and an attribute `T`.
        model: Callable invoked per sequence as `model(h0, b_sequence, tau)`.
        max_plots: Upper bound on the number of sequences that are plotted.

    Note:
        Reads the global `tau` defined earlier in the notebook.
    """
    # Append a trailing axis of size one to H and B.
    batched_H = testing_data.H[..., None]
    batched_B = testing_data.B[..., None]
    pred_H = jax.vmap(model, in_axes=(0, 0, None))(batched_H[:, 0, :], batched_B[:, 1:, :], tau)

    # The unique-value array of T is the same for every plot, so compute it once.
    temperatures = jnp.unique(testing_data.T)

    for i in range(min(batched_H.shape[0], max_plots)):
        fig, axs = plot_single_sequence(batched_B[i], batched_H[i], temperatures)
        axs[-1].plot(pred_H[i], label="pred")
        fig.legend()
        plt.show()

In [None]:
# NOTE(review): the `make_step` imported here is shadowed by the `make_step`
# that is defined in a later cell of this notebook; once that cell has run, the
# local definition is the one used by the training loop.
from mc2.training.optimization import make_step
from mc2.training.data_sampling import draw_data_uniformly

from mc2.models.NODE import NeuralEulerODE

In [None]:
@eqx.filter_value_and_grad
def grad_loss(model, u, x, tau, featurize):
    """Mean squared error between featurized predicted and reference sequences.

    The model is vmapped over the leading axis of `x` and `u`; for each entry
    it is called with `x[0]`, `u[1:]` and the shared `tau`. `featurize` is
    applied per entry to both the prediction and `x` before the error is taken.

    Because of the `eqx.filter_value_and_grad` decorator, calling this function
    yields the loss together with its gradients with respect to `model`.
    """
    # Open question: does the first "action" belong between x0 and x1 or between x_{-1} and x0?
    batched_model = jax.vmap(model, in_axes=(0, 0, None))
    predicted = batched_model(x[:, 0, :], u[:, 1:, :], tau)

    batched_featurize = jax.vmap(featurize)
    residual = batched_featurize(predicted) - batched_featurize(x)

    return jnp.mean(residual ** 2)


@eqx.filter_jit
def make_step(model, u, x, tau, opt_state, featurize, optim):
    """Run one optimisation step and return the loss with the updated state.

    Args:
        model: Model to optimise; differentiated by `grad_loss`.
        u: Batched input sequences, forwarded to `grad_loss`.
        x: Batched target sequences, forwarded to `grad_loss`.
        tau: Forwarded unchanged to the model via `grad_loss`.
        opt_state: Current optimizer state.
        featurize: Per-sequence feature function applied before the loss.
        optim: optax gradient transformation providing `update`.

    Returns:
        Tuple `(loss, model, opt_state)` with the loss of the step and the
        updated model and optimizer state.

    Note:
        This local definition takes precedence over the `make_step` imported
        from `mc2.training.optimization` earlier in the notebook.
    """
    loss, grads = grad_loss(model, u, x, tau, featurize)
    # Hand the current parameters to the optimizer as well. Optimizers whose
    # update rule depends on the parameter values (e.g. weight decay) need
    # them; optimizers that do not simply ignore the argument.
    params = eqx.filter(model, eqx.is_inexact_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

In [None]:
# --- Data and batching ---
n_sequences, full_sequence_length = training_data.H.shape
training_batch_size = 512
sequence_length = 100


def identity(x):
    """Featurizer that returns its input unchanged."""
    return x


# --- Optimizer ---
learning_rate = 1e-3
optim = optax.adam(learning_rate)

# --- Random keys: one for the model initialisation, one for the data loader ---
key = jax.random.key(4)
key, model_key, loader_key = jax.random.split(key, 3)

# --- Model and matching optimizer state ---
model = NeuralEulerODE(obs_dim=1, action_dim=1, width_size=128, depth=4, key=model_key)

# Only the inexact (floating point) array leaves of the model are optimised.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

In [None]:
# Total number of optimisation steps and how often (in steps) an intermediate
# loss curve and test-set evaluation are shown.
n_training_steps = 20_000
plot_interval = 10_000

training_losses = []

# NOTE(review): `model`, `opt_state` and `loader_key` are overwritten in place,
# so re-running this cell continues from the current state instead of
# restarting; re-run the setup cell first for a fresh training run.
for step in tqdm.tqdm(range(n_training_steps)):

    # Disabled schedule that grows the training sequence length, kept for reference:
    # if step % 1000 == 0 and step > 0 and step < 5_000:
    #     sequence_length += 100
    #     print("current sequence length:", sequence_length)

    batched_H, batched_B, loader_key = draw_data_uniformly(
        training_data, sequence_length, training_batch_size, loader_key
    )
    loss, model, opt_state = make_step(
        model,
        u=batched_B,
        x=batched_H,
        tau=tau,
        opt_state=opt_state,
        featurize=identity,
        optim=optim,
    )
    training_losses.append(loss)

    # Intermediate report; step 0 is skipped because nothing has been learned yet.
    if step > 0 and step % plot_interval == 0:
        plt.suptitle(f"Training loss over training steps at {step} steps")
        plt.plot(np.log(training_losses))
        plt.show()

        evaluate_on_test_data(testing_data, model)

In [None]:
# Plot the natural logarithm of every recorded training loss against the step index.
plt.suptitle("Final Training loss over training steps")
plt.plot(np.log(training_losses))
plt.show()

## Consider Testing data

In [None]:
# Append a trailing axis of size one to the test H and B arrays.
batched_H = testing_data.H[..., None]
batched_B = testing_data.B[..., None]

# Evaluate the model on all test sequences at once: each call receives the
# first H value of a sequence, its B values from index 1 onward, and `tau`.
batched_model = jax.vmap(model, in_axes=(0, 0, None))
pred_H = batched_model(batched_H[:, 0, :], batched_B[:, 1:, :], tau)

In [None]:
# Plot at most 20 test sequences, overlaying the prediction on the last axis.
test_temperatures = jnp.unique(testing_data.T)
n_plots = min(batched_H.shape[0], 20)

for i in range(n_plots):
    fig, axs = plot_single_sequence(batched_B[i], batched_H[i], test_temperatures)
    axs[-1].plot(pred_H[i], label="pred")
    fig.legend()

### Hysteresis plots:

In [None]:
# Hysteresis plots over the full sequences for at most 20 test sequences:
# the measured curve is drawn first, then the prediction is added to the same figure.
test_temperatures = jnp.unique(testing_data.T)
n_plots = min(batched_H.shape[0], 20)

for i in range(n_plots):
    fig, axs = plot_hysteresis(batched_B[i], batched_H[i], test_temperatures)
    fig, axs = plot_hysteresis(batched_B[i], pred_H[i], test_temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# Hysteresis plots restricted to the first 1000 samples of each sequence.
first_samples = slice(None, 1_000)
test_temperatures = jnp.unique(testing_data.T)
n_plots = min(batched_H.shape[0], 20)

for i in range(n_plots):
    fig, axs = plot_hysteresis(batched_B[i, first_samples], batched_H[i, first_samples], test_temperatures)
    fig, axs = plot_hysteresis(
        batched_B[i, first_samples], pred_H[i, first_samples], test_temperatures, fig=fig, axs=axs
    )
    fig.legend()

In [None]:
# Hysteresis plots restricted to the sample window [start, end) of each sequence.
start = 5_000
end = 6_000

window = slice(start, end)
test_temperatures = jnp.unique(testing_data.T)
n_plots = min(batched_H.shape[0], 20)

for i in range(n_plots):
    fig, axs = plot_hysteresis(batched_B[i, window], batched_H[i, window], test_temperatures)
    fig, axs = plot_hysteresis(batched_B[i, window], pred_H[i, window], test_temperatures, fig=fig, axs=axs)
    fig.legend()

In [None]:
# Hysteresis plots restricted to the first 1000 samples of each sequence.
# NOTE(review): this repeats the first-1000-samples comparison from two cells
# above; consider dropping one of the two cells.
first_samples = slice(None, 1_000)
test_temperatures = jnp.unique(testing_data.T)
n_plots = min(batched_H.shape[0], 20)

for i in range(n_plots):
    fig, axs = plot_hysteresis(batched_B[i, first_samples], batched_H[i, first_samples], test_temperatures)
    fig, axs = plot_hysteresis(
        batched_B[i, first_samples], pred_H[i, first_samples], test_temperatures, fig=fig, axs=axs
    )
    fig.legend()

### Save model:

In [None]:
# Location of the serialised model, relative to the notebook's working directory.
model_path = pathlib.Path("../../data/models", "naiveNODE.eqx")

In [None]:
# Create the target directory if needed so saving also works on a fresh
# checkout where the models directory does not exist yet.
model_path.parent.mkdir(parents=True, exist_ok=True)
# Write the leaves of `model` to `model_path`.
eqx.tree_serialise_leaves(model_path, model)

### Load model:

In [None]:
# Restore the stored leaves into `model`. This requires an existing model of the
# proper form, i.e. the same form that was used when the model was stored.
# TODO: to be extended/fixed.
saved_model_path = pathlib.Path("../../data/models") / "naiveNODE.eqx"
model = eqx.tree_deserialise_leaves(saved_model_path, like=model)

In [None]:
# Show the model as the cell output.
model