In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

from symplearn.datasets import VectorFieldDataset, SnapshotDataset
from symplearn.numerics import EulerDVISimulation, RK4Simulation, QuasiExactSimulation

from utils import (
    find_index_period,
    get_study_init_cond,
    load_models,
    colored_line,
    HandlerColormap,
    GuidingCenter,
)

# Run everything in double precision on the CPU.
torch.set_default_dtype(torch.float64)
torch.set_default_device("cpu")

# Output folders used further down: `simulate` caches trajectories in "simul"
# and every figure is saved under "figures". Create them up front so the
# notebook also runs from a fresh checkout (Restart & Run All).
for out_dir in ("simul", "figures"):
    os.makedirs(out_dir, exist_ok=True)

# Models indexed by name (keys such as "ref", "vf_reg", "dvi" are used below).
models = load_models()
# `z0` and `dt` are the initial states and the base time step handed to
# `simulate`; `t0` is presumably the matching initial time -- TODO confirm.
t0, z0, dt = get_study_init_cond()

# Long and short names of the four study trajectories, in trajectory order.
titles = ["Barely Passing", "Barely Trapped", "Well Trapped", "Deeply Trapped"]
labels = ["BP", "BT", "WT", "DT"]

# Full validation split of the vector-field dataset.
z_val, t_val, dt_z_val = VectorFieldDataset("val")[:]

In [None]:
# Font sizes shared by every figure of the notebook
SMALL_SIZE = 7
MEDIUM_SIZE = 8
BIGGER_SIZE = 9
HUGE_SIZE = 11

plt.rcParams.update(
    {
        # axis labels and titles
        "axes.labelsize": MEDIUM_SIZE,
        "axes.titlesize": HUGE_SIZE,
        # tick labels
        "xtick.labelsize": SMALL_SIZE,
        "ytick.labelsize": SMALL_SIZE,
        # legend entries and legend title
        "legend.fontsize": BIGGER_SIZE,
        "legend.title_fontsize": BIGGER_SIZE,
    }
)

# Colorblind-friendly palette: one sequential colormap per trajectory,
# sampled at four positions (suffixes a..d).
cmap0 = mpl.colormaps["Blues_r"]
cmap1 = mpl.colormaps["Greens_r"]
cmap2 = mpl.colormaps["Oranges_r"]
cmap3 = mpl.colormaps["Purples_r"]

c0a, c0b, c0c, c0d = cmap0([0.1, 0.25, 0.4, 0.6])
c1a, c1b, c1c, c1d = cmap1([0.1, 0.25, 0.4, 0.6])
c2a, c2b, c2c, c2d = cmap2([0.1, 0.25, 0.4, 0.6])
c3a, c3b, c3c, c3d = cmap3([0.1, 0.25, 0.4, 0.5])

# One RGBA row per trajectory
colors = np.stack([c0a, c1b, c2c, c3d])
colors_dark = np.stack([c0a, c1a, c2a, c3a])
colors_light = np.stack([c0c, c1c, c2c, c3c])
markers = ["o", "s", "D", "^"]

In [None]:
def simulate(
    name,
    model=GuidingCenter(),
    scheme=QuasiExactSimulation,
    z0=z0,
    dt=1e-3 * dt,
    nt=70_000,
):
    """Return the trajectories stored under `name`, simulating them if needed.

    The result is cached on disk as ``simul/t_<name>.pt`` and
    ``simul/z_<name>.pt``. The cache key is `name` only: changing `model`,
    `scheme`, `z0`, `dt` or `nt` without changing `name` (or deleting the
    files) returns the previously stored arrays.

    Parameters
    ----------
    name : str
        Identifier used to build the two cache file names.
    model
        Model handed to `scheme`. The default instance is created once, when
        the function is defined.
    scheme
        Callable building the integrator as ``scheme(model, dt)``; the result
        must provide ``simulate(z0, nt)`` returning ``(t, z)`` tensors.
    z0
        Initial states passed to the integrator.
    dt : float
        Time step passed to `scheme`.
    nt : int
        Second argument of ``sim.simulate`` (presumably the number of steps).

    Returns
    -------
    tuple of numpy.ndarray
        ``(t_sim, z_sim)`` converted to NumPy arrays.
    """
    t_fname = os.path.join("simul", "t_" + name + ".pt")
    z_fname = os.path.join("simul", "z_" + name + ".pt")
    if os.path.exists(t_fname) and os.path.exists(z_fname):
        t_sim = torch.load(t_fname, weights_only=True)
        z_sim = torch.load(z_fname, weights_only=True)

    else:
        sim = scheme(model, dt)
        t_sim, z_sim = sim.simulate(z0, nt)
        # drop autograd history before storing / converting to NumPy
        t_sim, z_sim = t_sim.detach(), z_sim.detach()
        # torch.save does not create missing parent folders
        os.makedirs(os.path.dirname(t_fname), exist_ok=True)
        torch.save(t_sim, t_fname)
        torch.save(z_sim, z_fname)

    return t_sim.numpy(), z_sim.numpy()


