In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

from symplearn.datasets import VectorFieldDataset, SnapshotDataset
from symplearn.numerics import EulerDVISimulation, RK4Simulation, QuasiExactSimulation

from utils import (
    HandlerColormap,
    colored_line,
    find_index_period,
    get_study_init_cond,
    load_models,
    GuidingCenter,
)

# Global tensor defaults: double precision on the CPU. Set first, before any
# model or dataset below is created.
torch.set_default_dtype(torch.float64)
torch.set_default_device("cpu")

# Short names of the four studied trajectories, used in legends.
# NOTE(review): presumably barely passing / barely trapped / well trapped /
# deeply trapped — confirm.
labels = ["BP", "BT", "WT", "DT"]

# Reference model (its `r0`, `vector_field` and `hamiltonian` are used below).
model_ref = load_models()["ref"]

# Study initial conditions; `z0` and `dt` serve as defaults for `simulate`.
t0, z0, dt = get_study_init_cond()

(z_val, t_val, dt_z_val) = VectorFieldDataset("val")[:]

In [None]:
# Font sizes (in points) for the figures.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 7, 8, 9

plt.rcParams.update(
    {
        "axes.labelsize": MEDIUM_SIZE,  # x and y labels
        "xtick.labelsize": SMALL_SIZE,  # tick labels
        "ytick.labelsize": SMALL_SIZE,
        "legend.fontsize": BIGGER_SIZE,
        "legend.title_fontsize": BIGGER_SIZE,
    }
)

# Colorblind-friendly palette: four shades from each of four reversed sequential
# colormaps (reversed, so smaller values give darker shades: a = darkest).
cmap0, cmap1, cmap2, cmap3 = (
    mpl.colormaps[cmap_name]
    for cmap_name in ("Blues_r", "Greens_r", "Oranges_r", "Purples_r")
)
c0a, c0b, c0c, c0d = cmap0([0.1, 0.25, 0.4, 0.6])
c1a, c1b, c1c, c1d = cmap1([0.1, 0.25, 0.4, 0.6])
c2a, c2b, c2c, c2d = cmap2([0.1, 0.25, 0.4, 0.6])
c3a, c3b, c3c, c3d = cmap3([0.1, 0.25, 0.4, 0.5])  # purple: last shade at 0.5

# Alternative tab20c palette (the figures below use the `colors*` arrays).
c0, c0light, c1, c1light, c2, c2light, c3, c3light = (
    mpl.colormaps["tab20c"](lut_index) for lut_index in (0, 2, 8, 9, 5, 6, 16, 18)
)

# Per-trajectory colors, indexed like `labels`.
colors = np.array([c0a, c1b, c2c, c3d])
colors_dark = np.array([c0a, c1a, c2a, c3a])
colors_light = np.array([c0c, c1c, c2c, c3c])
markers = ["o", "s", "D", "^"]

In [None]:
def simulate(
    name,
    model=None,
    scheme=QuasiExactSimulation,
    z0=z0,
    dt=1e-3 * dt,
    nt=70_000,
):
    """Simulate trajectories from ``z0``, caching the result on disk.

    The time and state tensors are stored as ``simul/t_<name>.pt`` and
    ``simul/z_<name>.pt``. If both files already exist they are loaded instead
    of re-running the simulation.

    NOTE: the cache key is ``name`` only. Changing ``model``, ``scheme``,
    ``z0``, ``dt`` or ``nt`` for an existing ``name`` silently returns the old
    cached result; delete the files (or pick a new name) after such a change.

    Parameters
    ----------
    name : str
        Cache identifier, used in the file names.
    model : optional
        Model passed to ``scheme``; a fresh ``GuidingCenter()`` when ``None``.
    scheme : callable
        Simulation class, instantiated as ``scheme(model, dt)``.
    z0 : tensor
        Initial state(s), passed to ``sim.simulate``.
    dt : float
        Time step (defaults to 1e-3 times the study time step).
    nt : int
        Passed as-is to ``sim.simulate`` (number of steps).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        The simulated times and states.
    """
    cache_dir = "simul"
    t_fname = os.path.join(cache_dir, "t_" + name + ".pt")
    z_fname = os.path.join(cache_dir, "z_" + name + ".pt")

    if os.path.exists(t_fname) and os.path.exists(z_fname):
        t_sim = torch.load(t_fname, weights_only=True)
        z_sim = torch.load(z_fname, weights_only=True)

    else:
        if model is None:
            # Built per call (not as a def-time default): cache hits don't pay
            # for it and separate calls never share one instance.
            model = GuidingCenter()
        sim = scheme(model, dt)
        t_sim, z_sim = sim.simulate(z0, nt)
        t_sim, z_sim = t_sim.detach(), z_sim.detach()
        # torch.save does not create missing parent directories.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(t_sim, t_fname)
        torch.save(z_sim, z_fname)

    return t_sim.numpy(), z_sim.numpy()


