# PINN
Physics-informed neural network (PINN) for the Jiles-Atherton hysteresis model

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import optax
import matplotlib.pyplot as plt

from typing import List

import pathlib

from mc2.utils.data_inspection import (
    get_available_material_names,
    get_file_overview,
    load_and_process_single_from_full_file_overview
)
from mc2.data_management import FrequencySet, MaterialSet, DataSet
from mc2.metrics import get_energy_loss

In [None]:
# Load the processed 3C90 data set and keep the subset selected by
# at_frequency(50_000) and filter_temperatures([25]).
dataset_3C90 = DataSet.load_from_file(pathlib.Path("../../../data/processed") / "3C90.pickle")
relevant_data = dataset_3C90.at_frequency(50_000).filter_temperatures([25])

H = relevant_data.H
B = relevant_data.B

# Flatten all measured sequences into one 1-D array of samples. reshape(-1)
# infers the length instead of hard-coding it, so the cell keeps working when
# the selected subset (frequency/temperature) changes.
# NOTE(review): flattening discards the sequence order, which a direction
# factor delta = sign(dH) would need -- confirm this is intended.
H = jnp.reshape(H, -1)
B = jnp.reshape(B, -1)

# loss_fn compares the prediction at H[i] with B[i], so both must line up.
if H.shape != B.shape:
    raise ValueError(
        f"H and B must contain the same number of samples, got {H.shape} and {B.shape}"
    )

In [None]:
# Hyperparameters. N_DOF_FD, N_COLLOCATION and BC_LOSS_WEIGHT are not
# referenced by the code in this notebook as shown.
N_DOF_FD = 100
N_COLLOCATION = 50
LEARNING_RATE = 0.001
N_OPTIMIZATION_EPOCHS = 1000
BC_LOSS_WEIGHT = 1e2

# Vacuum permeability, 4*pi*10^-7.
mu_0 = 4 * jnp.pi * 10 ** (-7)

In [None]:
# Model: MLP for the magnetization M(H) plus the learnable Jiles-Atherton parameters

class MyModel(eqx.Module):
    """Coordinate network bundled with the learnable scalar Jiles-Atherton parameters.

    All array fields (network weights and the scalars below) are optimized
    together, because the training step filters the whole model with
    ``eqx.is_array``. Calling the model just evaluates the wrapped network.
    """

    # Network whose output is treated as the magnetization M in He_fn/physics.
    net: eqx.nn.MLP
    # Scalar quantities that are learned alongside the network weights.
    # Field order matters: the model is constructed positionally.
    Ms: jnp.ndarray
    a: jnp.ndarray
    alpha: jnp.ndarray
    c: jnp.ndarray
    k: jnp.ndarray

    def __call__(self, x):
        """Evaluate the wrapped network at ``x``."""
        network = self.net
        return network(x)
    
# Fixed seed so that the network initialization is reproducible.
key = jr.PRNGKey(42)
key, init_key = jr.split(key)

# The PINN is a coordinate network: an MLP mapping a scalar input to a scalar
# output (width 32, depth 3, tanh activations).
pinn = eqx.nn.MLP(
    key=init_key,
    in_size="scalar",
    out_size="scalar",
    width_size=32,
    depth=3,
    activation=jax.nn.tanh,
)

# Every learnable Jiles-Atherton parameter starts at 0.5.
# NOTE(review): these starting values are not scaled to the units of H and B
# -- confirm they are in a sensible range for the data.
Ms, a, alpha, c, k = (jnp.array(0.5) for _ in range(5))
model = MyModel(pinn, Ms, a, alpha, c, k)

In [None]:
# H_e

def He_fn(model, H):
    """Return the effective field H_e = H + alpha * M, where M = model.net(H)."""
    magnetization = model.net(H)
    return H + model.alpha * magnetization

In [None]:
# Man

def Man_fn(model, H):
    """Anhysteretic magnetization of the Jiles-Atherton model.

    Man = Ms * L(He / a), where He = He_fn(model, H) is the effective field and
    L(x) = coth(x) - 1/x is the Langevin function (L(0) = 0).

    Args:
        model: object providing ``Ms`` and ``a`` plus whatever He_fn reads.
        H: field value at which to evaluate Man.

    Returns:
        Man at H, with the shape of He_fn(model, H).
    """
    x = He_fn(model, H) / model.a  # He evaluated once and reused
    # Below this |x| the closed form coth(x) - 1/x loses precision to
    # cancellation (and is 0/0 at x = 0), so a Taylor series is used instead.
    near_zero = jnp.abs(x) < 0.1
    # jnp.where evaluates both branches. Feed the closed form a harmless value
    # where it is not selected, so neither the value nor its gradient becomes
    # inf/NaN.
    x_safe = jnp.where(near_zero, 1.0, x)
    closed_form = 1.0 / jnp.tanh(x_safe) - 1.0 / x_safe  # coth(x) - 1/x
    series = x / 3.0 - x**3 / 45.0 + 2.0 * x**5 / 945.0
    return model.Ms * jnp.where(near_zero, series, closed_form)