# Reference run: all defaults of `simulate`
t_ex, z_ex = simulate("ref_ex")
th_ex = z_ex[..., 0]
r_ex = z_ex[..., 2]
u_ex = z_ex[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped
R_ex = models["ref"].r0 + r_ex * np.cos(th_ex)
Z_ex = r_ex * np.sin(th_ex)
# wrap the angle into [-pi, pi)
th_ex = (th_ex + np.pi) % (2 * np.pi) - np.pi

# One index per trajectory, as returned by `find_index_period`
nf_ex = list(map(find_index_period, np.stack((R_ex, Z_ex), -1)))

### Vector-field learning

In [None]:
# "vf_reg" model integrated with the default scheme of `simulate`
t_vf_ex, z_vf_ex = simulate("vf_ex", model=models["vf_reg"])

th_vf_ex = z_vf_ex[..., 0]
r_vf_ex = z_vf_ex[..., 2]
u_vf_ex = z_vf_ex[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_vf_ex, Z_vf_ex = (
    GuidingCenter().r0 + r_vf_ex * np.cos(th_vf_ex),
    r_vf_ex * np.sin(th_vf_ex),
)
th_vf_ex = (th_vf_ex + np.pi) % (2 * np.pi) - np.pi

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_vf_ex, y_vf_ex = torch.from_numpy(z_vf_ex).tensor_split(2, -1)
h_vf_ex = GuidingCenter().hamiltonian(x_vf_ex, y_vf_ex, None).numpy()
h_vf_ex = (h_vf_ex - h_vf_ex[:, :1]) / h_vf_ex[:, :1]

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_vf_ex = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_vf_ex, Z_vf_ex), -1)
    ]
)

In [None]:
# "vf_reg" model integrated with the Euler DVI scheme at the full time step `dt`
t_vf, z_vf = simulate(
    "vf_dvi", model=models["vf_reg"], scheme=EulerDVISimulation, dt=dt, nt=30_000
)

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_vf, y_vf = torch.from_numpy(z_vf).tensor_split(2, -1)
h_vf = GuidingCenter().hamiltonian(x_vf, y_vf, None).numpy()
h_vf = (h_vf - h_vf[:, :1]) / h_vf[:, :1]

# Component 1 is not needed. It is deliberately not bound to `_`: at the top
# level of a notebook that name is IPython's "last output" variable, and
# assigning it by hand stops IPython from updating it.
th_vf, r_vf, u_vf = z_vf[..., 0], z_vf[..., 2], z_vf[..., 3]
R_vf, Z_vf = GuidingCenter().r0 + r_vf * np.cos(th_vf), r_vf * np.sin(th_vf)
# wrap the angle into [-pi, pi) once (R, Z) have been computed
th_vf = (th_vf + np.pi) % (2 * np.pi) - np.pi

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_vf = [
    find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
    for RZk in np.stack((R_vf, Z_vf), -1)
]
nf_vf = np.array(nf_vf) + 2

In [None]:
def convert_to_rz(z, model=GuidingCenter()):
    """Convert states `z` to the pair ``(R, Z)``.

    Component 0 of the last axis is used as the angle and component 2 as the
    radius: ``R = model.r0 + r * cos(th)`` and ``Z = r * sin(th)``.
    The default `model` instance is created once, at definition time.
    """
    angle = z[..., 0]
    radius = z[..., 2]
    return model.r0 + radius * np.cos(angle), radius * np.sin(angle)


def plot_results_nn(z_nn_ex, nf_nn_ex, z_nn=None, step=50, title=None):
    """Compare a learned model with the reference solution in the (R, Z) plane.

    One panel is drawn per study trajectory (four in total). Each panel shows
    the reference solution (notebook-level `R_ex`, `Z_ex`, cut at `nf_ex[k]`),
    the model trajectory `z_nn_ex[k]` cut at `nf_nn_ex[k]` as a dashed line
    and, when `z_nn` is given, every `step`-th point of `z_nn[k]` as a scatter.

    Parameters
    ----------
    z_nn_ex : array
        Model trajectories drawn as dashed lines, indexed by trajectory first.
    nf_nn_ex : sequence of int
        Number of leading samples of `z_nn_ex[k]` to draw.
    z_nn : array, optional
        Trajectories drawn as scatter points (labelled "DVI" in the legend).
    step : int
        Stride applied to `z_nn`; unused when `z_nn` is None.
    title : str, optional
        Legend title of the first panel.

    Returns
    -------
    matplotlib.figure.Figure
    """
    with_dvi = z_nn is not None
    proxy_color = (0.3, 0.3, 0.3)

    fig, axes = plt.subplots(1, 4, figsize=(10, 3.25), sharey=True, sharex=True)
    for k, axk in enumerate(axes):
        # reference solution: thick light-grey line
        n_ref = nf_ex[k]
        axk.plot(
            R_ex[k, :n_ref],
            Z_ex[k, :n_ref],
            label=None,
            lw=2,
            color=(0.8, 0.8, 0.8, 1.0),
            ls="solid",
            alpha=1.0,
        )

        # model trajectory: dashed line in the dark shade of the trajectory
        R_model, Z_model = convert_to_rz(z_nn_ex[k, : nf_nn_ex[k]])
        axk.plot(
            R_model, Z_model, label=None, lw=2.5, color=colors_dark[k], ls="dashed"
        )

        # optional second trajectory: small dots drawn above the lines
        if with_dvi:
            R_dots, Z_dots = convert_to_rz(z_nn[k, ::step])
            axk.scatter(
                R_dots,
                Z_dots,
                label=None,
                s=0.8,
                color=colors_light[k],
                alpha=0.7,
                zorder=2.01,
            )

        axk.set_xlabel("$R$")

        if k == 0:
            # Empty proxy artists: the legend uses neutral greys rather than
            # the colours of the first trajectory.
            axk.plot([], [], label="exact", color=proxy_color, ls="dashed")
            if with_dvi:
                axk.scatter([], [], label="DVI", color=proxy_color, s=0.8)
            axk.plot([], [], label="ref. sol.", color=(0.8, 0.8, 0.8), lw=2)
            axk.legend(title=title)

        axk.set_title(titles[k])

    fig.supylabel("$Z$", fontsize=MEDIUM_SIZE)
    fig.tight_layout()
    return fig