t_ex, z_ex = simulate("ref_ex")

# State components used here: 0 -> theta, 2 -> r, 3 -> u.
th_ex = z_ex[..., 0]
r_ex = z_ex[..., 2]
u_ex = z_ex[..., 3]

# (R, Z) coordinates: R = r0 + r cos(theta), Z = r sin(theta).
R_ex = model_ref.r0 + r_ex * np.cos(th_ex)
Z_ex = r_ex * np.sin(th_ex)

# Wrap theta into [-pi, pi) for the (theta, u) plots (R, Z use the raw angle).
th_ex = (th_ex + np.pi) % (2 * np.pi) - np.pi

# Index at which each exact trajectory completes one period.
nf_ex = [find_index_period(traj_RZ) for traj_RZ in np.stack((R_ex, Z_ex), axis=-1)]

### Comparing trajectories

In [None]:
## period of each trajectory: time elapsed between the first sample and the
## index returned by `find_index_period` (independent of the start time)
[float(t_ex[k, nf_k] - t_ex[k, 0]) for k, nf_k in enumerate(nf_ex)]

In [None]:
def add_arrows(ax, x, y, u, v, nf=None, traj_colors=None, n_markers=6):
    """Draw direction-of-motion arrowheads along each trajectory on ``ax``.

    For trajectory ``k``, ``n_markers`` arrowheads are placed at evenly spaced
    fractions of its first ``nf[k]`` samples. Each trajectory gets its own
    offset so the arrows of different trajectories are staggered. Each
    arrowhead points along the local velocity ``(u[k][i], v[k][i])`` as it
    appears on screen.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; its data limits should already be set.
    x, y : sequences of 1D arrays
        Per-trajectory coordinates.
    u, v : sequences of 1D arrays
        Per-trajectory velocity components along ``x`` and ``y``.
    nf : sequence of int, optional
        Per-trajectory number of samples to spread arrows over; defaults to the
        global ``nf_ex``.
    traj_colors : sequence of colors, optional
        Per-trajectory colors; defaults to the global ``colors``.
    n_markers : int
        Number of arrowheads per trajectory.
    """
    if nf is None:
        nf = nf_ex
    if traj_colors is None:
        traj_colors = colors

    (x0, x1), (y0, y1) = ax.get_xlim(), ax.get_ylim()
    # Marker rotation happens in (isotropic) display space, so map the velocity
    # there: scale by display pixels per data unit along each axis. Normalising
    # by the data ranges alone is only right for square axes.
    bbox = ax.get_window_extent()
    x_scale = bbox.width / (x1 - x0)
    y_scale = bbox.height / (y1 - y0)

    n_traj = len(x)
    for k, (xk, yk, uk, vk) in enumerate(zip(x, y, u, v)):
        # Fractions (j + 2k / (3 n_traj)) / n_markers of the first nf[k] samples.
        marker_idx = np.arange(n_markers) * 3 * n_traj + 2 * k
        marker_idx = (marker_idx * nf[k]) // (n_markers * 3 * n_traj)
        for i in marker_idx:
            angle = np.arctan2(vk[i] * y_scale, uk[i] * x_scale)
            arrow = mpl.markers.MarkerStyle(">").rotated(rad=angle)
            ax.scatter(xk[i], yk[i], marker=arrow, s=30, color=traj_colors[k], zorder=2.5)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3.4), width_ratios=(3, 4))

