In [None]:
# Reload edited project modules (e.g. the `mc2` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# GPU-related settings; they are set before `import jax` below because JAX reads them at initialization.
# NOTE(review): JAX is later pinned to the CPU via jax.config.update("jax_platform_name", "cpu"),
# so these GPU settings presumably have no practical effect in this notebook — confirm intent.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to the process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # do not reserve most GPU memory up front
import pathlib
import glob

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
import equinox as eqx
import optax


In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    filter_file_overview,
    load_single_file,
    load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence
from mc2.data_management import FrequencySet, MaterialSet, DataSet, NormalizedFrequencySet, load_data_into_pandas_df
from mc2.features.features_jax import add_fe

In [None]:
from mc2.training.jax_routine import train_model

# Device selection: run all JAX computation on the CPU.
# NOTE(review): this conflicts with CUDA_VISIBLE_DEVICES="0" set at the top of the notebook.
# The commented-out lines are the GPU / float64 alternatives — pick one configuration and drop the rest.
# NOTE(review): the platform choice presumably only takes effect if no JAX backend has been
# initialized yet; newer JAX releases also document `jax_platforms` for this — confirm for the pinned version.
#gpus = jax.devices()
# jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
#jax.config.update("jax_default_device", gpus[0])

In [None]:
from mc2.models.model_interface import ModelInterface, RNNwInterface
from mc2.models.RNN import GRU

In [None]:
# Model / optimizer hyperparameters.
# NOTE(review): in_size must match the number of feature columns produced by
# `featurize` (an earlier value was 8) — TODO confirm against compute_fe_single.
hidden_size = 8
in_size = 7
out_size = 1
seed = 5

# Derive the model-init key from the seed; the other half (`key`) is kept for training.
key, model_key = jax.random.split(jax.random.PRNGKey(seed))

rnn_model = GRU(in_size, out_size, hidden_size, key=model_key)
optimizer = optax.adam(learning_rate=1e-3)

In [None]:
from mc2.features.features_jax import compute_fe_single
def featurize(norm_B_past, norm_H_past, norm_B_future, temperature, n_s=10):
    """Compute engineered B features for the future part of a sequence.

    The past and future B segments are joined before featurization so that
    features computed over a window (presumably derivatives / moving averages,
    cf. the commented feature names below) are also defined at the first
    future step. Only the rows belonging to the future segment are returned.

    Args:
        norm_B_past: normalized flux density of the past segment. Assumed to be
            1-D with time on axis 0 (TODO confirm): ``shape[0]`` is used as the
            past length and ``jnp.hstack`` concatenates 1-D inputs along axis 0.
        norm_H_past: normalized field of the past segment. Unused here; kept
            because this function is passed as the ``featurize`` callback to
            ``normalize`` / ``RNNwInterface``, whose call signature presumably
            includes it.
        norm_B_future: normalized flux density of the future segment, same
            layout as ``norm_B_past``.
        temperature: unused here; kept for the same reason as ``norm_H_past``.
        n_s: forwarded to ``compute_fe_single`` as ``n_s``. Defaults to 10, the
            previously hard-coded value.

    Returns:
        The output of ``compute_fe_single`` with its first
        ``norm_B_past.shape[0]`` rows (the past segment) removed.
    """
    past_length = norm_B_past.shape[0]
    full_B = jnp.hstack([norm_B_past, norm_B_future])
    featurized_B = compute_fe_single(full_B, n_s=n_s)
    # Drop the rows that correspond to the past segment.
    return featurized_B[past_length:]

#feature_names=["original_b","db_dt","d2b_dt2","dyn_avg","pwm_of_b"] #,"frequency"]

In [None]:
# Load the measurements for one material, split them into train/val/test with a
# fixed seed, and build the normalizer (+ featurization) from the training split.
# NOTE(review): `train_model` below only receives `material_name`; confirm it
# reproduces exactly this split (seed=12), otherwise evaluating on `test_set`
# may include sequences the model was trained on.
material = "3C90"
data_dict = load_data_into_pandas_df(material=material)
mat_set = MaterialSet.from_pandas_dict(data_dict)
train_set, val_set, test_set = mat_set.split_into_train_val_test(
    train_frac=0.7,
    val_frac=0.15,
    test_frac=0.15,
    seed=12,
)
train_set_norm = train_set.normalize(
    featurize=featurize,
    transform_H=True,
)

In [None]:
# Wrap the freshly initialized GRU with the training-set normalizer and the
# featurization callback, so it can be called on the un-normalized test data below.
rnn_wrap = RNNwInterface(
    rnn=rnn_model,
    featurize=featurize,
    normalizer=train_set_norm.normalizer,
)