In [None]:
# (R, Z) comparison for the "vf_reg" model (quasi-exact curve plus DVI points),
# saved as a PDF under "figures".
figs_vf = plot_results_nn(z_vf_ex, nf_vf_ex, z_vf, title="VF learning")

fname = os.path.join("figures", "vf_results_rz.pdf")
figs_vf.savefig(fname, bbox_inches="tight")

In [None]:
# Relative error on H for the "vf_reg" model: DVI run (top) and quasi-exact
# run (bottom), one curve per study trajectory.
fig, (ax, ax_ex) = plt.subplots(2, 1, figsize=(8, 3.5), sharex=True)

for k in range(4):
    # Cut the DVI curve at the final time of the quasi-exact run. The two runs
    # use different time steps, so take the closest DVI sample rather than
    # requiring bit-wise equal floats (np.where(...)[0][0] raises IndexError
    # as soon as round-off makes the two values differ). When an exact match
    # exists, this returns the same index as before.
    nk = int(np.argmin(np.abs(t_vf[k] - t_vf_ex[k, -1])))
    params = dict(color=colors[k], lw=2, label=labels[k])
    ex_params = dict(color=colors[k], lw=2.5, ls="dashed", label=labels[k])
    ax.plot(t_vf[k, :nk], h_vf[k, :nk], **params)
    ax_ex.plot(t_vf_ex[k], h_vf_ex[k], **ex_params)

ax_ex.set_xlabel("$t$")

# single y-label shared by both panels; the legend titles tell them apart
fig.supylabel("rel. err. on $H$ (VF learning)", fontsize=BIGGER_SIZE)

ax.legend(title="DVI")
ax_ex.legend(title="exact")

fig.tight_layout()

fname = os.path.join("figures", "vf_err_h.pdf")
fig.savefig(fname, bbox_inches="tight")

#### Without regularization

In [None]:
# "vf_no_reg" model integrated with the default scheme of `simulate`
t_cmy_ex, z_cmy_ex = simulate("cmy_ex", model=models["vf_no_reg"])

th_cmy_ex = z_cmy_ex[..., 0]
r_cmy_ex = z_cmy_ex[..., 2]
u_cmy_ex = z_cmy_ex[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_cmy_ex, Z_cmy_ex = (
    GuidingCenter().r0 + r_cmy_ex * np.cos(th_cmy_ex),
    r_cmy_ex * np.sin(th_cmy_ex),
)
th_cmy_ex = (th_cmy_ex + np.pi) % (2 * np.pi) - np.pi

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_cmy_ex, y_cmy_ex = torch.from_numpy(z_cmy_ex).tensor_split(2, -1)
h_cmy_ex = GuidingCenter().hamiltonian(x_cmy_ex, y_cmy_ex, None).numpy()
h_cmy_ex = (h_cmy_ex - h_cmy_ex[:, :1]) / h_cmy_ex[:, :1]

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_cmy_ex = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_cmy_ex, Z_cmy_ex), -1)
    ]
)

In [None]:
# "vf_no_reg" model integrated with the Euler DVI scheme at the full step `dt`
t_cmy, z_cmy = simulate(
    "cmy_dvi",
    model=models["vf_no_reg"],
    scheme=EulerDVISimulation,
    dt=dt,
    nt=30_000,
)

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_cmy, y_cmy = torch.from_numpy(z_cmy).tensor_split(2, -1)
h_cmy = GuidingCenter().hamiltonian(x_cmy, y_cmy, None).numpy()
h_cmy = (h_cmy - h_cmy[:, :1]) / h_cmy[:, :1]

th_cmy = z_cmy[..., 0]
r_cmy = z_cmy[..., 2]
u_cmy = z_cmy[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_cmy = GuidingCenter().r0 + r_cmy * np.cos(th_cmy)
Z_cmy = r_cmy * np.sin(th_cmy)
th_cmy = (th_cmy + np.pi) % (2 * np.pi) - np.pi

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_cmy = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_cmy, Z_cmy), -1)
    ]
)

In [None]:
# (R, Z) comparison for the "vf_no_reg" model, saved as a PDF under "figures".
# NOTE(review): this rebinds `figs_vf` and reuses the legend title "VF learning"
# of the regularized model -- confirm that both are intended.
figs_vf = plot_results_nn(z_cmy_ex, nf_cmy_ex, z_cmy, title="VF learning")

fname = os.path.join("figures", "cmy_results_rz.pdf")
figs_vf.savefig(fname, bbox_inches="tight")

In [None]:
# Relative error on H for the "vf_no_reg" model: DVI run (top) and quasi-exact
# run (bottom), one curve per study trajectory.
fig, (ax, ax_ex) = plt.subplots(2, 1, figsize=(8, 3.5), sharex=True)

for k in range(4):
    # Cut the DVI curve at the final time of the quasi-exact run. The two runs
    # use different time steps, so take the closest DVI sample rather than
    # requiring bit-wise equal floats (np.where(...)[0][0] raises IndexError
    # as soon as round-off makes the two values differ). When an exact match
    # exists, this returns the same index as before.
    nk = int(np.argmin(np.abs(t_cmy[k] - t_cmy_ex[k, -1])))
    params = dict(color=colors[k], lw=2, label=labels[k])
    ex_params = dict(color=colors[k], lw=2.5, ls="dashed", label=labels[k])
    ax.plot(t_cmy[k, :nk], h_cmy[k, :nk], **params)
    ax_ex.plot(t_cmy_ex[k], h_cmy_ex[k], **ex_params)

