In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import pathlib
import glob

import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
import equinox as eqx
import optax


In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    filter_file_overview,
    load_single_file,
    load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence
from mc2.data_management import FrequencySet, MaterialSet, DataSet, NormalizedFrequencySet, load_data_into_pandas_df
from mc2.features.features_jax import add_fe

In [None]:
from mc2.training.jax_routine import train_model

#gpus = jax.devices()
# Enable 64-bit floats/ints in JAX (JAX uses 32-bit by default).
jax.config.update("jax_enable_x64", True)
# Request the CPU backend for all computations in this notebook.
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "0" in the import cell, but this line
# selects CPU anyway -- confirm which device is actually intended.
jax.config.update("jax_platform_name", "cpu")
#jax.config.update("jax_default_device", gpus[0])

In [None]:
from mc2.models.model_interface import ModelInterface, RNNwInterface
from mc2.models.RNN import GRU

In [None]:
from mc2.features.features_jax import compute_fe_single
def featurize(norm_B_past, norm_H_past, norm_B_future, temperature):
    """Compute features for the future part of a normalized B sequence.

    The past and future B samples are joined into one sequence before calling
    ``compute_fe_single`` and only the rows belonging to the future part are
    returned (presumably so that features at the start of the future window can
    use the preceding past samples -- confirm against ``compute_fe_single``).

    Args:
        norm_B_past: Normalized past B samples; its first axis length determines
            how many leading feature rows are dropped.
        norm_H_past: Unused; kept so the signature matches the featurize interface.
        norm_B_future: Normalized future B samples.
        temperature: Unused; kept so the signature matches the featurize interface.

    Returns:
        ``compute_fe_single(...)[past_length:]``, i.e. the feature rows that
        follow the past window.
    """
    past_length = norm_B_past.shape[0]

    # n_s=10 is forwarded unchanged to compute_fe_single (meaning defined there).
    featurized_B = compute_fe_single(jnp.hstack([norm_B_past, norm_B_future]), n_s=10)

    return featurized_B[past_length:]

#feature_names=["original_b","db_dt","d2b_dt2","dyn_avg","pwm_of_b"] #,"frequency"]

In [None]:
# References:
# - "Adjoint method for estimating Jiles-Atherton hysteresis model parameters"
# - "On physical aspects of the Jiles-Atherton hysteresis models"
def softclip(x, limit=1e8):
    """Smoothly saturate ``x`` to the range ``(-limit, limit)`` with a scaled tanh.

    For ``|x|`` much smaller than ``limit`` the output is approximately ``x``;
    for large magnitudes it approaches ``+/-limit``.
    """
    scaled = x / limit
    squashed = jnp.tanh(scaled)
    return limit * squashed

def softplus(x, eps=1e-6):
    """Numerically stable softplus ``log(1 + exp(x))`` shifted by ``eps``.

    ``log1p(exp(x))`` overflows to ``inf`` once ``exp(x)`` exceeds the float
    range; ``logaddexp(x, 0)`` evaluates the same quantity without forming
    ``exp(x)`` directly, so large inputs return approximately ``x``.

    Args:
        x: Scalar or array input.
        eps: Small positive offset added to the result so the output stays
            strictly positive.

    Returns:
        ``log(exp(x) + exp(0)) + eps`` with the same shape as ``x``.
    """
    return jnp.logaddexp(x, 0.0) + eps

