In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Expose only GPU index 1 to CUDA and disable XLA's up-front GPU memory
# preallocation. These only take effect if set before JAX initializes its GPU
# backend. NOTE: a later cell sets jax_platform_name="cpu", so the GPU choice
# may be moot here — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import pathlib
import glob
from tqdm.notebook import tqdm
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from mc2.utils.data_inspection import (
    get_available_material_names, get_file_overview, load_and_process_single_from_full_file_overview,
)
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis
from mc2.data_management import FrequencySet, MaterialSet, DataSet

In [None]:
import jax
# Run all JAX computations on the CPU and enable 64-bit types (JAX defaults to 32-bit).
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import equinox as eqx
import optax

---

In [None]:
from mc2.runners.model_setup_jax import setup_model
from mc2.utils.model_evaluation import reconstruct_model_from_exp_id, get_exp_ids

In [None]:
from mc2.model_interfaces.ja_interfaces import JAwInterface
from mc2.models.jiles_atherton import JAStatic, JAEnsemble

In [None]:
from mc2.models.RNN import GRUaroundLinearModel

In [None]:
wrapped_model, optimizer, params, (train_set, eval_set, test_set) = setup_model(model_label="GRUaroundLinearModel", material_name="3C90", model_key=jax.random.PRNGKey(0), n_epochs=300, tbptt_size=128, batch_size=512,)

In [None]:
model = GRUaroundLinearModel(
    in_size=7,
    hidden_size=6,
    linear_in_size=3,
    key=jax.random.key(0),
)
model

In [None]:
norm_test_set = test_set.normalize(wrapped_model.normalizer)

In [None]:
test_set.frequencies

In [None]:
def add_gaussian_noise(in_data, key, noise_std: float = 0.002):
    """Return ``in_data`` with additive i.i.d. Gaussian noise.

    Args:
        in_data: Array to perturb.
        key: JAX PRNG key used to draw the noise.
        noise_std: Standard deviation of the noise.

    Returns:
        Array with the same shape as ``in_data``.
    """
    # Draw from the caller's key (not a notebook global) so the result depends
    # only on the arguments and is reproducible.
    return in_data + jax.random.normal(key, shape=in_data.shape) * noise_std

In [None]:
in_data = norm_test_set.at_frequency(50_000).B[0, 1000:2000]
noise_key = jax.random.PRNGKey(0)
noise_in_data = add_gaussian_noise(in_data, noise_key)

In [None]:
plt.plot(in_data, color="tab:blue")
plt.plot(noise_in_data, color="tab:orange")

In [None]:
model.construct_init_hidden(jnp.ones((200,2)), batch_size=200)

# JA:

In [None]:
wrapped_model, optimizer, params, (train_set, eval_set, test_set) = setup_model(model_label="JA", material_name="3C90", model_key=jax.random.PRNGKey(0), n_epochs=300, tbptt_size=128, batch_size=512,)

In [None]:
helper_wrapped_model = deepcopy(wrapped_model)

In [None]:
keys = jax.random.split(jax.random.key(0), 10)

model = eqx.filter_vmap(JAStatic)(keys)
model

In [None]:
key = jax.random.key(0) 
model = JAEnsemble(key, 100)

In [None]:
wrapped_model = JAwInterface(
    model = model,
    normalizer = helper_wrapped_model.normalizer,
    featurize = helper_wrapped_model.featurize,
)

In [None]:
wrapped_model

In [None]:
def set_physical_parameters(wrapped_model, Ms, a, alpha, k, c):
    """Write physical JA parameter values into the model's raw parameters.

    Each value is divided by a fixed scale (Ms: 2e6, a: 100, alpha: 1e-3,
    k: 100, c: 1) and mapped through the logit log(x / (1 - x)), presumably
    so that a sigmoid reparametrization inside the model recovers the given
    value — TODO confirm against the JA model. The scaled values must lie in
    (0, 1), otherwise the logit is NaN/inf.

    Args:
        wrapped_model: Object exposing ``.model`` with ``Ms_param``,
            ``a_param``, ``alpha_param``, ``k_param`` and ``c_param``.
        Ms, a, alpha, k, c: Physical parameter values.

    Returns:
        The updated wrapped model (``eqx.tree_at`` returns a new pytree).
    """
    def logit(value):
        return jnp.log(value / (1 - value))

    scaled_values = (
        ("Ms_param", Ms / 2e6),
        ("a_param", a / 100),
        ("alpha_param", alpha / 1e-3),
        ("k_param", k / 100),
        ("c_param", float(c)),
    )
    for attr_name, scaled in scaled_values:
        wrapped_model = eqx.tree_at(
            lambda m, attr=attr_name: getattr(m.model, attr),
            wrapped_model,
            jnp.array(logit(scaled)),
        )
    return wrapped_model

def set_parameters(wrapped_model, Ms_param, a_param, alpha_param, k_param, c_param):
    """Write raw (unconstrained) values directly into the JA model parameters.

    No scaling or logit transform is applied; each value is cast to a Python
    float and stored as a 0-d array.

    Returns:
        The updated wrapped model (``eqx.tree_at`` returns a new pytree).
    """
    raw_values = {
        "Ms_param": Ms_param,
        "a_param": a_param,
        "alpha_param": alpha_param,
        "k_param": k_param,
        "c_param": c_param,
    }
    for attr_name, raw in raw_values.items():
        wrapped_model = eqx.tree_at(
            lambda m, attr=attr_name: getattr(m.model, attr),
            wrapped_model,
            jnp.array(float(raw)),
        )
    return wrapped_model

In [None]:
# Peak-normalized B and H of sequence 0 (80 kHz, samples 1000-2000) on one axis,
# to eyeball the phase relation between the two signals.
subset_80k = test_set.at_frequency(80_000)
for signal_name in ("B", "H"):
    signal_window = getattr(subset_80k, signal_name)[0, 1000:2000]
    plt.plot(signal_window / jnp.max(jnp.abs(signal_window)), label=signal_name)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Evaluation window: one sequence of the 800 kHz test subset, samples 1500-2000.
# The temperature is taken from the SAME subset and sequence as B/H so the
# models are conditioned on the matching operating point.
eval_subset = test_set.at_frequency(800_000)
eval_seq = 5

T_test = eval_subset.T[eval_seq:eval_seq + 1]
B_test = eval_subset.B[eval_seq:eval_seq + 1, 1500:2000]
H_test = eval_subset.H[eval_seq:eval_seq + 1, 1500:2000]

# Synthetic piecewise-linear B excitation (uncomment to use instead):
# B_test = jnp.linspace(0, -1, 1000) 
# B_test = jnp.concatenate([B_test, jnp.linspace(-1, 0.4, 1000)], axis=0)
# B_test = jnp.concatenate([B_test, jnp.linspace(0.4, 0.3, 1000)], axis=0)
# B_test = jnp.concatenate([B_test, jnp.linspace(0.3, 1, 1000)], axis=0)
# B_test = jnp.concatenate([B_test, jnp.linspace(1, -1, 2000)], axis=0)
# # B_test = jnp.concatenate([B_test, jnp.linspace(-0.3, -0.5, 1000)], axis=0)
# # B_test = jnp.concatenate([B_test, jnp.linspace(-0.5, 0.15, 1000)], axis=0)
# # B_test = jnp.concatenate([B_test, jnp.linspace(0.15, -0.05, 1000)], axis=0)
# # B_test = jnp.concatenate([B_test, jnp.linspace(-0.049, 0.1, 1000)], axis=0)
# B_test = B_test[None, ...] * 0.25
# H_test = jnp.zeros(B_test.shape)

In [None]:
# Optional manual overrides of the JA parameters (uncomment to use):
# wrapped_model = set_physical_parameters(
#     wrapped_model, 
#     Ms=5e5,
#     a=50.,
#     alpha=1.47 * 1e-4,
#     k=25,
#     c=0.6
# )

# wrapped_model = set_parameters(
#     wrapped_model, 
#     Ms_param=0.0,
#     a_param=0.0,
#     alpha_param=-6.0,
#     k_param=0.0,
#     c_param=0.0,
# )

#print(wrapped_model.model.physical_params)

n_past = 100  # leading samples passed to the models as the "past" B/H history

H_pred = wrapped_model(
    B_past = B_test[:, :n_past],
    B_future = B_test[:, n_past:],
    H_past = H_test[:, :n_past],
    T=T_test,
)

# Reconstruct the reference GRU only once per kernel session.
if 'gru_model' not in locals():
    gru_model = reconstruct_model_from_exp_id('3C90_GRU_72562eee-55a6-48')
gru_pred = gru_model(
    B_past = B_test[:, :n_past],
    B_future = B_test[:, n_past:],
    H_past = H_test[:, :n_past],
    T=T_test,
)

x = jnp.arange(0, B_test.shape[1], 1)