ax_ex.set_xlabel("$t$")
# single y-label shared by both panels; the legend titles tell them apart
fig.supylabel("rel. err. on $H$ (VF learning)", fontsize=BIGGER_SIZE)

ax.legend(title="DVI")
ax_ex.legend(title="exact")

fig.tight_layout()

fname = os.path.join("figures", "cmy_err_h.pdf")
fig.savefig(fname, bbox_inches="tight")

### Scheme-fitted model

In [None]:
# --- "dvi" model integrated with the default scheme of `simulate` ---
t_sch_ex, z_sch_ex = simulate("sch_ex", model=models["dvi"])

th_sch_ex, r_sch_ex, u_sch_ex = z_sch_ex[..., 0], z_sch_ex[..., 2], z_sch_ex[..., 3]
# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_sch_ex = GuidingCenter().r0 + r_sch_ex * np.cos(th_sch_ex)
Z_sch_ex = r_sch_ex * np.sin(th_sch_ex)
th_sch_ex = (th_sch_ex + np.pi) % (2 * np.pi) - np.pi

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_sch_ex, y_sch_ex = torch.from_numpy(z_sch_ex).tensor_split(2, -1)
h_sch_ex = GuidingCenter().hamiltonian(x_sch_ex, y_sch_ex, None).numpy()
h_sch_ex = (h_sch_ex - h_sch_ex[:, :1]) / h_sch_ex[:, :1]

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_sch_ex = [
    find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
    for RZk in np.stack((R_sch_ex, Z_sch_ex), -1)
]
nf_sch_ex = np.array(nf_sch_ex) + 2

# --- same model integrated with the Euler DVI scheme at the full step `dt` ---
t_sch, z_sch = simulate(
    "sch_dvi", model=models["dvi"], scheme=EulerDVISimulation, dt=dt, nt=30_000
)

x_sch, y_sch = torch.from_numpy(z_sch).tensor_split(2, -1)
h_sch = GuidingCenter().hamiltonian(x_sch, y_sch, None).numpy()
h_sch = (h_sch - h_sch[:, :1]) / h_sch[:, :1]

# Component 1 is not needed. It is deliberately not bound to `_`: at the top
# level of a notebook that name is IPython's "last output" variable, and
# assigning it by hand stops IPython from updating it.
th_sch, r_sch, u_sch = z_sch[..., 0], z_sch[..., 2], z_sch[..., 3]
R_sch, Z_sch = GuidingCenter().r0 + r_sch * np.cos(th_sch), r_sch * np.sin(th_sch)
th_sch = (th_sch + np.pi) % (2 * np.pi) - np.pi

nf_sch = [
    find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
    for RZk in np.stack((R_sch, Z_sch), -1)
]
nf_sch = np.array(nf_sch) + 2

In [None]:
# (R, Z) comparison for the "dvi" model; every 62nd DVI sample is drawn.
figs_sch = plot_results_nn(z_sch_ex, nf_sch_ex, z_sch, step=62, title="Sch. learning")

fname = os.path.join("figures", "sch_results_rz.pdf")
figs_sch.savefig(fname, bbox_inches="tight")

# NOTE(review): the commented-out loop below iterates over `figs_sch` as if it
# were a list of figures, but `plot_results_nn` returns a single figure, so it
# looks stale.
# for k, fig in enumerate(figs_sch):
#     fname = os.path.join("figures", "sch", labels[k] + ".pdf")
#     fig.savefig(fname, bbox_inches="tight")

In [None]:
# Relative error on H for the "dvi" model: DVI run (top) and quasi-exact run
# (bottom), one curve per study trajectory.
fig, (ax, ax_ex) = plt.subplots(2, 1, figsize=(8, 3.5), sharex=True)

# NOTE(review): `nf_h` is not read by any visible cell; kept in case a later
# cell relies on it.
nf_h = 2 * max(nf_sch)
for k in range(4):
    # Cut the DVI curve at the final time of the quasi-exact run. The two runs
    # use different time steps, so take the closest DVI sample rather than
    # requiring bit-wise equal floats (np.where(...)[0][0] raises IndexError
    # as soon as round-off makes the two values differ). When an exact match
    # exists, this returns the same index as before.
    nk = int(np.argmin(np.abs(t_sch[k] - t_sch_ex[k, -1])))
    params = dict(color=colors[k], lw=2, label=labels[k])
    ex_params = dict(color=colors[k], lw=2.5, ls="dashed", label=labels[k])
    ax.plot(t_sch[k, :nk], h_sch[k, :nk], **params)
    ax_ex.plot(t_sch_ex[k], h_sch_ex[k], **ex_params)

ax_ex.set_xlabel("$t$")

# single y-label shared by both panels; the legend titles tell them apart
fig.supylabel("rel. err. on $H$ (scheme learning)", fontsize=BIGGER_SIZE)

ax.legend(title="DVI")
ax_ex.legend(title="exact")

fig.tight_layout()

fname = os.path.join("figures", "sch_err_h.pdf")
fig.savefig(fname, bbox_inches="tight")

### VF-fitted without Gram-based norm

In [None]:
# "vf_no_gram" model integrated with the default scheme of `simulate`
t_ng_ex, z_ng_ex = simulate("no_gram_ex", model=models["vf_no_gram"])