# Left: one period of each exact trajectory in the (R, Z) plane.
# Right: the same period in the (theta, u) plane.
for k, (Rk, Zk, thk, uk) in enumerate(zip(R_ex, Z_ex, th_ex, u_ex)):
    nk = nf_ex[k]
    Rk, Zk, thk, uk = Rk[:nk], Zk[:nk], thk[:nk], uk[:nk]
    kwargs = {"lw": 1.5, "color": colors[k]}

    ax1.plot(Rk, Zk, label=labels[k], **kwargs)

    # theta is wrapped to [-pi, pi): split the curve wherever it jumps by ~2*pi.
    # abs() catches the wrap whichever way theta travels; otherwise a spurious
    # horizontal segment is drawn across the panel.
    (idx,) = np.where(np.abs(np.diff(thk)) > 1.9 * np.pi)
    idx = np.concatenate(([0], idx + 1))

    for i0, i1 in zip(idx[:-1], idx[1:]):
        ax2.plot(thk[i0:i1], uk[i0:i1], **kwargs)
    ax2.plot(thk[idx[-1] :], uk[idx[-1] :], label=labels[k], **kwargs)

ax1.set(xlabel="$R$", ylabel="$Z$")
ax2.set(xlabel="$\\theta$", ylabel="$u$")

# Lay out once before adding arrows: add_arrows reads the current axes state.
fig.tight_layout(pad=0.2)
ax2.legend()

# Velocities in the (R, Z) plane by the chain rule on
# R = r0 + r cos(theta), Z = r sin(theta).
dt_z = model_ref.vector_field(torch.from_numpy(z_ex), None).detach().numpy()
dt_th, dt_r, dt_u = dt_z[..., 0], dt_z[..., 2], dt_z[..., 3]
dt_R = dt_r * np.cos(th_ex) - r_ex * np.sin(th_ex) * dt_th
dt_Z = dt_r * np.sin(th_ex) + r_ex * np.cos(th_ex) * dt_th

add_arrows(ax1, R_ex, Z_ex, dt_R, dt_Z)
add_arrows(ax2, th_ex, u_ex, dt_th, dt_u)

fig.tight_layout(pad=0.2)
os.makedirs("figures", exist_ok=True)  # savefig does not create directories
fig.savefig(os.path.join("figures", "intro_traj.pdf"))

### Studying the numerical schemes (RK4 and Euler DVI)

In [None]:
t_rk4, z_rk4 = simulate("ref_rk4", scheme=RK4Simulation, dt=dt, nt=30_000)

# Relative error on the reference Hamiltonian: (H(t) - H(t_0)) / H(t_0).
x_rk4, y_rk4 = torch.from_numpy(z_rk4).tensor_split(2, -1)
h_rk4 = model_ref.hamiltonian(x_rk4, y_rk4, None).detach().numpy()
h_rk4 = (h_rk4 - h_rk4[:, :1]) / h_rk4[:, :1]

# Only components 0 (theta), 2 (r) and 3 (u) are needed. Component 1 is not
# unpacked into `_`, which would overwrite IPython's last-output variable.
th_rk4, r_rk4, u_rk4 = z_rk4[..., 0], z_rk4[..., 2], z_rk4[..., 3]
R_rk4, Z_rk4 = model_ref.r0 + r_rk4 * np.cos(th_rk4), r_rk4 * np.sin(th_rk4)
th_rk4 = (th_rk4 + np.pi) % (2 * np.pi) - np.pi

nf_rk4 = [
    find_index_period(RZk, eps_dist=1.0, eps_diff=0.0)
    for RZk in np.stack((R_rk4, Z_rk4), -1)
]
nf_rk4

In [None]:
t_dvi, z_dvi = simulate("ref_dvi", scheme=EulerDVISimulation, dt=dt, nt=30_000)

# Relative error on the reference Hamiltonian: (H(t) - H(t_0)) / H(t_0).
x_dvi, y_dvi = torch.from_numpy(z_dvi).tensor_split(2, -1)
h_dvi = model_ref.hamiltonian(x_dvi, y_dvi, None).detach().numpy()
h_dvi = (h_dvi - h_dvi[:, :1]) / h_dvi[:, :1]

