# PINN
PINN for the Jiles-Atherton model

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
import jax.numpy as jnp
import jax.random as jr

# jax.config.update("jax_enable_x64", True)

import equinox as eqx
import optax
import matplotlib.pyplot as plt

from typing import List

import pathlib

from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    load_and_process_single_from_full_file_overview
)
from mc2.data_management import FrequencySet, MaterialSet, DataSet
from mc2.metrics import get_energy_loss

In [2]:
dataset_3C90 = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "3C90.pickle")
relevant_data = dataset_3C90.at_frequency(500_000).filter_temperatures([25])

H = relevant_data.H[:10,0:1000][...,None]
B = relevant_data.B[:10,0:1000][...,None]
B_next = jnp.squeeze(B)


# # normalization
# max_H = jnp.max(H)
# min_H = jnp.min(H)

# max_B = jnp.max(B)
# min_B = jnp.min(B)

# H = H/(jnp.abs(min_H)+max_H)
# B = B/(jnp.abs(min_B)+max_B)


INFO:2025-10-31 15:14:12,412:jax._src.xla_bridge:925: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


2025-10-31 15:14:12 | INFO : Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-10-31 15:14:12,414:jax._src.xla_bridge:925: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


2025-10-31 15:14:12 | INFO : Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


AttributeError: 'FrequencySet' object has no attribute 'H_RMS'

In [3]:
LEARNING_RATE = 1e-3
N_OPTIMIZATION_EPOCHS = 3_000

mu_0 = 4*jnp.pi*10**(-7)
TAU = 1/16e6

In [None]:
# MLP

class MyModel(eqx.Module):
    net: eqx.nn.MLP
    Ms: jnp.ndarray  # skalare Größen, die wir lernen wollen
    a: jnp.ndarray
    alpha: jnp.ndarray
    c: jnp.ndarray
    k: jnp.ndarray
    

 
    def __call__(self, x):
        return self.net(x)
    
# Reproducibility
key = jr.PRNGKey(42)

# Our PINN is a coordinate network in the form of a MLP, mapping from scalar to scalar values
key, init_key = jr.split(key)

pinn = eqx.nn.MLP(
    in_size="scalar",
    out_size="scalar",
    width_size=32,
    depth=10,
    activation=jax.nn.tanh,
    key=init_key,
)
Ms = jnp.array(.5)
a = jnp.array(.5)
alpha = jnp.array(.5)
c = jnp.array(.5)
k = jnp.array(.5)
model = MyModel(pinn, Ms, a, alpha, c, k)

In [None]:
# LSTM
class MyModel(eqx.Module):
    net: eqx.nn.LSTMCell
    Ms: jnp.ndarray  # skalare Größen, die wir lernen wollen
    a: jnp.ndarray
    alpha: jnp.ndarray
    c: jnp.ndarray
    k: jnp.ndarray
    
    def __init__(self,input_size,hidden_size,use_bias,key):
        self.net = eqx.nn.LSTMCell(
            input_size = 1,
            hidden_size = 2,
            use_bias = True,
            key = init_key
        )
        self.Ms = jnp.array(.5)
        self.a=jnp.array(.5)
        self.alpha = jnp.array(.5)
        self.c=jnp.array(.5)
        self.k = jnp.array(.5)
 
    def __call__(self,inp):
        init_state = (jnp.zeros(self.net.hidden_size),
                      jnp.zeros(self.net.hidden_size))

        def scan_fn(carry, inp):
            lstm_out = self.net(inp,carry)
            lstm_out_0 = jnp.alteast_2d(lstm_out)
            print(f"lstem_out {lstm_out}")
            out = lstm_out_o[:,0]
            return lstm_out

        h, = jax.lax.scan(scan_fn, init_state, inp)
        return h
    
# Reproducibility
key = jr.PRNGKey(42)

# Our PINN is a coordinate network in the form of a LSTM, mapping from scalar to scalar values
key, init_key = jr.split(key)

hidden_size=2
model = MyModel(input_size=1,hidden_size=hidden_size,use_bias=True,key=init_key)

In [4]:
# GRU