th_ng_ex = z_ng_ex[..., 0]
r_ng_ex = z_ng_ex[..., 2]
u_ng_ex = z_ng_ex[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_ng_ex, Z_ng_ex = (
    GuidingCenter().r0 + r_ng_ex * np.cos(th_ng_ex),
    r_ng_ex * np.sin(th_ng_ex),
)
th_ng_ex = (th_ng_ex + np.pi) % (2 * np.pi) - np.pi

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_ng_ex, y_ng_ex = torch.from_numpy(z_ng_ex).tensor_split(2, -1)
h_ng_ex = GuidingCenter().hamiltonian(x_ng_ex, y_ng_ex, None).numpy()
h_ng_ex = (h_ng_ex - h_ng_ex[:, :1]) / h_ng_ex[:, :1]

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_ng_ex = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_ng_ex, Z_ng_ex), -1)
    ]
)

In [None]:
# (R, Z) plot for the "vf_no_gram" model. No DVI trajectory is passed, so only
# the reference and quasi-exact curves are drawn and `step` has no effect.
figs_ng = plot_results_nn(z_ng_ex, nf_ng_ex, step=62, title="VF without Gram")

fname = os.path.join("figures", "no_gram_results_rz.pdf")
figs_ng.savefig(fname, bbox_inches="tight")

In [None]:
# Relative error on H for the "vf_no_gram" model (quasi-exact run only)
fig, ax_ex = plt.subplots(figsize=(8, 2))

for k in range(4):
    ax_ex.plot(
        t_ng_ex[k],
        h_ng_ex[k],
        color=colors[k],
        lw=2.5,
        ls="dashed",
        label=labels[k],
    )

ax_ex.set_xlabel("$t$")
ax_ex.set_ylabel("rel. err. on $H$")
ax_ex.legend(title="VF without Gram")

fig.tight_layout()

fname = os.path.join("figures", "no_gram_err_h.pdf")
fig.savefig(fname, bbox_inches="tight")

In [None]:
# "vf_no_reg" model integrated with RK4 at the full time step `dt`
t_cmy_rk4, z_cmy_rk4 = simulate(
    "cmy_rk4",
    model=models["vf_no_reg"],
    scheme=RK4Simulation,
    dt=dt,
    nt=30_000,
)

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_cmy_rk4, y_cmy_rk4 = torch.from_numpy(z_cmy_rk4).tensor_split(2, -1)
h_cmy_rk4 = GuidingCenter().hamiltonian(x_cmy_rk4, y_cmy_rk4, None).numpy()
h_cmy_rk4 = (h_cmy_rk4 - h_cmy_rk4[:, :1]) / h_cmy_rk4[:, :1]

th_cmy_rk4 = z_cmy_rk4[..., 0]
r_cmy_rk4 = z_cmy_rk4[..., 2]
u_cmy_rk4 = z_cmy_rk4[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_cmy_rk4, Z_cmy_rk4 = (
    GuidingCenter().r0 + r_cmy_rk4 * np.cos(th_cmy_rk4),
    r_cmy_rk4 * np.sin(th_cmy_rk4),
)
th_cmy_rk4 = (th_cmy_rk4 + np.pi) % (2 * np.pi) - np.pi

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_cmy_rk4 = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_cmy_rk4, Z_cmy_rk4), -1)
    ]
)

In [None]:
# "vf_reg" model integrated with RK4 at the full time step `dt`
t_vf_rk4, z_vf_rk4 = simulate(
    "vf_rk4",
    model=models["vf_reg"],
    scheme=RK4Simulation,
    dt=dt,
    nt=30_000,
)

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_vf_rk4, y_vf_rk4 = torch.from_numpy(z_vf_rk4).tensor_split(2, -1)
h_vf_rk4 = GuidingCenter().hamiltonian(x_vf_rk4, y_vf_rk4, None).numpy()
h_vf_rk4 = (h_vf_rk4 - h_vf_rk4[:, :1]) / h_vf_rk4[:, :1]

th_vf_rk4 = z_vf_rk4[..., 0]
r_vf_rk4 = z_vf_rk4[..., 2]
u_vf_rk4 = z_vf_rk4[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_vf_rk4, Z_vf_rk4 = (
    GuidingCenter().r0 + r_vf_rk4 * np.cos(th_vf_rk4),
    r_vf_rk4 * np.sin(th_vf_rk4),
)
th_vf_rk4 = (th_vf_rk4 + np.pi) % (2 * np.pi) - np.pi

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_vf_rk4 = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_vf_rk4, Z_vf_rk4), -1)
    ]
)

In [None]:
# "dvi" model integrated with RK4 at the full time step `dt`
t_sch_rk4, z_sch_rk4 = simulate(
    "sch_rk4",
    model=models["dvi"],
    scheme=RK4Simulation,
    dt=dt,
    nt=30_000,
)

# Hamiltonian of `GuidingCenter` along the trajectories, then its relative
# deviation from the first stored value of each trajectory
x_sch_rk4, y_sch_rk4 = torch.from_numpy(z_sch_rk4).tensor_split(2, -1)
h_sch_rk4 = GuidingCenter().hamiltonian(x_sch_rk4, y_sch_rk4, None).numpy()
h_sch_rk4 = (h_sch_rk4 - h_sch_rk4[:, :1]) / h_sch_rk4[:, :1]

th_sch_rk4 = z_sch_rk4[..., 0]
r_sch_rk4 = z_sch_rk4[..., 2]
u_sch_rk4 = z_sch_rk4[..., 3]

# (R, Z) coordinates, computed before the angle is wrapped into [-pi, pi)
R_sch_rk4, Z_sch_rk4 = (
    GuidingCenter().r0 + r_sch_rk4 * np.cos(th_sch_rk4),
    r_sch_rk4 * np.sin(th_sch_rk4),
)
th_sch_rk4 = (th_sch_rk4 + np.pi) % (2 * np.pi) - np.pi

