In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import bisect
from scipy.stats.qmc import Halton, Sobol
from jaxopt import LBFGS, OptStep

from magpi.prelude import *
from magpi.integrate import gauss
from magpi.calc import *
from magpi import r_fun
from magpi.surface_integral import charge_tensor, source_tensor, single_layer_potential, curl_single_layer_potential, scalar_potential_charge, vector_potential_charge

jax.config.update("jax_enable_x64", False)

# Demagnetization of a hard magnetic cube with ELM ansatz for $\phi_1$ and $A_1$ and L-BFGS optimization

This notebook is similar to `demag_cube_elm.ipynb`, but it minimizes the energy with L-BFGS instead of a trust-region method and uses a more efficient implementation.

In [None]:
# Material parameters and reduced quantities; the unit notes at the end of each
# line are the original author's.
mu0 = 4*pi*10**-7  # Tm/A
Js = 1.61  # T
Ms = Js / mu0  # A/m
Km = mu0 * Ms**2  # J/m3
L = 70.0  # nm
Ka = 4.3e6 / Km  # -
# NOTE(review): if 7.3e-12 is given in J/m, dividing by Km (J/m3), converting
# m2 -> nm2 with 1e18 and dividing by L**2 (nm2) looks dimensionless rather
# than 1/nm2 — confirm the unit note below.
A = 7.3e-12 / Km * 1e18 / L ** 2 # 1/nm2
# Divides x by norm(x) taken along the last axis (keepdims keeps it broadcastable).
unit_vec = lambda x: x / norm(x, axis=-1, keepdims=True)
# Field step and sweep: arange(-3.5, 1, dh) reversed, i.e. in descending order.
dh = 5e-3
hext = jnp.arange(-3.5, 1., dh)[::-1]  # -
# Normalised direction (1, 0, 10) along which the external field is applied.
hext_axis = unit_vec(array([1., 0., 10.]))
# Zero array shaped like x with component index 2 of the last axis set to 1.
m0 = lambda x: zeros_like(x).at[..., 2].set(1.)

In [None]:
from magpi.domain import Hypercube

# Hypercube with corners (lb, lb, lb) and (ub, ub, ub); its `transform` is
# applied to all sample points and surface parametrizations below.
lb, ub = -0.5, 0.5
cube = Hypercube((lb,lb,lb), (ub, ub, ub))

@partial(jit, static_argnames=("i",))
def parametrization(x, i):
    """Embed 2D face coordinates ``x`` into 3D for cube face ``i``.

    A constant coordinate is inserted at axis position ``i % 3`` of the last
    axis (value 0.0 for faces 0-2, 1.0 for faces 3-5) and the result is passed
    through ``cube.transform``.

    Args:
        x: array whose last axis has length 2.
        i: static face index, must satisfy ``0 <= i <= 5``.

    Returns:
        The output of ``cube.transform`` for the 3-component points.
    """
    assert 0 <= i <= 5
    assert x.shape[-1] == 2
    # For the first three faces i % 3 == i, so one expression covers both halves.
    fixed_axis = i % 3
    fixed_value = 0.0 if i <= 2 else 1.0
    return cube.transform(jnp.insert(x, fixed_axis, fixed_value, axis=-1))


# k equally spaced nodes on [0, 1] for each of the two face coordinates; u and v
# are handed to source_tensor / charge_tensor in the helpers below.
k = 7
u = jnp.linspace(0, 1, k)
v = jnp.linspace(0, 1, k)
def compute_source(x):
    """Evaluate ``source_tensor`` at ``x`` for all six cube faces.

    Each face is evaluated with ``gauss(10)`` and ``compute_jacfwd=True``, which
    yields a pair per face.

    Returns:
        Tuple ``(z, dz)``: the first and second elements of the per-face pairs,
        each concatenated over the faces along axis 0.
    """
    values, jacobians = [], []
    for face in range(6):
        # The lambda is consumed inside this iteration, so late binding of `face` is harmless.
        z_face, dz_face = source_tensor(
            x, lambda q: parametrization(q, face), u, v,
            method=gauss(10), compute_jacfwd=True,
        )
        values.append(z_face)
        jacobians.append(dz_face)
    return jnp.concatenate(values, axis=0), jnp.concatenate(jacobians, axis=0)
    
def compute_charge(f, *args, **kwargs):
    """Evaluate ``charge_tensor`` for ``f`` on all six cube faces.

    Extra positional and keyword arguments are forwarded unchanged to
    ``charge_tensor``; ``order=2`` is always passed.

    Returns:
        The per-face results concatenated along axis 0.
    """
    per_face = []
    for face in range(6):
        face_map = lambda q: parametrization(q, face)
        per_face.append(charge_tensor(f, face_map, u, v, *args, order=2, **kwargs))
    return jnp.concatenate(per_face, axis=0)


