In [None]:
from magpi.prelude import *
from magpi.integrate import gauss, integrate, integrate_sphere
from magpi.calc import *
from magpi.opt import TR
import matplotlib.pyplot as plt
from magpi import r_fun
from magpi.domain import Sphere
from scipy.stats.qmc import Halton, Sobol
from magpi.surface_integral import charge_tensor, source_tensor, single_layer_potential, curl_single_layer_potential, integrate_surface, scalar_potential_charge, vector_potential_charge

# Demagnetization of a hard magnetic cube using scalar and vector potentials with a PINN ansatz for $\phi_1$ and $A_1$

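This notebook is similar to `demag_sphere.ipynb`, but treats a cubic geometry.
The minimization algorithm is the same.
In both formulations, the interior part of the potential ($\phi_1$ or $A_1$) is a neural network.
The correction ($\phi_2$ or $A_2$) is evaluated as a single-layer potential over the six faces of the cube.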

In [2]:
mu0 = 4*pi*10**-7  # vacuum permeability, Tm/A
Js = 1.61  # saturation polarization, T
Ms = Js / mu0  # saturation magnetization, A/m
Km = mu0 * Ms**2  # energy-density unit mu0*Ms^2, J/m3
L = 70.0  # cube edge length, used as the length unit, nm
Ka = 4.3e6 / Km  # anisotropy constant in units of Km, -
# Exchange constant: 7.3e-12 J/m / Km -> m^2, * 1e18 -> nm^2, / L**2 (nm^2) -> dimensionless.
# The domain is the unit cube (length unit L), so A is dimensionless, not 1/nm^2.
A = 7.3e-12 / Km * 1e18 / L ** 2  # -
unit_vec = lambda x: x / norm(x, axis=-1, keepdims=True)
dh = 5e-3  # field step of the sweep, -
hext = jnp.arange(-3.5, 1., dh)[::-1]  # applied field in units of Ms, descending from just below 1 to -3.5
# Field axis slightly tilted away from z (presumably to break the symmetry of the reversal).
hext_axis = unit_vec(array([1., 0., 10.]))
# Uniform +z state; NOTE(review): not referenced elsewhere in this notebook (cells use m_init).
m0 = lambda x: zeros_like(x).at[..., 2].set(1.)

In [3]:
from magpi.domain import Hypercube

lb, ub = -0.5, 0.5
cube = Hypercube((lb,) * 3, (ub,) * 3)

@partial(jit, static_argnames=("i",))
def parametrization(x, i):
    """Map 2-D face coordinates ``x`` (last axis of size 2) onto face ``i`` of ``cube``.

    Faces 0-2 fix the unit-cube coordinate ``i % 3`` at 0 and faces 3-5 fix it at 1.
    The 3-D point is then passed through ``cube.transform``, which presumably maps
    [0, 1]^3 onto [lb, ub]^3.
    """
    assert 0 <= i <= 5
    assert x.shape[-1] == 2
    fixed_axis = i % 3
    fixed_value = 0.0 if i < 3 else 1.0
    return cube.transform(jnp.insert(x, fixed_axis, fixed_value, axis=-1))


k = 7
u = jnp.linspace(0, 1, k)
v = jnp.linspace(0, 1, k)

def compute_source(x):
    """Source tensors of all six cube faces for one point ``x``.

    Calls ``source_tensor`` per face on the ``u`` x ``v`` grid with a 10-point Gauss rule
    and ``compute_jacfwd=True``. It returns the per-face results and their derivative
    tensors, each concatenated along axis 0.
    """
    per_face = [
        source_tensor(x, lambda y: parametrization(y, face), u, v,
                      method=gauss(10), compute_jacfwd=True)
        for face in range(6)
    ]
    sources = [s for s, _ in per_face]
    derivatives = [ds for _, ds in per_face]
    return jnp.concatenate(sources, axis=0), jnp.concatenate(derivatives, axis=0)

def compute_charge(f, *args, **kwargs):
    """Charge tensors of ``f`` on all six faces (``order=2``), concatenated along axis 0."""
    charges = []
    for face in range(6):
        charges.append(charge_tensor(f, lambda y: parametrization(y, face), u, v, *args, order=2, **kwargs))
    return jnp.concatenate(charges, axis=0)


# Cube function from r_fun; multiplied onto the network potentials below
# (presumably vanishes on the cube surface — confirm in magpi.r_fun).
adf = r_fun.cube(ub - lb, centering=True)

# Training points: 2**12 Halton points mapped into the cube, plus the boundary
# source tensors for every point (lax.map evaluates compute_source point by point).
X = cube.transform(array(Halton(3, seed=0).random(2 ** 12)))
Z, dZ = lax.map(compute_source, X)

# Independent validation points (different seed), used by mean_mag.
X_val = cube.transform(array(Halton(3, seed=1562).random(2 ** 10)))


## Computation with the scalar potential

In [4]:
key = random.key(0)
# Uniform initial magnetization along +z (last component set to 1).
m_init = lambda x: zeros_like(x).at[..., -1].set(1.0)

class ScalarPotential(nn.Module):
    """MLP R^3 -> R used as the ansatz for the interior scalar potential phi_1.

    Two GELU hidden layers of width 16 with He-normal kernels. The layer names
    ("dense1", "dense2", "dense4") define the parameter tree and are kept as is.
    """
    @nn.compact
    def __call__(self, x):
        width = 16
        init = nn.initializers.he_normal()
        hidden = x
        for layer_name in ("dense1", "dense2"):
            hidden = nn.gelu(nn.Dense(width, name=layer_name, kernel_init=init)(hidden))
        out = nn.Dense(1, name="dense4", kernel_init=init)(hidden)
        return out[0]

key, _k = random.split(key)
phi1_model = ScalarPotential()
phi1_params_init = phi1_model.init(_k, zeros((3,)))

def phi1(x, params):
    """Evaluate the phi_1 network at ``x`` with parameters ``params``."""
    return phi1_model.apply(params, x)

class PINN(nn.Module):
    """MLP R^3 -> R^3 whose output parametrizes the Cayley rotation applied to ``m_init``.

    Two GELU hidden layers of width 16 with He-normal kernels. The layer names
    ("dense1", "dense2", "dense4") define the parameter tree and are kept as is.
    """
    @nn.compact
    def __call__(self, x):
        hidden = x
        for layer_name in ("dense1", "dense2"):
            dense = nn.Dense(16, name=layer_name, kernel_init=nn.initializers.he_normal())
            hidden = nn.gelu(dense(hidden))
        return nn.Dense(3, name="dense4", kernel_init=nn.initializers.he_normal())(hidden)

key, _k = random.split(key)
mag_pinn_model = PINN()
pinn_params_init = mag_pinn_model.init(_k, zeros((3,)))

def to_skew_simmetric_matrix(x):
    """Return the cross-product matrix of the 3-vector ``x``.

    The result ``S`` is skew-symmetric with ``S @ v == cross(x, v)``:
    ``[[0, -x2, x1], [x2, 0, -x0], [-x1, x0, 0]]``.
    """
    S = zeros((3, 3))
    S = S.at[1, 0].set(x[2])
    S = S.at[2, 0].set(-x[1])
    S = S.at[2, 1].set(x[0])
    # Mirror the lower triangle with opposite sign to obtain S = -S^T.
    S = S - S.T
    return S

def cayley_rotation(p, x):
    """Rotate ``x`` by the Cayley transform ``R = (I - Q)^{-1} (I + Q)`` of ``Q = [p]_x``.

    ``R`` is orthogonal for every real ``p``, so the rotation preserves the norm of ``x``.
    ``p = 0`` gives the identity.

    Args:
        p: 3-vector of rotation parameters.
        x: vector (or matrix of column vectors) with leading dimension 3.
    """
    assert p.shape[0] == 3, f"{p.shape}"
    Q = to_skew_simmetric_matrix(p)
    I = jnp.eye(3)
    # Solve (I - Q) y = (I + Q) x instead of forming the explicit inverse: cheaper and
    # better conditioned. I - Q is always invertible because Q has purely imaginary eigenvalues.
    return jnp.linalg.solve(I - Q, (I + Q) @ x)


def mag(x, params):
    """Magnetization at ``x``: ``m_init(x)`` rotated by the Cayley map of the PINN output."""
    return cayley_rotation(mag_pinn_model.apply(params, x), m_init(x))


def exchange_energy(x, params_m):
    """Exchange energy density ``A * sum_ij (dm_i/dx_j)^2`` at ``x``."""
    grad_m = jacfwd(mag)(x, params_m)
    return A * jnp.sum(grad_m * grad_m)


def ani_energy(x, params_m):
    """Uniaxial anisotropy energy density ``Ka * (1 - m_z^2)``, easy axis along z."""
    easy_axis = array([0., 0., 1.])
    m_parallel = mag(x, params_m) @ easy_axis
    return Ka * (1 - m_parallel ** 2)


def ext_energy(x, params_m, hext, hext_axis):
    """Zeeman energy density ``-m . (hext * hext_axis)``."""
    applied = hext * hext_axis
    return - mag(x, params_m) @ applied

def solve_phi12(mag, phi1, params_mag=(), params_phi1=()):
    """Build a closure that evaluates the boundary potential phi2 and its gradient.

    The surface charges of the scalar-potential formulation are computed once from
    ``mag`` / ``phi1`` (with the given parameter tuples). The returned
    ``solve(z, grad_z)`` applies the single-layer potential to precomputed source
    tensors and returns ``(phi2, Jphi2)``.
    """
    charges = compute_charge(
        scalar_potential_charge(adf, mag, phi1, normalized=True),
        params_mag=params_mag,
        params_phi1=params_phi1,
    )

    def solve(z, grad_z):
        return single_layer_potential(z, charges), single_layer_potential(grad_z, charges)

    return solve

def mag_energy(x, z, dz, params_m, phi1, phi2_Jphi1):
    """Magnetostatic energy density and the ``lower_bound`` expression at one point ``x``.

    Args:
        x: evaluation point.
        z, dz: source tensors of ``x`` (one row of ``Z`` / ``dZ``).
        params_m: magnetization network parameters.
        phi1: callable ``x -> phi1(x)`` (interior potential).
        phi2_Jphi1: closure from ``solve_phi12``; returns ``(phi2, Jphi2)`` for ``(z, dz)``.

    Returns:
        ``(e_d, lower_bound)``, where ``e_d = -1/2 m . h`` and
        ``h = -(grad phi1 + grad phi2)``; the second entry is also scaled by 1/2.
    """
    m = lambda y: mag(y, params_m)
    m_x = m(x)
    Jphi1 = jacfwd(phi1)(x)
    phi2, Jphi2 = phi2_Jphi1(z, dz)

    grad_phi = Jphi1 + Jphi2
    h = -grad_phi
    leading = -jnp.sum(Jphi1 * Jphi1) + 2 * m_x @ grad_phi
    correction = divergence(m)(x) * phi2 + m_x @ Jphi2 - laplace(phi1)(x) * phi2 - Jphi1 @ Jphi2
    lower_bound = leading - correction

    return 1 / 2 * (- (m_x @ h)), 1 / 2 * lower_bound

def _integrand(x_z_dz, params_m, params_old, hext, hext_axis, phi2_Jphi1):
    """Energy densities at one training sample ``(x, z, dz)``.

    phi1 is taken from the previous iterate (second entry of ``params_old``) and
    multiplied by ``adf``. Returns a dict with keys e_ex, e_ani, e_ext, e_d, lower_bound.
    """
    x, z, dz = x_z_dz
    _, params_phi1_prev = params_old
    phi1_prev = lambda y: adf(y) * phi1(y, params_phi1_prev)
    e_d, lower_bound = mag_energy(x, z, dz, params_m, phi1_prev, phi2_Jphi1)
    return {
        "e_ex": exchange_energy(x, params_m),
        "e_ani": ani_energy(x, params_m),
        "e_ext": ext_energy(x, params_m, hext, hext_axis),
        "e_d": e_d,
        "lower_bound": lower_bound,
    }

def mc_integrate(f, X, *args, **kwargs):
    """Monte Carlo estimate: mean of ``f(x, *args, **kwargs)`` over the samples ``X``.

    ``f`` is vmapped over the leading axis of ``X``, which may be a pytree such as
    ``(X, Z, dZ)``. Every leaf of the (possibly dict-valued) result is averaged over
    that axis.
    """
    def integrand(x):
        # Keyword arguments must be forwarded with ``**``; ``*kwargs`` would pass the
        # dict's keys as extra positional arguments.
        return f(x, *args, **kwargs)

    return tree_map(partial(mean, axis=0), vmap(integrand)(X))

@jit
def loss_m(params_m, params_old, hext, hext_axis, data):
    """Objective for the magnetization update with the potentials frozen at ``params_old``.

    Returns ``(objective, energies)``. The objective counts e_d twice (presumably
    because the demag field is frozen at the old m, so doubling restores the correct
    gradient). ``energies`` holds the averaged densities plus ``e_tot`` (e_d counted once).
    """
    params_m_old, params_phi1_old = params_old
    phi2_Jphi1 = solve_phi12(mag, phi1, (params_m_old,), (params_phi1_old,))
    energies = mc_integrate(_integrand, data, params_m, params_old, hext, hext_axis, phi2_Jphi1)
    local_terms = energies["e_ex"] + energies["e_ani"] + energies["e_ext"]
    objective = local_terms + 2 * energies["e_d"]
    return objective, {**energies, "e_tot": local_terms + energies["e_d"]}


@jit
def loss_phi1(params_phi1, params_old, data):
    """Mean squared residual of ``lap(adf * phi1) = div m`` over the points ``data``.

    The magnetization is taken from the first entry of ``params_old``.
    """
    params_m = params_old[0]
    params_m, _ = params_old
    m = lambda y: mag(y, params_m)
    phi1_hard = lambda y: adf(y) * phi1(y, params_phi1)

    def residual_sq(y):
        return norm(laplace(phi1_hard)(y) - divergence(m)(y)) ** 2

    return mc_integrate(residual_sq, data)

In [None]:
from jaxopt import OptStep
from magpi.tr import TR

@partial(jit, static_argnames=("solver1", "solver2",))
def run(solver1, solver2, params, state, hext, hext_axis, data):
    """One alternating sweep: update m with ``solver1``, then phi1 with ``solver2``.

    Both solvers are warm-started from their previous state, with ``iter_num`` reset to 0.
    ``solver1`` receives the pre-sweep ``params`` tuple (``loss_m`` uses it as
    ``params_old``). ``solver2`` receives the updated m with the old phi1, and only the
    sample points ``data[0]``.
    """
    params_m, params_phi1 = params
    state_m, state_phi1 = state

    warm_m = OptStep(params_m, state_m._replace(iter_num=0))
    new_params_m, new_state_m = solver1.run(warm_m, params, hext, hext_axis, data)

    warm_phi1 = OptStep(params_phi1, state_phi1._replace(iter_num=0))
    new_params_phi1, new_state_phi1 = solver2.run(warm_phi1, (new_params_m, params_phi1), data[0])

    return (new_params_m, new_params_phi1), (new_state_m, new_state_phi1)


hist_scalar_pot = []
def update_hist(params, state, iter_num):
    """Append one snapshot of both solver states to ``hist_scalar_pot``.

    The snapshot also records the outer-iteration index ``iter_num`` returned by
    ``solve`` and the mean magnetization along the module-level ``hext_axis``.
    """
    st_m, st_phi1 = state
    snapshot = {
        "value_m": st_m.value,
        "error_m": st_m.error,
        "aux_m": st_m.aux,
        "tr_radius_m": st_m.tr_radius,
        "iter_num_m": st_m.iter_num,
        "value_phi1": st_phi1.value,
        "error_phi1": st_phi1.error,
        "aux_phi1": st_phi1.aux,
        "tr_radius_phi1": st_phi1.tr_radius,
        "iter_num_phi1": st_phi1.iter_num,
        "iter_num": iter_num,
        "mean_mag": mean_mag(params[0], hext_axis),
    }
    hist_scalar_pot.append(snapshot)
    

@jit
def mean_mag(params, hext_axis):
    """Average of ``m . hext_axis`` over the validation points ``X_val``."""
    projection = vmap(lambda y: mag(y, params) @ hext_axis)(X_val)
    return mean(projection)

tol_phi1 = 5e-2
tol_m = 1e-3
def solve(params, state, hext, hext_axis, max_iter=50):
    """Repeat ``run`` sweeps at a fixed field until both sub-problems are converged.

    Each sweep updates m (``solver1``) and then phi1 (``solver2``) on ``(X, Z, dZ)``.
    The loop stops when the L2 norm of ``solver1.optimality_fun`` is below ``tol_m`` and
    that of ``solver2.optimality_fun`` is below ``tol_phi1``. It also stops after
    ``max_iter`` sweeps (default 50, the previously hard-coded value).

    Returns:
        ``(params, state, i)``, where ``i`` is the 0-based index of the last sweep.

    Raises:
        ValueError: if ``max_iter < 1`` (no sweep would run and ``i`` would be undefined).
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    data = (X, Z, dZ)
    for i in range(max_iter):
        params, state = run(solver1, solver2, params, state, hext, hext_axis, data)
        error1 = tree_l2_norm(solver1.optimality_fun(params[0], params, hext, hext_axis, data))
        error2 = tree_l2_norm(solver2.optimality_fun(params[1], params, X))
        if error1 < tol_m and error2 < tol_phi1:
            break
    return params, state, i

# Magnetization solver: one trust-region (TR) step per outer sweep.
solver1 = TR(
    loss_m,
    maxiter=1,
    tol=tol_m,
    has_aux=True,
    damping_factor=1e-3,
    unroll=False,
)
# phi1 solver: up to 20 TR steps per outer sweep.
solver2 = TR(loss_phi1, maxiter=20, tol=tol_phi1, has_aux=False, damping_factor=1e-3, unroll=False)

# hext_init is only used to initialize the solver state; the sweep itself runs below.
hext_init = 1.0
params = (pinn_params_init, phi1_params_init)
state1 = solver1.init_state(params[0], params, hext_init, hext_axis, (X, Z, dZ))
state2 = solver2.init_state(params[1], params, X)
state = state1, state2


In [None]:
# Field sweep (hext is ordered from high to low): each field step is warm-started from
# the previous solution and state, and update_hist records one snapshot per field value.
_params = params
for h in hext:
    print(h)  # progress output, one line per field value
    _params, state, iter_num = solve(_params, state, h, hext_axis)
    update_hist(_params, state, iter_num)

## Computation with the vector potential

In [7]:
key = random.key(0)
# Uniform initial magnetization along +z (last component set to 1).
m_init = lambda x: zeros_like(x).at[..., -1].set(1.0)

class VecPotential(nn.Module):
    """MLP R^3 -> R^3 used as the ansatz for the interior vector potential A_1.

    Two GELU hidden layers of width 16 with He-normal kernels. The layer names
    ("dense1", "dense2", "dense4") define the parameter tree and are kept as is.
    """
    @nn.compact
    def __call__(self, x):
        width = 16
        init = nn.initializers.he_normal()
        hidden = x
        for layer_name in ("dense1", "dense2"):
            hidden = nn.gelu(nn.Dense(width, name=layer_name, kernel_init=init)(hidden))
        return nn.Dense(3, name="dense4", kernel_init=init)(hidden)

key, _k = random.split(key)
A1_model = VecPotential()
vec_pot_params_init = A1_model.init(_k, zeros((3,)))

def A1(x, params):
    """Evaluate the A_1 network at ``x`` with parameters ``params``."""
    return A1_model.apply(params, x)

class PINN(nn.Module):
    """MLP R^3 -> R^3 whose output parametrizes the Cayley rotation applied to ``m_init``.

    Two GELU hidden layers of width 16 with He-normal kernels. The layer names
    ("dense1", "dense2", "dense4") define the parameter tree and are kept as is.
    """
    @nn.compact
    def __call__(self, x):
        hidden = x
        for layer_name in ("dense1", "dense2"):
            dense = nn.Dense(16, name=layer_name, kernel_init=nn.initializers.he_normal())
            hidden = nn.gelu(dense(hidden))
        return nn.Dense(3, name="dense4", kernel_init=nn.initializers.he_normal())(hidden)

key, _k = random.split(key)
mag_pinn_model = PINN()
pinn_params_init = mag_pinn_model.init(_k, zeros((3,)))

def to_skew_simmetric_matrix(x):
    """Return the cross-product matrix of the 3-vector ``x``.

    The result ``S`` is skew-symmetric with ``S @ v == cross(x, v)``:
    ``[[0, -x2, x1], [x2, 0, -x0], [-x1, x0, 0]]``.
    """
    S = zeros((3, 3))
    S = S.at[1, 0].set(x[2])
    S = S.at[2, 0].set(-x[1])
    S = S.at[2, 1].set(x[0])
    # Mirror the lower triangle with opposite sign to obtain S = -S^T.
    S = S - S.T
    return S

def cayley_rotation(p, x):
    """Rotate ``x`` by the Cayley transform ``R = (I - Q)^{-1} (I + Q)`` of ``Q = [p]_x``.

    ``R`` is orthogonal for every real ``p``, so the rotation preserves the norm of ``x``.
    ``p = 0`` gives the identity.

    Args:
        p: 3-vector of rotation parameters.
        x: vector (or matrix of column vectors) with leading dimension 3.
    """
    assert p.shape[0] == 3, f"{p.shape}"
    Q = to_skew_simmetric_matrix(p)
    I = jnp.eye(3)
    # Solve (I - Q) y = (I + Q) x instead of forming the explicit inverse: cheaper and
    # better conditioned. I - Q is always invertible because Q has purely imaginary eigenvalues.
    return jnp.linalg.solve(I - Q, (I + Q) @ x)


def mag(x, params):
    """Magnetization at ``x``: ``m_init(x)`` rotated by the Cayley map of the PINN output."""
    return cayley_rotation(mag_pinn_model.apply(params, x), m_init(x))


def exchange_energy(x, params_m):
    """Exchange energy density ``A * sum_ij (dm_i/dx_j)^2`` at ``x``."""
    grad_m = jacfwd(mag)(x, params_m)
    return A * jnp.sum(grad_m * grad_m)


def ani_energy(x, params_m):
    """Uniaxial anisotropy energy density ``Ka * (1 - m_z^2)``, easy axis along z."""
    easy_axis = array([0., 0., 1.])
    m_parallel = mag(x, params_m) @ easy_axis
    return Ka * (1 - m_parallel ** 2)


def ext_energy(x, params_m, hext, hext_axis):
    """Zeeman energy density ``-m . (hext * hext_axis)``."""
    applied = hext * hext_axis
    return - mag(x, params_m) @ applied

def solve_A2(mag, A1, params_mag=(), params_A1=()):
    """Build a closure that evaluates the boundary vector potential A2 and its derivatives.

    The surface charges of the vector-potential formulation are computed once from
    ``mag`` / ``A1``. The returned ``solve(z, grad_z)`` gives ``(A2, JA2, curl_A2)`` for
    precomputed source tensors, using the single-layer potential and its curl.
    """
    charges = compute_charge(
        vector_potential_charge(adf, mag, A1, normalized=True),
        params_mag=params_mag,
        params_A1=params_A1,
    )

    def solve(z, grad_z):
        potential = single_layer_potential(z, charges)
        jacobian = single_layer_potential(grad_z, charges)
        rotation = curl_single_layer_potential(grad_z, charges)
        return potential, jacobian, rotation

    return solve

def mag_energy(x, z, dz, params_m, A1, A2_JA2_curlA2):
    """Magnetostatic energy density and the ``upper_bound`` expression at one point ``x``.

    Args:
        x: evaluation point.
        z, dz: source tensors of ``x`` (one row of ``Z`` / ``dZ``).
        params_m: magnetization network parameters.
        A1: callable ``x -> A1(x)`` (interior vector potential).
        A2_JA2_curlA2: closure from ``solve_A2`` returning ``(A2, JA2, curl_A2)``.

    Returns:
        ``(e_d, upper_bound)``, where ``e_d = 1/2 (1 - m . curl A)`` and ``A = A1 + A2``;
        the second entry is also scaled by 1/2.
    """
    m = lambda y: mag(y, params_m)
    m_x = m(x)
    JA1 = jacfwd(A1)(x)
    curl_A1 = curl(A1)(x)
    A2, JA2, curl_A2 = A2_JA2_curlA2(z, dz)
    curl_A = curl_A1 + curl_A2
    upper_bound = (norm(m_x) ** 2 + jnp.sum(JA1 * JA1) - 2 * m_x @ curl_A
                   - curl(m)(x) @ A2 + m_x @ curl_A2 - laplace(A1)(x) @ A2 - jnp.sum(JA1 * JA2))

    return 1 / 2 * (1 - (m_x @ curl_A)), 1 / 2 * upper_bound

def _integrand(x_z_dz, params_m, params_old, hext, hext_axis, A2_JA2_curlA2):
    """Energy densities at one training sample ``(x, z, dz)``.

    A1 is taken from the previous iterate (second entry of ``params_old``) and
    multiplied by ``adf``. Returns a dict with keys e_ex, e_ani, e_ext, e_d, upper_bound.
    """
    x, z, dz = x_z_dz
    _, params_A1_prev = params_old
    A1_prev = lambda y: adf(y) * A1(y, params_A1_prev)
    e_d, upper_bound = mag_energy(x, z, dz, params_m, A1_prev, A2_JA2_curlA2)
    return {
        "e_ex": exchange_energy(x, params_m),
        "e_ani": ani_energy(x, params_m),
        "e_ext": ext_energy(x, params_m, hext, hext_axis),
        "e_d": e_d,
        "upper_bound": upper_bound,
    }

def mc_integrate(f, X, *args, **kwargs):
    """Monte Carlo estimate: mean of ``f(x, *args, **kwargs)`` over the samples ``X``.

    ``f`` is vmapped over the leading axis of ``X``, which may be a pytree such as
    ``(X, Z, dZ)``. Every leaf of the (possibly dict-valued) result is averaged over
    that axis.
    """
    def integrand(x):
        # Keyword arguments must be forwarded with ``**``; ``*kwargs`` would pass the
        # dict's keys as extra positional arguments.
        return f(x, *args, **kwargs)

    return tree_map(partial(mean, axis=0), vmap(integrand)(X))

@jit
def loss_m(params_m, params_old, hext, hext_axis, data):
    """Objective for the magnetization update with the potentials frozen at ``params_old``.

    Returns ``(objective, energies)``. The objective counts e_d twice (presumably
    because the demag field is frozen at the old m, so doubling restores the correct
    gradient). ``energies`` holds the averaged densities plus ``e_tot`` (e_d counted once).
    """
    params_m_old, params_A1_old = params_old
    A2_JA2_curlA2 = solve_A2(mag, A1, (params_m_old,), (params_A1_old,))
    energies = mc_integrate(_integrand, data, params_m, params_old, hext, hext_axis, A2_JA2_curlA2)
    local_terms = energies["e_ex"] + energies["e_ani"] + energies["e_ext"]
    objective = local_terms + 2 * energies["e_d"]
    return objective, {**energies, "e_tot": local_terms + energies["e_d"]}


@jit
def loss_A1(params_A1, params_old, data):
    """Mean squared residual of ``lap(adf * A1) = -curl m`` over the points ``data``.

    The magnetization is taken from the first entry of ``params_old``.
    """
    params_m, _ = params_old
    m = lambda y: mag(y, params_m)
    A1_hard = lambda y: adf(y) * A1(y, params_A1)

    def residual_sq(y):
        return norm(laplace(A1_hard)(y) + curl(m)(y)) ** 2

    return mc_integrate(residual_sq, data)

In [8]:
from jaxopt import OptStep
from magpi.tr import TR

@partial(jit, static_argnames=("solver1", "solver2",))
def run(solver1, solver2, params, state, hext, hext_axis, data):
    """One alternating sweep: update m with ``solver1``, then A1 with ``solver2``.

    Both solvers are warm-started from their previous state, with ``iter_num`` reset to 0.
    ``solver1`` receives the pre-sweep ``params`` tuple (``loss_m`` uses it as
    ``params_old``). ``solver2`` receives the updated m with the old A1, and only the
    sample points ``data[0]``.
    """
    params_m, params_A1 = params
    state_m, state_A1 = state

    warm_m = OptStep(params_m, state_m._replace(iter_num=0))
    new_params_m, new_state_m = solver1.run(warm_m, params, hext, hext_axis, data)

    warm_A1 = OptStep(params_A1, state_A1._replace(iter_num=0))
    new_params_A1, new_state_A1 = solver2.run(warm_A1, (new_params_m, params_A1), data[0])

    return (new_params_m, new_params_A1), (new_state_m, new_state_A1)



hist_vec_pot = []
def update_hist(params, state, iter_num):
    """Append one snapshot of both solver states to ``hist_vec_pot``.

    The snapshot also records the outer-iteration index ``iter_num`` returned by
    ``solve`` and the mean magnetization along the module-level ``hext_axis``.
    """
    st_m, st_A1 = state
    snapshot = {
        "value_m": st_m.value,
        "error_m": st_m.error,
        "aux_m": st_m.aux,
        "tr_radius_m": st_m.tr_radius,
        "iter_num_m": st_m.iter_num,
        "value_A1": st_A1.value,
        "error_A1": st_A1.error,
        "aux_A1": st_A1.aux,
        "tr_radius_A1": st_A1.tr_radius,
        "iter_num_A1": st_A1.iter_num,
        "iter_num": iter_num,
        "mean_mag": mean_mag(params[0], hext_axis),
    }
    hist_vec_pot.append(snapshot)
    

@jit
def mean_mag(params, hext_axis):
    """Average of ``m . hext_axis`` over the validation points ``X_val``."""
    projection = vmap(lambda y: mag(y, params) @ hext_axis)(X_val)
    return mean(projection)

tol_A1 = 5e-2
tol_m = 1e-3
def solve(params, state, hext, hext_axis, max_iter=50):
    """Repeat ``run`` sweeps at a fixed field until both sub-problems are converged.

    Each sweep updates m (``solver1``) and then A1 (``solver2``) on ``(X, Z, dZ)``.
    The loop stops when the L2 norm of ``solver1.optimality_fun`` is below ``tol_m`` and
    that of ``solver2.optimality_fun`` is below ``tol_A1``. It also stops after
    ``max_iter`` sweeps (default 50, the previously hard-coded value).

    Returns:
        ``(params, state, i)``, where ``i`` is the 0-based index of the last sweep.

    Raises:
        ValueError: if ``max_iter < 1`` (no sweep would run and ``i`` would be undefined).
    """
    if max_iter < 1:
        raise ValueError(f"max_iter must be at least 1, got {max_iter}")
    data = (X, Z, dZ)
    for i in range(max_iter):
        params, state = run(solver1, solver2, params, state, hext, hext_axis, data)
        error1 = tree_l2_norm(solver1.optimality_fun(params[0], params, hext, hext_axis, data))
        error2 = tree_l2_norm(solver2.optimality_fun(params[1], params, X))
        if error1 < tol_m and error2 < tol_A1:
            break
    return params, state, i

# Magnetization solver: one trust-region (TR) step per outer sweep.
solver1 = TR(
    loss_m,
    maxiter=1,
    tol=tol_m,
    has_aux=True,
    damping_factor=1e-3,
    unroll=False,
)
# A1 solver: up to 20 TR steps per outer sweep.
solver2 = TR(loss_A1, maxiter=20, tol=tol_A1, has_aux=False, damping_factor=1e-3, unroll=False)

# hext_init is only used to initialize the solver state; the sweep itself runs below.
hext_init = 1.0
params = (pinn_params_init, vec_pot_params_init)
state1 = solver1.init_state(params[0], params, hext_init, hext_axis, (X, Z, dZ))
state2 = solver2.init_state(params[1], params, X)
state = state1, state2

In [None]:
# Field sweep (hext is ordered from high to low): each field step is warm-started from
# the previous solution and state, and update_hist records one snapshot per field value.
_params = params
for h in hext:
    print(h)  # progress output, one line per field value
    _params, state, iter_num = solve(_params, state, h, hext_axis)
    update_hist(_params, state, iter_num)
    

## Compute switching fields

In [14]:
import numpy as np
from scipy.optimize import bisect

mean_mags = asarray([h["mean_mag"] for h in hist_scalar_pot])
# hext is swept in descending order, but jnp.interp needs ascending abscissae, so
# reverse the part of the sweep that was actually recorded.
_hext_asc = hext[:len(mean_mags)][::-1]
_mean_mags_asc = mean_mags[::-1]

def _g(x):
    """Linear interpolant of <m . hext_axis> vs. field; its zero is the switching field."""
    return float(jnp.interp(x, _hext_asc, _mean_mags_asc))

# Bracket the root with the recorded field range instead of repeating the sweep limits
# by hand; bisect requires _g to change sign between the two ends.
hsw_scalar_pot = bisect(_g, float(_hext_asc[0]), float(_hext_asc[-1]))

In [15]:
import numpy as np
from scipy.optimize import bisect

mean_mags = asarray([h["mean_mag"] for h in hist_vec_pot])
# hext is swept in descending order, but jnp.interp needs ascending abscissae, so
# reverse the part of the sweep that was actually recorded.
_hext_asc = hext[:len(mean_mags)][::-1]
_mean_mags_asc = mean_mags[::-1]

def _g(x):
    """Linear interpolant of <m . hext_axis> vs. field; its zero is the switching field."""
    return float(jnp.interp(x, _hext_asc, _mean_mags_asc))

# Bracket the root with the recorded field range instead of repeating the sweep limits
# by hand; bisect requires _g to change sign between the two ends.
hsw_vec_pot = bisect(_g, float(_hext_asc[0]), float(_hext_asc[-1]))

In [None]:

# Compare both formulations over the field sweep: mean magnetization (with the switching
# fields annotated) and the energy terms. Vector potential: red/orange; scalar potential:
# blue/green. text.usetex is enabled, so a LaTeX installation is required.
with plt.rc_context(rc={'text.usetex': True, 'text.latex.preamble': r"\usepackage{amsmath}", "axes.labelsize": 16, "axes.titlesize": 16}):
    fig = plt.figure(figsize=(8, 8))
    ax1, ax2, ax3, ax4, ax5 = fig.subplots(5, 1, sharex=True)
    # --- vector-potential results ---
    ax1.plot(hext[:len(hist_vec_pot)],  [h["mean_mag"] for h in hist_vec_pot], c="tab:red", label="vector potential")
    ax1.set_title("mean magnetization", fontsize=16)
    ax1.grid()
    ax1.annotate(f"$h_c={hsw_vec_pot:.3f}$", (hsw_vec_pot, 0), (hsw_vec_pot - 0.7, -0.5), xycoords="data", fontsize=14, arrowprops={"arrowstyle": "->"}, color="tab:red")
    ax2.plot(hext[:len(hist_vec_pot)], [h["aux_m"]["e_ext"] for h in hist_vec_pot], c="tab:red")
    ax2.set_title("$e_{zee}$", fontsize=16)
    ax2.grid()

    ax3.plot(hext[:len(hist_vec_pot)], [h["aux_m"]["e_ani"] for h in hist_vec_pot], c="tab:red")
    ax3.set_title("$e_{a}$", fontsize=16)
    ax3.grid()
    # Exchange energy is small; force scientific notation on its y axis.
    ax4.ticklabel_format(axis="y", style="sci", useOffset=True, scilimits=(0,0))
    ax4.plot(hext[:len(hist_vec_pot)], [h["aux_m"]["e_ex"] for h in hist_vec_pot], c="tab:red")
    ax4.set_title("$e_{ex}$", fontsize=16)
    ax4.grid()
    ax5.plot(hext[:len(hist_vec_pot)], [(h["aux_m"]["e_d"]) for h in hist_vec_pot], label=r"$e_d^{\mathbf{A}}$", c="tab:red")
    ax5.plot(hext[:len(hist_vec_pot)], [(h["aux_m"]["upper_bound"]) for h in hist_vec_pot], linestyle="dotted", label=r"$e_\mathbf{A}$ (upper bound)", c="tab:orange")
    
    # --- scalar-potential results (dashed) ---
    ax1.plot(hext[:len(hist_scalar_pot)], [h["mean_mag"] for h in hist_scalar_pot], "--", c="tab:blue", label="scalar potential")
    ax1.annotate(f"$h_c={hsw_scalar_pot:.3f}$", (hsw_scalar_pot, 0), (hsw_scalar_pot + 0.2, -0.3), xycoords="data", fontsize=14, arrowprops={"arrowstyle": "->"}, color="tab:blue")
    ax2.plot(hext[:len(hist_scalar_pot)], [h["aux_m"]["e_ext"] for h in hist_scalar_pot], "--", c="tab:blue")
    ax3.plot(hext[:len(hist_scalar_pot)], [h["aux_m"]["e_ani"] for h in hist_scalar_pot], "--", c="tab:blue")
    
    ax4.plot(hext[:len(hist_scalar_pot)], [h["aux_m"]["e_ex"] for h in hist_scalar_pot], "--", c="tab:blue")
    ax5.plot(hext[:len(hist_scalar_pot)], [(h["aux_m"]["lower_bound"]) for h in hist_scalar_pot], linestyle="dashdot", label=r"$e_\phi$ (lower bound)", c="tab:green")
    ax5.plot(hext[:len(hist_scalar_pot)], [(h["aux_m"]["e_d"]) for h in hist_scalar_pot], "--", label=r"$e_d^{\phi}$", c="tab:blue")
    
    ax1.legend(fontsize=16)
    
    ax5.legend(loc="right")
    ax5.set_title("$e_{d}$", fontsize=16)
    ax5.set_xlabel(r"$h_{ext} [-]$", fontsize=16)
    # Shared y-axis labels: energies (axes 2-5) in units of mu0*Ms^2, mean magnetization (axis 1).
    fig.text(0.04, 0.43, r"$e\;[\mu_0 M_s^2]$", ha='center', va='center', rotation='vertical', fontsize=16)
    fig.text(0.04, 0.89, r"$\langle \mathbf{m}\rangle \;[-]$", ha='center', va='center', rotation='vertical', fontsize=16)
    ax5.grid()
    # NOTE(review): Python min/max iterate the jax array element by element;
    # float(hext.min()) / float(hext.max()) would be cheaper.
    ax5.set_xlim((min(hext),max(hext)))
    # Leave a left margin for the fig.text labels.
    fig.tight_layout(rect=(0.03, 0, 1, 1))

    fig.savefig("demag_cube_scalar_and_vec_pot.pdf")
    

In [None]:
# Reference switching field used to score both methods.
# NOTE(review): the source of this value is not stated in the notebook — document it.
hsw_true = -2.767
def rel_error(hsw):
    """Relative deviation of ``hsw`` from ``hsw_true``, in percent."""
    deviation = jnp.abs(hsw - hsw_true)
    return deviation / jnp.abs(hsw_true) * 100

print("vector potential rel error", rel_error(hsw_vec_pot))
print("scalar potential rel error", rel_error(hsw_scalar_pot))