# Index returned by `find_index_period` for each trajectory, shifted by 2
nf_sch_rk4 = 2 + np.array(
    [
        find_index_period(RZk, eps_dist=2e-2, eps_diff=0.0)
        for RZk in np.stack((R_sch_rk4, Z_sch_rk4), -1)
    ]
)

In [None]:
# Scratch cell: evaluates the "Greys" colormap at 0.5 so the notebook displays
# the resulting colour. Nothing is assigned, so later cells do not depend on it.
mpl.colormaps["Greys"](0.5)

In [None]:
def plot_num_rz(ax, R_rk4, Z_rk4, t_rk4, R_dvi, Z_dvi, t_dvi, title=None, step=22):
    """Scatter the first RK4 and DVI trajectories in the (R, Z) plane on `ax`.

    Only row 0 of every array is used, thinned to every `step`-th sample.
    RK4 points are coloured by their own time `t_rk4` with the "plasma_r"
    colormap; DVI points get evenly spaced shades of "Greys".

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axes.
    R_rk4, Z_rk4, t_rk4 : array
        Coordinates and times of the RK4 run, trajectory index first.
    R_dvi, Z_dvi, t_dvi : array
        Coordinates and times of the DVI run, trajectory index first.
    title : str, optional
        Legend title.
    step : int
        Stride applied along the sample axis before plotting.
    """
    # Colour by the times passed in, not by the notebook-level `t_vf_rk4`.
    Ri, Zi, ti = R_rk4[0, ::step], Z_rk4[0, ::step], t_rk4[0, ::step]
    handler_rk4 = HandlerColormap(mpl.colormaps["plasma_r"])
    ax.scatter(Ri, Zi, c=ti, s=0.5, cmap="plasma_r", label="RK4")

    Ri, Zi, ti = R_dvi[0, ::step], Z_dvi[0, ::step], t_dvi[0, ::step]
    cmap_dvi = mpl.colormaps["Greys"]
    # `ti` only fixes the number of shades: colours run linearly from 0.2 to 0.8
    colors_dvi = cmap_dvi(np.linspace(0.2, 0.8, len(ti)))
    handler_dvi = HandlerColormap(cmap_dvi)
    ax.scatter(Ri, Zi, c=colors_dvi, s=0.5, label="DVI")

    ax.set(xlabel="$R$")
    # `plt_labels` rather than `labels`, so the notebook-level list of
    # trajectory labels is not shadowed inside this function
    handles, plt_labels = ax.get_legend_handles_labels()
    ax.legend(
        title=title,
        handles=handles,
        labels=plt_labels,
        # the two scatters are the last two labelled artists added above
        handler_map={handles[-2]: handler_rk4, handles[-1]: handler_dvi},
        loc=(0.15, 0.4),
    )


# RK4 vs DVI orbits of the first trajectory for the three learned models
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7.2, 2.7))

plot_num_rz(
    ax1, R_vf_rk4, Z_vf_rk4, t_vf_rk4, R_vf, Z_vf, t_vf, title="VF with reg."
)
plot_num_rz(
    ax2, R_cmy_rk4, Z_cmy_rk4, t_cmy_rk4, R_cmy, Z_cmy, t_cmy, title="VF no reg."
)
plot_num_rz(
    ax3, R_sch_rk4, Z_sch_rk4, t_sch_rk4, R_sch, Z_sch, t_sch, title="Sch. learning"
)

# only the left-most panel keeps its y label and tick labels
ax1.set_ylabel("$Z$")
ax2.set_yticklabels([])
ax3.set_yticklabels([])

fig.tight_layout(pad=0.2, w_pad=0.8)

fig.savefig(os.path.join("figures", "bp_nn.pdf"), bbox_inches="tight")

In [None]:
def plot_err_h(t_rk4, h_rk4, t_dvi, h_dvi, ax, title=None, **kwargs_dvi):
    """Plot the DVI and RK4 curves of the "BP" trajectory on `ax`, with a legend.

    The row to draw is the position of "BP" in the notebook-level `labels`;
    every 22nd sample is kept. The DVI curve is a grey line (line width 1.0
    unless overridden through `kwargs_dvi`), the RK4 curve a line coloured by
    time with the "plasma_r" colormap.

    Parameters
    ----------
    t_rk4, h_rk4 : array
        Times and values of the RK4 run, trajectory index first.
    t_dvi, h_dvi : array
        Times and values of the DVI run, trajectory index first.
    ax : matplotlib.axes.Axes
        Target axes.
    title : str, optional
        Legend title.
    **kwargs_dvi
        Extra keyword arguments forwarded to ``ax.plot`` for the DVI curve.
    """
    row = labels.index("BP")
    stride = 22

    dvi_style = {"lw": 1.0, **kwargs_dvi}
    ax.plot(
        t_dvi[row, ::stride], h_dvi[row, ::stride], c="gray", label="DVI", **dvi_style
    )

    t_rk4_plt = t_rk4[row, ::stride]
    h_rk4_plt = h_rk4[row, ::stride]
    # invisible line: only there so that the y-limits take the RK4 data into account
    ax.plot(t_rk4_plt, h_rk4_plt, alpha=0.0, label=None)
    colored_line(t_rk4_plt, h_rk4_plt, t_rk4_plt, ax, label="RK4", cmap="plasma_r")

    handles, plt_labels = ax.get_legend_handles_labels()
    # the multicoloured line is the last labelled artist and needs its own handler
    rk4_handler = HandlerColormap(mpl.colormaps["plasma_r"])
    ax.legend(
        title=title,
        handles=handles,
        labels=plt_labels,
        handler_map={handles[-1]: rk4_handler},
    )

