In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

import jax
import jax.numpy as jnp
# JAX computes in float32 unless 64-bit mode is enabled; turn it on so the
# simulation can run in double precision (neuralmag's config.dtype is set to
# "float64" in the next cell). This must run before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
# Optional: uncomment to force JAX onto the CPU platform (e.g. for debugging or
# when no accelerator is available).
# jax.config.update("jax_platform_name", "cpu")

import neuralmag as nm

In [None]:
from neuralmag import config

# Global neuralmag settings; presumably these must be in place before any
# Mesh/State is created below — confirm against the neuralmag docs.
config.dtype = "float64"  # matches jax_enable_x64 enabled in the import cell
# The attribute was previously misspelled `backlend`, which only created an
# unused attribute and never selected the JAX backend.
config.backend = "jax"

In [None]:
freq = 800 * 1e3  # drive frequency: 800 kHz
H_AMPLITUDE = 10_000  # peak x-field of the drive; presumably A/m (SI) — confirm
DT = 1e-7  # spacing between successive field samples [s]
N_STEPS = 30  # N_STEPS * DT = 3 us, i.e. 2.4 periods of the 800 kHz drive


def h_ext(t):
    """Return the external field vector (hx, 0, 0) at time ``t``.

    The x-component oscillates as ``H_AMPLITUDE * sin(2*pi*freq*t)``; y and z are
    zero. Only ``jnp`` operations are used, so the function can be ``jax.vmap``-ed
    over the time grid ``ts``.

    Variants tried earlier: adding a higher harmonic
    ``1000 * sin(2*pi*100*freq*t)`` to hx, or a linear ramp ``(-t * 1e14, 0, 0)``
    instead of the sine drive.
    """
    hx = H_AMPLITUDE * jnp.sin(2 * jnp.pi * freq * t)
    zero = jnp.zeros_like(hx)
    return jnp.stack([hx, zero, zero])


# Build the grid from an integer count: a float-step arange such as
# jnp.arange(0, 3e-6, 1e-7) can gain or lose an endpoint to rounding.
# (Earlier runs noted that a much finer 1e-11 s step also works.)
ts = jnp.arange(N_STEPS) * DT

print(ts.shape)

# Sample the drive on the whole time grid; `h` is reused later as the x-axis
# of the response plot.
h = jax.vmap(h_ext)(ts)

fig, ax = plt.subplots()
ax.plot(ts, h[:, 0])
ax.set(xlabel="time t [s]", ylabel="applied field $H_x$", title="External drive field")
plt.show()

In [None]:
# Discretization: presumably (nx, ny, nz) cell counts and (dx, dy, dz) cell
# edge lengths in metres — confirm against nm.Mesh.
N_CELLS = (25, 25, 2)
CELL_SIZE = (5e-9,) * 3
mesh = nm.Mesh(N_CELLS, CELL_SIZE)

state = nm.State(mesh)

# Material parameters (presumably SI units — confirm).
state.material.Ms = 8e5      # saturation magnetization
state.material.A = 1.3e-11   # exchange stiffness
state.material.alpha = 0.4   # Gilbert damping

# Uniform initial magnetization of unit length along the in-plane (1, 1, 0) diagonal.
diag = 0.5**0.5
state.m = nm.VectorFunction(state).fill((diag, diag, 0))

# Effective-field terms, registered by name; the total field sums the named ones.
nm.ExchangeField().register(state, "exchange")
nm.DemagField().register(state, "demag")
nm.TotalField("exchange", "demag").register(state)

In [None]:
# Average m over all axes except the trailing one (presumably the vector-component axis).
spatial_axes = (-4, -3, -2)
mag = jnp.mean(state.m.tensor, axis=spatial_axes)
print(mag)

In [None]:
# Relax the initial state under the fields registered so far (exchange + demag,
# no external field yet). "Kvaerno5" presumably names diffrax's implicit
# Kvaerno5 integrator — confirm.
SOLVER_TYPE = "Kvaerno5"
llg = nm.LLGSolver(state, solver_type=SOLVER_TYPE)
llg.relax()

In [None]:
# Spatially averaged magnetization after relaxation, for comparison with the initial value.
mag = jnp.mean(state.m.tensor, axis=(-4, -3, -2))
print(mag)

In [None]:
# Quasi-static sweep: at each time sample, apply the instantaneous drive h_ext(t)
# and relax to equilibrium. Nothing here re-initializes state.m, so (assuming
# llg.reset() leaves state.m untouched) each relaxation starts from the previous
# step's equilibrium.
#
# The exchange and demag terms were registered on `state` during setup and do not
# depend on t, so they are not re-created here. Only the external field, and the
# total field that includes it, are re-registered per step.
mags = []
for t in tqdm(ts, desc="field steps"):
    nm.ExternalField(h_ext(t)).register(state, "external")
    nm.TotalField("exchange", "demag", "external").register(state)

    # Presumably required so the solver picks up the re-registered field terms — confirm.
    llg.reset()
    llg.relax()
    mag = jnp.mean(state.m.tensor, axis=(-2, -3, -4))
    mags.append(mag)

In [None]:
# Averaged m_x against the applied H_x. Slicing `h` to len(mags) keeps both axes
# the same length if the sweep was stopped part-way.
fig, ax = plt.subplots()
ax.plot(h[:len(mags), 0], jnp.stack(mags)[..., 0])
ax.set(xlabel="applied field $H_x$", ylabel="mean $m_x$", title="$m_x$ response to the sinusoidal drive")
plt.show()

In [None]:
# Mean x-component of m after each completed field step (shown as the cell output).
mx = jnp.stack(mags)[..., 0]
mx