class MyModel(eqx.Module):
    net: eqx.Module
    Ms: jnp.ndarray  # skalare Größen, die wir lernen wollen
    a: jnp.ndarray
    alpha: jnp.ndarray
    c: jnp.ndarray
    k: jnp.ndarray
    # out_layer: eqx.Module
    
    def __init__(self,input_size,hidden_size,*,key):
        init_key, key = jr.split(key)
        self.net = eqx.nn.GRUCell(
            input_size = 1,
            hidden_size = 2,
            key = init_key
        )
        
        key1, key2, key3, key4, key5, key6 = jr.split(key,6)
        Ms_key = jr.uniform(key1)
        a_key = jr.uniform(key2)
        a_key = jr.uniform(key3)
        alpha_key = jr.uniform(key4)
        c_key = jr.uniform(key5)
        k_key = jr.uniform(key6) 
        
        self.Ms = jnp.array([1.6e6], dtype=jnp.float32)
        self.a=jnp.array([110], dtype=jnp.float32)
        self.alpha = jnp.array([1.6e-3], dtype=jnp.float32)
        self.c=jnp.array([0.2], dtype=jnp.float32)
        self.k = jnp.array([400], dtype=jnp.float32)
        # self.out_layer = eqx.nn.Linear(in_features=1,out_features=1,key=init_key)
 
    def __call__(self, inp):
        hidden = jnp.zeros(self.net.hidden_size)

        def scan_fn(carry, inp):
            gru_out = self.net(inp,carry)
            gru_out_o = jnp.atleast_2d(gru_out)
            out = gru_out_o[:,0]
            return gru_out, out

        _, out = jax.lax.scan(scan_fn, hidden, inp)
        # out_o = self.out_layer(out)
        return out
    
# Reproducibility
key = jr.PRNGKey(0)

# Our PINN is a coordinate network in the form of a LSTM, mapping from scalar to scalar values
key, init_key = jr.split(key)

hidden_size=32
model = MyModel(input_size=1,hidden_size=hidden_size,key=init_key)

In [5]:
# H_e
def He_fn(model,H):
    return  H + model.alpha*model(H)

# Man
def Man_fn(model,H):
    return model.Ms*(jnp.tanh(He_fn(model,H)/model.a)-(model.a/He_fn(model,H)))


# delta
def delta_fn(H):
    H_new = jnp.squeeze(H)
    diff_H = jnp.diff(H_new)
    delta = jnp.sign(diff_H)

    return delta

#  eq (19)
def dM_dH(model,H,M):
    numerator = Man_fn(model,H) - M
    denominator = delta_fn(H)*model.k/mu_0 - model.alpha*(Man_fn(H)-M)
    
    dM_dH = numerator/denominator
    return dM_dH

def dB_dt_est(B,B_next):
    return dB_dt_est = (B_next - B) / TAU
    
        

In [6]:
def physics(model,H,B,B_next):
    
    M = B / mu_0 - H
    dM_dH = dM_dH(model,H,M)
    dM_dB = dM_dH / (mu_0 * (1 + dM_dH))
    dM_dt = dM_dB * dB_dt_est
    dH_dt = 1/mu_0 * dB_dt_est - dM_dt

    return dH_dt

def residuum(model,H):
    return model(H)

In [7]:
# loss function

def loss_fn(model,H,B,B_next):
    physics_at_collocation_points = jax.vmap(physics, in_axes=(None,0,None,None))(model, H, B, B_next)
    physics_loss_contribution = 0.5*jnp.mean(jnp.square(physics_at_collocation_points))

    prediction = jax.vmap(residuum,in_axes=(None,0))(model,H)
    prediction_loss_contribution = 0.5*jnp.mean(jnp.square(prediction-B))

    # total_loss = physics_loss_contribution + prediction_loss_contribution
    total_loss = prediction_loss_contribution

    return total_loss
    

In [9]:
# Training loop

optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def make_step(model,H,B,opt_state,B_next):
    loss, grad = eqx.filter_value_and_grad(loss_fn)(model,H,B,B_next)
    updates, new_state = optimizer.update(grad, opt_state, model)
    new_model = eqx.apply_updates(model, updates)
    B_next = B_next

    return new_model, new_state, loss

loss_history = []
Ms_history = []
a_history = []
alpha_history = []
c_history = []
k_history = []

for epoch in range(N_OPTIMIZATION_EPOCHS):
    model, opt_state, loss = make_step(model,H,B,opt_state,B_next)
    loss_history.append(loss)
    Ms_history.append(model.Ms)
    a_history.append(model.a)
    alpha_history.append(model.alpha)
    c_history.append(model.c)
    k_history.append(model.k)
    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, loss: {loss}, Ms:{model.Ms}, a:{model.a}, alpha:{model.alpha}, c:{model.c}, k:{model.k}")

ValueError: Incompatible shapes for broadcasting: shapes=[(10, 1000), (10, 1000, 1)]

In [None]:
plt.plot(loss_history)
plt.yscale("log")

In [None]:
for ii in range(10):
    plt.plot(H[ii], B[ii], label="data", color="gray",alpha=0.6)
    # plt.plot(H[ii], jax.vmap(model)(H)[ii], label="Final PINN solution")
    plt.legend()
    plt.show()

In [None]:
for ii in range(10):
    plt.plot(B[ii], label="data", color="gray",alpha=0.6)
    plt.plot(jax.vmap(model)(H)[ii], label="Final PINN solution")
    plt.legend()
    plt.show()