# PINN

PINN for the Jiles-Atherton model with GRUs

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
import jax.numpy as jnp
import jax.random as jr

# jax.config.update("jax_enable_x64", True)

import equinox as eqx
import optax
import matplotlib.pyplot as plt

from typing import List

import pathlib

In [2]:
from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    filter_file_overview,
    load_single_file,
    load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence
from mc2.data_management import FrequencySet, MaterialSet, DataSet, NormalizedFrequencySet, load_data_into_pandas_df
from mc2.features.features_jax import add_fe

In [3]:
from mc2.training.jax_routine import train_model
from mc2.models.model_interface import ModelInterface, RNNwInterface
from mc2.models.RNN import GRU

from mc2.features.features_jax import compute_fe_single
def featurize(norm_B_past, norm_H_past, norm_B_future, temperature):
    """Build the feature rows that belong to the future part of a B sequence.

    The past and future B samples are stacked into one sequence, passed to
    ``compute_fe_single`` (with ``n_s=10``), and the rows belonging to the
    past samples are dropped from the result.

    Args:
        norm_B_past: B samples preceding the prediction window; only its
            leading length is used to know how many rows to drop.
        norm_H_past: not used by this featurizer.
        norm_B_future: B samples of the prediction window.
        temperature: not used by this featurizer.

    Returns:
        ``compute_fe_single(...)[past_length:]``, i.e. the features aligned
        with ``norm_B_future``.
        NOTE(review): assumes compute_fe_single keeps one row per input
        sample -- confirm against mc2.features.features_jax.
    """
    past_length = norm_B_past.shape[0]

    # Features are computed on the full (past + future) sequence.
    full_B = jnp.hstack([norm_B_past, norm_B_future])
    featurized_B = compute_fe_single(full_B, n_s=10)

    return featurized_B[past_length:]

In [4]:
# Load the raw measurements of material "3C90" and wrap them in a MaterialSet.
data_dict = load_data_into_pandas_df(material="3C90")
mat_set = MaterialSet.from_pandas_dict(data_dict)

Loading data for 3C90: 100%|██████████| 21/21 [00:06<00:00,  3.13it/s]
INFO:2025-11-10 09:27:38,444:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


2025-11-10 09:27:38 | INFO : Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-11-10 09:27:38,446:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


2025-11-10 09:27:38 | INFO : Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


In [5]:
# Data selection: one frequency, three temperatures, each split 70/15/15.
# dataset_3C90 = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "3C90.pickle")

dataset_3C90 = mat_set

# Restrict to the 500 kHz measurements, then separate by temperature.
mat_set_f = dataset_3C90.at_frequency(500_000)
mat_set_f_T_25 = mat_set_f.filter_temperatures(temperatures=[25])
mat_set_f_T_50 = mat_set_f.filter_temperatures(temperatures=[50])
mat_set_f_T_70 = mat_set_f.filter_temperatures(temperatures=[70])

# Every temperature uses the same split fractions and the same seed.
_SPLIT_KWARGS = {"train_frac": 0.7, "val_frac": 0.15, "test_frac": 0.15, "seed": 12}

(
    train_set_f_T_25,
    val_set_f_T_25,
    test_set_f_T_25,
) = mat_set_f_T_25.split_into_train_val_test(**_SPLIT_KWARGS)
(
    train_set_f_T_50,
    val_set_f_T_50,
    test_set_f_T_50,
) = mat_set_f_T_50.split_into_train_val_test(**_SPLIT_KWARGS)
(
    train_set_f_T_70,
    val_set_f_T_70,
    test_set_f_T_70,
) = mat_set_f_T_70.split_into_train_val_test(**_SPLIT_KWARGS)

IndexError: boolean index did not match shape of indexed array in index 0: got (3300,), expected (2978,)

In [None]:
# Notebook display of the loaded material set (no side effects).
mat_set

In [None]:
# GRU

