In [None]:
from mumag3 import *
from magpi.prelude import *
import matplotlib.pyplot as plt
from scipy.stats.qmc import Halton
from magpi.integrate import gauss5
from magpi.opt import TR

# Select the GPU platform and leave 64-bit mode disabled (JAX default precision).
jax.config.update('jax_platform_name', 'gpu')
jax.config.update("jax_enable_x64", False)
# Interactive matplotlib backend (ipympl).
%matplotlib widget

# Root PRNG key; sub-keys are split off it further down.
key = random.PRNGKey(0)



In [None]:
# Material and sweep parameters. Energy-like constants are divided by
# Km = mu0 * Ms**2, so Ka and A below are reduced (unit-free) quantities.
mu0 = 4*pi*10**-7  # Tm/A
Js = 1.61  # T
Ms = Js / mu0  # A/m
Km = mu0 * Ms**2  # J/m3
L = 70.0  # nm
Ka = 4.3e6 / Km  # -
easy_axis = unit_vec(array([0., 0., 1.]))
# NOTE(review): 7.3e-12 is presumably an exchange constant in J/m. Then
# 7.3e-12 / Km is in m2, * 1e18 converts to nm2, and / L**2 (nm2) leaves a
# dimensionless value (not 1/nm2) - confirm.
A = 7.3e-12 / Km * 1e18 / L ** 2 # - (dimensionless, see note above)
# Field step of the sweep.
dh = 5e-3
# Reduced field values in descending order: arange excludes the stop value 1.0
# and [::-1] reverses it, so the sweep runs from just below 1 down to -3.5.
hext = jnp.arange(-3.5, 1., dh)[::-1]  # -
# Direction of the applied field.
hext_axis = unit_vec(array([1., 0., 10.]))
# Reference magnetisation: zeros shaped like x with component 0 of the last axis set to 1.
m0 = lambda x: zeros_like(x).at[..., 0].set(1.)

In [None]:
# Cuboid spanning [-1/2, 1/2] along each axis, built from 5 evenly spaced
# values per axis.
domain = Cuboid(*[linspace(-1 / 2, 1 / 2, 5) for _axis in range(3)])

# 2**12 Halton points from the unit cube, mapped into the domain.
x_dom = domain.transform(array(Halton(3, seed=42).random(2**12)))

# Surface tensors for every collocation point, evaluated one point at a time
# with lax.map.
x_tensor_dom = lax.map(
    lambda pt: surface_tensors_grad(pt, domain, 15, gauss5),
    x_dom,
)

In [None]:
# 2**9 Halton points in 4 dimensions; Halton samples lie in [0, 1), so the
# affine map (u * 2 - 1) * 2 puts every entry in [-2, 2).
_weights = array(Halton(4, seed=43).random(2 ** 9))
# First three columns become the input weights, the fourth the biases.
# NOTE(review): "elm" presumably stands for an extreme-learning-machine
# feature layer used inside the solver - confirm against create_stray_field_solver.
W_elm = (_weights[:, :3] * 2 - 1) * 2
b_elm = (_weights[:, 3] * 2 - 1) * 2

# Build the stray-field solver on the collocation points. With
# use_precomputed_grad_tensors=True the returned field function is later
# called with a precomputed tensor argument (see its use in `loss`).
stray_field_solver = create_stray_field_solver(
    x_dom, domain, W_elm, b_elm,
    use_precomputed_grad_tensors=True
)

class PINN(nn.Module):
    """Small fully connected network: three GELU hidden layers and a 3-output head."""

    @nn.compact
    def __call__(self, x):
        width = 20
        act = nn.gelu
        hidden = x
        # Hidden layers keep their original parameter names.
        for layer_name in ("dense1", "dense2", "dense3"):
            hidden = act(nn.Dense(width, name=layer_name)(hidden))
        # Linear output layer (named "dense5"; there is no "dense4").
        return nn.Dense(3, name="dense5")(hidden)

# Split a fresh sub-key off the root key for parameter initialisation.
key, _k = random.split(key)
mag_pinn_model = PINN()
# Initialise the parameters by tracing the network with a dummy 3-vector input.
pinn_params_init = mag_pinn_model.init(_k, zeros((3,)))

    
def mag(mag0, x, params):
    """Evaluate the magnetisation model at ``x``.

    ``x`` is first passed through ``domain.normalize``. The network output at
    the normalised point and the reference magnetisation ``mag0`` at that same
    normalised point are then combined by ``cayley_rotation``.
    """
    x_hat = domain.normalize(x)
    net_out = mag_pinn_model.apply(params, x_hat)
    return cayley_rotation(net_out, mag0(x_hat))