# adf is built by r_fun.cube from the cube edge length; below it multiplies the
# ELM basis and the potentials (presumably an approximate distance function
# that vanishes on the boundary — confirm in magpi.r_fun).
adf = r_fun.cube(ub - lb, centering=True)

# 2**12 Halton points, mapped through cube.transform: collocation / integration points.
X = array(Halton(3, seed=42).random(2 ** 12))
X = cube.transform(X)
# Source tensors and their Jacobians for every point of X, evaluated point by point via lax.map.
Z, dZ = lax.map(compute_source, X)

# 2**11 Sobol points, mapped through cube.transform; used by mean_mag to average the magnetization.
X_val = array(Sobol(3, seed=1562).random_base2(11))
X_val = cube.transform(X_val)


In [None]:
key = random.key(0)
# Initial magnetization: zero array shaped like x with the last component set to 1.
m_init = lambda x: zeros_like(x).at[..., -1].set(1.0)

# Damping term added to the squared singular values in the pseudo-inverse below.
l2_reg = 1e-3

# 2**9 Halton points in 4D: the first three columns become the input weights and
# the fourth the biases, both rescaled by (w * 2 - 1) * 2.
_weights = array(Halton(4, seed=43).random(2 ** 9))
W_elm = (_weights[:, :3] * 2 - 1) * 2
b_elm = (_weights[:, 3] * 2 - 1) * 2
# Fixed tanh features of a single point x, and the same features multiplied by adf(x).
h_elm = lambda x: nn.tanh(W_elm @ x + b_elm)
u_elm = lambda x: h_elm(x) * adf(x)
# One row per point in X: -laplace(u_elm) evaluated at that point.
Q_phi1 = vmap(lambda x: -laplace(u_elm)(x))(X)

U_phi1, S_phi1, VT_phi1 = jax.scipy.linalg.svd(
    Q_phi1, full_matrices=False, lapack_driver="gesvd"
)
# Regularised pseudo-inverse V diag(s / (s**2 + l2_reg)) U^T. `*` and `@` have equal
# precedence and associate left to right, so the columns of V are scaled first.
Pinv_phi1 = VT_phi1.T * (S_phi1 / (S_phi1 ** 2 + l2_reg)) @ U_phi1.T

@jit
def solve_phi1(params_m):
    """Compute the ELM output weights of phi1 for the magnetization ``params_m``.

    The right-hand side ``-divergence(mag)`` is evaluated at every point of
    ``X`` and multiplied by the precomputed pseudo-inverse ``Pinv_phi1``.
    """
    def rhs(x):
        return -divergence(mag)(x, params_m)

    return Pinv_phi1 @ vmap(rhs)(X)

def phi1(x, params):
    """Linear combination of the ELM features ``h_elm(x)`` with weights ``params``.

    The ``adf`` factor is not applied here; callers multiply by it themselves.
    """
    features = h_elm(x)
    return features @ params


# 2**9 Halton points mapped through cube.transform.
# NOTE(review): W_m is not referenced anywhere else in the visible notebook —
# possibly a leftover; confirm before removing.
W_m = array(Halton(3, seed=43).random(2 ** 9))
W_m = cube.transform(W_m)

class PINN(nn.Module):
    """Small MLP: two hidden Dense layers of 16 units with GELU and a 3-unit output.

    All Dense layers use He-normal kernel initialisation. The 3-component
    output is used by ``mag`` as the rotation parameter vector.
    """

    @nn.compact
    def __call__(self, x):
        width = 16
        act = nn.gelu
        hidden = x
        # Submodules are created in the same order as before (two hidden, one output).
        for _ in range(2):
            hidden = act(nn.Dense(width, kernel_init=nn.initializers.he_normal())(hidden))
        return nn.Dense(3, kernel_init=nn.initializers.he_normal())(hidden)
    
    
# Split off a sub-key and initialise the network parameters with a 3-component dummy input.
key, _k = random.split(key)
mag_pinn_model = PINN()
pinn_params_init = mag_pinn_model.init(_k, zeros((3,)))

# Cast the parameters to float64 only when the JAX config reports x64 mode as enabled.
if getattr(jax.config, "jax_enable_x64", False):
    pinn_params_init = tree_map(lambda p: p.astype(jnp.float64), pinn_params_init)

def to_skew_simmetric_matrix(x):
    """Build the 3x3 skew-symmetric matrix associated with the 3-vector ``x``.

    Returns:
        ``[[0, -x2, x1], [x2, 0, -x0], [-x1, x0, 0]]`` with the dtype of
        ``zeros((3, 3))``.
    """
    # Fill the strictly lower triangle in one scatter, then antisymmetrise.
    rows = jnp.array([1, 2, 2])
    cols = jnp.array([0, 0, 1])
    lower = zeros((3, 3)).at[rows, cols].set(jnp.stack([x[2], -x[1], x[0]]))
    return lower - lower.T