class MyModel(eqx.Module):
    """GRU cell unrolled over a sequence, bundled with five scalar parameters.

    ``Ms``, ``a``, ``alpha``, ``c`` and ``k`` are stored as module fields so
    that they live in the same pytree as the GRU weights. They are read by
    the physics helpers of this notebook (``He_fn``, ``Man_fn``,
    ``fn_dM_dH``). NOTE(review): presumably the usual Jiles-Atherton
    parameters -- confirm units and starting values.
    """

    net: eqx.Module  # recurrent cell (an eqx.nn.GRUCell)
    Ms: jnp.ndarray  # scalar quantities that we want to learn
    a: jnp.ndarray
    alpha: jnp.ndarray
    c: jnp.ndarray
    k: jnp.ndarray
    # out_layer: eqx.Module

    def __init__(self,input_size,hidden_size,*,key):
        """Create the GRU cell and set the scalar parameters.

        Args:
            input_size: number of features fed to the cell per step.
            hidden_size: size of the GRU hidden state.
            key: JAX PRNG key; only used to initialise the GRU weights.
        """
        # Split once so the GRU weights are derived from ``key`` in a
        # reproducible way; the remaining half is not needed.
        init_key, _ = jr.split(key)
        self.net = eqx.nn.GRUCell(
            input_size=input_size,
            hidden_size=hidden_size,
            key=init_key,
        )

        # The scalar parameters start from fixed values, not random draws.
        self.Ms = jnp.array([1.6e6], dtype=jnp.float32)
        self.a = jnp.array([110], dtype=jnp.float32)
        self.alpha = jnp.array([1.6e-3], dtype=jnp.float32)
        self.c = jnp.array([0.2], dtype=jnp.float32)
        self.k = jnp.array([400.01], dtype=jnp.float32)
        # self.out_layer = eqx.nn.Linear(in_features=1,out_features=1,key=init_key)

    def __call__(self, inp):
        """Run the GRU over ``inp`` and return one value per step.

        Args:
            inp: sequence that ``jax.lax.scan`` iterates along its leading
                axis; each slice is passed to the GRU cell as its input.

        Returns:
            Stacked per-step outputs. Each step contributes the first
            component of the hidden state, kept as a length-1 array.
        """
        # The recurrence always starts from an all-zero hidden state.
        hidden = jnp.zeros(self.net.hidden_size)

        def scan_fn(carry, x_t):
            new_hidden = self.net(x_t, carry)
            # Only hidden unit 0 is exposed as the model output.
            out_t = jnp.atleast_2d(new_hidden)[:, 0]
            return new_hidden, out_t

        _, out = jax.lax.scan(scan_fn, hidden, inp)
        # out_o = self.out_layer(out)
        return out
    
# Reproducibility: all randomness derives from this fixed seed.
key = jr.PRNGKey(0)

# The network is a GRU cell (not an LSTM) that MyModel unrolls over the input
# sequence; with input_size=1 it consumes one scalar per step and emits one value.
key, init_key = jr.split(key)

hidden_size=2
model = MyModel(input_size=1,hidden_size=hidden_size,key=init_key)

In [None]:
# H_e, after (4)
def He_fn(model,H,M):
    """Return the effective field ``H + model.alpha * M``.

    Args:
        model: object exposing the coupling coefficient ``alpha``.
        H: applied field value(s).
        M: magnetization value(s), broadcast against ``H``.
    """
    mean_field = model.alpha * M
    return H + mean_field

# Man (6)
def Man_fn(model,H,M):
    """Return the anhysteretic magnetization ``Ms * L(He / a)``.

    ``L(x) = coth(x) - 1/x`` is the Langevin function and ``He`` is the
    effective field from ``He_fn``. (Using ``tanh`` in place of ``coth``
    makes the expression diverge to -inf as the field approaches zero.)

    Args:
        model: object exposing ``Ms``, ``a`` and (via ``He_fn``) ``alpha``.
        H: applied field value(s).
        M: magnetization value(s).

    Returns:
        Array broadcast from ``H``, ``M`` and the model parameters.
    """
    x = He_fn(model, H, M) / model.a

    # For small |x| the closed form subtracts two nearly equal large numbers
    # (and is 0/0 at x == 0), so switch to the Taylor series x/3 - x^3/45.
    near_zero = jnp.abs(x) < 0.1
    # Evaluate the closed form on a harmless placeholder where it is not
    # selected, so neither the value nor its gradient produces NaN/inf.
    x_safe = jnp.where(near_zero, 1.0, x)
    closed_form = 1.0 / jnp.tanh(x_safe) - 1.0 / x_safe
    series = x / 3.0 - x**3 / 45.0

    return model.Ms * jnp.where(near_zero, series, closed_form)


# delta
def delta_fn(H):
    """Return the sign of the forward difference of ``H``.

    Size-1 axes of ``H`` are squeezed away first. The difference is taken
    along the last axis with the constant ``1`` appended, so the result has
    as many entries as the squeezed input.

    NOTE(review): because of ``append=1`` the final entry is
    ``sign(1 - H[-1])`` rather than a difference of two samples -- confirm
    this padding is intended.
    """
    return jnp.sign(jnp.diff(jnp.squeeze(H), append=1))