def exchange_energy(m, x):
    """Return ``A`` times the mean over ``x`` of the squared Jacobian entries of ``m``."""
    def density(point):
        # Forward-mode Jacobian of m at a single point.
        jac = jacfwd(m)(point)
        return jnp.sum(jac * jac)

    per_point = vmap(density)(x)
    return A * mean(per_point)

def ani_energy(m, x):
    """Return ``Ka`` times the mean over ``x`` of ``1 - (m . easy_axis)**2``."""
    def density(point):
        along_axis = m(point) @ easy_axis
        return 1 - along_axis ** 2

    per_point = vmap(density)(x)
    return Ka * mean(per_point)

def mag_energy(hs, m, x, xt):
    """Return ``-1/2`` times the mean of ``m(x) . hs(x, xt)`` over paired rows of ``x`` and ``xt``.

    ``hs`` is called as ``hs(point, tensor)`` with one row of ``x`` and the
    matching row of ``xt``.
    """
    def density(point, tensor):
        return dot(m(point), hs(point, tensor))

    per_point = vmap(density)(x, xt)
    return - 1 / 2 * mean(per_point)

def ext_energy(m, x, hext):
    """Return minus the mean over ``x`` of ``m(x) @ hext``."""
    def density(point):
        return m(point) @ hext

    per_point = vmap(density)(x)
    return - mean(per_point)


@jit
def loss(params, hext):
    """Objective for the optimizer plus a dict of the individual energy terms.

    Args:
        params: network parameters passed to ``mag``.
        hext: scalar field value; the applied field is ``hext * hext_axis``.

    Returns:
        ``(objective, energies)``. ``energies`` holds ``e_tot``, ``e_mag``,
        ``e_ex``, ``e_ani`` and ``e_ext``.
    """
    def m(x):
        return mag(m0, x, params)

    def m_frozen(x):
        # Same magnetisation, but with gradients through params blocked.
        return mag(m0, x, lax.stop_gradient(params))

    e_ex = exchange_energy(m, x_dom)
    e_ani = ani_energy(m, x_dom)
    # The stray field is built from the gradient-stopped magnetisation.
    hs = stray_field_solver(m_frozen)
    e_mag = mag_energy(hs, m, x_dom, x_tensor_dom)
    e_ext = ext_energy(m, x_dom, hext * hext_axis)

    # The objective counts e_mag twice while e_tot counts it once.
    # NOTE(review): presumably this offsets the missing gradient through hs
    # caused by the stop_gradient above - confirm.
    objective = e_ex + e_ani + e_ext + 2 * e_mag
    energies = {
        'e_tot': e_ex + e_ani + e_mag + e_ext,
        'e_mag': e_mag,
        'e_ex': e_ex,
        'e_ani': e_ani,
        'e_ext': e_ext
    }
    return objective, energies

In [None]:
# Relax the freshly initialised network once with the trust-region optimizer
# at a field of 1.0 + dh, i.e. slightly above the largest value in `hext`.
# `loss` returns (value, aux), hence has_aux=True.
init_params, init_state = TR(
    loss,
    has_aux=True, 
    maxiter=100,
    tol=1e-4,
    jit=True,
    unroll=False,
    rho_accept=1 / 5
).run(pinn_params_init, 1.0 + dh)
    

In [None]:
# Separate, larger Halton point set (2**13 points, different seed) mapped
# into the domain; used to evaluate the mean magnetisation.
x_val = domain.transform(array(Halton(3, seed=41232).random(2**13)))

# Surface tensors at the validation points.
# NOTE(review): x_tensor_val is not referenced anywhere else in the visible
# cells - confirm whether a later cell needs it.
x_tensor_val = lax.map(
    lambda pt: surface_tensors_grad(pt, domain, 15, gauss5),
    x_val,
)