def cayley_rotation(p, x):
    """Apply the Cayley transform of ``p`` to the vector ``x``.

    With ``Q`` the skew-symmetric matrix of ``p``, returns
    ``inv(I - Q) @ (I + Q) @ x``.

    Args:
        p: parameter vector whose first dimension must have length 3.
        x: vector to transform.
    """
    assert p.shape[0] == 3, f"{p.shape}"
    skew = to_skew_simmetric_matrix(p)
    eye = jnp.eye(3)
    rotation = jnp.linalg.inv(eye - skew) @ (eye + skew)
    return rotation @ x


def mag(x, params):
    """Magnetization at ``x``: ``m_init(x)`` transformed by the network-driven Cayley map.

    The network output for ``x`` is used as the parameter vector of ``cayley_rotation``.
    """
    rotation_params = mag_pinn_model.apply(params, x)
    return cayley_rotation(rotation_params, m_init(x))


def exchange_energy(x, params_m):
    """Exchange energy density at ``x``: ``A`` times the squared entries of the Jacobian of ``mag``, summed."""
    jac_m = jacfwd(mag)(x, params_m)
    squared_norm = jnp.sum(jac_m * jac_m)
    return A * squared_norm

def ani_energy(x, params_m):
    """Uniaxial anisotropy energy density at ``x``: ``Ka * (1 - (m . c)**2)`` with axis ``c = (0, 0, 1)``."""
    easy_axis = array([0., 0., 1.])
    projection = mag(x, params_m) @ easy_axis
    return Ka * (1 - projection ** 2)

def ext_energy(x, params_m, hext, hext_axis):
    """External-field energy density at ``x``: minus ``m`` dotted with ``hext * hext_axis``."""
    field = hext * hext_axis
    m = mag(x, params_m)
    return -m @ field

def mag_energy_sp(x, params_m, hd):
    """Stray-field energy density at ``x`` for a given field ``hd``: ``-1/2 * (m . hd)``."""
    m_dot_h = mag(x, params_m) @ hd
    return 1 / 2 * (-m_dot_h)


def _integrand(x_hd, params_m, hext, hext_axis):
    """Energy densities at one sample for the scalar-potential formulation.

    Args:
        x_hd: pair ``(x, hd)`` of a sample point and the stray field at that point.
        params_m: magnetization network parameters.
        hext, hext_axis: external field magnitude and direction.

    Returns:
        Dict with keys ``e_ex``, ``e_ani``, ``e_ext`` and ``e_d``.
    """
    x, hd = x_hd
    return {
        "e_ex": exchange_energy(x, params_m),
        "e_ani": ani_energy(x, params_m),
        "e_ext": ext_energy(x, params_m, hext, hext_axis),
        "e_d": mag_energy_sp(x, params_m, hd),
    }

def mc_integrate(f, X, *args, **kwargs):
    """Monte Carlo estimate of the integral of ``f``: mean of ``f`` over the samples ``X``.

    ``f`` is vmapped over the leading axis of ``X`` (a pytree of arrays is
    allowed) and every leaf of its output is averaged over that axis.

    Args:
        f: function called as ``f(x, *args, **kwargs)``; may return a pytree.
        X: samples, mapped over axis 0.
        *args: extra positional arguments forwarded to ``f`` (not mapped).
        **kwargs: extra keyword arguments forwarded to ``f`` (not mapped).

    Returns:
        A pytree with the structure of ``f``'s output, each leaf averaged over axis 0.
    """
    # Keyword arguments must be unpacked with `**`; `*kwargs` would pass the
    # dict keys (strings) to `f` as extra positional arguments.
    return tree_map(partial(mean, axis=0), vmap(lambda x: f(x, *args, **kwargs))(X))

def _compute_hd(params_phi1, charge_tensor, x_dz):
    """Stray field at a single point from the two potential contributions.

    Args:
        params_phi1: ELM weights of phi1.
        charge_tensor: charge values passed to ``single_layer_potential``.
        x_dz: pair ``(x, grad_z)`` of the point and its source-tensor Jacobian.

    Returns:
        ``-(grad(adf * phi1)(x) + single_layer_potential(grad_z, charge_tensor))``.
    """
    x, grad_z = x_dz

    def masked_phi1(y):
        return adf(y) * phi1(y, params_phi1)

    grad_phi1 = jacfwd(masked_phi1)(x)
    grad_phi2 = single_layer_potential(grad_z, charge_tensor)
    return -(grad_phi1 + grad_phi2)