#  eq (19)
def fn_dM_dH(model,H,M):
    """Evaluate the differential susceptibility expression used by ``physics``.

    Computes ``(Man - M) / (delta*k/mu_0 - alpha*(Man - M)) + c*(Man - M)``
    with ``Man`` from ``Man_fn`` and ``delta`` from ``delta_fn``; size-1 axes
    are squeezed from the individual terms before they are combined.

    Args:
        model: object exposing ``alpha``, ``c`` and ``k`` (plus whatever
            ``Man_fn`` reads).
        H: field samples.
        M: magnetization samples.

    Returns:
        The combined expression, shaped like the squeezed inputs.
    """
    # Distance to the anhysteretic curve; every term below is built from it.
    man_minus_m = Man_fn(model,H,M) - M
    numerator = jnp.squeeze(man_minus_m)

    part1 = delta_fn(H)*model.k/mu_0
    part2 = model.alpha*man_minus_m
    part2_sqee = jnp.squeeze(part2)

    # Shape diagnostics (executed whenever the function is traced or run).
    print(f"part1: {part1.shape}")
    print(f"part2: {part2.shape}")
    print(f"part2_sqee: {part2_sqee.shape}")
    print(f"numerator: {numerator.shape}")

    denominator = part1 - part2_sqee

    M_rev = jnp.squeeze(model.c*man_minus_m)

    # (19) + (31)
    return numerator/denominator + M_rev

    

In [None]:
def physics(model,H,B,B_next):
    """Estimate dH/dt from consecutive B samples and the model parameters.

    Args:
        model: object holding the parameters read by ``fn_dM_dH``.
        H: field samples.
        B: flux-density samples.
        B_next: flux-density samples one step after ``B``.

    Returns:
        ``dB/dt / mu_0 - dM/dt`` where ``dB/dt`` is the finite difference
        ``(B_next - B) / TAU`` and ``dM/dt = dM/dB * dB/dt``.
    """
    # Finite-difference time derivative of B, with size-1 axes removed.
    dB_dt_est = jnp.squeeze((B_next - B) / TAU)

    # Magnetization from B = mu_0 * (H + M).
    M = B / mu_0 - H
    dM_dH = fn_dM_dH(model,H,M)
    # Chain rule: dB/dH = mu_0 * (1 + dM/dH).
    dM_dB = dM_dH / (mu_0 * (1 + dM_dH))

    print(f"dM_dB {dM_dB.shape}")
    print(f"dB_dt_est: {dB_dt_est.shape}")

    dM_dt = dM_dB * dB_dt_est
    return 1/mu_0 * dB_dt_est - dM_dt

def residuum(model,B):
    """Apply ``model`` to ``B`` and return its output unchanged."""
    prediction = model(B)
    return prediction

In [None]:
# loss function

def loss_fn(model,H,B,B_next,physics_weight=1e-20):
    """Weighted sum of the physics term and the data-fit term.

    ``H``, ``B`` and ``B_next`` are batched along axis 0 (one sequence per
    entry). All three are mapped together so that every B sequence is paired
    with its own H and B_next sequence.

    Args:
        model: the network plus physical parameters being optimised.
        H: target field sequences.
        B: flux-density input sequences.
        B_next: ``B`` advanced by one time step.
        physics_weight: factor applied to the physics term.

    Returns:
        Scalar loss ``physics_weight * physics_term + data_term``.
    """
    # NOTE(review): this term penalises the magnitude of the dH/dt value
    # returned by `physics`, not its mismatch with the network output --
    # confirm that this is the intended residual.
    physics_at_collocation_points = jax.vmap(physics, in_axes=(None,0,0,0))(model, H, B, B_next)
    physics_loss_contribution = 0.5*jnp.mean(jnp.square(physics_at_collocation_points))

    prediction = jax.vmap(residuum,in_axes=(None,0))(model,B)
    prediction_loss_contribution = 0.5*jnp.mean(jnp.square(prediction-H))

    total_loss = physics_weight*physics_loss_contribution + prediction_loss_contribution
    # total_loss = prediction_loss_contribution

    return total_loss
    

In [None]:
# Reload the processed 3C90 data set from disk; this rebinds `dataset_3C90`,
# which an earlier cell had pointed at `mat_set`.
dataset_3C90 = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "3C90.pickle")
# 500 kHz / 25 degC subset. NOTE(review): not referenced by the visible cells
# below -- confirm whether it is still needed.
relevant_data = dataset_3C90.at_frequency(500_000).filter_temperatures([25])

# H = relevant_data.H[:10,0:1000][...,None]
# B = relevant_data.B[:10,0:1000][...,None]
# B_next = relevant_data.B[:10,1:1001][...,None]