def train(params, maxiter=100, tol=1e-3, init_tr_radius=1., max_tr_radius=1.):
    """Sweep the external field over ``hext`` and re-minimise ``loss`` at each value.

    Each field step starts from the parameters returned by the previous step.

    Args:
        params: starting network parameters.
        maxiter: maximum optimizer iterations per field step.
        tol: optimizer tolerance per field step.
        init_tr_radius: initial trust-region radius passed to ``TR``.
        max_tr_radius: maximum trust-region radius passed to ``TR``.

    Returns:
        dict with
        - ``'params'``: parameters after the last field step,
        - ``'mean_mags'``: per field step, the mean over ``x_val`` of the
          magnetisation projected on ``hext_axis``,
        - ``'curvatures'``: per field step, the smallest ``steihaug_curvature``
          recorded by the callback (NaN if nothing was recorded for that step),
        - ``'states'``: the optimizer state returned for every field step.
    """
    @jit
    def mean_mag(params):
        m = lambda x: mag(m0, x, params)
        def m_proj(x):
            return m(x) @ hext_axis
        return mean(vmap(m_proj)(x_val))

    # Filled by the optimizer callback, in call order across all field steps.
    curvatures = []

    def cb(step):
        curvatures.append(step.state.steihaug_curvature)

    @jit
    def update(params, hext):
        opt = TR(
            loss,
            has_aux=True,
            maxiter=maxiter,
            tol=tol,
            jit=True,
            unroll=False,
            init_tr_radius=init_tr_radius,
            max_tr_radius=max_tr_radius,
            rho_accept=1 / 5,
            callback=cb
        )
        return opt.run(params, hext)

    states = []
    mean_mags = []

    for _hext in hext:
        params, state = update(params, _hext)
        states.append(state)
        mean_mags.append(mean_mag(params))

    # Split the flat curvature log into one chunk per field step.
    # NOTE(review): this assumes the callback fires once per counted
    # iteration (state.iter_num) - confirm against the TR implementation.
    iters = [int(s.iter_num) for s in states]
    min_curvatures = []
    start = 0
    for n_iter in iters:
        chunk = curvatures[start:start + n_iter]
        # A step with no recorded curvature (e.g. zero iterations) would make
        # min() raise ValueError and lose the whole sweep; report NaN instead.
        min_curvatures.append(min(chunk) if chunk else float("nan"))
        start += n_iter

    return {
        'params': params,
        'mean_mags': mean_mags,
        'curvatures': min_curvatures,
        'states': states
    }

In [None]:
# Run the full field sweep starting from the pre-relaxed parameters.
# max_tr_radius is not passed, so train's default (1.) applies.
result = train(
    init_params, 
    maxiter=100,
    tol=1e-4,
    init_tr_radius=0.1,
)

In [None]:
import numpy as np
# Plot the mean magnetisation projection stored by train() for every field
# step against the corresponding field value.
mean_mags = result['mean_mags']
plt.figure(figsize=(7,3.5))
plt.plot(hext, mean_mags, "-")
# Axis limits match the swept field range; the projection is shown in [-1.1, 1.1].
plt.xlim((-3.5, 1.))
plt.ylim((-1.1, 1.1))
plt.grid()
plt.xlabel("$h_{ext}$")
plt.ylabel("$\\langle \\mathbf{m}\\rangle$")
# Ten evenly spaced ticks over the field range.
plt.xticks(np.linspace(-3.5, 1, 10));
plt.tight_layout()


In [None]:
# Auxiliary output (the energy dict returned by `loss`) stored on the
# optimizer state of every field step.
energies = [s.aux for s in result['states']]
fig = plt.figure(figsize=(7,3.5))
(ax1, ax2) = fig.subplots(2, 1, sharex=True)
# Top panel: total energy per field step.
ax1.plot(hext, [e['e_tot'] for e in energies], "-")
# Bottom panel: minimum curvature per field step. Drawn on ax2 explicitly
# instead of via plt.plot, which targets whatever the current axes happen to be.
ax2.plot(hext, result['curvatures'], "-")
ax1.grid()
ax2.grid()
ax2.set_xlabel("$h_{ext}$")
ax1.set_ylabel("$e \\left[\\mu_0M_s^2\\right]$")
ax2.set_ylabel("$curvature$")
# The x axis is shared, so setting ticks on ax1 applies to both panels.
ax1.set_xticks(np.linspace(-3.5, 1, 10));
fig.tight_layout()