@jit
def compute_hd(params_m, params_phi1, X_dZ):
    """Stray field at all sample points for the given magnetization and phi1 weights.

    Builds the charge with ``scalar_potential_charge`` and ``compute_charge``,
    then maps ``_compute_hd`` over the leading axis of ``X_dZ``.

    Args:
        params_m: magnetization network parameters.
        params_phi1: ELM weights of phi1.
        X_dZ: pair ``(X, dZ)`` of sample points and source-tensor Jacobians.
    """
    charge_fn = scalar_potential_charge(adf, mag, phi1, normalized=True)
    charge = compute_charge(charge_fn, params_mag=(params_m,), params_phi1=(params_phi1,))
    # Only the (x, grad_z) pairs are batched; weights and charge are shared.
    field_at_points = vmap(_compute_hd, (None, None, 0))
    return field_at_points(params_phi1, charge, X_dZ)


@jit
def loss_m_sp(params_m, hext, hext_axis, data):
    """Objective for the magnetization in the scalar-potential formulation.

    Args:
        params_m: magnetization network parameters.
        hext, hext_axis: external field magnitude and direction.
        data: pair ``(X, hd)`` of sample points and the stray field at those points.

    Returns:
        Tuple ``(l, energies)``: ``l`` counts ``e_d`` twice, while
        ``energies["e_tot"]`` counts it once; ``energies`` also holds the
        individual averaged terms.
    """
    X, hd = data
    energies = mc_integrate(_integrand, (X, hd), params_m, hext, hext_axis)
    non_demag = energies["e_ex"] + energies["e_ani"] + energies["e_ext"]
    objective = non_demag + 2 * energies["e_d"]
    energies["e_tot"] = non_demag + energies["e_d"]
    return objective, energies


In [None]:
@partial(jit, static_argnames=("solver",))
def run_sp(solver, params, state, hext, hext_axis, data):
    """One outer iteration: optimise m for a fixed stray field, then refresh the field.

    Args:
        solver: jaxopt solver (static under jit).
        params: pair ``(params_m, params_phi1)``; only ``params_m`` is used as
            the starting point, phi1 is re-solved.
        state: solver state to warm-start from; its ``iter_num`` is reset to 0.
        hext, hext_axis: external field magnitude and direction.
        data: tuple ``(X, Z, dZ, hd)``; ``Z`` is not used here.

    Returns:
        ``((params_m_new, params_phi1_new), state, error, hd)`` where ``hd`` is
        the stray field of the updated magnetization and ``error`` the l2 norm
        of the solver's optimality function evaluated with that field.
    """
    X, _, dZ, hd = data
    params_m, _ = params
    state = state._replace(iter_num=0)
    step = OptStep(params_m, state)
    params_m_new, state = solver.run(step, hext, hext_axis, (X, hd))
    params_phi1_new = solve_phi1(params_m_new)
    # The field must come from the UPDATED magnetization: phi1 was just solved
    # for params_m_new, so pairing it with the old params_m would give an
    # inconsistent field for the error check and the next iteration.
    hd = compute_hd(params_m_new, params_phi1_new, (X, dZ))
    error = tree_l2_norm(solver.optimality_fun(params_m_new, hext, hext_axis, (X, hd)))
    return (params_m_new, params_phi1_new), state, error, hd


@jit
def mean_mag(params, hext_axis):
    """Average of ``mag(x, params) @ hext_axis`` over the validation points ``X_val``."""
    projections = vmap(lambda x: mag(x, params) @ hext_axis)(X_val)
    return mean(projections)



def _hist(params, state, last_iter_num, n_field_eval, hext_axis):
    """Assemble one history record from the solver state.

    Args:
        params: parameter tuple; ``params[0]`` are the magnetization parameters.
        state: solver state providing ``value``, ``error``, ``aux`` and ``iter_num``.
        last_iter_num: iteration count accumulated before this record.
        n_field_eval: field-evaluation count before this record (incremented by one).
        hext_axis: direction onto which the mean magnetization is projected.
    """
    return dict(
        value_m=state.value,
        error_m=state.error,
        aux_m=state.aux,
        iter_num=last_iter_num + state.iter_num,
        n_field_eval=n_field_eval + 1,
        mean_mag=mean_mag(params[0], hext_axis),
    )