# DVI and RK4 curves of the "BP" trajectory for the three learned models
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7.2, 2.2))

i = labels.index("BP")
step = 22
ti = t_vf[i, ::step]

plot_err_h(t_vf_rk4, h_vf_rk4, t_vf, h_vf, ax1, title="VF with reg.")
plot_err_h(t_cmy_rk4, h_cmy_rk4, t_cmy, h_cmy, ax2, title="VF no reg.")
plot_err_h(
    t_sch_rk4, h_sch_rk4, t_sch, h_sch, ax3, title="sch. fit", zorder=2.01, lw=1.5
)

# Same y-range on the three panels: the union of their automatic limits
ylims1, ylims2, ylims3 = ax1.get_ylim(), ax2.get_ylim(), ax3.get_ylim()
y_low = min(ylims1[0], ylims2[0], ylims3[0])
y_high = max(ylims1[1], ylims2[1], ylims3[1])
for ax in (ax1, ax2, ax3):
    ax.set_ylim(y_low, y_high)

ax1.set_xlabel("$t$")
ax1.set_ylabel("rel. err. on $H$")

# the other panels share the y-range, so their tick labels are dropped
ax2.set(xlabel="$t$", yticklabels=[])
ax3.set(xlabel="$t$", yticklabels=[])

fig.tight_layout(pad=0.2, w_pad=0.8)

fig.savefig(os.path.join("figures", "bp_err_h.pdf"), bbox_inches="tight")

In [None]:
def plot_err_h(t_rk4, h_rk4, t_dvi, h_dvi, ax, title=None, **kwargs_dvi):
    """Plot the DVI and RK4 curves of the "BP" trajectory on `ax`, with a legend.

    NOTE(review): a function with this name and signature is already defined
    in an earlier cell; this definition shadows it and is not called in this
    cell.

    `t_rk4`/`h_rk4` and `t_dvi`/`h_dvi` are indexed by trajectory first;
    `title` is the legend title and `kwargs_dvi` is forwarded to ``ax.plot``
    for the DVI curve (line width 1.0 unless overridden).
    """
    # row of "BP" in the notebook-level `labels`; keep every 22nd sample
    i, step = labels.index("BP"), 22

    ax.plot(t_dvi[i, ::step], h_dvi[i, ::step], c="gray", label="DVI", **({"lw": 1.0} | kwargs_dvi))

    t_rk4_plt, h_rk4_plt = t_rk4[i, ::step], h_rk4[i, ::step]
    ax.plot(t_rk4_plt, h_rk4_plt, alpha=0.0, label=None)  # dummy line for ylims
    # RK4 curve coloured by its own time values
    colored_line(t_rk4_plt, h_rk4_plt, t_rk4_plt, ax, label="RK4", cmap="plasma_r")

    handles, plt_labels = ax.get_legend_handles_labels()
    ax.legend(
        title=title,
        handles=handles, 
        labels=plt_labels, 
        # special treatment for the multicolored line
        handler_map={handles[-1]: HandlerColormap(mpl.colormaps["plasma_r"])},
    )

# "BP" trajectory, every 22nd sample: the three DVI runs on the left panel,
# the three RK4 runs on the right one (default colour cycle, no legend).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7.2, 2.4))

i = labels.index("BP")
step = 22
ti = t_vf[i, ::step]

ax1.plot(t_vf[i, ::step], h_vf[i, ::step])
ax1.plot(t_cmy[i, ::step], h_cmy[i, ::step])
ax1.plot(t_sch[i, ::step], h_sch[i, ::step])

ax2.plot(t_vf_rk4[i, ::step], h_vf_rk4[i, ::step])
ax2.plot(t_cmy_rk4[i, ::step], h_cmy_rk4[i, ::step])
# slightly lower zorder: this curve is drawn underneath the two others
ax2.plot(t_sch_rk4[i, ::step], h_sch_rk4[i, ::step], zorder=1.99)

fig.tight_layout(pad=0.2, w_pad=0.8)

# fig.savefig(os.path.join("figures", "bp_err_h.pdf"), bbox_inches="tight")

In [None]:
# First trajectory (index 0, labelled "BP") of the "vf_reg" model:
# (R, Z) orbit on the left, relative error on H against time on the right.
fig, (ax_rz, ax_h) = plt.subplots(1, 2, figsize=(7.2, 2.8), width_ratios=(4, 7))

# RK4 run, thinned to every 5th sample; the points are coloured by time
R_rk4_plt, Z_rk4_plt = R_vf_rk4[0, ::5], Z_vf_rk4[0, ::5]
# NOTE(review): th_rk4_plt / u_rk4_plt are not used further down in this cell
th_rk4_plt, u_rk4_plt = th_vf_rk4[0, ::5], u_vf_rk4[0, ::5]
c_rk4_plt = t_vf_rk4[0, ::5]

# Reference solution, cut at the index stored in `nf_ex`
nf_ex_plt = nf_ex[0]
R_ex_plt, Z_ex_plt = R_ex[0, :nf_ex_plt], Z_ex[0, :nf_ex_plt]
# NOTE(review): th_ex_plt / u_ex_plt are not used further down in this cell
th_ex_plt, u_ex_plt = th_ex[0, :nf_ex_plt], u_ex[0, :nf_ex_plt]

