In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
import jax
jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
import equinox as eqx
import optax

---

In [None]:
from mc2.runners.model_setup_jax import setup_model
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

In [None]:
wrapped_model, optimizer, params, (train_set, eval_set, test_set) = setup_model(model_label="GRU", material_name="3C90", model_key=jax.random.PRNGKey(0), n_epochs=300, tbptt_size=128, batch_size=512,)

In [None]:
get_exp_ids(material_name=None, model_type="GRU")

In [None]:
wrapped_model = reconstruct_model_from_exp_id('3C90_GRU_31895366-dd82-4f')

In [None]:
wrapped_model

In [None]:
# Peak-normalize B and H (first 1000 samples of test sequence 0) so both fit on one axis
# and their relative phase is visible. The frequency subset is fetched once instead of four times.
preview_set = test_set.at_frequency(80_000)
B_preview = preview_set.B[0, :1000]
H_preview = preview_set.H[0, :1000]

plt.plot(B_preview / jnp.max(jnp.abs(B_preview)), label="B")
plt.plot(H_preview / jnp.max(jnp.abs(H_preview)), label="H")
plt.xlabel("sample index")
plt.ylabel("peak-normalized amplitude")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
from mc2.model_interfaces.linear_interfaces import LinearInterface
from mc2.models.linear import LinearStatic

from mc2.models.RNN import GRUwLinearModel, GRU
from mc2.model_interfaces.rnn_Interfaces import GRUwLinearModelInterface, MagnetizationRNNwInterface, RNNwInterface

In [None]:
# model = LinearStatic(11, 1, key=jax.random.PRNGKey(0)) 

# model = GRUwLinearModel(in_size=7, hidden_size=8, linear_in_size=7, key=jax.random.PRNGKey(0))

In [None]:
# wrapped_linear_model = LinearInterface(model, normalizer=wrapped_model.normalizer, featurize=wrapped_model.featurize)

# wrapped_model = GRUwLinearModelInterface(model, normalizer=wrapped_model.normalizer, featurize=wrapped_model.featurize)

In [None]:
# model_params_d = dict(hidden_size=8, in_size=7, key=jax.random.PRNGKey(0))
# model = GRU(**model_params_d)

# wrapped_model = RNNwInterface(
#     model=model,
#     normalizer=wrapped_model.normalizer,
#     featurize=wrapped_model.featurize
# )

# wrapped_model = MagnetizationRNNwInterface(
#     model=model,
#     normalizer=wrapped_model.normalizer,
#     featurize=wrapped_model.featurize
# )

In [None]:
from mc2.features.features_jax import db_dt, d2b_dt2, dyn_avg

In [None]:
seq_idx = 2
seq_start = 0
seq_len = 100

B_test = test_set.at_frequency(800_000).B[seq_idx, seq_start:seq_start+seq_len]
H_test = test_set.at_frequency(800_000).H[seq_idx, seq_start:seq_start+seq_len]

In [None]:
B_test_norm = B_test / jnp.max(jnp.abs(B_test))
H_test_norm = H_test / jnp.max(jnp.abs(H_test))