@partial(jit, static_argnames=("solver",))
def solve_sp(solver, params, state, hext, hext_axis):
    """Relax the magnetization at one external field value (scalar potential).

    Repeats ``run_sp`` until the optimality error is no longer above
    ``solver.tol``. ``run_sp`` is always executed at least once.

    Args:
        solver: jaxopt solver (static under jit).
        params: pair ``(params_m, params_phi1)``.
        state: solver state carried over from the previous field step.
        hext, hext_axis: external field magnitude and direction.

    Returns:
        ``(params, state, hist)`` with ``hist`` the last history record.
    """
    hd = compute_hd(*params, (X, dZ))

    params, state, error, hd = run_sp(solver, params, state, hext, hext_axis, (X, Z, dZ, hd))
    # No iterations precede the first run at this field value, so the running
    # total starts at 0; passing state.iter_num here would count them twice.
    # The field counter starts at 1 for the compute_hd call above.
    hist = _hist(params, state, 0, 1, hext_axis)

    def body(val):
        params, state, error, hd, hist = val
        params, state, error, hd = run_sp(solver, params, state, hext, hext_axis, (X, Z, dZ, hd))
        last_mean_mag = hist["mean_mag"]
        hist = _hist(params, state, hist["iter_num"], hist["n_field_eval"], hext_axis)
        _mean_mag = hist["mean_mag"]
        # Re-initialise the solver state when the mean magnetization changes
        # sign from positive to negative between two consecutive runs.
        state = lax.cond(
            (last_mean_mag > 0) & (_mean_mag < 0),
            lambda: solver.init_state(params[0], hext, hext_axis, (X, hd)),
            lambda: state
        )
        return params, state, error, hd, hist

    def cond(val):
        params, state, error, hd, hist = val
        return error > solver.tol

    init_val = (params, state, error, hd, hist)
    params, state, error, hd, hist = lax.while_loop(cond, body, init_val)
    return params, state, hist


@partial(jit, static_argnames=("solver",))
def hysteresis_loop_sp(solver, pinn_params_init, hext):
    """Sweep the external field and relax the magnetization at every value (scalar potential).

    Args:
        solver: jaxopt solver (static under jit).
        pinn_params_init: initial magnetization network parameters.
        hext: sequence of external field magnitudes, scanned in order.

    Returns:
        ``(params_m, history)``: final magnetization parameters and the
        per-field history records stacked by ``lax.scan``.
    """
    phi1_params_init = solve_phi1(pinn_params_init)
    hd_init = compute_hd(pinn_params_init, phi1_params_init, (X, dZ))
    state_init = solver.init_state(pinn_params_init, hext[0], hext_axis, (X, hd_init))

    def field_step(carry, h):
        params, state = carry
        params, state, hist = solve_sp(solver, params, state, h, hext_axis)
        return (params, state), hist

    carry0 = ((pinn_params_init, phi1_params_init), state_init)
    (params_final, _), history = lax.scan(field_step, init=carry0, xs=hext, unroll=False)
    return params_final[0], history

In [None]:
# maxiter=1: run_sp recomputes the stray field after every solver.run call, so
# the solver is limited to a single iteration per call; convergence is judged
# against tol inside solve_sp.
solver_sp = LBFGS(loss_m_sp, has_aux=True, maxiter=1, tol=1e-3, unroll=False, history_size=100)
params, hist_scalar_pot = hysteresis_loop_sp(solver_sp, pinn_params_init, hext)

In [None]:
# Second call with identical arguments, timed (presumably reuses the compilation from the previous cell).
%time hysteresis_loop_sp(solver_sp, pinn_params_init, hext);

## Vector Potential

In [None]:
# Same regularised pseudo-inverse as for phi1, with flipped sign. solve_A1 also
# negates its right-hand side (-curl(mag)), so the two signs cancel.
Pinv_A1 = -Pinv_phi1

@jit
def solve_A1(params_m):
    """Compute the ELM output weights of A1 for the magnetization ``params_m``.

    The right-hand side ``-curl(mag)`` is evaluated at every point of ``X`` and
    multiplied by ``Pinv_A1``.
    """
    def rhs(x):
        return -curl(mag)(x, params_m)

    return Pinv_A1 @ vmap(rhs)(X)

def A1(x, params):
    """Linear combination of the ELM features ``h_elm(x)`` with weights ``params``.

    The ``adf`` factor is not applied here; callers multiply by it themselves.
    """
    features = h_elm(x)
    return features @ params


def _compute_bd(params_A1, charge_tensor, x_dz):
    """B-field at a single point from the two vector-potential contributions.

    Args:
        params_A1: ELM weights of A1.
        charge_tensor: charge values passed to ``curl_single_layer_potential``.
        x_dz: pair ``(x, grad_z)`` of the point and its source-tensor Jacobian.

    Returns:
        ``curl(adf * A1)(x) + curl_single_layer_potential(grad_z, charge_tensor)``.
    """
    x, grad_z = x_dz

    def masked_A1(y):
        return adf(y) * A1(y, params_A1)

    interior_part = curl(masked_A1)(x)
    surface_part = curl_single_layer_potential(grad_z, charge_tensor)
    return interior_part + surface_part