# Time domain: measured H over the whole window, predictions over the future part.
plt.plot(x, jnp.squeeze(H_test), label="H_test", c="tab:blue")
plt.plot(x[n_past:], jnp.squeeze(H_pred), label="H_pred", c="tab:orange")
plt.plot(x[n_past:], jnp.squeeze(gru_pred), label="GRU_pred", c="tab:green")
# plt.plot(x, jnp.squeeze(B_test) / jnp.max(jnp.abs(B_test)), label="B_test") 
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# B-H loops of the future part, compared pairwise (one figure per pair).
B_future_test = jnp.squeeze(B_test[:, n_past:])
hysteresis_curves = {
    "H_test": (jnp.squeeze(H_test[:, n_past:]), "tab:blue"),
    "H_pred": (jnp.squeeze(H_pred), "tab:orange"),
    "GRU_pred": (jnp.squeeze(gru_pred), "tab:green"),
}
for curve_pair in (("H_test", "H_pred"), ("GRU_pred", "H_test"), ("GRU_pred", "H_pred")):
    for curve_label in curve_pair:
        H_curve, curve_color = hysteresis_curves[curve_label]
        plt.plot(H_curve, B_future_test, c=curve_color, label=curve_label)
    plt.xlabel("H")
    plt.ylabel("B")
    plt.legend()
    plt.show()


In [None]:
wrapped_model

In [None]:
wrapped_model.model.physical_params

In [None]:
gru_model

In [None]:
wrapped_model.model.params

In [None]:
wrapped_model.model.physical_params

In [None]:
# Deliberate stop marker: a bare `raise` with no active exception raises
# RuntimeError ("No active exception to reraise"). Presumably this is meant to
# halt "Run All" before the exploratory cells below.
raise

In [None]:
plt.plot(jnp.squeeze(H_pred), jnp.squeeze(B_test[:, 100:]))
plt.xlabel("H")
plt.ylabel("B")

In [None]:
plt.plot(jnp.squeeze(H_test[:, 100:]), jnp.squeeze(B_test[:, 100:]))

In [None]:
from mc2.model_interfaces.linear_interfaces import LinearInterface
from mc2.models.linear import LinearStatic

from mc2.models.RNN import GRUwLinearModel, GRU
from mc2.model_interfaces.rnn_Interfaces import GRUwLinearModelInterface, MagnetizationRNNwInterface, RNNwInterface

In [None]:
# model = LinearStatic(11, 1, key=jax.random.PRNGKey(0)) 

# model = GRUwLinearModel(in_size=7, hidden_size=8, linear_in_size=7, key=jax.random.PRNGKey(0))

In [None]:
# wrapped_linear_model = LinearInterface(model, normalizer=wrapped_model.normalizer, featurize=wrapped_model.featurize)

# wrapped_model = GRUwLinearModelInterface(model, normalizer=wrapped_model.normalizer, featurize=wrapped_model.featurize)

In [None]:
# model_params_d = dict(hidden_size=8, in_size=7, key=jax.random.PRNGKey(0))
# model = GRU(**model_params_d)

# wrapped_model = RNNwInterface(
#     model=model,
#     normalizer=wrapped_model.normalizer,
#     featurize=wrapped_model.featurize
# )

# wrapped_model = MagnetizationRNNwInterface(
#     model=model,
#     normalizer=wrapped_model.normalizer,
#     featurize=wrapped_model.featurize
# )

In [None]:
from mc2.features.features_jax import db_dt, d2b_dt2, dyn_avg

In [None]:
seq_idx = 2
seq_start = 0
seq_len = 100

B_test = test_set.at_frequency(800_000).B[seq_idx, seq_start:seq_start+seq_len]
H_test = test_set.at_frequency(800_000).H[seq_idx, seq_start:seq_start+seq_len]

In [None]:
B_test_norm = B_test / jnp.max(jnp.abs(B_test))
H_test_norm = H_test / jnp.max(jnp.abs(H_test))