# Only components 0 (theta), 2 (r) and 3 (u) are needed. Component 1 is not
# unpacked into `_`, which would overwrite IPython's last-output variable.
th_dvi, r_dvi, u_dvi = z_dvi[..., 0], z_dvi[..., 2], z_dvi[..., 3]
R_dvi, Z_dvi = model_ref.r0 + r_dvi * np.cos(th_dvi), r_dvi * np.sin(th_dvi)
th_dvi = (th_dvi + np.pi) % (2 * np.pi) - np.pi

nf_dvi = [
    find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
    for RZk in np.stack((R_dvi, Z_dvi), -1)
]
# TODO(review): document why the DVI period index is shifted by +2 here (the
# long-time plot then uses `nf_dvi[0] - 1`).
nf_dvi = np.array(nf_dvi) + 2

### Long-time integration

In [None]:
fig = plt.figure(figsize=(8, 4))

# Left: (R, Z) plane. Top right: (theta, u) plane. Bottom right: energy error.
ax_rz = plt.subplot2grid((3, 2), (0, 0), rowspan=3)
ax_thu = plt.subplot2grid((3, 2), (0, 1), rowspan=2)
ax_h = plt.subplot2grid((3, 2), (2, 1), rowspan=1)

# First trajectory only. RK4 is subsampled (every 5th step) and colored by time.
R_rk4_plt, Z_rk4_plt = R_rk4[0, ::5], Z_rk4[0, ::5]
th_rk4_plt, u_rk4_plt = th_rk4[0, ::5], u_rk4[0, ::5]
c_rk4_plt = t_rk4[0, ::5]

nf_ex_plt = nf_ex[0]
R_ex_plt, Z_ex_plt = R_ex[0, :nf_ex_plt], Z_ex[0, :nf_ex_plt]
th_ex_plt, u_ex_plt = th_ex[0, :nf_ex_plt], u_ex[0, :nf_ex_plt]

nf_dvi_plt = nf_dvi[0] - 1
R_dvi_plt, Z_dvi_plt = R_dvi[0, :nf_dvi_plt], Z_dvi[0, :nf_dvi_plt]
th_dvi_plt, u_dvi_plt = th_dvi[0, :nf_dvi_plt], u_dvi[0, :nf_dvi_plt]

ax_rz.scatter(R_rk4_plt, Z_rk4_plt, c=c_rk4_plt, s=0.5, cmap="plasma_r", label="RK4")
ax_rz.plot(R_dvi_plt, Z_dvi_plt, color="gray", marker="o", label="DVI")
ax_rz.plot(R_ex_plt, Z_ex_plt, color="black", ls="dashed", lw=1.5, label="exact")

ax_rz.set_xlabel("$R$")
ax_rz.set_ylabel("$Z$")

ax_thu.scatter(th_rk4_plt, u_rk4_plt, s=0.5, c=c_rk4_plt, cmap="plasma_r")

# theta is wrapped to [-pi, pi): split each line plot wherever theta jumps by
# ~2*pi. abs() catches the wrap whichever way theta travels.
(idx,) = np.where(np.abs(np.diff(th_dvi_plt)) > 1.5 * np.pi)
idx = np.concatenate(([0], idx + 1))

for i0, i1 in zip(idx[:-1], idx[1:]):
    ax_thu.plot(th_dvi_plt[i0:i1], u_dvi_plt[i0:i1], color="gray", marker="o")
ax_thu.plot(th_dvi_plt[idx[-1] :], u_dvi_plt[idx[-1] :], color="gray", marker="o")

(idx,) = np.where(np.abs(np.diff(th_ex_plt)) > 1.5 * np.pi)
idx = np.concatenate(([0], idx + 1))

for i0, i1 in zip(idx[:-1], idx[1:]):
    ax_thu.plot(th_ex_plt[i0:i1], u_ex_plt[i0:i1], color="black", lw=1, ls="dashed")
ax_thu.plot(th_ex_plt[idx[-1] :], u_ex_plt[idx[-1] :], color="black", lw=1, ls="dashed")