@jit
def compute_bd(params_m, params_A1, X_dZ):
    """B-field at all sample points for the given magnetization and A1 weights.

    Builds the charge with ``vector_potential_charge`` and ``compute_charge``,
    then maps ``_compute_bd`` over the leading axis of ``X_dZ``.

    Args:
        params_m: magnetization network parameters.
        params_A1: ELM weights of A1.
        X_dZ: pair ``(X, dZ)`` of sample points and source-tensor Jacobians.
    """
    charge_fn = vector_potential_charge(adf, mag, A1, normalized=True)
    charge = compute_charge(charge_fn, params_mag=(params_m,), params_A1=(params_A1,))
    # Only the (x, grad_z) pairs are batched; weights and charge are shared.
    field_at_points = vmap(_compute_bd, (None, None, 0))
    return field_at_points(params_A1, charge, X_dZ)


def mag_energy_vec_pot(x, params_m, bd):
    """Stray-field energy density at ``x`` for a given B-field ``bd``: ``1/2 * (1 - m . bd)``."""
    m_dot_b = mag(x, params_m) @ bd
    return 1 / 2 * (1 - m_dot_b)


# Named `_integrand_vp` so it does not replace the module-level `_integrand`
# that `loss_m_sp` looks up at trace time; re-running the scalar-potential
# cells after this one would otherwise silently use the vector-potential energy.
def _integrand_vp(x_bd, params_m, hext, hext_axis):
    """Energy densities at one sample for the vector-potential formulation.

    Args:
        x_bd: pair ``(x, bd)`` of a sample point and the B-field at that point.
        params_m: magnetization network parameters.
        hext, hext_axis: external field magnitude and direction.

    Returns:
        Dict with keys ``e_ex``, ``e_ani``, ``e_ext`` and ``e_d``.
    """
    x, bd = x_bd
    e_d = mag_energy_vec_pot(x, params_m, bd)
    return dict(
        e_ex = exchange_energy(x, params_m),
        e_ani = ani_energy(x, params_m),
        e_ext = ext_energy(x, params_m, hext, hext_axis),
        e_d = e_d,
    )

@jit
def loss_m_vp(params_m, hext, hext_axis, data):
    """Objective for the magnetization in the vector-potential formulation.

    Args:
        params_m: magnetization network parameters.
        hext, hext_axis: external field magnitude and direction.
        data: pair ``(X, bd)`` of sample points and the B-field at those points.

    Returns:
        Tuple ``(l, energies)``: ``l`` counts ``e_d`` twice, while
        ``energies["e_tot"]`` counts it once; ``energies`` also holds the
        individual averaged terms.
    """
    X, bd = data
    energies = mc_integrate(_integrand_vp, (X, bd), params_m, hext, hext_axis)
    l = (energies["e_ex"] + energies["e_ani"] + energies["e_ext"] + 2 * energies["e_d"])
    e_tot = (energies["e_ex"] + energies["e_ani"] + energies["e_ext"] + energies["e_d"])
    energies |= {"e_tot": e_tot}
    return l, energies


In [None]:
from jaxopt import LBFGS, OptStep

@partial(jit, static_argnames=("solver",))
def run_vp(solver, params, state, hext, hext_axis, data):
    """One outer iteration: optimise m for a fixed B-field, then refresh the field.

    Args:
        solver: jaxopt solver (static under jit).
        params: pair ``(params_m, params_A1)``; only ``params_m`` is used as
            the starting point, A1 is re-solved.
        state: solver state to warm-start from; its ``iter_num`` is reset to 0.
        hext, hext_axis: external field magnitude and direction.
        data: tuple ``(X, Z, dZ, bd)``; ``Z`` is not used here.

    Returns:
        ``((params_m_new, params_A1_new), state, error, bd)`` where ``bd`` is
        the B-field of the updated magnetization and ``error`` the l2 norm of
        the solver's optimality function evaluated with that field.
    """
    X, _, dZ, bd = data
    params_m, _ = params
    state = state._replace(iter_num=0)
    step = OptStep(params_m, state)
    params_m_new, state = solver.run(step, hext, hext_axis, (X, bd))
    params_A1_new = solve_A1(params_m_new)
    # The field must come from the UPDATED magnetization: A1 was just solved for
    # params_m_new, so pairing it with the old params_m would give an
    # inconsistent field for the error check and the next iteration.
    bd = compute_bd(params_m_new, params_A1_new, (X, dZ))
    error = tree_l2_norm(solver.optimality_fun(params_m_new, hext, hext_axis, (X, bd)))
    return (params_m_new, params_A1_new), state, error, bd