In [None]:
plt.plot(B_test_norm, label="B")
plt.plot(H_test_norm, label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
from mc2.features.features_jax import shift_signal

# Shifted copy of B used by the derivative comparison plots further down.
# Assumes the 2nd positional argument of `shift_signal` is the shift in samples
# (as in the `shift_signal(..., 3)` call below) — TODO confirm.
B_shifted = shift_signal(B_test_norm, 5)

# Cross-correlation of the mean-free B and H windows; a peak away from lag 0
# suggests a phase shift between B and H.
correlation_values = jnp.correlate(
    B_test_norm - jnp.mean(B_test_norm),
    H_test_norm - jnp.mean(H_test_norm),
    mode="full",
)

# mode="full" on two length-seq_len inputs gives lags -(seq_len-1) ... seq_len-1.
x = jnp.arange(-seq_len+1, seq_len, 1)

plt.plot(x, correlation_values)
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
plt.plot(shift_signal(B_test_norm - jnp.mean(B_test_norm), 0), label="B", linestyle="dashed")
plt.plot(shift_signal(B_test_norm - jnp.mean(B_test_norm), 3), label="B_shifted")
plt.plot(H_test_norm - jnp.mean(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()

In [None]:
signal1 = B_test_norm - jnp.mean(B_test_norm)
signal2 = H_test_norm - jnp.mean(H_test_norm)

def min_max_norm(x):
    """Affinely rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    A constant input gives 0/0, i.e. NaN everywhere.
    """
    lo, hi = jnp.min(x), jnp.max(x)
    return (x - lo) / (hi - lo)

signal1 = min_max_norm(signal1)
signal2 = min_max_norm(signal2)

In [None]:
plt.plot(db_dt(signal1), label="B")
plt.plot(db_dt(signal2), label="H")
plt.grid(True, alpha=0.3)

In [None]:
plt.plot(db_dt(B_test_norm - jnp.mean(B_test_norm)), label="B")
plt.plot(db_dt(H_test_norm - jnp.mean(H_test_norm)), label="H")
plt.grid(True, alpha=0.3)

In [None]:
plt.plot(B_shifted, label="B_shifted")
plt.plot(B_test_norm, label="B")
plt.plot(H_test_norm, label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(db_dt(B_shifted), label="B_shifted")
plt.plot(db_dt(B_test_norm), label="B", linestyle="dashed")
plt.plot(db_dt(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(d2b_dt2(B_shifted), label="B_shifted")
plt.plot(d2b_dt2(B_test_norm), label="B")
plt.plot(d2b_dt2(H_test_norm), label="H")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
plt.plot(dyn_avg(B_test_norm, n_s=15, mirrored_padding=True), label="B_averaged")
plt.plot(B_test_norm, label="B")
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

---

In [None]:
B_test = test_set.at_frequency(80_000).B[:10, :1000]
H_past = test_set.at_frequency(80_000).H[:10, :100]
H_future = test_set.at_frequency(80_000).H[:10, 100:1000]
T = test_set.at_frequency(80_000).T[:10]

In [None]:
B_past = B_test[:, :100]
B_future = B_test[:, 100:]

In [None]:
H_est = wrapped_model(B_past, H_past, B_future, T)
H_est.shape

In [None]:
# B_future_norm, H_est_norm, T_norm = wrapped_linear_model.normalizer.normalize(B_future, H_est, T)
# _, H_future_norm, _ = wrapped_linear_model.normalizer.normalize(B_future, H_future, T)

B_future_norm, H_est_norm, T_norm = wrapped_model.normalizer.normalize(B_future, H_est, T)
_, H_future_norm, _ = wrapped_model.normalizer.normalize(B_future, H_future, T)

plt.plot(B_future_norm[2])
plt.show()
plt.plot(H_future_norm[2])
plt.plot(H_est_norm[2])

In [None]:
# Raw (unnormalized) B of sequence 2, then its full H trace (past + future).
plt.plot(B_test[2])
plt.grid(True, alpha=0.3)
plt.show()
plt.plot(jnp.hstack([H_past[2], H_future[2]]))
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# B_future_norm, H_est_norm, T_norm = wrapped_linear_model.normalizer.normalize(B_future, H_est, T)
# _, H_future_norm, _ = wrapped_linear_model.normalizer.normalize(B_future, H_future, T)

B_future_norm, H_est_norm, T_norm = wrapped_model.normalizer.normalize(B_future, H_est, T)
_, H_future_norm, _ = wrapped_model.normalizer.normalize(B_future, H_future, T)

plt.plot(B_future_norm[2])
plt.plot(H_future_norm[2])
plt.plot(H_est_norm[2])

In [None]:
B_all = B_test

In [None]:
B_all_padded = jnp.pad(B_all, ((0, 0), (5, 5)), mode='reflect', reflect_type="odd")

In [None]:
B_all_padded.shape

In [None]:
# Stack 11 time-shifted copies of B (shifts -5..+5) along a new feature axis,
# then drop the 5-sample padding at both ends and the first 100 "past" samples.
# jnp.roll(x, s)[t] == x[t - s], so negative shifts look ahead in time.
# Roll along the time axis only, with integer shifts: without `axis` jnp.roll
# flattens the batch and moves samples across sequence boundaries.
B_in = jnp.concatenate(
    [jnp.roll(B_all_padded, shift, axis=-1)[..., None] for shift in range(-5, 6)],
    axis=-1,
)[:, 5 + 100: -5, :]

In [None]:
plt.plot(B_future[0])

In [None]:
plt.plot(B_in[0, :, 0])

In [None]:
from mc2.utils.data_plotting import plot_single_sequence, plot_hysteresis

In [None]:
B_past = jnp.ones((1,1)) * 0.01
H_past = jnp.ones((1,1)) * 0.01
B_future = jnp.concatenate([jnp.linspace(0.01, 0.4, 1000), jnp.linspace(0.4, -0.4, 2000), np.linspace(-0.4, 1, 2000)])[None, :]
H_future = jnp.ones_like(B_future)
T=jnp.array([25])

In [None]:
H_pred = wrapped_model(B_past, H_past, B_future, T)

In [None]:
fig, axs = plot_single_sequence(B_future[0], H_future[0], T=T)
axs[-1].plot(H_pred[0])

In [None]:
fig, axs = plot_hysteresis(B_future[0], H_future[0], T=T)
# Overlay the predicted B-H loop with H on x and B on y, matching the other
# hysteresis plots in this notebook (assumes plot_hysteresis uses the same
# orientation — TODO confirm). The B/H permeability curve is plotted separately
# in the next cell.
axs.plot(H_pred[0], B_future[0])

In [None]:
mu_0 = 4 * jnp.pi * 1e-7

plt.plot(H_pred[0], B_future[0] / H_pred[0] / mu_0) 
plt.ylim(0, 5000)

# LLG:

## testing

In [None]:
from mc2.models.RNN import VectorfieldGRU
from mc2.model_interfaces.rnn_Interfaces import VectorfieldGRUInterface

In [None]:
material_set = MaterialSet.from_material_name("3C90")
train_set, val_set, test_set = material_set.split_into_train_val_test(
    train_frac=0.7, val_frac=0.15, test_frac=0.15, seed=0
)

In [None]:
# model = VectorfieldGRU(in_size=6, n_locs=9, key=jax.random.key(0))
# wrapped_model = VectorfieldGRUInterface(
#     model=model,
#     normalizer=wr_model.normalizer,
#     featurize=wr_model.featurize,
# )


# wrapped_model, _, _, (train_set, eval_set, test_set) = setup_model(
#     model_label="VectorfieldGRU",
#     material_name="N49",
#     model_key=jax.random.PRNGKey(0),
#     n_epochs=300,
#     tbptt_size=128,
#     batch_size=512,
# )

wrapped_model = reconstruct_model_from_exp_id('3C90_VectorfieldGRU_0b3f92ef-91e8-4f')
wrapped_model

In [None]:
wrapped_model.normalizer.H_inverse_transform(1.0)

In [None]:
wrapped_model.n_params

In [None]:
subset = test_set.at_frequency(50_000)
seq_idx = 2
past_size = 10
seq_length = 1000

H = subset.H[seq_idx:seq_idx+3]
B = subset.B[seq_idx:seq_idx+3]
T = subset.T[seq_idx:seq_idx+3]

B_past = B[:, :past_size]
B_future = B[:, past_size:seq_length]
H_past  = H[:, :past_size]
H_future = H[:, past_size:seq_length]

In [None]:
H_pred, mag_pred_norm = wrapped_model(B_past, H_past, B_future, T, warmup=False, debug=True)
H_pred.shape

In [None]:
plt.plot(H_pred[0], jnp.squeeze(B_future[0]), c="tab:orange")
plt.plot(jnp.squeeze(H_future)[0], jnp.squeeze(B_future[0]), c="tab:blue")
plt.show()

plt.plot(H_pred[0], c="tab:orange")
plt.plot(jnp.squeeze(H_future[0]), c="tab:blue")
plt.show()

plt.plot(B_future[0], c="tab:blue")
plt.show()

In [None]:
from ipywidgets import interact, IntSlider

In [None]:
n_locs = wrapped_model.model.n_locs
locs_per_dim = int(jnp.sqrt(n_locs))
locs_per_dim

In [None]:
data = mag_pred_norm[0].reshape((-1, locs_per_dim, locs_per_dim, 2))

T, nx, ny, _ = data.shape
x = np.arange(nx)
y = np.arange(ny)


def plot_quiver(k):
    """Plot the predicted magnetization vector field at timestep ``k``.

    Panels: (0) quiver of the (m_x, m_y) grid at step ``k``; (1) every cell's
    m_x trace plus their sum; (2) every cell's m_y trace plus their sum;
    (3) predicted vs. true H. Panels 1-3 get a dotted red marker at ``k``.

    Reads the notebook globals ``data`` (shape (T, nx, ny, 2)), ``x``, ``y``,
    ``nx``, ``ny``, ``H_pred`` and ``H_future``.
    """
    fig, axs = plt.subplots(4, 1, figsize=(6, 8))

    def mark_timestep(ax):
        # Dotted vertical line at k spanning the panel's current y-range.
        ymin, ymax = ax.get_ylim()
        ax.plot([k, k], [ymin, ymax], color="r", linestyle=":", linewidth=1)

    axs[0].quiver(x, y, data[k, :, :, 0], data[k, :, :, 1])
    axs[0].set_xlim(-0.5, nx - 0.5)
    axs[0].set_ylim(-0.5, ny - 0.5)
    axs[0].set_aspect("equal")

    # Component traces: m_x in orange, m_y in green; bold line is the grid sum.
    for ax, component, color in ((axs[1], 0, "tab:orange"), (axs[2], 1, "tab:green")):
        m_c = data[..., component]
        for m_i in m_c.reshape((-1, nx * ny)).T:
            ax.plot(m_i, alpha=.4, c=color)
        ax.plot(jnp.sum(m_c, axis=(-1, -2)), c=color)
        mark_timestep(ax)

    axs[3].plot(H_pred[0], c="tab:orange", label="H_hat")
    axs[3].plot(jnp.squeeze(H_future[0]), c="tab:blue", label="H_true")
    mark_timestep(axs[3])
    axs[3].legend()

    for ax in axs:
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
interact(
    plot_quiver,
    k=IntSlider(min=0, max=T-1, step=1, value=0, description="timestep")
);

In [None]:
# which is the summing direction?

## GIF:

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter

In [None]:
# NOTE: this rebinds the global `T` (earlier the temperature slice `subset.T[...]`)
# to the number of frames; re-define T before calling the model again.
T, nx, ny, _ = data.shape

# Create the grid for quiver
# indexing="ij" maps the first grid axis of `data` to x. NOTE(review): the
# interactive plot_quiver passes 1-D x/y, which matplotlib expands with the
# default "xy" meshgrid, so the two views may be transposed — confirm which is intended.
x = np.arange(nx)
y = np.arange(ny)
X, Y = np.meshgrid(x, y, indexing="ij")

# --- Set up the figure ---
fig, axs = plt.subplots(4, 1, figsize=(6, 8))

# Quiver placeholder
# Initial arrows from frame 0; `update` below swaps in later frames via set_UVC.
U0 = data[0, :, :, 0]
V0 = data[0, :, :, 1]
quiv = axs[0].quiver(X, Y, U0, V0)

axs[0].set_xlim(-0.5, nx - 0.5)
axs[0].set_ylim(-0.5, ny - 0.5)
axs[0].set_aspect("equal")

# m_x traces of every cell (thin) and their grid sum (bold).
m_x = data[..., 0]
for m_i in m_x.reshape((-1, nx*ny)).T:
    axs[1].plot(m_i, alpha=.4, c="tab:orange")
axs[1].plot(jnp.sum(m_x, axis=(-1, -2)), c="tab:orange")

# y-limits are captured once so `update` can keep the marker spanning the panel.
ymin1, ymax1 = axs[1].get_ylim()
vline1, = axs[1].plot([0, 0], [ymin1, ymax1], color="r", linestyle=":", linewidth=1)


# m_y traces (panel 2). Its marker is named vline3; vline2 belongs to panel 3.
m_y = data[..., 1]
for m_i in m_y.reshape((-1, nx*ny)).T:
    axs[2].plot(m_i, alpha=.4, c="tab:green")
axs[2].plot(jnp.sum(m_y, axis=(-1, -2)), c="tab:green")
ymin3, ymax3 = axs[2].get_ylim()
vline3, = axs[2].plot([0, 0], [ymin3, ymax3], color="r", linestyle=":", linewidth=1)


# Predicted vs. true H.
axs[3].plot(H_pred[0], c="tab:orange", label="H_hat")
axs[3].plot(jnp.squeeze(H_future[0]), c="tab:blue", label="H_true")
ymin2, ymax2 = axs[3].get_ylim()
vline2, = axs[3].plot([0, 0], [ymin2, ymax2], color="r", linestyle=":", linewidth=1)
axs[3].legend()

for ax in axs:
    ax.grid(alpha=0.3)

plt.tight_layout()

# --- Animation update function ---
def update(k):
    """FuncAnimation callback: redraw the figure for frame ``k``.

    Replaces the quiver arrows with ``data[k]``, moves the three dotted
    timestep markers to x = k (keeping the y-extents captured at setup) and
    retitles the quiver panel. Returns the modified artists.
    """
    quiv.set_UVC(data[k, :, :, 0], data[k, :, :, 1])

    markers = (
        (vline1, ymin1, ymax1),
        (vline2, ymin2, ymax2),
        (vline3, ymin3, ymax3),
    )
    for marker, lower, upper in markers:
        marker.set_xdata([k, k])
        marker.set_ydata([lower, upper])

    axs[0].set_title(f"Vector field (t={k})")
    return quiv, vline1, vline2, vline3

# --- Create animation ---
anim = FuncAnimation(fig, update, frames=T, interval=200, blit=False)

# --- Save as GIF ---
anim.save("vector_field_timeseries.gif", dpi=120, writer=PillowWriter(fps=5))

print("GIF saved as vector_field_timeseries.gif")

## Physical Interaction:

### LLG version:

- based on:
    - mumax documentation,
    - "The design and verification of MuMax3",
    - mumax open source code

In [None]:
def random_unit_vectors_2d(n, key):
    """Draw ``n`` random 2-D unit vectors with uniformly distributed angles.

    Args:
        n: Number of vectors.
        key: JAX PRNG key.

    Returns:
        Array of shape ``(n, 2)`` with rows ``(cos θ, sin θ)``, θ ~ U[0, 2π).
    """
    # jax.random expects `shape` as a sequence, so wrap n in a tuple.
    angles = jax.random.uniform(key, (n,), minval=0, maxval=2 * jnp.pi)
    return jnp.column_stack((jnp.cos(angles), jnp.sin(angles)))

In [None]:
m_t = jnp.concatenate([jnp.zeros((25, 1)), jnp.ones((25, 1))], axis=-1).reshape((5,5,2))
# m_t = random_unit_vectors_2d(9, jax.random.PRNGKey(0)).reshape(3,3,2)
# m_t = jnp.concatenate([jnp.zeros((25, 1)), jnp.ones((25, 1))], axis=-1).reshape((5,5,2))
m_t = m_t.at[:1, :, 1].set(-1/jnp.sqrt(2))
m_t = m_t.at[:1, :, 0].set(+1/jnp.sqrt(2))
# m_t = m_t.at[2, :, 1].set(0)
# m_t = m_t.at[2, :, 0].set(-1)

m_t.shape

In [None]:
nx, ny, _ = m_t.shape

# Create the grid for quiver
X = np.arange(nx)
Y = np.arange(ny)

plt.quiver(X, Y, m_t[..., 0], m_t[..., 1])

In [None]:
H_ext = jnp.array([[0.0, 0.0]])

In [None]:
def laplacian(m, dx=1, dy=1):
    """Finite-difference Laplacian of a field sampled on a 2-D grid.

    Axis 0 of ``m`` is treated as y (spacing ``dy``) and axis 1 as x (spacing
    ``dx``). Only those two axes are differentiated, so trailing axes (e.g.
    vector components) are handled independently. ``jnp.gradient`` is applied
    twice (central differences inside, one-sided at the borders).

    Returns:
        Array of the same shape as ``m``: d²m/dx² + d²m/dy².
    """
    # jnp.gradient pairs spacings with `axis` in order: dy for axis 0, dx for axis 1.
    grad_y, grad_x = jnp.gradient(m, dy, dx, axis=(0, 1))
    grad_xx = jnp.gradient(grad_x, dx, axis=1)
    grad_yy = jnp.gradient(grad_y, dy, axis=0)
    return grad_xx + grad_yy

def get_effective_field(m, H_ext):
    """Effective field acting on the magnetization ``m``.

    Sum of the external field, a demagnetization contribution (currently a
    placeholder of 0) and an exchange term taken as the grid Laplacian of
    ``m`` with unit prefactor.
    """
    exchange_field = laplacian(m)
    demag_field = 0  # placeholder: demagnetization is not modelled yet
    return H_ext + demag_field + exchange_field
    
@eqx.filter_jit
def LLG_dynamics(m, H_ext, gamma, alpha):
    """Time derivative dm/dt of an in-plane (2-D) magnetization under LLG.

    Evaluates only the damping term -|γ|/(1+α²) · α · m×(m×H_eff), with
    m×(m×H) expanded as m(m·H) - H(m·m) for in-plane vectors. The precession
    term m×H has only a z-component for in-plane m and H and is not included.

    Args:
        m: Magnetization with (x, y) components in the last axis.
        H_ext: External field, added to the exchange field.
        gamma: Gyromagnetic ratio (only |gamma| is used).
        alpha: Damping constant.
    """
    H_eff = get_effective_field(m, H_ext)

    m_x, m_y = m[..., 0], m[..., 1]
    h_x, h_y = H_eff[..., 0], H_eff[..., 1]

    # In-plane components of m × (m × H_eff).
    double_cross = jnp.stack(
        [
            m_y * m_x * h_y - m_y**2 * h_x,
            -m_x**2 * h_y + m_x * m_y * h_x,
        ],
        axis=-1,
    )

    prefactor = -jnp.abs(gamma) / (1 + alpha**2)
    return prefactor * (alpha * double_cross)

In [None]:
ms = [m_t]
m = m_t
n_steps = 10_000

for t in range(n_steps):
    m_next = m + 0.1 * LLG_dynamics(m, H_ext, gamma=0.1, alpha=10)
    m = m_next
    ms.append(m)

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(10, 10))

ms = jnp.array(ms)
ms_plot = ms.reshape(n_steps+1, -1, 2)  # (time, grid cell, component)

# Final state of the vector field (drawn once, not once per cell).
axs[0].quiver(X, Y, ms[-1, ..., 0], ms[-1, ..., 1])
axs[0].set_xlim(-0.5, nx - 0.5)
axs[0].set_ylim(-0.5, ny - 0.5)
axs[0].set_aspect("equal")

# Time traces of every cell's m_x (middle) and m_y (bottom).
for i in range(ms_plot.shape[1]):
    axs[1].plot(ms_plot[:, i, 0], alpha=0.4, c="tab:blue")
    axs[2].plot(ms_plot[:, i, 1], alpha=0.4, c="tab:blue")

for ax in axs:
    ax.grid(True, alpha=0.3)

In [None]:
data = deepcopy(ms)

T, nx, ny, _ = data.shape
x = np.arange(nx)
y = np.arange(ny)


def plot_quiver(k):
    """Plot the simulated LLG magnetization field at timestep ``k``.

    Panels: (0) quiver of the (m_x, m_y) grid at step ``k``; (1) m_x trace of
    every cell; (2) m_y trace of every cell. Panels 1 and 2 get a dotted red
    marker at ``k``. Reads the notebook globals ``data`` (shape
    (T, nx, ny, 2)), ``x``, ``y``, ``nx`` and ``ny``.
    """
    fig, axs = plt.subplots(3, 1, figsize=(6, 6))

    axs[0].quiver(x, y, data[k, :, :, 0], data[k, :, :, 1])
    axs[0].set_xlim(-0.5, nx - 0.5)
    axs[0].set_ylim(-0.5, ny - 0.5)
    axs[0].set_aspect("equal")

    for ax, component in ((axs[1], 0), (axs[2], 1)):
        for m_i in data[..., component].reshape((-1, nx * ny)).T:
            ax.plot(m_i, alpha=.4, c="tab:blue")
        # Dotted vertical line at k spanning the panel's current y-range.
        ymin, ymax = ax.get_ylim()
        ax.plot([k, k], [ymin, ymax], color="r", linestyle=":", linewidth=1)

    for ax in axs:
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
interact(
    plot_quiver,
    k=IntSlider(min=0, max=T-1, step=1, value=0, description="timestep")
);

In [None]:
m_t.shape

In [None]:
lap = laplacian(m_t)

In [None]:
nx, ny, _ = m_t.shape

# Create the grid for quiver
X = np.arange(nx)
Y = np.arange(ny)

plt.quiver(X, Y, m_t[..., 0], m_t[..., 1])

In [None]:
nx, ny, _ = lap.shape

# Create the grid for quiver
X = np.arange(nx)
Y = np.arange(ny)

plt.quiver(X, Y, lap[..., 0], lap[..., 1])

### Dipole version:

In [None]:
def random_unit_vectors_2d(n, key):
    """Draw ``n`` random 2-D unit vectors with uniformly distributed angles.

    Args:
        n: Number of vectors.
        key: JAX PRNG key.

    Returns:
        Array of shape ``(n, 2)`` with rows ``(cos θ, sin θ)``, θ ~ U[0, 2π).
    """
    # jax.random expects `shape` as a sequence, so wrap n in a tuple.
    angles = jax.random.uniform(key, (n,), minval=0, maxval=2 * jnp.pi)
    return jnp.column_stack((jnp.cos(angles), jnp.sin(angles)))

In [None]:
m_t = jnp.concatenate([jnp.zeros((2, 1)), jnp.ones((2, 1))], axis=-1).reshape((1,2,2))
# m_t = jnp.array([[0.1, -0.8], [0, 1]]).reshape((1,2,2))

m_t = random_unit_vectors_2d(2, jax.random.PRNGKey(4)).reshape(1,2,2)
# m_t = jnp.concatenate([jnp.zeros((25, 1)), jnp.ones((25, 1))], axis=-1).reshape((5,5,2))

#m_t = m_t.at[:1, :, 1].set(-1/jnp.sqrt(2))
#m_t = m_t.at[:1, :, 0].set(+1/jnp.sqrt(2))
# m_t = m_t.at[2, :, 1].set(0)
# m_t = m_t.at[2, :, 0].set(-1)

# m_t.shape

In [None]:
nx, ny, _ = m_t.shape

# Create the grid for quiver
X = np.arange(nx)
Y = np.arange(ny)

xx, yy = np.meshgrid(X,Y)

plt.quiver(xx, yy, m_t[..., 0], m_t[..., 1])

In [None]:
init_state = dict(
    m=m_t[0],
    omega=jnp.array([0.0, 0.0]),
)
init_state

In [None]:
def angle_from_m(m):
    """Polar angle (radians, in [-π, π]) of 2-D vectors stored in the last axis."""
    x_comp = m[..., 0]
    y_comp = m[..., 1]
    return jnp.arctan2(y_comp, x_comp)

In [None]:
angle_from_m(init_state["m"]) / (2*jnp.pi) * 360

In [None]:
mu0 = 4 * jnp.pi * 1e-7

def get_torque(m1, m2, r):
    """In-plane torque on dipole ``m1`` from the point-dipole field of ``m2``.

    The field of ``m2`` at separation ``r`` is
    H = (3 (m2·r̂) r̂ - m2) / (4π |r|³). The result is even in ``r``, so the
    sign convention of ``r`` does not matter. The torque is the z-component of
    the 2-D cross product m1 × H, i.e. m1_x H_y - m1_y H_x.

    Args:
        m1: 2-vector receiving the torque.
        m2: 2-vector generating the field.
        r: 2-vector separating the two dipoles (must be non-zero).

    Returns:
        Scalar torque.
    """
    r_magnitude = jnp.linalg.norm(r)
    r_unit = r / r_magnitude

    prefactor = 1 / (4 * jnp.pi * r_magnitude**3)
    # (m2 · r̂) is a scalar projection. An element-wise product only coincides
    # with it when r is axis-aligned.
    field = prefactor * (3 * jnp.dot(m2, r_unit) * r_unit - m2)

    # z-component of m1 × field for in-plane vectors.
    return m1[0] * field[1] - m1[1] * field[0]

def step_rot_ode(state, tau=1, J=100, d=1e-1):
    """One explicit-Euler step of the rotational dynamics of two coupled dipoles.

    Each dipole behaves as a rigid rotor: its angular acceleration is
    torque / J minus a viscous term d·omega, and the torque comes from the
    partner's dipole field. The separation vectors are hard-coded to ±[0, 1],
    so exactly two dipoles are supported.

    Args:
        state: dict with ``"m"`` (shape (2, 2): dipole, component) and
            ``"omega"`` (shape (2,): angular velocity per dipole).
        tau: Time step.
        J: Inertia-like divisor of the torque.
        d: Damping coefficient.

    Returns:
        New state dict with the same keys. Each dipole keeps its own magnitude;
        only its angle changes.
    """
    m = state["m"]
    omega = state["omega"]
    theta = angle_from_m(m)

    torques = jnp.array([
        get_torque(m[0], m[1], r=jnp.array([0, 1])),
        get_torque(m[1], m[0], r=jnp.array([0, -1])),
    ])

    domega = torques / J - d * omega
    dtheta = omega

    next_omega = omega + tau * domega
    next_theta = theta + tau * dtheta

    # Rebuild every dipole from its new angle, scaled by its OWN magnitude.
    # The norms have shape (n_dipoles,) and must broadcast over the component
    # axis, hence [:, None].
    directions = jnp.stack([jnp.cos(next_theta), jnp.sin(next_theta)], axis=-1)
    next_m = directions * jnp.linalg.norm(m, axis=-1)[:, None]

    return dict(
        omega=next_omega,
        m=next_m,
    )
    

In [None]:
state =  init_state
plt.quiver(xx, yy, state["m"][..., 0], state["m"][..., 1])
plt.show() 

ms = [state["m"]]
omegas = [state["omega"]]

for i in tqdm(range(1000)):
    next_state = step_rot_ode(state)
    state = next_state

    ms.append(state["m"])
    omegas.append(state["omega"])
    if i % 100 == 0 and i > 0:
        plt.quiver(xx, yy, next_state["m"][..., 0], next_state["m"][..., 1])
        plt.show() 

ms = jnp.stack(ms)
omegas = jnp.stack(omegas)

In [None]:
plt.plot(ms[:, 0], color="tab:blue")
plt.plot(ms[:, 1], color="tab:orange")
plt.show()

plt.plot(jnp.linalg.norm(ms[:, 0], axis=-1), color="tab:blue")
plt.plot(jnp.linalg.norm(ms[:, 1], axis=-1), color="tab:orange")
plt.show()

plt.plot(angle_from_m(ms[:, 0]), color="tab:blue")
plt.plot(angle_from_m(ms[:, 1]), color="tab:orange")
plt.show()


plt.plot(omegas[:, 0], color="tab:blue")
plt.plot(omegas[:, 1], color="tab:orange")
plt.show()

# Grid version:

In [None]:
from mc2.models.physical_regularization.dipole_interaction import DipoleGrid, random_unit_vectors_2d, step_rot_ode, simulate_rot_ode

In [None]:
init_state = DipoleGrid.from_random_key(jax.random.key(12), 400)

init_state.visualize()
plt.show()
# The Gibbs energy is evaluated in the "Phase Field Method" section below:
# `get_gibbs_energy` is not defined yet at this point, and `params` here still
# holds the object returned by `setup_model`, not the physical-parameter dict
# that function expects.

# Integration settings forwarded to simulate_rot_ode.
tau = 1
J = 100
d = 1e-1

n_field_steps = 1000  # number of external-field samples (simulation steps)

# ext_fields = jnp.stack(
#     [
#         jnp.zeros(n_field_steps),
#         -jnp.sin(jnp.linspace(0, 8*jnp.pi, n_field_steps)) * 10# * jnp.linspace(0, 4, n_field_steps)
#     ],
#     axis=-1
# ) 

ext_fields = jnp.zeros((n_field_steps, 2))

states, ms, omegas = simulate_rot_ode(init_state, ext_fields, tau, J, d)

states[-1].visualize()
plt.show()

In [None]:
M = jnp.sum(ms, axis=(-3, -2))
M.shape

plt.plot(M)
plt.show()

plt.plot(ext_fields)

In [None]:
#plt.plot(H[:-1, 0], ext_fields[..., 0])
plt.plot(ext_fields[..., 1], M[:-1, 1])
plt.xlabel("H_ext")
plt.ylabel("M")

In [None]:
#plt.plot(H[:-1, 0], ext_fields[..., 0])
plt.plot(ext_fields[..., 1], M[:-1, 1] / 1500 + ext_fields[..., 1] * 3e-3)
plt.xlabel("H_ext")
plt.ylabel("B")

In [None]:
init_state = DipoleGrid.from_data(
    m=init_state.m.at[:, :, 0].set(1.0),
    pos=init_state.pos,
    omega=init_state.omega,
    distance=1,
)

# m = random_unit_vectors_2d(12, jax.random.key(2)).reshape(12, 1, 2)
# n_x = m.shape[0]
# n_y = m.shape[1]
# distance = 0.1

# xx, yy = jnp.meshgrid(jnp.arange(0, n_x, 1), jnp.arange(0, n_y, 1), indexing="ij")
# pos = jnp.concatenate([xx[..., None], yy[..., None]], axis=-1)

# init_state = DipoleGrid(
#     m=m,
#     n_elements=n_x * n_y,
#     n_x=n_x,
#     n_y=n_y,
#     pos=pos,
#     omega=jnp.zeros(pos.shape[:-1]),
#     distance=distance,
# )

In [None]:
ext_field = jnp.array([0.0, 0.0])

# Start from the grid prepared in the previous cell. Without this, `state` is
# whatever an earlier cell left behind (e.g. the dict-based two-dipole state).
state = init_state

ms = [state.m]
omegas = [state.omega]

for i in tqdm(range(100000)):

    # if i > 100_000:
    #     ext_field = jnp.array([-0.1, -0.9])
    # elif i > 70_000:
    #     ext_field = jnp.array([0.0, 1.0])
    # elif i > 50_000 :
    #     ext_field = jnp.array([1.0, 0.0])

    next_state = step_rot_ode(state, ext_field, tau, J, d)
    state = next_state

    ms.append(state.m)
    omegas.append(state.omega)
    # Snapshot of the grid every 10k steps.
    if i % 10_000 == 0 and i > 0:
        state.visualize()
        plt.show()

state.visualize()
plt.show()

# Collect the trajectories into arrays with a leading time axis.
ms = jnp.stack(ms)
omegas = jnp.stack(omegas)

In [None]:
ms = ms.reshape(-1, state.n_elements, 2)


plt.plot(jnp.sum(ms[..., 0], axis=-1))
plt.plot(jnp.sum(ms[..., 1], axis=-1))
plt.show()

In [None]:
def angle_from_m(m):
    """Return atan2(m_y, m_x) in radians for 2-D vectors stored in the last axis."""
    components = jnp.moveaxis(m, -1, 0)
    return jnp.arctan2(components[1], components[0])

ms = ms.reshape(-1, state.n_elements, 2)
omegas = omegas.reshape(-1, state.n_elements)

for idx in range(min(ms.shape[1], 20)):

    fig, axs = plt.subplots(1,4, figsize=(16, 4))

    axs[0].set_title("value")
    axs[0].plot(ms[:, idx, 0], color="tab:blue")
    axs[0].plot(ms[:, idx, 1], color="tab:orange")

    axs[1].set_title("abs_value")
    axs[1].plot(jnp.linalg.norm(ms[:, idx], axis=-1), color="tab:blue")
    
    axs[2].set_title("angle")
    axs[2].plot(angle_from_m(ms[:, idx]) / (2*jnp.pi) * 360, color="tab:blue")

    axs[3].set_title("angular velocity")
    axs[3].plot(omegas[:, idx], color="tab:blue")

    for ax in axs:
        ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

# Phase Field Method / Minimization of Energy

Based on [Li 2024]: "Effect of magnetic field on macroscopic hysteresis and microscopic
magnetic domains for different ferromagnetic materials".

In [None]:
from mc2.models.physical_regularization.dipole_interaction import DipoleGrid, random_unit_vectors_2d, step_rot_ode, simulate_rot_ode

In [None]:
init_state = DipoleGrid.from_random_key(jax.random.key(12), 400)

init_state.visualize()
plt.show()

init_state

In [None]:
init_state.shape

In [None]:
def get_landau_potential_energy(
    state: DipoleGrid,
    alpha: float,
    beta_1: float,
    beta_2: float,
    gamma_1: float,
    gamma_2: float,
    delta_T: float=200.0,
) -> jax.Array:
    """Landau (polynomial) free-energy term of the magnetization grid.

    Per cell:
        α ΔT/2 |m|² + β₁/4 (m_x⁴ + m_y⁴) + β₂/4 m_x² m_y²
        + γ₁/6 (m_x⁶ + m_y⁶) + γ₂/6 (m_x⁴ m_y² + m_y⁴ m_x²).
    The cell values are summed and multiplied by 1 / ``state.n_elements**2``.

    NOTE(review): elsewhere in the notebook ``n_elements`` is used as the total
    cell count (``ms.reshape(-1, state.n_elements, 2)``), so this normalization
    is not a plain mean — confirm against the reference formulation.
    """
    m_x = state.m[..., 0]
    m_y = state.m[..., 1]

    quadratic = 0.5 * alpha * delta_T * (m_x**2 + m_y**2)
    quartic_pure = 0.25 * beta_1 * (m_x**4 + m_y**4)
    quartic_mixed = 0.25 * beta_2 * (m_x**2*m_y**2)
    sextic_pure = 1/6 * gamma_1 * (m_x**6 + m_y**6)
    sextic_mixed = 1/6 * gamma_2 * (m_x**4*m_y**2 + m_y**4*m_x**2)

    energy_density = quadratic + quartic_pure + quartic_mixed + sextic_pure + sextic_mixed
    return 1 / (state.n_elements**2) * jnp.sum(energy_density)


def get_exchange_energy(
    state: DipoleGrid, A: float, M_s: float
) -> jax.Array:
    """Exchange energy A · Σ (squared centered differences of m_x and m_y).

    Differences are field[i+1] - field[i-1] computed with ``jnp.roll``, so the
    grid wraps around periodically. They are not divided by 2 or by a grid
    spacing. Axis 1 is treated as x and axis 0 as y.
    ``M_s`` is accepted but unused — NOTE(review): confirm whether the gradient
    should be normalized by M_s.
    """
    def centered_diff(field, axis):
        # Periodic neighbour difference along `axis`.
        return jnp.roll(field, shift=-1, axis=axis) - jnp.roll(field, shift=1, axis=axis)

    m_x, m_y = state.m[..., 0], state.m[..., 1]

    dMx_dx = centered_diff(m_x, 1)
    dMy_dx = centered_diff(m_y, 1)
    dMx_dy = centered_diff(m_x, 0)
    dMy_dy = centered_diff(m_y, 0)

    return A * jnp.sum(dMx_dx**2 + dMx_dy**2 + dMy_dx**2 + dMy_dy**2)


def get_magentization_energy(state: DipoleGrid, ext_field: jax.Array, mu_0: float) -> jax.Array:
    """Field / Zeeman-type energy of the grid in a uniform external field.

    E = μ0/2 · |h|² + μ0 / (2 N²) · Σ (h · m), with N = ``state.n_elements``.
    (The function name keeps its original spelling because callers use it.)

    Args:
        state: Grid whose ``m`` holds (x, y) components in the last axis.
        ext_field: External field vector (h_x, h_y).
        mu_0: Vacuum permeability.
    """
    m_x, m_y = (state.m[..., 0], state.m[..., 1])
    h_x, h_y = (ext_field[0], ext_field[1])

    E = 0.5 * mu_0 * (h_x**2 + h_y**2)
    # μ0/(2N²) scales the field-magnetization coupling sum. Adding it instead
    # would only contribute a field-independent constant and leave the
    # coupling unscaled.
    # NOTE(review): the Zeeman energy is usually -μ0 H·M; confirm the sign of
    # this term against the reference formulation.
    E += mu_0 / (2 * state.n_elements**2) * jnp.sum(h_x * m_x + h_y * m_y)
    return E


def get_demagnetization_energy(state: DipoleGrid, ext_field: jax.Array, N_d: float, mu_0: float) -> jax.Array:
    """Demagnetization energy term.

    Computes 0.5 · μ0 · N_d · (Σ h·m + Σ |m|²) over all grid cells.
    NOTE(review): the standard shape-demagnetization energy is μ0 N_d M²/2 and
    does not involve the external field. Confirm whether the h·m sum belongs here.
    """
    m_x, m_y = (state.m[..., 0], state.m[..., 1])
    h_x, h_y = (ext_field[0], ext_field[1])

    E = 0.5 * mu_0 * N_d * (jnp.sum(h_x * m_x + h_y * m_y) + jnp.sum(m_x**2 + m_y**2))
    return E

def get_anisotropic_energy(state: DipoleGrid, K_1: float, M_s: float) -> jax.Array:
    """Anisotropy term: ``K_1 / (n_elements**2 * M_s**4) * sum_cells(m_x^2 * m_y^2)``."""
    mx_sq = state.m[..., 0]**2
    my_sq = state.m[..., 1]**2
    prefactor = K_1 / (state.n_elements**2 * M_s**4)
    return prefactor * jnp.sum(mx_sq * my_sq)

def get_confinement_energy(state: DipoleGrid, A_s: float, M_s: float) -> jax.Array:
    """Confinement term penalizing |m| != M_s: ``A_s / n_elements**2 * sum((|m| - M_s)^2)``.

    |m| is the 2-D magnitude ``sqrt(m_x^2 + m_y^2)`` of each cell.
    """
    mx = state.m[..., 0]
    my = state.m[..., 1]
    deviation = jnp.sqrt(mx**2 + my**2) - M_s
    return A_s / state.n_elements**2 * jnp.sum(deviation**2)
    
def get_eddy_energy(state: DipoleGrid, K_2: float) -> jax.Array:
    """Eddy term: ``K_2 / n_elements**2 * sum_cells(m_x^2 * m_y^2)``."""
    cross_product_sq = (state.m[..., 0]**2) * (state.m[..., 1]**2)
    return K_2 / (state.n_elements**2) * jnp.sum(cross_product_sq)

def get_strain_energy(state: DipoleGrid, K_3: float, K_4: float, E_11: float, E_22: float) -> jax.Array:
    """Strain term summed over all cells.

    Per cell: ``K_3*(E_11*m_x^2 + E_22*m_y^2) + K_4*(E_11*m_y^2 + E_22*m_x^2)``.
    NOTE(review): unlike most other terms this sum is not divided by n_elements**2;
    confirm whether that is intended.
    """
    mx_sq = state.m[..., 0]**2
    my_sq = state.m[..., 1]**2
    per_cell = K_3 * (E_11 * mx_sq + E_22 * my_sq) + K_4 * (E_11 * my_sq + E_22 * mx_sq)
    return jnp.sum(per_cell)

In [None]:
@eqx.filter_jit
def get_gibbs_energy(
    state: DipoleGrid, ext_field: jax.Array, params: dict[str, float]
) -> tuple[jax.Array, jax.Array]:
    """Total Gibbs energy of a dipole grid plus a per-term breakdown.

    Args:
        state: dipole grid; every energy term reads ``state.m``.
        ext_field: applied field; index 0 is used as the x- and index 1 as the
            y-component.
        params: coefficient dict. Required keys: alpha, beta_1, beta_2, gamma_1,
            gamma_2, A, M_s, mu_0, N_d, K_1, A_s, K_2, K_3, K_4, E_11, E_22.
            Optional key: delta_T (defaults to 200.0).

    Returns:
        ``(E_g, terms)`` where ``E_g = E_ld + E_exc + E_ma + E_de`` and ``terms`` is
        ``[E_ld, E_exc, E_ma, E_de, E_an, E_co, E_ed, E_st]``. The anisotropy,
        confinement, eddy and strain terms are evaluated and reported in ``terms``
        but are NOT part of ``E_g``.
    """
    E_ld = get_landau_potential_energy(
        state,
        alpha=params["alpha"],
        delta_T=params.get("delta_T", 200.0),
        beta_1=params["beta_1"],
        beta_2=params["beta_2"],
        gamma_1=params["gamma_1"],
        gamma_2=params["gamma_2"],
    )
    E_exc = get_exchange_energy(
        state,
        A=params["A"],
        M_s=params["M_s"],
    )
    E_ma = get_magentization_energy(
        state, ext_field, mu_0=params["mu_0"],
    )
    E_de = get_demagnetization_energy(
        state, ext_field, N_d=params["N_d"], mu_0=params["mu_0"],
    )
    E_an = get_anisotropic_energy(
        state, K_1=params["K_1"], M_s=params["M_s"]
    )
    E_co = get_confinement_energy(state, A_s=params["A_s"], M_s=params["M_s"])
    E_ed = get_eddy_energy(state, K_2=params["K_2"])
    E_st = get_strain_energy(state, K_3=params["K_3"], K_4=params["K_4"], E_11=params["E_11"], E_22=params["E_22"])

    # Only these four terms enter the total; add E_an, E_co, E_ed and/or E_st here to
    # include them in the dynamics (they are reported in the breakdown regardless).
    E_g = E_ld + E_exc + E_ma + E_de
    return E_g, jnp.array([E_ld, E_exc, E_ma, E_de, E_an, E_co, E_ed, E_st])

In [None]:
# Coefficients consumed by ``get_gibbs_energy`` (looked up by key there) and by
# ``dmdt`` ("L"). Values are written as mantissa * 1e<exponent>.
# NOTE(review): units are not stated anywhere in this notebook — confirm against the source model.
params = dict(
    # Landau polynomial coefficients (get_landau_potential_energy)
    alpha=8.99*1e-10,
    beta_1=5.06*1e-19,
    beta_2=-4.87*1e-19,
    gamma_1=8.59*1e-29,
    gamma_2=6.59*1e-30,
    # Anisotropy (K_1), eddy (K_2) and strain (K_3, K_4, E_11, E_22) coefficients;
    # get_gibbs_energy evaluates these terms but leaves them out of the total.
    K_1=-5.35*1e4,
    K_2=3.31*1e-10,
    K_3=-1.95*1e-17,
    K_4=-2.10*1e-17,
    E_11=2*1e10,
    E_22=2*1e10,
    # Exchange coefficient. NOTE(review): negative A lowers the energy for larger
    # neighbour differences in get_exchange_energy — confirm the sign.
    A=-3.54*1e-12,
    # Rate prefactor in dmdt; being negative, euler_step moves m against dE/dm.
    L=-1.572*1e5,
    # Used by the anisotropy and confinement terms (passed to exchange, unused there).
    M_s=1.33*1e6,
    # Demagnetization factor (get_demagnetization_energy)
    N_d=1e-4,
    # Confinement coefficient (term not included in the total)
    A_s=-6.2*1e-10,
    # Vacuum permeability, 4*pi*1e-7
    mu_0=4*jnp.pi*1e-7,
)

In [None]:
# Sanity check: (total Gibbs energy, per-term breakdown) of the initial grid under zero applied field.
get_gibbs_energy(init_state, ext_field = jnp.array([0.0, 0.0]), params=params)

In [None]:
def get_gibbs_energy_expl_m(m: jax.Array, state: DipoleGrid, ext_field: jax.Array, params: dict):
    """Total Gibbs energy written as an explicit function of the magnetization ``m``.

    ``state`` supplies every other field of the grid; only its ``m`` is replaced, so
    the result can be differentiated with respect to ``m`` alone.
    """
    swapped_state = eqx.tree_at(lambda grid: grid.m, state, m)
    total_energy, _energy_terms = get_gibbs_energy(swapped_state, ext_field, params)
    return total_energy


# Gradient of the total Gibbs energy w.r.t. ``m`` (the first positional argument).
de_dm = eqx.filter_grad(get_gibbs_energy_expl_m)

In [None]:
def dmdt(state, ext_field, params):
    """Time derivative of the magnetization: ``params["L"] * dE_gibbs/dm``.

    ``de_dm`` returns the gradient of the total Gibbs energy w.r.t. ``state.m``; with a
    negative ``L`` the update therefore moves ``m`` down that gradient.
    """
    rate = params["L"]
    energy_gradient = de_dm(state.m, state, ext_field=ext_field, params=params)
    return rate * energy_gradient


@eqx.filter_jit
def euler_step(state, ext_field, params, tau):
    """Advance ``state.m`` by one explicit Euler step of size ``tau``.

    Only ``m`` is replaced (via ``eqx.tree_at``); all other fields of ``state`` are
    carried over unchanged.
    """
    m_next = state.m + tau * dmdt(state, ext_field, params)
    return eqx.tree_at(lambda grid: grid.m, state, m_next)

In [None]:
# Number of explicit Euler steps (one per row of ``ext_fields``). Despite the name this
# is a time-step count, not ``state.n_elements``.
n_elements = 40_000
tau = 1e-6  # Euler step size passed to euler_step

# ext_fields = jnp.zeros((n_elements,2))

# Applied field per step: x-component zero; y-component a 2-period sine whose amplitude
# ramps linearly from 0 to 1000 over the run.
ext_fields = jnp.stack(
    [
        jnp.zeros(n_elements),
        jnp.sin(jnp.linspace(0, 4*jnp.pi, n_elements)) * 1000 * jnp.linspace(0, 1, n_elements),
    ],
    axis=-1
)

# ext_fields = jnp.concatenate([
#     ext_fields, 
#     jnp.stack(
#         [
#             jnp.zeros(n_elements),
#             jnp.concatenate([jnp.linspace(0, 20, 100), jnp.linspace( 20, 0, 100), jnp.linspace(0, - 20, 100), jnp.linspace(- 20, 0, 100), jnp.linspace(-0, 400, n_elements-400)])
#
#         ],
#         axis=-1
#     ),
# ], axis=0)

# ext_fields = jnp.stack(
#     [
#         jnp.zeros(n_elements),
#         -jnp.linspace(0, 0.005, n_elements)
#     ],
#     axis=-1
# )

state = init_state
state.visualize()
plt.show()
init_energy = get_gibbs_energy(init_state, ext_fields[0], params=params)[0]
print("Starting gibbs energy:", init_energy)

# Trajectories: index 0 is the initial state, index i + 1 the state after step i.
states = [state]
ms = [state.m]
omegas = [state.omega]  # initial value only: euler_step never changes omega
energies = [init_energy]

for ext_field in tqdm(ext_fields, total=n_elements):
    state = euler_step(state, ext_field, params, tau)

    energy = get_gibbs_energy(state, ext_field, params=params)[0]

    energies.append(energy)
    states.append(state)
    ms.append(state.m)

ms = jnp.stack(ms)
energies = jnp.stack(energies)

state.visualize()
plt.show()
# Report the total like the starting print does, then the per-term breakdown.
final_energy, final_energy_terms = get_gibbs_energy(states[-1], ext_fields[-1], params=params)
print("Final gibbs energy:", final_energy)
print("Final energy terms [E_ld, E_exc, E_ma, E_de, E_an, E_co, E_ed, E_st]:", final_energy_terms)

In [None]:
# Grid-averaged magnetization per stored step: sum ``ms`` over both grid axes and divide
# by ``state.n_elements**2`` (the normalization the energy terms use).
# Shape: (stored steps, vector components).
M = jnp.sum(ms, axis=(-3, -2)) / state.n_elements**2

fig, ax = plt.subplots()
ax.plot(M)
ax.legend([f"component {c}" for c in range(M.shape[-1])])
ax.set(title="Grid-averaged magnetization", xlabel="stored step", ylabel="M")
plt.show()

fig, ax = plt.subplots()
ax.plot(ext_fields)
ax.legend([f"component {c}" for c in range(ext_fields.shape[-1])])
ax.set(title="Applied field", xlabel="time step", ylabel="H_ext")
plt.show()

fig, ax = plt.subplots()
ax.plot(energies)
ax.set(title="Total Gibbs energy", xlabel="stored step", ylabel="E_g")
plt.show()

In [None]:
# y-magnetization before each step against the y-field applied in that step.
# NOTE(review): this cell puts +H_y on the x-axis, while the B-H cell and the widget /
# animation cells plot against -H_y — confirm which orientation is intended.
fig, ax = plt.subplots()
ax.plot(ext_fields[..., 1], M[:-1, 1])
ax.set_xlabel("H_ext")
ax.set_ylabel("M")

In [None]:
# B-like curve: rescaled y-magnetization plus mu_0 * H_y, plotted against -H_y.
# NOTE(review): the x-axis is -H_y while the mu_0 term uses +H_y, and the 1/1500
# scaling of M is unexplained — confirm both.
h_ext_y = ext_fields[..., 1]
fig, ax = plt.subplots()
ax.plot(-h_ext_y, M[:-1, 1] / 1500 + h_ext_y * params["mu_0"])
ax.set_xlabel("H_ext")
ax.set_ylabel("B")

In [None]:
# Deliberate stop so "Run All" does not continue into the inspection, widget and
# GIF-export cells below; run those manually when needed. A bare ``raise`` outside an
# except block only produced a confusing "No active exception to reraise".
raise RuntimeError("Intentional stop: run the cells below manually.")

In [None]:
# Stacked trajectory: (stored step, grid axis 0, grid axis 1, vector component);
# the step axis has one more entry than ext_fields because it includes the initial state.
ms.shape

In [None]:
# Scale each vector component by its own largest absolute value over all stored steps
# and grid cells, so both normalized components fall within [-1, 1].
ms_norm = jnp.stack(
    [ms[..., comp] / jnp.max(jnp.abs(ms[..., comp])) for comp in (0, 1)],
    axis=-1,
)

In [None]:
from ipywidgets import interact, IntSlider

In [None]:
data = deepcopy(ms_norm)

T, nx, ny, _ = data.shape
x = np.arange(nx)
y = np.arange(ny)
# Explicit quiver positions with the same (nx, ny) shape as data[k, :, :, c]: array
# axis 1 runs horizontally and axis 0 vertically (get_exchange_energy takes d/dx along
# axis 1). For square grids this equals what quiver(x, y, U, V) built implicitly; with
# 1-D x, y quiver expects len(x) columns, which breaks for non-square grids.
quiver_X, quiver_Y = np.meshgrid(y, x)
# ext_fields and M[:-1] have one row fewer than data (which also holds the initial state).
n_field_rows = ext_fields.shape[0]


def plot_quiver(k):
    """Draw the normalized dipole field at stored step ``k`` together with its history.

    Panels (top to bottom): quiver plot of the grid at step k; every cell's normalized
    m_x over time; every cell's normalized m_y over time; the M_y vs -H_y curve with
    step k marked. Panels 1 and 2 carry a dotted red line at k.
    """
    fig, axs = plt.subplots(4, 1, figsize=(8, 12))

    axs[0].quiver(quiver_X, quiver_Y, data[k, :, :, 0], data[k, :, :, 1])
    axs[0].set_xlim(-0.5, ny - 0.5)
    axs[0].set_ylim(-0.5, nx - 0.5)
    axs[0].set_aspect("equal")

    # Faint per-cell traces of each component, with a marker line at the selected step.
    for ax, component, color in ((axs[1], 0, "tab:orange"), (axs[2], 1, "tab:green")):
        for m_i in data[..., component].reshape((-1, nx*ny)).T:
            ax.plot(m_i, alpha=.4, c=color)
        ymin, ymax = ax.get_ylim()
        ax.plot([k, k], [ymin, ymax], color="r", linestyle=":", linewidth=1)

    axs[3].plot(-ext_fields[..., 1], M[:-1, 1], c="tab:orange")

    # data has one more step than ext_fields, so clamp the lookup explicitly instead of
    # relying on JAX's out-of-bounds index clamping.
    field_idx = min(k, n_field_rows - 1)
    axs[3].plot(-ext_fields[field_idx, 1], M[field_idx, 1],
            marker='x',
            markersize=2,
            markeredgewidth=2,
            color='red',
            label=f'Selected Point (k={k})')
    axs[3].set_xlabel("-H_ext (y)")
    axs[3].set_ylabel("M_y")
    axs[3].legend()

    for ax in axs:
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.show()

In [None]:
# Slider over every stored step of ``data`` (indices 0 .. T-1).
timestep_slider = IntSlider(min=0, max=T-1, step=1, value=0, description="timestep")
interact(plot_quiver, k=timestep_slider);

In [None]:
from matplotlib.animation import FuncAnimation, PillowWriter

In [None]:
sub_sampling = 5  # animate every 5th stored step

data = deepcopy(ms_norm[::sub_sampling, ...])
T, nx, ny, _ = data.shape
# ext_fields (and M[:-1]) have one row fewer than ms_norm, which also stores the initial
# state, so field lookups for the last frame are clamped explicitly in update().
n_field_rows = ext_fields.shape[0]

# Create the grid for quiver
x = np.arange(nx)
y = np.arange(ny)
# NOTE(review): "ij" indexing draws data[k, i, j] at (x=i, y=j). get_exchange_energy
# takes d/dx along array axis 1, which would put axis 1 on the horizontal instead —
# confirm the intended orientation.
X, Y = np.meshgrid(x, y, indexing="ij")

# --- Set up the figure ---
fig, axs = plt.subplots(4, 1, figsize=(6, 8))

# Quiver placeholder
U0 = data[0, :, :, 0]
V0 = data[0, :, :, 1]
quiv = axs[0].quiver(X, Y, U0, V0)

axs[0].set_xlim(-0.5, nx - 0.5)
axs[0].set_ylim(-0.5, ny - 0.5)
axs[0].set_aspect("equal")

# Faint per-cell normalized m_x traces plus their sum over the grid.
# NOTE(review): the summed curve spans up to +/- nx*ny while per-cell traces lie in
# [-1, 1], so it dominates the y-range (plot_quiver leaves it out).
m_x = data[..., 0]
for m_i in m_x.reshape((-1, nx*ny)).T:
    axs[1].plot(m_i, alpha=.4, c="tab:orange")
axs[1].plot(jnp.sum(m_x, axis=(-1, -2)), c="tab:orange")

ymin1, ymax1 = axs[1].get_ylim()
vline1, = axs[1].plot([0, 0], [ymin1, ymax1], color="r", linestyle=":", linewidth=1)


m_y = data[..., 1]
for m_i in m_y.reshape((-1, nx*ny)).T:
    axs[2].plot(m_i, alpha=.4, c="tab:green")
axs[2].plot(jnp.sum(m_y, axis=(-1, -2)), c="tab:green")
ymin3, ymax3 = axs[2].get_ylim()
vline3, = axs[2].plot([0, 0], [ymin3, ymax3], color="r", linestyle=":", linewidth=1)


# axs[3].plot(H_pred[0], c="tab:orange", label="H_hat")
# axs[3].plot(jnp.squeeze(H_future[0]), c="tab:blue", label="H_true")
# ymin2, ymax2 = axs[3].get_ylim()
# vline2, = axs[3].plot([0, 0], [ymin2, ymax2], color="r", linestyle=":", linewidth=1)
# axs[3].legend()

axs[3].plot(-ext_fields[..., 1], M[:-1, 1], c="tab:orange")
axs[3].set_xlabel("-H_ext (y)")
axs[3].set_ylabel("M_y")

# Cross marking the current frame on the M-H curve
line_cross, = axs[3].plot(-ext_fields[0, 1], M[:-1, 1][0],
        marker='x',
        markersize=2,
        markeredgewidth=2,
        color='red',
)

for ax in axs:
    ax.grid(alpha=0.3)

plt.tight_layout()

# --- Animation update function ---
def update(k):
    """Redraw frame ``k`` (stored step ``k * sub_sampling``); return the changed artists."""
    step = k * sub_sampling

    # Update quiver
    quiv.set_UVC(data[k, :, :, 0], data[k, :, :, 1])

    # Update vertical lines (x is in frame units, matching the subsampled traces)
    vline1.set_xdata([k, k])
    vline1.set_ydata([ymin1, ymax1])

    vline3.set_xdata([k, k])
    vline3.set_ydata([ymin3, ymax3])

    # The last stored step has no matching field row; clamp explicitly instead of
    # relying on JAX's out-of-bounds index clamping.
    field_idx = min(step, n_field_rows - 1)
    line_cross.set_data([-ext_fields[field_idx, 1]], [M[field_idx, 1]])

    # Title reports the simulation step, not the frame index
    axs[0].set_title(f"Vector field (step={step})")
    return quiv, vline1, vline3, line_cross

# --- Create animation ---
anim = FuncAnimation(fig, update, frames=T, interval=200, blit=False)

# --- Save as GIF ---
anim.save("vector_field_timeseries.gif", dpi=120, writer=PillowWriter(fps=25))

print("GIF saved as vector_field_timeseries.gif")