In [None]:
import matplotlib.pyplot as plt
import numpy as np

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_platform_name", "cpu")

import neuralmag as nm
import pyvista as pv

In [None]:
# Render PyVista figures as static images inside the notebook.
# NOTE(review): no PyVista plot appears in the visible cells — TODO confirm this is still needed.
pv.set_jupyter_backend("static")

# First test: relax a small film, then switch it with a static external field

In [None]:
mesh = nm.Mesh(
    (10, 10, 1),         # presumably number of cells along x, y, z
    (5e-9, 5e-9, 3e-9),  # presumably cell edge lengths along x, y, z [m]
)

In [None]:
# Simulation state on the mesh; later cells attach material parameters (state.material),
# the magnetization (state.m) and field terms to it, and read the time (state.t).
state = nm.State(mesh)

In [None]:
# Material parameters (presumably SI units — TODO confirm against neuralmag docs).
state.material.Ms = 8e5  # saturation magnetization, presumably A/m
state.material.A = 1.3e-11  # exchange stiffness, presumably J/m
state.material.alpha = 0.02  # Gilbert damping constant (dimensionless)

In [None]:
# Uniform initial magnetization along the in-plane (x, y) diagonal.
# 0.5**0.5 == 1/sqrt(2), so the vector has unit length.
state.m = nm.VectorFunction(state).fill(
    (0.5**0.5, 0.5**0.5, 0)
)

In [None]:
# Static switching field. Multiplied by mu0 = 4*pi*1e-7 these values give approx.
# (-24.6, 4.3, 0) mT, i.e. presumably field 1 of muMAG standard problem 4.
h_ext = nm.VectorFunction(state).fill([-19576.0, 3421.0, 0.0], expand=True)

# Register the individual field terms under the names TotalField refers to.
nm.ExchangeField().register(state, "exchange")
nm.DemagField().register(state, "demag")
nm.ExternalField(h_ext).register(state, "external")

# Total field for the relaxation: "external" is left out on purpose and is added later
# to trigger the switch.
nm.TotalField("exchange", "demag").register(state)

In [None]:
# Relax the initial state with exchange + demag only (external field not yet in TotalField).
llg = nm.LLGSolver(state)
llg.relax()

In [None]:
# Simulation time after the relaxation; shown via the cell's rich output instead of print().
state.t

In [None]:
# Add the external field to the total field to trigger the switch.
nm.TotalField("exchange", "demag", "external").register(state)
# Re-initialise the solver after the field set changed.
llg.reset()

T_SWITCH = 1e-9  # end time of the switching run [s]
DT_LOG = 1e-11   # step between logged snapshots [s]

logger = nm.Logger("data", ["t", "m"], ["m"])
while state.t < T_SWITCH:
    logger.log(state)
    llg.step(DT_LOG)
# The loop logs *before* each step, so the state reached by the final step would
# otherwise be missing from the log; record it explicitly.
logger.log(state)

In [None]:
# Log written by nm.Logger above. NOTE(review): presumably the columns are t, m_x, m_y, m_z
# (from the ["t", "m"] entries) — TODO confirm.
data = np.loadtxt("data/log.dat")

fig, ax = plt.subplots()
for col, label in enumerate(("m_x", "m_y", "m_z"), start=1):
    ax.plot(data[:, 0], data[:, col], label=label)
ax.set(title="Switching under a static external field", xlabel="t [s]", ylabel="m_i")
ax.legend()
plt.show()

In [None]:
# (rows, columns) of the loaded log.
np.shape(data)

In [None]:
# 10 000 evenly spaced times from 0 to 1 ns, endpoint included.
# NOTE(review): `t` is not used by any later cell in view — TODO remove or use.
t = jnp.linspace(
    0,
    1e-9,
    num=10_000,
)

In [None]:
# NOTE(review): empty plot stub (no data passed) — looks like an unfinished experiment;
# TODO complete or remove.
plt.gca().plot()

# Second test: driving the magnetization with a sinusoidal external field

In [None]:
import os

# Equinox reads EQX_ON_ERROR to decide what a failed runtime check does ("raise" is its
# default). "breakpoint" opens an interactive debugger, which hangs an unattended
# "Restart & Run All", so it is opt-in here.
# NOTE(review): equinox may read this variable only when it is first imported; if it was
# already imported above (e.g. via neuralmag), this assignment may have no effect — TODO confirm.
DEBUG_SOLVER_ERRORS = False
os.environ["EQX_ON_ERROR"] = "breakpoint" if DEBUG_SOLVER_ERRORS else "raise"

In [None]:
from neuralmag import config

In [None]:
# Double precision and the JAX backend for the time-dependent-field experiment below
# (h_ext is written with jax.numpy and evaluated with jax.vmap).
# Fix: the attribute was misspelled `backlend`, so the backend was never set.
# NOTE(review): the first test's objects were created before this cell — TODO confirm
# whether neuralmag needs the backend chosen before any Mesh/State is built.
config.dtype = "float64"
config.backend = "jax"

In [None]:
# Second system: 25 x 25 x 2 cells of 5 nm edge length
# (presumably nm.Mesh(number_of_cells, cell_size_in_m) — TODO confirm).
mesh = nm.Mesh((25, 25, 2), (5e-9, 5e-9, 5e-9))