@partial(jit, static_argnames=("solver",))
def solve_vp(solver, params, state, hext, hext_axis):
    """Relax the magnetization at one external field value (vector potential).

    Repeats ``run_vp`` until the optimality error is no longer above
    ``solver.tol``. ``run_vp`` is always executed at least once.

    Args:
        solver: jaxopt solver (static under jit).
        params: pair ``(params_m, params_A1)``.
        state: solver state carried over from the previous field step.
        hext, hext_axis: external field magnitude and direction.

    Returns:
        ``(params, state, hist)`` with ``hist`` the last history record.
    """
    bd = compute_bd(*params, (X, dZ))

    params, state, error, bd = run_vp(solver, params, state, hext, hext_axis, (X, Z, dZ, bd))
    # No iterations precede the first run at this field value, so the running
    # total starts at 0; passing state.iter_num here would count them twice.
    # The field counter starts at 1 for the compute_bd call above.
    hist = _hist(params, state, 0, 1, hext_axis)

    def body(val):
        params, state, error, bd, hist = val
        params, state, error, bd = run_vp(solver, params, state, hext, hext_axis, (X, Z, dZ, bd))
        last_mean_mag = hist["mean_mag"]
        hist = _hist(params, state, hist["iter_num"], hist["n_field_eval"], hext_axis)
        _mean_mag = hist["mean_mag"]
        # Re-initialise the solver state when the mean magnetization changes
        # sign from positive to negative between two consecutive runs.
        state = lax.cond(
            (last_mean_mag > 0) & (_mean_mag < 0),
            lambda: solver.init_state(params[0], hext, hext_axis, (X, bd)),
            lambda: state
        )
        return params, state, error, bd, hist

    def cond(val):
        params, state, error, bd, hist = val
        return error > solver.tol

    init_val = (params, state, error, bd, hist)
    params, state, error, bd, hist = lax.while_loop(cond, body, init_val)
    return params, state, hist


@partial(jit, static_argnames=("solver",))
def hysteresis_loop_vp(solver, pinn_params_init, hext):
    """Sweep the external field and relax the magnetization at every value (vector potential).

    Args:
        solver: jaxopt solver (static under jit).
        pinn_params_init: initial magnetization network parameters.
        hext: sequence of external field magnitudes, scanned in order.

    Returns:
        ``(params_m, history)``: final magnetization parameters and the
        per-field history records stacked by ``lax.scan``.
    """
    A1_params_init = solve_A1(pinn_params_init)
    bd_init = compute_bd(pinn_params_init, A1_params_init, (X, dZ))
    state_init = solver.init_state(pinn_params_init, hext[0], hext_axis, (X, bd_init))

    def field_step(carry, h):
        params, state = carry
        params, state, hist = solve_vp(solver, params, state, h, hext_axis)
        return (params, state), hist

    carry0 = ((pinn_params_init, A1_params_init), state_init)
    (params_final, _), history = lax.scan(field_step, init=carry0, xs=hext, unroll=False)
    return params_final[0], history


In [None]:
# maxiter=1: run_vp recomputes the B-field after every solver.run call, so the
# solver is limited to a single iteration per call; convergence is judged
# against tol inside solve_vp.
# NOTE: `params` from the scalar-potential run is overwritten here.
solver_vp = LBFGS(loss_m_vp, has_aux=True, maxiter=1, tol=1e-3, unroll=False, history_size=100)
params, hist_vec_pot = hysteresis_loop_vp(solver_vp, pinn_params_init, hext)

In [None]:
# Second call with identical arguments, timed (presumably reuses the compilation from the previous cell).
%time hysteresis_loop_vp(solver_vp, pinn_params_init, hext);

## Computation of the switching field

In [None]:
# Reference switching field used by rel_error; the commented value is an
# alternative reference kept by the original author.
#hsw_true = -2.767
hsw_true = -2.74243605445065
def rel_error(hsw):
    """Relative deviation of ``hsw`` from ``hsw_true`` in percent."""
    deviation = jnp.abs(hsw - hsw_true)
    reference = jnp.abs(hsw_true)
    return deviation / reference * 100


# Mean magnetization (projected on hext_axis) recorded at every field step of the two runs.
mean_mags_sp = hist_scalar_pot["mean_mag"]
mean_mags_vp = hist_vec_pot["mean_mag"]

def _g(x, mean_mags_sp):
    """Linearly interpolated mean magnetization at field value ``x``.

    ``hext`` is stored in descending order, so both arrays are reversed to give
    ``jnp.interp`` ascending sample positions. Only the first
    ``len(mean_mags_sp)`` field values are used.
    """
    n_steps = len(mean_mags_sp)
    field_ascending = hext[:n_steps][::-1]
    mag_ascending = asarray(mean_mags_sp)[::-1]
    return np.array(jnp.interp(x, field_ascending, mag_ascending))