ax_thu.set_xlabel("$\\theta$")
ax_thu.set_ylabel("$u$")

ax_h.plot(t_dvi[0, ::6], h_dvi[0, ::6], c="gray", lw=0.5)
t_rk4_plt, h_rk4_plt = t_rk4[0, ::21], h_rk4[0, ::21]
ax_h.plot(t_rk4_plt, h_rk4_plt, alpha=0.0, label=None)  # invisible: only sets y-limits
colored_line(t_rk4_plt, h_rk4_plt, t_rk4_plt, ax_h, cmap="plasma_r")
ax_h.set_xlabel("$t$")
ax_h.set_ylabel("rel. err. on $H$")

# The RK4 scatter gets a colormap legend entry. Look its handle up by label
# rather than relying on the order of handles.
handles, plt_labels = ax_rz.get_legend_handles_labels()
rk4_handle = handles[plt_labels.index("RK4")]
ax_rz.legend(
    handles=handles,
    labels=plt_labels,
    handler_map={rk4_handle: HandlerColormap(mpl.colormaps["plasma_r"])},
)

fig.tight_layout(pad=0.5)
os.makedirs("figures", exist_ok=True)  # savefig does not create directories
fig.savefig(os.path.join("figures", "intro_ref.pdf"))

### (Mis-)characterization of the trajectories

In [None]:
# Shared styles for the two "(mis-)characterization" figures (also read by the
# next cell). They are never mutated: each plot call gets a copy with its color.
step = 75  # plot every `step`-th DVI point
param_ex = dict(lw=2, ls="dashed")
param_dvi = dict(s=3, zorder=2.1, alpha=0.5)

fig, ax = plt.subplots(figsize=(4, 4))

for k in [2, 0]:
    # Exact solution: one period, dark shade.
    Rk_ex, Zk_ex = R_ex[k, : nf_ex[k]], Z_ex[k, : nf_ex[k]]
    style_ex = {**param_ex, "color": colors_dark[k]}
    ax.plot(Rk_ex, Zk_ex, label=labels[k] + " (exact)", **style_ex)

    # DVI: every `step`-th point of the whole simulation, light shade.
    Rk_dvi, Zk_dvi = R_dvi[k, ::step], Z_dvi[k, ::step]
    style_dvi = {**param_dvi, "color": colors_light[k]}
    ax.scatter(Rk_dvi, Zk_dvi, label=labels[k] + " (DVI)", **style_dvi)

ax.set_xlabel("$R$")
ax.set_ylabel("$Z$")
ax.legend(bbox_to_anchor=(0.15, 0.5))

fig.tight_layout(pad=0.5)
os.makedirs("figures", exist_ok=True)  # savefig does not create directories
fig.savefig(os.path.join("figures", "intro_chara_dvi1.pdf"))

In [None]:
# Same figure as the previous cell for the other two trajectories; reuses its
# `step`, `param_ex` and `param_dvi` without mutating them.
fig, ax = plt.subplots(figsize=(4, 4))

for k in [1, 3]:
    # Exact solution: one period, dark shade.
    Rk_ex, Zk_ex = R_ex[k, : nf_ex[k]], Z_ex[k, : nf_ex[k]]
    style_ex = {**param_ex, "color": colors_dark[k]}
    ax.plot(Rk_ex, Zk_ex, label=labels[k] + " (exact)", **style_ex)

    # DVI: every `step`-th point of the whole simulation, light shade.
    Rk_dvi, Zk_dvi = R_dvi[k, ::step], Z_dvi[k, ::step]
    style_dvi = {**param_dvi, "color": colors_light[k]}
    ax.scatter(Rk_dvi, Zk_dvi, label=labels[k] + " (DVI)", **style_dvi)

ax.set_xlabel("$R$")
ax.set_ylabel("$Z$")
ax.legend(bbox_to_anchor=(0.15, 0.5))

fig.tight_layout(pad=0.5)
os.makedirs("figures", exist_ok=True)  # savefig does not create directories
fig.savefig(os.path.join("figures", "intro_chara_dvi2.pdf"))