state = nm.State(mesh)

state.material.Ms = 8e5
state.material.A = 1.3e-11
state.material.alpha = 0.4  # much larger damping than in the first test (0.02)

# Uniform initial magnetization along the in-plane diagonal (unit vector).
state.m = nm.VectorFunction(state).fill((0.5**0.5, 0.5**0.5, 0))

# Sinusoidal drive along x.
freq = 800e3    # drive frequency [Hz] (800 kHz, period 1.25 us)
H_AMP = 10_000  # drive amplitude, same units as the field values of the first test


def h_ext(t):
    """Return the external field at time ``t`` as a length-3 array (h_x, 0, 0).

    h_x(t) = H_AMP * sin(2 * pi * freq * t), with ``t`` a scalar time in seconds.
    Evaluate many times at once with ``jax.vmap(h_ext)``.
    An earlier experiment used a linear ramp instead: h_x(t) = -t * 1e14.
    """
    h_x = jnp.sin(2 * jnp.pi * freq * t) * H_AMP
    return jnp.stack([h_x, 0.0, 0.0])


# Output times 0, 1 ns, ..., 999 ns. An explicit sample count replaces a float-step
# jnp.arange, whose length can be off by one through rounding; endpoint=False gives the
# same 1000 sample times. (An output spacing of 1e-11 s was also reported to work.)
# NOTE: the window (1 us) is shorter than one drive period (1.25 us), so the H-m curve
# plotted at the end does not close into a full hysteresis loop.
T_DRIVE = 1e-6
N_OUT = 1000
ts = jnp.linspace(0.0, T_DRIVE, N_OUT, endpoint=False)

h = jax.vmap(h_ext)(ts)
plt.plot(ts, h[:, 0])
plt.title("Applied field")
plt.xlabel("t [s]")
plt.ylabel("h_x")
plt.show()

nm.ExchangeField().register(state, "exchange")
nm.DemagField().register(state, "demag")
nm.ExternalField(h_ext).register(state, "external")

# Relaxation uses exchange + demag only; the drive is added to TotalField later.
nm.TotalField("exchange", "demag").register(state)

In [None]:
# Relax the initial state with exchange + demag only (the drive is registered but not yet
# part of TotalField). solver_type selects the integration scheme by name; "Kvaerno5" is
# presumably diffrax's implicit Kvaerno5 solver — TODO confirm against neuralmag.
llg = nm.LLGSolver(state, solver_type="Kvaerno5")
llg.relax()

In [None]:
# Include the sinusoidal drive in the total field for the driven run.
nm.TotalField("exchange", "demag", "external").register(state)

# A fresh solver is created instead of calling llg.reset() on the relaxation solver.
# scale_t=1e-9 presumably rescales time internally (later cells convert sol.ts back to
# seconds with * 1e-9); max_steps presumably caps the integrator's internal step count.
# Earlier experiments also tried rtol=100, atol=100, and a manual nm.Logger stepping loop.
llg = nm.LLGSolver(state, scale_t=1e-9, max_steps=100_000)

# Integrate and keep the solution at the output times `ts`
# (wrap in `with jax.disable_jit():` when debugging).
sol = llg.solve(ts)

In [None]:
# Display the solution object returned by llg.solve.
sol

In [None]:
# Uncomment the next line to open IPython's post-mortem debugger on the last exception.
# %debug

In [None]:
# Spatial average of m at each stored time: average over the three axes in front of the
# last one. NOTE(review): assumes sol.ys is laid out as (time, x, y, z, component) — TODO confirm.
mag = jnp.mean(
    sol.ys,
    axis=(-2, -3, -4),
)
jnp.shape(mag)

In [None]:
# sol.ts appears to be in units of scale_t = 1e-9 s (the solver cell passes scale_t=1e-9),
# so convert it back to seconds once. NOTE(review): TODO confirm the units of sol.ts.
t_seconds = sol.ts * 1e-9

fig, ax = plt.subplots()
for i, label in enumerate(("m_x", "m_y", "m_z")):
    ax.plot(t_seconds, mag[:, i], label=label)
ax.set(title="Spatially averaged magnetization under sinusoidal drive",
       xlabel="t [s]", ylabel="m_i")
ax.legend()
plt.show()

In [None]:
# External field at the solver's output times (sol.ts scaled back to seconds with the
# same 1e-9 factor used for the magnetization plot).
t_seconds = sol.ts * 1e-9
h = jax.vmap(h_ext)(t_seconds)

fig, ax = plt.subplots()
ax.plot(t_seconds, h[:, 0])
ax.set(title="Applied field at output times", xlabel="t [s]", ylabel="h_x")
plt.show()

In [None]:
# Field-magnetization curve: x component of the drive vs. the spatially averaged m_x.
# m was initialised as a unit vector, so the y axis is the dimensionless m_x, not M in A/m.
fig, ax = plt.subplots()
ax.plot(h[:, 0], mag[:, 0])
ax.set(title="m_x vs. applied field", xlabel="h_x (presumably A/m)", ylabel="m_x")
plt.show()