In [None]:
# Train the wrapped GRU with truncated BPTT.
# NOTE(review): `train_model` is given only the material name, so it presumably
# loads and splits the data itself — confirm the split matches `test_set` above.
n_steps = 100
logs, model = train_model(
    model=rnn_wrap,
    optimizer=optimizer,
    n_steps=n_steps,
    material_name="3C90",
    tbptt_size=50,
    batch_size=64,
    # Validate about ten times per run. The previous fixed value (500) was larger
    # than n_steps (100), so validation could run at most once and the
    # validation-loss curve plotted below had no trend to show.
    val_every=max(1, n_steps // 10),
    past_size=10,
    key=key,
    seed=seed,
)

In [None]:
# Training-loss trend on a log scale, labelled so the figure stands on its own.
fig, ax = plt.subplots()
ax.plot(logs["loss_trends_train"])
ax.set_yscale("log")
ax.set(title="Training loss", xlabel="logged step", ylabel="loss (log scale)");

In [None]:
# Validation-loss trend on a log scale. It has one point per validation run, so it
# is only informative when `val_every` is well below `n_steps` in the training call.
fig, ax = plt.subplots()
ax.plot(logs["loss_trends_val"])
ax.set_yscale("log")
ax.set(title="Validation loss", xlabel="validation check", ylabel="loss (log scale)");

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis


In [None]:
# Forecast H for a few test sequences at one frequency with the trained model:
# condition on the first `n_past` samples and predict the remainder.
# The frequency itself is not passed to the model.
frequency_idx = 2
batch_idx = jnp.arange(1, 8)  # sequences 1..7
n_past = 15
freq_test = test_set[frequency_idx]
H_pred = model(
    B_past=freq_test.B[batch_idx, :n_past],
    H_past=freq_test.H[batch_idx, :n_past],
    B_future=freq_test.B[batch_idx, n_past:],
    T=freq_test.T[batch_idx],
)
H_pred.shape

In [None]:
# Overlay each prediction on its measured sequence (future part only).
# The slice start 15 must match the past length used to compute H_pred.
freq_test = test_set[frequency_idx]
per_sequence = zip(
    H_pred,
    freq_test.H[batch_idx, 15:],
    freq_test.B[batch_idx, 15:],
    freq_test.T[batch_idx],
)
for H_p, H, B, T in per_sequence:
    fig, axs = plot_single_sequence(B, H, T)
    axs[-1].plot(H_p, label="pred")
    fig.legend()
    plt.show()

In [None]:
from mc2.models.model_interface import save_model, load_model
from mc2.data_management import MODEL_DUMP_ROOT

In [None]:
# save_model(MODEL_DUMP_ROOT / "testy_test.eqx", {"in_size": in_size, "out_size": out_size, "hidden_size": hidden_size}, model.rnn)

In [None]:
# Load a previously trained GRU from the checkpoint file "bdd216a8.eqx" under MODEL_DUMP_ROOT.
# NOTE(review): it is wrapped below with this notebook's `train_set_norm.normalizer`;
# confirm the checkpoint was trained with the same data split and normalization.
loaded_rnn = load_model(MODEL_DUMP_ROOT / "bdd216a8.eqx", GRU)

In [None]:
# Wrap the loaded checkpoint exactly like the freshly trained model.
test_model = RNNwInterface(
    rnn=loaded_rnn,
    featurize=featurize,
    normalizer=train_set_norm.normalizer,
)

In [None]:
# Display the wrapped checkpoint model's repr as a quick sanity check.
test_model

In [None]:
# Same kind of forecast with the loaded checkpoint, but conditioning on a single
# past sample and predicting the rest of each sequence.
# The frequency itself is not passed to the model.
frequency_idx = 3
batch_idx = jnp.arange(1, 16)  # sequences 1..15
n_past = 1
freq_test = test_set[frequency_idx]
H_pred = test_model(
    B_past=freq_test.B[batch_idx, :n_past],
    H_past=freq_test.H[batch_idx, :n_past],
    B_future=freq_test.B[batch_idx, n_past:],
    T=freq_test.T[batch_idx],
)
H_pred.shape

In [None]:
# Overlay each checkpoint prediction on its measured sequence (future part only).
# The slice start 1 must match the past length used to compute H_pred.
freq_test = test_set[frequency_idx]
per_sequence = zip(
    H_pred,
    freq_test.H[batch_idx, 1:],
    freq_test.B[batch_idx, 1:],
    freq_test.T[batch_idx],
)
for H_p, H, B, T in per_sequence:
    fig, axs = plot_single_sequence(B, H, T)
    axs[-1].plot(H_p, label="pred")
    fig.legend()
    plt.show()

In [None]:
from mc2.metrics import evaluate_model

In [None]:
# Score the checkpoint on every test sequence at one frequency (units presumably Hz),
# conditioning on the first sample and predicting the rest.
frequency = 80_000
test_at_f = test_set.at_frequency(frequency)

metrics = evaluate_model(
    test_model,
    B_past=test_at_f.B[:, :1],
    H_past=test_at_f.H[:, :1],
    B_future=test_at_f.B[:, 1:],
    H_future=test_at_f.H[:, 1:],
    T=test_at_f.T[:],
    reduce_to_scalar=True,
)

In [None]:
# Display the evaluation metrics computed above (reduced to scalars).
metrics

In [None]:
# NOTE(review): leftover inspection cell — within this notebook, val_set is only created
# by the split and displayed here; remove it or use it for validation explicitly.
val_set