class JilesAthertonStatic(eqx.Module):
    """Jiles-Atherton hysteresis model that reconstructs H from a given B sequence.

    Starting from an initial field ``H0``, the model integrates ``dH/dt`` with a
    fixed-step explicit Euler scheme (step ``tau``) along ``B_seq``.

    The five trainable leaves are unconstrained scalars. The model parameters are
    exposed as properties that squash them through a sigmoid into a bounded range:

    * ``Ms``    in (0, 6e6)
    * ``a``     in (0, 2000)
    * ``alpha`` in (0, 1e-2)
    * ``k``     in (0, 1000)
    * ``c``     in (0, 1)
    """

    # Unconstrained trainable scalars; the properties below map them to bounded values.
    Ms_param: jax.Array
    a_param: jax.Array
    alpha_param: jax.Array
    k_param: jax.Array
    c_param: jax.Array
    # Permeability constant used in M = B / mu_0 - H.
    mu_0: float = 4e-7 * jnp.pi
    # Fixed time step used both for the dB/dt estimate and the Euler update.
    tau: float = 1/16e6

    @property
    def a(self):
        """Parameter ``a``, bounded to (0, 2000)."""
        return 2000*jax.nn.sigmoid(self.a_param)

    @property
    def alpha(self):
        """Parameter ``alpha``, bounded to (0, 1e-2)."""
        return 1e-2*jax.nn.sigmoid(self.alpha_param)

    @property
    def k(self):
        """Parameter ``k``, bounded to (0, 1000)."""
        return 1000.0 *jax.nn.sigmoid(self.k_param)

    @property
    def c(self):
        """Parameter ``c``, bounded to (0, 1)."""
        return jax.nn.sigmoid(self.c_param)

    @property
    def Ms(self):
        """Parameter ``Ms``, bounded to (0, 6e6)."""
        return 6e6*jax.nn.sigmoid(self.Ms_param)

    def __init__(self, key, **kwargs):
        """Draw the initial unconstrained parameters from ``key``.

        Args:
            key: JAX PRNG key; split into one sub-key per parameter.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(**kwargs)
        k_key, alpha_key, c_key, Ms_key, a_key = jax.random.split(key, 5)
        self.k_param = jax.random.uniform(k_key, ()) * 0.05 + 0.5
        self.c_param = jax.random.uniform(c_key, ()) * 0.05 + 0.5
        self.a_param = jax.random.uniform(a_key, ()) * 0.05 + 0.5
        self.Ms_param = jax.random.uniform(Ms_key, ()) * 0.05 + 0.5
        self.alpha_param= jax.random.uniform(alpha_key, ())*0.001 + 0.002

    def coth(self,x):
        """Plain hyperbolic cotangent ``1 / tanh(x)`` (singular at ``x == 0``)."""
        return 1/ jnp.tanh(x)

    def coth_stable(self,x):
        """Hyperbolic cotangent with ``|x|`` clamped away from zero.

        Inputs with ``|x| < eps`` are replaced by ``+/-eps``. An input of exactly
        zero is mapped to ``+eps`` (``sign(0)`` would leave it at zero and the
        result would be ``inf``).
        """
        eps = 1e-7
        clamped = jnp.where(x < 0, -eps, eps)
        x = jnp.where(jnp.abs(x) < eps, clamped, x)
        return 1/jnp.tanh(x)

    def _langevin(self, x):
        """Langevin function ``coth(x) - 1/x`` without cancellation near zero.

        For ``|x| < 0.1`` the Taylor series
        ``x/3 - x^3/45 + 2 x^5/945 - x^7/4725`` is used; otherwise the closed form.
        """
        small = jnp.abs(x) < 0.1
        # Feed each branch only inputs it can handle, so neither the value nor
        # the gradient of the unselected branch becomes inf/NaN.
        x_direct = jnp.where(small, 1.0, x)
        x_series = jnp.where(small, x, 0.0)
        direct = 1.0 / jnp.tanh(x_direct) - 1.0 / x_direct
        x2 = x_series * x_series
        series = x_series * (1.0 / 3.0 - x2 * (1.0 / 45.0 - x2 * (2.0 / 945.0 - x2 / 4725.0)))
        return jnp.where(small, series, direct)

    def _langevin_prime(self, x):
        """Derivative of the Langevin function, ``1 - coth(x)^2 + 1/x^2``.

        For ``|x| < 0.1`` the Taylor series
        ``1/3 - x^2/15 + 2 x^4/189 - x^6/675`` is used; otherwise the closed form.
        """
        small = jnp.abs(x) < 0.1
        x_direct = jnp.where(small, 1.0, x)
        x_series = jnp.where(small, x, 0.0)
        direct = 1.0 - (1.0 / jnp.tanh(x_direct)) ** 2 + 1.0 / x_direct**2
        x2 = x_series * x_series
        series = 1.0 / 3.0 - x2 * (1.0 / 15.0 - x2 * (2.0 / 189.0 - x2 / 675.0))
        return jnp.where(small, series, direct)

    # Updated dM_dH function
    def dM_dH(self, H, M, dB_dt):
        """Return dM/dH for field ``H``, magnetization ``M`` and flux rate ``dB_dt``.

        ``dB_dt`` only enters through its sign (``delta``) and through the sign of
        ``(M_an - M) * dB_dt`` (``delta_m``).
        """
        H_e = H + self.alpha * M
        x = H_e / self.a
        # Anhysteretic magnetization Ms * (coth(x) - 1/x) and its derivative
        # w.r.t. H_e, both evaluated in a form that stays finite at H_e == 0.
        M_an = self.Ms * self._langevin(x)
        delta_m = 0.5*(1+jnp.sign((M_an-M)*dB_dt))

        dM_an_dH_e = self.Ms / self.a * self._langevin_prime(x)
        delta = jnp.sign(dB_dt)

        numerator = delta_m * (M_an - M) + self.c * self.k * delta * dM_an_dH_e
        denominator = self.k * delta - self.alpha * numerator

        return numerator / denominator

    def ode(self, B, B_next, H):
        """Evaluate the right-hand side ``dH/dt`` for one sample interval.

        Args:
            B: Flux density at the current sample.
            B_next: Flux density at the next sample.
            H: Field at the current sample.

        Returns:
            ``(dH_dt, dB_dt_est)`` where ``dB_dt_est = (B_next - B) / tau`` and
            ``dH_dt`` is soft-clipped to ``+/-1e8``.
        """
        dB_dt_est = (B_next - B) / self.tau
        M = B / self.mu_0 - H
        dM_dH = self.dM_dH(H, M, dB_dt_est)
        # Chain rule: dM/dB = dM/dH / (mu_0 * (1 + dM/dH)).
        dM_dB = dM_dH / (self.mu_0 * (1 + dM_dH))
        dM_dt = dM_dB * dB_dt_est
        dH_dt = 1/self.mu_0 * dB_dt_est - dM_dt

        dH_dt = softclip(dH_dt, limit=1e8)

        return dH_dt, dB_dt_est

    def step(self, H, B, B_next):
        """Advance ``H`` by one explicit Euler step of size ``tau``.

        Returns:
            ``(H_next, B_next)``; ``B_next`` is passed through unchanged.
        """
        dH_dt, _ = self.ode(B, B_next, H)
        H_next = H + self.tau * dH_dt
        return H_next, B_next

    def __call__(self, H0, B_seq):
        """Integrate H along a B sequence.

        Args:
            H0: Scalar field value belonging to ``B_seq[0]``.
            B_seq: 1-D array of flux-density samples.

        Returns:
            1-D array with the same length as ``B_seq``; element 0 is ``H0`` and
            element ``i`` is the Euler estimate belonging to ``B_seq[i]``.
        """

        def body_fun(carry, B_pair):
            H_prev = carry
            B_curr, B_next = B_pair
            H_next, _ = self.step(H_prev, B_curr, B_next)
            return H_next, H_next

        # Row i holds (B_seq[i], B_seq[i + 1]).
        B_pairs = jnp.stack([B_seq[:-1], B_seq[1:]], axis=1)
        _, H_seq = jax.lax.scan(body_fun, H0, B_pairs)
        H_seq = jnp.concatenate([jnp.array([H0]), H_seq], axis=0)
        return H_seq

In [None]:
# hidden_size = 8
# in_size = 7#8
# out_size = 1
# Reproducible setup: one PRNG key for the model initialisation, one kept for training.
seed = 6
key, model_key = jax.random.split(jax.random.PRNGKey(seed), 2)

# Untrained Jiles-Atherton model; every wrapper created later starts from it.
ja_model = JilesAthertonStatic(key=model_key)
optimizer = optax.adam(learning_rate=1e-3)

In [None]:
from typing import Callable
class JAwJAInterface(ModelInterface):
    """Batch/normalization adapter around a ``JilesAthertonStatic`` model.

    The wrapped model works on denormalized quantities, so both entry points
    convert between the normalized and the denormalized domain via
    ``normalizer``.
    """

    model: JilesAthertonStatic
    normalizer: eqx.Module
    featurize: Callable = eqx.field(static=True)

    def __call__(self, B_past, H_past, B_future, T):
        """Predict the future H trajectory from denormalized inputs.

        Args:
            B_past: Past B samples, indexed as (batch, time).
            H_past: Past H samples, indexed as (batch, time).
            B_future: Future B samples, indexed as (batch, time).
            T: Temperatures passed to the normalizer.

        Returns:
            Denormalized H prediction with one value per ``B_future`` sample.
        """
        B_all = jnp.concatenate([B_past, B_future], axis=1)
        B_all_norm, H_past_norm, T_norm = self.normalizer.normalize(B_all, H_past, T)
        past_len = B_past.shape[1]
        B_past_norm = B_all_norm[:, :past_len]
        B_future_norm = B_all_norm[:, past_len:]
        batch_H_pred_norm = self.normalized_call(B_past_norm, H_past_norm, B_future_norm, T_norm)
        batch_H_pred_denorm = jax.vmap(jax.vmap(self.normalizer.denormalize_H))(batch_H_pred_norm)
        return batch_H_pred_denorm

    def normalized_call(self, B_past_norm, H_past_norm, B_future_norm,T_norm):
        """Predict the future H trajectory in the normalized domain.

        The integration starts at the last past sample: ``H_past[:, -1]`` is the
        field belonging to ``B_past[:, -1]``, so the B sequence handed to the
        model is ``[B_past[-1], B_future...]`` and the initial state is dropped
        from the output. Previously ``H_past[:, -1]`` was paired with
        ``B_future[:, 0]`` and returned as the first future prediction, which
        shifted the whole trajectory by one sample.

        Returns:
            Normalized H prediction with one value per ``B_future_norm`` sample.
        """
        B_all_norm = jnp.concatenate([B_past_norm, B_future_norm], axis=1)
        B_all, H_past, T = self.normalizer.denormalize(B_all_norm, H_past_norm, T_norm)
        past_len = B_past_norm.shape[1]
        # Last past sample followed by every future sample.
        B_seq = B_all[:, past_len - 1:]
        H0 = H_past[:, -1]

        def single_batch(H0_i, B_seq_i):
            # Element 0 of the model output is H0 itself; keep only the future part.
            return self.model(H0_i, B_seq_i)[1:]

        batch_H_pred = jax.vmap(single_batch)(H0, B_seq)
        batch_H_pred_norm = jax.vmap(jax.vmap(self.normalizer.normalize_H))(batch_H_pred)
        return batch_H_pred_norm

In [None]:
# Load material 3C90 and keep only the 500 kHz measurements.
data_dict = load_data_into_pandas_df(material="3C90")
mat_set = MaterialSet.from_pandas_dict(data_dict)
mat_set_f = mat_set.filter_frequencies(frequencies=[500_000])

# One single-temperature subset per temperature (25, 50, 70).
mat_set_f_T_25, mat_set_f_T_50, mat_set_f_T_70 = (
    mat_set_f.filter_temperatures(temperatures=[temperature]) for temperature in (25, 50, 70)
)

# Identical 70/15/15 split (same seed) for every temperature subset.
_split_kwargs = dict(train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12)
train_set_f_T_25, val_set_f_T_25, test_set_f_T_25 = mat_set_f_T_25.split_into_train_val_test(**_split_kwargs)
train_set_f_T_50, val_set_f_T_50, test_set_f_T_50 = mat_set_f_T_50.split_into_train_val_test(**_split_kwargs)
train_set_f_T_70, val_set_f_T_70, test_set_f_T_70 = mat_set_f_T_70.split_into_train_val_test(**_split_kwargs)


In [None]:
def _normalize_and_wrap(train_subset):
    """Normalize ``train_subset`` and wrap the shared initial model with its normalizer.

    Returns:
        ``(normalized_set, wrapper)``.
    """
    normalized_set = train_subset.normalize(transform_H=True, featurize=featurize)
    wrapper = JAwJAInterface(
        model=ja_model,
        normalizer=normalized_set.normalizer,
        featurize=featurize,
    )
    return normalized_set, wrapper


# `train_set_norm` is rebound on every line; after this cell it refers to the T=70 set.
train_set_norm, ja_wrap_25 = _normalize_and_wrap(train_set_f_T_25)
train_set_norm, ja_wrap_50 = _normalize_and_wrap(train_set_f_T_50)
train_set_norm, ja_wrap_70 = _normalize_and_wrap(train_set_f_T_70)

In [None]:
def _train_and_save(model, data_tuple, save_path):
    """Train ``model`` on ``data_tuple`` and serialise the result to ``save_path``.

    All three per-temperature runs share the same hyper-parameters, optimizer,
    PRNG key and seed; only the wrapper and the data differ.

    Returns:
        ``(logs, trained_model)`` as returned by ``train_model``.
    """
    run_logs, trained_model = train_model(
        model=model,
        optimizer=optimizer,
        n_epochs=75,
        n_steps=0,
        material_name="3C90",
        data_tuple=data_tuple,
        tbptt_size=20,  # 50 was tried as an alternative
        batch_size=64,
        val_every=1,
        past_size=20,
        key=key,
        seed=seed,
    )
    eqx.tree_serialise_leaves(save_path, trained_model)
    return run_logs, trained_model


# Create the output directory up front so a missing folder cannot make the
# serialisation fail after a complete training run.
os.makedirs("temp_data", exist_ok=True)

# `logs` is rebound by every run; after this cell it holds the T=70 logs.
logs, model_f_T_25 = _train_and_save(
    ja_wrap_25, (train_set_f_T_25, val_set_f_T_25, test_set_f_T_25), "temp_data/model_f_T_25.eqx"
)
logs, model_f_T_50 = _train_and_save(
    ja_wrap_50, (train_set_f_T_50, val_set_f_T_50, test_set_f_T_50), "temp_data/model_f_T_50.eqx"
)
logs, model_f_T_70 = _train_and_save(
    ja_wrap_70, (train_set_f_T_70, val_set_f_T_70, test_set_f_T_70), "temp_data/model_f_T_70.eqx"
)

In [None]:
# Table: initial parameters (ja_model) next to the trained ones, per training temperature.
print(f"{'Parameter':>8} | {'ja_model':>15} | {'model.model':>15}")
print("-"*45)
params = ["Ms", "alpha", "a", "k", "c"]

for T,model in zip([25,50,70],[model_f_T_25,model_f_T_50,model_f_T_70]):
    print("Temperature: ",T)
    for p in params:
        ja_val = getattr(ja_model, p)
        model_val = getattr(model.model, p)
        print(f"{p:>8} | {ja_val:15.6e} | {model_val:15.6e}")

# Loss curves of whatever run `logs` currently refers to.
# NOTE(review): `logs` is overwritten by each training run, so this shows only the
# most recent one -- confirm that is the intended curve.
plt.plot(logs["loss_trends_train"],label="train_loss")
plt.plot(logs["loss_trends_val"],label="val_loss")
plt.yscale("log")
plt.ylabel("loss")
# The labels above are only rendered once a legend is requested.
plt.legend()
   


In [None]:
#model_T_f = eqx.tree_deserialise_leaves("temp_data/model_T_f.eqx", ja_wrap)
# Re-create the test split of the full material set and keep the T=50 measurements.
_, _, test_set = mat_set.split_into_train_val_test(
    train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12
)
test_set_T = test_set.filter_temperatures([50])

for frequency_idx in [4]:
    batch_idx = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    print("Frequency in Hz: ",test_set_T.frequencies[frequency_idx])

    # The first 20 samples are the warm-up ("past") window, the rest is predicted.
    freq_set = test_set_T[frequency_idx]
    model_inputs = dict(
        B_past=freq_set.B[batch_idx, :20],
        H_past=freq_set.H[batch_idx, :20],
        B_future=freq_set.B[batch_idx, 20:],
        T=freq_set.T[batch_idx],
    )
    H_pred_T_f_25 = model_f_T_25(**model_inputs)
    H_pred_T_f_50 = model_f_T_50(**model_inputs)
    H_pred_T_f_70 = model_f_T_70(**model_inputs)

    labelled_predictions = [
        (H_pred_T_f_25, "pred -- Training: freq=500_000, T=25"),
        (H_pred_T_f_50, "pred -- Training: freq=500_000, T=50"),
        (H_pred_T_f_70, "pred -- Training: freq=500_000, T=70"),
    ]
    H_future = freq_set.H[batch_idx, 20:]
    B_future = freq_set.B[batch_idx, 20:]
    T_batch = freq_set.T[batch_idx]

    # One figure per sequence: measurement plus the three model predictions.
    for row, (H, B, T) in enumerate(zip(H_future, B_future, T_batch)):
        fig, axs = plot_single_sequence(B, H, T)
        for H_pred, label in labelled_predictions:
            axs[-1].plot(H_pred[row], label=label, linestyle="--")
        axs[-1].legend(loc="upper left")
        plt.show()

In [None]:
# Split the complete material set (no frequency/temperature filtering) 70/15/15.
train_set, val_set, test_set = mat_set.split_into_train_val_test(
    train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12
)
# Normalize the training split and reuse its normalizer inside the model wrapper.
train_set_norm = train_set.normalize(transform_H=True, featurize=featurize) #, feature_names=feature_names
# The wrapper starts from the same initial `ja_model` instance.
ja_wrap=JAwJAInterface(model=ja_model, normalizer=train_set_norm.normalizer, featurize=featurize)

In [None]:
# Train on the complete material set; `logs` is rebound to the logs of this run.
logs, model_all = train_model(
    model=ja_wrap,
    optimizer=optimizer,
    n_epochs=100,
    n_steps=0,
    material_name="3C90",
    data_tuple=(train_set, val_set, test_set),
    tbptt_size=20,#, #50
    batch_size=64,
    val_every=1,
    past_size=20,
    # Same PRNG key and seed as used elsewhere in this notebook.
    key=key,
    seed=seed,
)

In [None]:
# Table: initial parameters (ja_model) next to those of the model trained on all data.
print(f"{'Parameter':>8} | {'ja_model':>15} | {'model.model':>15}")
print("-"*45)

params = ["Ms", "alpha", "a", "k", "c"]

for p in params:
    ja_val = getattr(ja_model, p)
    model_val = getattr(model_all.model, p)
    print(f"{p:>8} | {ja_val:15.6e} | {model_val:15.6e}")

# Loss curves of the run currently stored in `logs`.
plt.plot(logs["loss_trends_train"],label="train_loss")
plt.plot(logs["loss_trends_val"],label="val_loss")
plt.yscale("log")
plt.ylabel("loss")
# The labels above are only rendered once a legend is requested.
plt.legend()

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis

In [None]:
def plot_single_sequence2(B, H, T, t=None, fig=None, axs=None):
    """Plot a B sequence on the first axis and prepare a labelled second axis for H.

    Only ``B`` is drawn. ``H`` is accepted but currently not plotted; the second
    axis just receives its label and grid so callers can draw onto it.

    Args:
        B: B samples to plot on ``axs[0]``.
        H: H samples (currently not drawn).
        T: Temperature shown in the figure title.
        t: Optional time vector. When given, ``B`` is plotted against it and all
            axes get the x-label "Time in s".
        fig: Existing figure to draw into. A new 2x1 figure is created when
            either ``fig`` or ``axs`` is ``None``.
        axs: Existing sequence of two axes belonging to ``fig``.

    Returns:
        ``(fig, axs)``.
    """
    if fig is None or axs is None:
        fig, axs = plt.subplots(2, 1, figsize=(10, 10), sharex=True)

    # Degree sign written as an escape so the title survives encoding round-trips
    # (the literal had turned into mojibake).
    fig.suptitle("Temperature: " + str(T) + " C\u00b0")

    if t is None:
        axs[0].plot(B)
    else:
        axs[0].plot(t, B)
        for ax in axs:
            ax.set_xlabel("Time in s")

    axs[0].set_ylabel("B in T")
    axs[1].set_ylabel("H in A/m")

    for ax in axs:
        ax.grid()

    fig.tight_layout()
    return fig, axs

In [None]:
#model_T_f = eqx.tree_deserialise_leaves("temp_data/model_T_f.eqx", ja_wrap)
# Re-create the test split of the full material set and keep the T=25 measurements.
_, _, test_set = mat_set.split_into_train_val_test(
    train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=12
)
test_set_T = test_set.filter_temperatures([25])

for frequency_idx in [1]:
    batch_idx = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    print("Frequency in Hz: ",test_set_T.frequencies[frequency_idx])

    # The first 20 samples are the warm-up ("past") window, the rest is predicted.
    freq_set = test_set_T[frequency_idx]
    H_pred_all = model_all(
        B_past=freq_set.B[batch_idx, :20],
        H_past=freq_set.H[batch_idx, :20],
        B_future=freq_set.B[batch_idx, 20:],
        T=freq_set.T[batch_idx],
    )

    H_future = freq_set.H[batch_idx, 20:]
    B_future = freq_set.B[batch_idx, 20:]
    T_batch = freq_set.T[batch_idx]

    # One figure per sequence: measurement plus the prediction of the all-data model.
    for H_p_all, H, B, T in zip(H_pred_all, H_future, B_future, T_batch):
        fig, axs = plot_single_sequence(B, H, T)
        axs[-1].plot(H_p_all, label="pred -- Training: freq=All, T=All", linestyle="--")
        axs[-1].legend(loc="upper left")
        plt.show()