In [None]:
plt.plot(B_test_norm, label="B")
plt.plot(H_test_norm, label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
from mc2.features.features_jax import shift_signal

# B_shifted = shift_signal(B_test_norm, k_0=5)

correlation_values = jnp.correlate(
    B_test_norm - jnp.mean(B_test_norm),
    H_test_norm - jnp.mean(H_test_norm),
    mode="full",
)

x = jnp.arange(-seq_len+1, seq_len, 1)

plt.plot(x, correlation_values)
plt.grid(True, alpha=0.3)
plt.show()
print()

In [None]:
plt.plot(shift_signal(B_test_norm - jnp.mean(B_test_norm), 0), label="B", linestyle="dashed")
plt.plot(shift_signal(B_test_norm - jnp.mean(B_test_norm), 3), label="B_shifted")
plt.plot(H_test_norm - jnp.mean(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()

In [None]:
# Zero-mean copies of the peak-normalized B (signal1) and H (signal2).
signal1, signal2 = (sig - jnp.mean(sig) for sig in (B_test_norm, H_test_norm))

def min_max_norm(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array-like of numbers; min/max are taken over all elements.

    Returns:
        Array with the same shape as ``x``. A constant input (max == min)
        yields all zeros instead of NaN from a 0/0 division.
    """
    min_x = jnp.min(x)
    max_x = jnp.max(x)
    span = max_x - min_x
    # With a zero span, x - min_x is identically 0, so dividing by 1 gives 0 instead of NaN.
    safe_span = jnp.where(span > 0, span, 1)
    return (x - min_x) / safe_span

signal1 = min_max_norm(signal1)
signal2 = min_max_norm(signal2)

In [None]:
plt.plot(db_dt(signal1), label="B")
plt.plot(db_dt(signal2), label="H")
plt.grid(True, alpha=0.3)

In [None]:
plt.plot(db_dt(B_test_norm - jnp.mean(B_test_norm)), label="B")
plt.plot(db_dt(H_test_norm - jnp.mean(H_test_norm)), label="H")
plt.grid(True, alpha=0.3)

In [None]:
plt.plot(B_shifted, label="B_shifted")
plt.plot(B_test_norm, label="B")
plt.plot(H_test_norm, label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(db_dt(B_shifted), label="B_shifted")
plt.plot(db_dt(B_test_norm), label="B", linestyle="dashed")
plt.plot(db_dt(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(d2b_dt2(B_shifted), label="B_shifted")
plt.plot(d2b_dt2(B_test_norm), label="B")
plt.plot(d2b_dt2(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(dyn_avg(B_test_norm, n_s=15, mirrored_padding=True), label="B_averaged")
plt.plot(B_test_norm, label="B")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

---

In [None]:
B_test = test_set.at_frequency(80_000).B[:10, :1000]
H_past = test_set.at_frequency(80_000).H[:10, :100]
H_future = test_set.at_frequency(80_000).H[:10, 100:1000]
T = test_set.at_frequency(80_000).T[:10]

In [None]:
B_past = B_test[:, :100]
B_future = B_test[:, 100:]

In [None]:
H_est = wrapped_model(B_past, H_past, B_future, T)
H_est.shape

In [None]:
# B_future_norm, H_est_norm, T_norm = wrapped_linear_model.normalizer.normalize(B_future, H_est, T)
# _, H_future_norm, _ = wrapped_linear_model.normalizer.normalize(B_future, H_future, T)

B_future_norm, H_est_norm, T_norm = wrapped_model.normalizer.normalize(B_future, H_est, T)
_, H_future_norm, _ = wrapped_model.normalizer.normalize(B_future, H_future, T)

plt.plot(B_future_norm[2])
plt.show()
plt.plot(H_future_norm[2])
plt.plot(H_est_norm[2])

In [None]:
plt.plot(B_test[2])
plt.grid(True, alpha=0.3)
plt.show()
plt.plot(jnp.hstack([H_past[2], H_future[2]]))
plt.grid(True, alpha=0.3)
plt.plot()

In [None]:
# B_future_norm, H_est_norm, T_norm = wrapped_linear_model.normalizer.normalize(B_future, H_est, T)
# _, H_future_norm, _ = wrapped_linear_model.normalizer.normalize(B_future, H_future, T)

B_future_norm, H_est_norm, T_norm = wrapped_model.normalizer.normalize(B_future, H_est, T)
_, H_future_norm, _ = wrapped_model.normalizer.normalize(B_future, H_future, T)

plt.plot(B_future_norm[2])
plt.plot(H_future_norm[2])
plt.plot(H_est_norm[2])

In [None]:
B_all = B_test

In [None]:
B_all_padded = jnp.pad(B_all, ((0, 0), (5, 5)), mode='reflect', reflect_type="odd")

In [None]:
B_all_padded.shape

In [None]:
B_in = jnp.concatenate([jnp.roll(B_all_padded,idx)[..., None] for idx in jnp.arange(-5, 5 + 1e-7, 1,)], axis=-1)[:, 5 + 100: -5, :]

In [None]:
plt.plot(B_future[0])

In [None]:
plt.plot(B_in[0, :, 0])

In [None]:
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis

In [None]:
B_past = jnp.ones((1,1)) * 0.01
H_past = jnp.ones((1,1)) * 0.01
B_future = jnp.concatenate([jnp.linspace(0.01, 0.4, 1000), jnp.linspace(0.4, -0.4, 2000), np.linspace(-0.4, 1, 2000)])[None, :]
H_future = jnp.ones_like(B_future)
T=jnp.array([25])

In [None]:
H_pred = wrapped_model(B_past, H_past, B_future, T)

In [None]:
fig, axs = plot_single_sequence(B_future[0], H_future[0], T=T)
axs[-1].plot(H_pred[0])

In [None]:
fig, axs = plot_hysteresis(B_future[0], H_future[0], T=T)
axs.plot(H_pred[0], B_future[0] / H_pred[0])

In [None]:
mu_0 = 4 * jnp.pi * 1e-7

plt.plot(H_pred[0], B_future[0] / H_pred[0] / mu_0) 
plt.ylim(0, 5000)