In [None]:
# delta

def delta_fn(H, dH=None):
    """Direction factor delta of the Jiles-Atherton model.

    delta = +1 while the field increases and -1 while it decreases.

    Args:
        H: field value(s). Not used for the direction; kept so existing calls
            ``delta_fn(H)`` keep working.
        dH: optional field increment (e.g. H[t] - H[t-1]). When omitted, the
            field is assumed to be increasing.

    Returns:
        The int 1 when ``dH`` is None. Otherwise an array of +1/-1 with the
        shape of ``dH``; ``dH == 0`` maps to +1.
    """
    if dH is None:
        # physics() calls this with a single scalar H (vmapped in loss_fn),
        # which carries no direction information.
        return 1
    # jnp.where instead of Python if/elif: works on traced values under
    # jit/vmap, and always yields a defined value (including dH == 0).
    return jnp.where(jnp.asarray(dH) >= 0, 1, -1)

In [None]:
def physics(model, H):
    """Residual of the Jiles-Atherton ODE at one scalar field value H.

    With M = model.net(H), returns
        -dM/dH + 1 / (delta*k/mu_0 - alpha*(Man - M)) * (Man - M),
    which is zero where the network satisfies the ODE.
    """
    # NOTE(review): model.c is not referenced anywhere in the loss, so it gets
    # a zero gradient and never changes. The reversible c-term of the full JA
    # model is absent -- confirm this is intended.
    dM_dH = jax.grad(model.net)(H)
    mismatch = Man_fn(model, H) - model.net(H)
    denominator = delta_fn(H) * model.k / mu_0 - model.alpha * mismatch
    return -dM_dH + 1 / denominator * mismatch

def residuum(model, H):
    """Predicted flux density B at one field value H.

    The network output is used as the magnetization M by He_fn, Man_fn and
    physics, so the quantity comparable to the measured B is
    B = mu_0 * (H + M), not M itself.
    """
    # NOTE(review): M must now reach the scale of B / mu_0. Consider
    # normalizing the network output if training struggles.
    magnetization = model.net(H)
    return mu_0 * (H + magnetization)

In [None]:
# loss function

def loss_fn(model, H, B):
    """Total PINN loss over all samples.

    Returns 0.5*mean(physics_residual**2) + 0.5*mean((prediction - B)**2).
    Both terms are evaluated at every entry of H (vmapped over axis 0) and are
    weighted equally; BC_LOSS_WEIGHT is not applied here.
    """
    batched_physics = jax.vmap(physics, in_axes=(None, 0))
    batched_prediction = jax.vmap(residuum, in_axes=(None, 0))

    physics_term = 0.5 * jnp.mean(jnp.square(batched_physics(model, H)))
    data_term = 0.5 * jnp.mean((batched_prediction(model, H) - B) ** 2)
    return physics_term + data_term
    

In [None]:
# Training loop

optimizer = optax.adam(LEARNING_RATE)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))


@eqx.filter_jit
def make_step(model, state, H, B):
    """Run one full-batch optimizer step on loss_fn.

    Returns:
        (new_model, new_state, loss), where loss is evaluated at the input
        model (before the update is applied).
    """
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, H, B)
    # Pass only the array leaves as ``params``, matching how opt_state was
    # initialized above. Optimizers that read params (e.g. optax.adamw) may
    # otherwise fail on the model's non-array leaves.
    params = eqx.filter(model, eqx.is_array)
    updates, new_state = optimizer.update(grads, state, params)
    new_model = eqx.apply_updates(model, updates)
    return new_model, new_state, loss

# Per-epoch records: the loss (evaluated before that epoch's update) and each
# learnable Jiles-Atherton parameter (after the update).
loss_history = []
Ms_history, a_history, alpha_history, c_history, k_history = [], [], [], [], []

for epoch in range(N_OPTIMIZATION_EPOCHS):
    model, opt_state, loss = make_step(model, opt_state, H, B)

    loss_history.append(loss)
    Ms_history.append(model.Ms)
    a_history.append(model.a)
    alpha_history.append(model.alpha)
    c_history.append(model.c)
    k_history.append(model.k)

    # Progress report every 100 epochs, starting with epoch 0.
    if not epoch % 100:
        print(f"Epoch: {epoch}, loss: {loss}, Ms:{model.Ms}, a:{model.a}, alpha:{model.alpha}, c:{model.c}, k:{model.k}")