# Switching field: root of the interpolated mean-magnetization curve on [-3.5, 1].
hsw_scalar_pot = bisect(_g, -3.5, 1, (mean_mags_sp,))
hsw_vec_pot = bisect(_g, -3.5, 1, (mean_mags_vp,))


# Report both estimates with their deviation from hsw_true, plus the field
# evaluation counts summed over all field steps.
print(f"h_c for scalar potential = {hsw_scalar_pot}, rel error", rel_error(hsw_scalar_pot), "%")
print(f"h_c for vector potential = {hsw_vec_pot}, rel error", rel_error(hsw_vec_pot), "%")
print("total number of stray field evaluations:", jnp.sum(hist_scalar_pot["n_field_eval"]))
print("total number of b-field evaluations:", jnp.sum(hist_vec_pot["n_field_eval"]))

In [None]:
%matplotlib widget

In [None]:
# Five stacked panels sharing the field axis: mean magnetization, Zeeman,
# anisotropy, exchange and stray-field energy. Vector-potential curves are drawn
# first (solid red), scalar-potential curves on top (dashed blue).
# NOTE: 'text.usetex' requires a working LaTeX installation.
with plt.rc_context(rc={'text.usetex': True, 'text.latex.preamble': r"\usepackage{amsmath}", "axes.labelsize": 16, "axes.titlesize": 16}):
    fig = plt.figure(figsize=(8, 8))
    ax1, ax2, ax3, ax4, ax5 = fig.subplots(5, 1, sharex=True)
    # --- vector potential results ---
    ax1.plot(hext,  hist_vec_pot["mean_mag"], c="tab:red", label="vector potential")
    ax1.set_title("mean magnetization", fontsize=16)
    ax1.grid()
    # Arrow annotation pointing at the zero crossing (hsw_vec_pot, 0).
    ax1.annotate(f"$h_c={hsw_vec_pot:.3f}$", (hsw_vec_pot, 0), (hsw_vec_pot - 0.7, -0.5), xycoords="data", fontsize=14, arrowprops={"arrowstyle": "->"}, color="tab:red")
    ax2.plot(hext, hist_vec_pot["aux_m"]["e_ext"], c="tab:red")
    ax2.set_title("$e_{zee}$", fontsize=16)
    ax2.grid()

    ax3.plot(hext, hist_vec_pot["aux_m"]["e_ani"], c="tab:red")
    ax3.set_title("$e_{a}$", fontsize=16)
    ax3.grid()
    # Scientific notation on the y axis of the exchange panel.
    ax4.ticklabel_format(axis="y", style="sci", useOffset=True, scilimits=(0,0))
    ax4.plot(hext, hist_vec_pot["aux_m"]["e_ex"], c="tab:red")
    ax4.set_title("$e_{ex}$", fontsize=16)
    ax4.grid()
    ax5.plot(hext, hist_vec_pot["aux_m"]["e_d"], label=r"$e_d^{\mathbf{A}}$", c="tab:red")
    
    # --- scalar potential results ---
    ax1.plot(hext, hist_scalar_pot["mean_mag"], "--", c="tab:blue", label="scalar potential")
    ax1.annotate(f"$h_c={hsw_scalar_pot:.3f}$", (hsw_scalar_pot, 0), (hsw_scalar_pot + 0.2, -0.3), xycoords="data", fontsize=14, arrowprops={"arrowstyle": "->"}, color="tab:blue")
    ax2.plot(hext, hist_scalar_pot["aux_m"]["e_ext"], "--", c="tab:blue")
    ax3.plot(hext, hist_scalar_pot["aux_m"]["e_ani"], "--", c="tab:blue")
    
    ax4.plot(hext, hist_scalar_pot["aux_m"]["e_ex"], "--", c="tab:blue")
    ax5.plot(hext, hist_scalar_pot["aux_m"]["e_d"], "--", label=r"$e_d^{\phi}$", c="tab:blue")
    
    ax1.legend(fontsize=16)
    
    ax5.legend(loc="right")
    ax5.set_title("$e_{d}$", fontsize=16)
    ax5.set_xlabel(r"$h_{ext} [-]$", fontsize=16)
    # Shared y labels placed with fig.text in figure coordinates: one for the
    # energy panels, one for the magnetization panel.
    fig.text(0.04, 0.43, r"$e\;[\mu_0 M_s^2]$", ha='center', va='center', rotation='vertical', fontsize=16)
    fig.text(0.04, 0.89, r"$\langle \mathbf{m}\rangle \;[-]$", ha='center', va='center', rotation='vertical', fontsize=16)
    ax5.grid()
    ax5.set_xlim((min(hext),max(hext)))

    # Leave room on the left for the fig.text labels.
    fig.tight_layout(rect=(0.03, 0, 1, 1))

    fig.savefig("demag_cube_scalar_and_vec_pot_elm_lbfgs.pdf")