# DVI run, cut one sample before the index stored in `nf_vf`
nf_dvi_plt = nf_vf[0] - 1
R_dvi_plt, Z_dvi_plt = R_vf[0, :nf_dvi_plt], Z_vf[0, :nf_dvi_plt]
th_dvi_plt, u_dvi_plt = th_vf[0, :nf_dvi_plt], u_vf[0, :nf_dvi_plt]

# The scatter is created first: the legend code below relies on it being handles[0]
ax_rz.scatter(R_rk4_plt, Z_rk4_plt, c=c_rk4_plt, s=0.5, cmap="plasma_r", label="RK4")
ax_rz.plot(R_dvi_plt, Z_dvi_plt, color="gray", marker="o", label="DVI")
ax_rz.plot(R_ex_plt, Z_ex_plt, color="black", ls="dashed", lw=1.5, label="exact")

ax_rz.set_xlabel("$R$")
ax_rz.set_ylabel("$Z$")

# Positions where the wrapped angle jumps by more than 1.5*pi between samples
# NOTE(review): `idx` is computed but not used in this cell
(idx,) = np.where(np.diff(th_dvi_plt) > 1.5 * np.pi)
idx = np.concatenate(([0], idx + 1))

t_rk4_plt, h_rk4_plt = t_vf_rk4[0, ::21], h_vf_rk4[0, ::21]
# DVI curve first, so the coloured RK4 line is added after it
ax_h.plot(t_vf[0, ::6], h_vf[0, ::6], c="gray", lw=0.5)
ax_h.plot(t_rk4_plt, h_rk4_plt, alpha=0.0, label=None)  # dummy line for ylims
colored_line(t_rk4_plt, h_rk4_plt, t_rk4_plt, ax_h, cmap="plasma_r")
ax_h.set_xlabel("$t$")
ax_h.set_ylabel("rel. err. on $H$")

handles, plt_labels = ax_rz.get_legend_handles_labels()
ax_rz.legend(
    handles=handles, 
    labels=plt_labels, 
    # special treatment for the multicolored line
    handler_map={handles[0]: HandlerColormap(mpl.colormaps["plasma_r"])},
)

# ax_rz.legend()

fig.tight_layout(pad=0.5)

In [None]:
# First trajectory (index 0, labelled "BP") of the "dvi" model:
# (R, Z) orbit on the left, relative error on H against time on the right.
fig, (ax_rz, ax_h) = plt.subplots(1, 2, figsize=(7.2, 2.8), width_ratios=(4, 7))

# RK4 run, thinned to every 5th sample; the points are coloured by time
R_rk4_plt, Z_rk4_plt = R_sch_rk4[0, ::5], Z_sch_rk4[0, ::5]
# NOTE(review): th_rk4_plt / u_rk4_plt are not used further down in this cell
th_rk4_plt, u_rk4_plt = th_sch_rk4[0, ::5], u_sch_rk4[0, ::5]
c_rk4_plt = t_sch_rk4[0, ::5]

# Reference solution, cut at the index stored in `nf_ex`
nf_ex_plt = nf_ex[0]
R_ex_plt, Z_ex_plt = R_ex[0, :nf_ex_plt], Z_ex[0, :nf_ex_plt]
# NOTE(review): th_ex_plt / u_ex_plt are not used further down in this cell
th_ex_plt, u_ex_plt = th_ex[0, :nf_ex_plt], u_ex[0, :nf_ex_plt]

# DVI run, cut one sample before the index stored in `nf_sch`
nf_dvi_plt = nf_sch[0] - 1
R_dvi_plt, Z_dvi_plt = R_sch[0, :nf_dvi_plt], Z_sch[0, :nf_dvi_plt]
th_dvi_plt, u_dvi_plt = th_sch[0, :nf_dvi_plt], u_sch[0, :nf_dvi_plt]

# The scatter is created first: the legend code below relies on it being handles[0]
ax_rz.scatter(R_rk4_plt, Z_rk4_plt, c=c_rk4_plt, s=0.5, cmap="plasma_r", label="RK4")
ax_rz.plot(R_dvi_plt, Z_dvi_plt, color="gray", marker="o", label="DVI")
ax_rz.plot(R_ex_plt, Z_ex_plt, color="black", ls="dashed", lw=1.5, label="exact")

ax_rz.set_xlabel("$R$")
ax_rz.set_ylabel("$Z$")

# Positions where the wrapped angle jumps by more than 1.5*pi between samples
# NOTE(review): `idx` is computed but not used in this cell
(idx,) = np.where(np.diff(th_dvi_plt) > 1.5 * np.pi)
idx = np.concatenate(([0], idx + 1))

t_rk4_plt, h_rk4_plt = t_sch_rk4[0, ::21], h_sch_rk4[0, ::21]
ax_h.plot(t_rk4_plt, h_rk4_plt, alpha=0.0, label=None)  # dummy line for ylims
colored_line(t_rk4_plt, h_rk4_plt, t_rk4_plt, ax_h, cmap="plasma_r")
# here the DVI curve is added after the coloured RK4 line
ax_h.plot(t_sch[0, ::6], h_sch[0, ::6], c="gray", lw=0.5)
ax_h.set_xlabel("$t$")
ax_h.set_ylabel("rel. err. on $H$")

handles, plt_labels = ax_rz.get_legend_handles_labels()
ax_rz.legend(
    handles=handles, 
    labels=plt_labels, 
    # special treatment for the multicolored line
    handler_map={handles[0]: HandlerColormap(mpl.colormaps["plasma_r"])},
)

# ax_rz.legend()

fig.tight_layout(pad=0.2, w_pad=1.0)