# Optimiser step size and number of optimisation steps per training run.
LEARNING_RATE = 1e-3
N_OPTIMIZATION_EPOCHS = 1_000
# 4*pi*1e-7: the vacuum permeability in SI units.
mu_0 = 4*jnp.pi*10**(-7)
# Time step used by `physics` for the finite difference of B.
# NOTE(review): presumably the sampling period at 16 MHz -- confirm.
TAU = 1/16e6

In [None]:
# Normalise the 25 degC training split (H is transformed too) using the
# notebook's `featurize` function; the 50/70 degC variants are kept for reference.
train_set_norm = train_set_f_T_25.normalize(transform_H=True, featurize=featurize) #, feature_names=feature_names
# train_set_norm = train_set_f_T_50.normalize(transform_H=True, featurize=featurize) #, feature_names=feature_names
# train_set_norm = train_set_f_T_70.normalize(transform_H=True, featurize=featurize) #, feature_names=feature_names

In [None]:
# Training loop

# Adam with a constant learning rate; the optimiser state is built only from
# the array leaves of the model.
optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def make_step(model,H,B,opt_state,B_next):
    """Perform one optimiser update and return the new training state.

    Args:
        model: current model.
        H: target field sequences passed to ``loss_fn``.
        B: flux-density input sequences passed to ``loss_fn``.
        opt_state: current optimiser state.
        B_next: ``B`` advanced by one time step, passed to ``loss_fn``.

    Returns:
        ``(new_model, new_opt_state, loss)`` where ``loss`` was evaluated
        on the model before the update.
    """
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model,H,B,B_next)
    # Hand the optimiser the same array-only view of the model that was used
    # to initialise its state, so gradients and parameters share one tree.
    params = eqx.filter(model, eqx.is_array)
    updates, new_state = optimizer.update(grad, opt_state, params)
    new_model = eqx.apply_updates(model, updates)

    return new_model, new_state, loss

# Per-step records of the loss and of each scalar model parameter.
loss_history = []
Ms_history = []
a_history = []
alpha_history = []
c_history = []
k_history = []

# One training split per temperature, visited in this order (25, 50, 70).
train_sets_by_temp = (train_set_f_T_25, train_set_f_T_50, train_set_f_T_70)
N_TEMP = len(train_sets_by_temp)
sequ = 1

for temp, train_set in enumerate(train_sets_by_temp):
    train_set_norm = train_set.normalize(transform_H=True, featurize=featurize) #, feature_names=feature_names

    # Sequences 1 and 2 of the split, with a trailing feature axis added.
    H = train_set_norm.H[1:3,:][...,None]
    B = train_set_norm.B[1:3,:][...,None]
    # Shift along the time axis only, so B_next[:, t] == B[:, t + 1] within
    # each sequence. The last sample wraps around to the first one.
    # NOTE(review): the wrap-around assumes periodic waveforms -- confirm.
    B_next = jnp.roll(B, -1, axis=1)

    # The model and optimiser state carry over from one temperature to the next.
    for epoch in range(N_OPTIMIZATION_EPOCHS):
        model, opt_state, loss = make_step(model,H,B,opt_state,B_next)
        loss_history.append(loss)
        Ms_history.append(model.Ms)
        a_history.append(model.a)
        alpha_history.append(model.alpha)
        c_history.append(model.c)
        k_history.append(model.k)
        if epoch % 100 == 0:
            print(f"Epoch: {epoch}, loss: {loss}, Ms:{model.Ms}, a:{model.a}, alpha:{model.alpha}, c:{model.c}, k:{model.k}")

In [None]:
# Loss over all optimisation steps, on a logarithmic y-axis.
plt.plot(loss_history)
plt.yscale("log")

In [None]:
# for ii in range(10):
#     plt.plot(H[ii], label="data", color="gray",alpha=0.6)
#     # plt.plot(H[ii], jax.vmap(model)(H)[ii], label="Final PINN solution")
#     plt.legend()
#     plt.show()

In [None]:
# Compare the measured H with the model output for every sequence in the batch.
# The model is evaluated once for the whole batch rather than once per figure.
H_pred = jax.vmap(model)(B)
# Iterate over the sequences actually present in H instead of a fixed count.
for ii in range(H.shape[0]):
    plt.plot(H[ii], label="data", color="gray",alpha=0.6)
    plt.plot(H_pred[ii], label="Final PINN solution")
    plt.legend()
    plt.show()

In [None]:
# Notebook display of the trained model, including its learned scalar parameters.
model