In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
from torch import nn

# Make float64 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float64)

from symplearn.datasets import VectorFieldDataset
from symplearn.training import Trainer, VectorFieldLoss
from symplearn.numerics import RK4Simulation, EulerDVISimulation, QuasiExactSimulation

from smallmodels import LotkaVolterra, NeuralBaseLV, NeuralSympLV, NeuralHamLV
from utils import HandlerColormap, colored_line, find_index_period

# Module-level instance of the LotkaVolterra model imported from `smallmodels`.
model_ref = LotkaVolterra()

In [None]:
# Shared artist styles, one dict per curve family. They are unpacked into the
# plotting calls of the figure cells so that every figure uses the same look.
# (A previous palette used plt.cm.Greys / "mediumseagreen" (lw=2.5) /
# "darkorange" (lw=2.5, dotted) for base / symp / ham respectively.)
ref_params = {"c": "black", "lw": 1.5, "ls": "dashed", "label": "reference"}
base_params = {"cmap": plt.cm.plasma_r, "label": "no structure"}
symp_params = {"c": "gray", "lw": 2, "label": "non-canonical"}
ham_params = {"c": "tab:brown", "lw": 2, "ls": "dotted", "label": "canonical"}

# Font sizes used for tick labels, axis labels and legends.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 7, 8, 9

for _rc_group, _rc_settings in (
    ("axes", {"labelsize": MEDIUM_SIZE}),   # x and y axis labels
    ("xtick", {"labelsize": SMALL_SIZE}),   # x tick labels
    ("ytick", {"labelsize": SMALL_SIZE}),   # y tick labels
    ("legend", {"fontsize": BIGGER_SIZE}),  # legend entries
):
    plt.rc(_rc_group, **_rc_settings)

In [None]:
def simulate(
    name,
    model=LotkaVolterra(),
    scheme=QuasiExactSimulation,
    z0=torch.tensor([[1.0, 1.0]]),
    dt=1e-2,
    nt=20_000,
):
    """Run a simulation, or load it from the on-disk cache if it exists.

    Results are cached as ``simul/intro/t_<name>.pt`` and
    ``simul/intro/z_<name>.pt``. When both files exist they are loaded and
    returned as-is: ``model``, ``scheme``, ``z0``, ``dt`` and ``nt`` are then
    ignored, so delete the files (or change ``name``) to force a recomputation.

    Parameters
    ----------
    name : str
        Identifier used to build the cache file names.
    model
        Object handed to ``scheme`` as its first argument.
    scheme
        Callable invoked as ``scheme(model, dt)``; the result must provide
        ``simulate(z0, nt)`` returning a ``(t, z)`` pair of tensors.
    z0 : torch.Tensor
        Initial condition(s). If its first dimension has size 1, the leading
        axis of both outputs is dropped.
    dt : float
        Time step passed to ``scheme``.
    nt : int
        Second argument of ``sim.simulate`` (presumably the number of steps --
        confirm against symplearn).

    Returns
    -------
    tuple of numpy.ndarray
        ``(t_sim, z_sim)`` converted to NumPy arrays.
    """
    cache_dir = os.path.join("simul", "intro")
    t_fname = os.path.join(cache_dir, "t_" + name + ".pt")
    z_fname = os.path.join(cache_dir, "z_" + name + ".pt")
    if os.path.exists(t_fname) and os.path.exists(z_fname):
        t_sim = torch.load(t_fname, weights_only=True)
        z_sim = torch.load(z_fname, weights_only=True)

    else:
        sim = scheme(model, dt)
        t_sim, z_sim = sim.simulate(z0, nt)
        if z0.shape[0] == 1:
            t_sim, z_sim = t_sim[0], z_sim[0]
        t_sim, z_sim = t_sim.detach(), z_sim.detach()
        # torch.save does not create missing parent directories; without this
        # the first run on a fresh checkout fails after the simulation is done.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(t_sim, t_fname)
        torch.save(z_sim, z_fname)

    return t_sim.numpy(), z_sim.numpy()


def compute_rel_energy_err(z, model=LotkaVolterra()):
    """Return the relative deviation of the Hamiltonian along a trajectory.

    ``z`` (a NumPy array) gets a leading axis of size 1, is split in two along
    its last axis into ``x`` and ``y``, and ``model.hamiltonian(x, y, None)``
    is evaluated on them.

    The returned array is ``(H - H[0]) / H``: the deviation from the first
    sample is normalised by the *current* value of ``H``, not by ``H[0]``.
    """
    batched = torch.from_numpy(z).unsqueeze(0)
    x, y = torch.tensor_split(batched, 2, dim=-1)
    energy = model.hamiltonian(x, y, None)
    # Back to NumPy, dropping the leading axis added above.
    energy = energy.detach().numpy()[0]
    drift = energy - energy[0]
    return drift / energy

## Train and study small models

In [None]:
# Datasets constructed with the names "train" and "test". They are read as
# globals by `train()`, which hands them to `Trainer` as its first two arguments.
train_data = VectorFieldDataset("train")
test_data = VectorFieldDataset("test")


def train(
    nn_model, nn_name, batch_size=500, learning_rates=[1e-2, 1e-3], num_epochs=[50, 150]
):
    """Train ``nn_model``, or restore it from disk if it was trained before.

    Two files are associated with a model: ``nn/small_<nn_name>.pt`` (state
    dict) and ``nn/small_<nn_name>_trace.pt`` (training traces). The cached
    result is used only when *both* exist; if either is missing (for example
    after an interrupted run) the model is trained again instead of failing
    while loading the missing file.

    Parameters
    ----------
    nn_model : torch.nn.Module
        Model to train; updated in place.
    nn_name : str
        Suffix used to build the file names.
    batch_size : int
        Forwarded to ``Trainer``.
    learning_rates, num_epochs : sequence
        Paired element-wise: stage ``i`` runs ``num_epochs[i]`` epochs with a
        fresh Adam optimizer at ``learning_rates[i]``. The default lists are
        only read, never mutated.

    Returns
    -------
    tuple
        ``(nn_model, traces)`` where ``traces`` holds one entry per stage (the
        value returned by ``trainer.train``).
    """
    nn_save_path = os.path.join("nn", "small_" + nn_name)
    model_fname = nn_save_path + ".pt"
    trace_fname = nn_save_path + "_trace.pt"

    if os.path.exists(model_fname) and os.path.exists(trace_fname):
        nn_model.load_state_dict(torch.load(model_fname, weights_only=True))
        traces = torch.load(trace_fname, weights_only=True)
        return nn_model, traces

    # else perform training
    # The output directory must exist before anything is saved into it.
    os.makedirs(os.path.dirname(nn_save_path), exist_ok=True)
    loss_fn = VectorFieldLoss(nn_model, reg_weight=None)
    trainer = Trainer(train_data, test_data, loss_fn, batch_size=batch_size)

    traces = []
    for n, lr in zip(num_epochs, learning_rates):
        opt = torch.optim.Adam(nn_model.parameters(), lr=lr)
        # NOTE(review): trainer.train presumably writes the checkpoint to
        # nn_save_path + ".pt" (this function never saves the model itself).
        trace = trainer.train(n, opt, nn_save_path)
        traces.append(trace)

    torch.save(traces, trace_fname)
    return nn_model, traces

In [None]:
## Training

# torch is re-seeded with the same value immediately before each model is
# constructed and trained, so every architecture starts from the same RNG state.
_seed = 42
_trained = {}
for _tag, _factory in (
    ("base", NeuralBaseLV),
    ("symp", NeuralSympLV),
    ("hnn", NeuralHamLV),
):
    torch.manual_seed(_seed)
    _trained[_tag] = train(_factory(), _tag)

base_nn, base_trace = _trained["base"]
symp_nn, symp_trace = _trained["symp"]
ham_nn, ham_trace = _trained["hnn"]

In [None]:
## Simulations

nt = 12_001


def _first_samples(name, **sim_kwargs):
    """Run or load simulation `name`, keep its first `nt` samples, add the energy error."""
    t_all, z_all = simulate(name, **sim_kwargs)
    t_cut, z_cut = t_all[:nt], z_all[:nt]
    return t_cut, z_cut, compute_rel_energy_err(z_cut)


# Reference model (default arguments of `simulate`).
t_ref, z_ref, h_ref = _first_samples("ref_qe")
n_ref = find_index_period(z_ref)

# Trained networks, integrated with the same default scheme and step.
t_base, z_base, h_base = _first_samples("base_qe", model=base_nn)
t_symp, z_symp, h_symp = _first_samples("symp_qe", model=symp_nn)
t_ham, z_ham, h_ham = _first_samples("hnn_qe", model=ham_nn)

In [None]:
# Phase-space portrait: reference orbit, base network (colour-coded by time),
# and one detected period of the two structured networks.
fig, ax = plt.subplots(figsize=(4.5, 3.5))

(p_ref,) = ax.plot(*z_ref[:n_ref:5].T, zorder=3, **ref_params)

# Map time linearly onto the [0.2, 0.8] portion of the colormap.
tb_min = t_base.min()
tb_max = t_base.max()
_t_span = tb_max - tb_min
c_base = base_params["cmap"](0.2 + 0.6 * (t_base - tb_min) / _t_span)
p_base = colored_line(*z_base[::3].T, c_base[::3], ax, **base_params)

n_symp = find_index_period(z_symp)
(p_symp,) = ax.plot(*z_symp[:n_symp:8].T, **symp_params)

n_ham = find_index_period(z_ham)
(p_ham,) = ax.plot(*z_ham[:n_ham:8].T, **ham_params)

handles, labels = ax.get_legend_handles_labels()
# The second handle is the multicoloured line and needs the custom handler.
_cmap_handler = HandlerColormap(base_params["cmap"])
ax.legend(handles=handles, labels=labels, handler_map={handles[1]: _cmap_handler})

ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
fig.tight_layout()

figname = os.path.join("figures", "intro", "small_nns_phasespace.pdf")
fig.savefig(figname, bbox_inches="tight")

In [None]:
# Quick inspection of the simulated trajectories. A cell displays only its
# last expression, so the first line below produces no visible output.
t_ham, z_ham
t_symp, z_symp

In [None]:
# Relative energy error over time for the reference and the three networks.
# Reuses `c_base` (per-sample colours) computed in the phase-space cell.
fig, ax = plt.subplots(figsize=(4, 3.5))

_thin = {"lw": 1.5}
_ham_style = {**ham_params, **_thin, "zorder": 0.5}
_symp_style = {**symp_params, **_thin}

(p_ref,) = ax.plot(t_ref[::200], h_ref[::200], zorder=3, **ref_params)
colored_line(t_base[::16], h_base[::16], c_base[::16], ax, **base_params)
(p_ham,) = ax.plot(t_ham, h_ham, **_ham_style)
(p_symp,) = ax.plot(t_symp, h_symp, **_symp_style)

ax.set_xlabel("$t$")
ax.set_ylabel("rel. err. on $H$")

handles, labels = ax.get_legend_handles_labels()
# The second handle is the multicoloured line and needs the custom handler.
_cmap_handler = HandlerColormap(base_params["cmap"])
ax.legend(
    loc="upper left",
    handles=handles,
    labels=labels,
    handler_map={handles[1]: _cmap_handler},
)

fig.tight_layout()

figname = os.path.join("figures", "intro", "small_nns_hamiltonian.pdf")
fig.savefig(figname, bbox_inches="tight")

In [None]:
def plot_sol(
    ax_left, ax_right, t, z, ylabel=None, add_ticks=True, colored=False, **kwargs
):
    """Plot the beginning and the end of a time series on two adjacent axes.

    The first quarter of the samples is drawn on ``ax_left`` and the last
    quarter on ``ax_right`` (every 11th sample in both cases).

    Parameters
    ----------
    ax_left, ax_right : matplotlib axes
        Axes receiving the head and the tail of the series.
    t, z : array-like
        Abscissae and values; must support slicing and ``len``.
    ylabel : str, optional
        If given, set as the y-label of ``ax_left``.
    add_ticks : bool
        If True, hide the facing spines and the left ticks of ``ax_right`` and
        draw diagonal break markers between the two axes. If False, give both
        axes a common y-range covering their current limits instead.
    colored : bool
        If True, draw with ``colored_line`` using colours taken from
        ``kwargs["cmap"]`` (default ``plasma_r``): the lower third of the
        colormap on the left, the upper third on the right.
    **kwargs
        Forwarded to ``Axes.plot`` or ``colored_line``.
    """
    ratio = 4
    nt_plt = len(t) // ratio
    t_left, z_left = t[:nt_plt:11], z[:nt_plt:11]
    # Negative step: the tail is walked backwards from the last sample.
    t_right, z_right = t[:-nt_plt:-11], z[:-nt_plt:-11]

    if colored:
        cmap = kwargs.get("cmap", mpl.colormaps["plasma_r"])
        # Dummy plot for xlims and ylims
        ax_left.plot(t_left, z_left, lw=0.01, color="white")
        ax_right.plot(t_right, z_right, lw=0.01, color="white")

        s_left = (t_left - t.min()) / (t.max() - t.min())
        c_left = cmap(s_left * ratio / 3.0)
        colored_line(t_left, z_left, c_left, ax_left, **kwargs)

        s_right = (t_right - t.min()) / (t.max() - t.min())
        c_right = cmap(1.0 - (1.0 - s_right) * ratio / 3.0)
        colored_line(t_right, z_right, c_right, ax_right, **kwargs)
    else:
        ax_left.plot(t_left, z_left, **kwargs)
        ax_right.plot(t_right, z_right, **kwargs)

    if ylabel is not None:
        ax_left.set_ylabel(ylabel)

    if add_ticks:
        ax_right.spines.left.set_visible(False)
        ax_left.spines.right.set_visible(False)
        # `left` expects a bool; pass False explicitly rather than None.
        ax_right.tick_params(left=False, labelleft=False)

        # Slanted markers at the facing edges. Kept in its own dict so the
        # caller's **kwargs is not overwritten.
        break_style = dict(
            marker=[(-1, -2), (1, 2)],
            markersize=6,
            linestyle="none",
            color="k",
            mec="k",
            mew=1,
            clip_on=False,
        )
        ax_left.plot([1, 1], [0, 1], transform=ax_left.transAxes, **break_style)
        ax_right.plot([0, 0], [0, 1], transform=ax_right.transAxes, **break_style)

    else:
        ymin = min(ax_left.get_ylim()[0], ax_right.get_ylim()[0])
        ymax = max(ax_left.get_ylim()[1], ax_right.get_ylim()[1])
        ax_left.set_ylim(ymin, ymax)
        ax_right.set_ylim(ymin, ymax)


# Figure 1: head/tail view of x(t) and y(t) for the reference model (top row)
# and the canonical network (bottom row).
fig, ax = plt.subplots(2, 2, figsize=(4.5, 3.5))
fig.subplots_adjust(wspace=0.07)  # adjust space between Axes
_ref_row, _ham_row = ax

_x_ref_style = {**ref_params, "ls": "solid", "label": "x"}
_y_ref_style = {**ref_params, "ls": "dashed", "label": "y"}
plot_sol(*_ref_row, t_ref, z_ref[:, 0], ylabel="reference", **_x_ref_style)
plot_sol(*_ref_row, t_ref, z_ref[:, 1], add_ticks=False, **_y_ref_style)
_ref_row[1].legend(loc="lower left")

_y_ham_style = {**ham_params, "ls": "dashed"}
plot_sol(*_ham_row, t_ham, z_ham[:, 0], ylabel="canonical", **ham_params)
plot_sol(*_ham_row, t_ham, z_ham[:, 1], add_ticks=False, **_y_ham_style)

fig.supxlabel("$t$", fontsize=MEDIUM_SIZE)
fig.tight_layout(pad=0.2, w_pad=0.1)

figname = os.path.join("figures", "intro", "small_nns_sols1.pdf")
fig.savefig(figname, bbox_inches="tight")

# Figure 2: head/tail view of x(t) and y(t) for the unstructured network
# (top row, colour-coded) and the non-canonical network (bottom row).
fig, ax = plt.subplots(2, 2, figsize=(4.5, 3.5))
fig.subplots_adjust(wspace=0.07)  # adjust space between Axes
_base_row, _symp_row = ax

plot_sol(
    *_base_row, t_base, z_base[:, 0], ylabel="no structure", colored=True, **base_params
)
plot_sol(*_base_row, t_base, z_base[:, 1], add_ticks=False, colored=True, **base_params)

_y_symp_style = {**symp_params, "ls": "dashed"}
plot_sol(*_symp_row, t_symp, z_symp[:, 0], ylabel="non-canonical", **symp_params)
plot_sol(*_symp_row, t_symp, z_symp[:, 1], add_ticks=False, **_y_symp_style)

fig.supxlabel("$t$", fontsize=MEDIUM_SIZE)
fig.tight_layout(pad=0.2, w_pad=0.1)

figname = os.path.join("figures", "intro", "small_nns_sols2.pdf")
fig.savefig(figname, bbox_inches="tight")

In [None]:
# Variant of the phase-space portrait (narrower figure, canonical curve drawn
# before the non-canonical one).
# NOTE(review): this cell writes to the same path as the earlier phase-space
# cell ("figures/intro/small_nns_phasespace.pdf"), so running the notebook top
# to bottom overwrites that figure. It also passes the raw times `t_base` to
# `colored_line`, where the earlier cell passes precomputed colours -- confirm
# which version is intended.
fig, ax = plt.subplots(figsize=(4, 3.5))

(p_ref,) = ax.plot(*z_ref[:n_ref:5].T, zorder=3, **ref_params)

p_base = colored_line(*z_base[::3].T, t_base[::3], ax, **base_params)

n_ham = find_index_period(z_ham)
(p_ham,) = ax.plot(*z_ham[:n_ham:8].T, **ham_params)

n_symp = find_index_period(z_symp)
(p_symp,) = ax.plot(*z_symp[:n_symp:8].T, **symp_params)

handles, labels = ax.get_legend_handles_labels()
# The second handle is the multicoloured line and needs the custom handler.
_cmap_handler = HandlerColormap(base_params["cmap"])
ax.legend(handles=handles, labels=labels, handler_map={handles[1]: _cmap_handler})

ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
fig.tight_layout()

figname = os.path.join("figures", "intro", "small_nns_phasespace.pdf")
fig.savefig(figname, bbox_inches="tight")

## Long-time integration of the reference model

In [None]:
# Coarse step and long horizon for comparing the two integrators.
dt = 0.2
nt = 100_000

# Reference trajectory, reloaded here without the [:nt] truncation applied in
# the small-model section.
t_ref, z_ref = simulate("ref_qe")
n_ref = find_index_period(z_ref)

# Reuse the existing styles, changing only the legend labels.
rk4_params = {**base_params, "label": "RK4"}
dvi_params = {**symp_params, "label": "DVI"}

_coarse = dict(dt=dt, nt=nt)
t_rk4, z_rk4 = simulate("rk4", scheme=RK4Simulation, **_coarse)
t_dvi, z_dvi = simulate("dvi", scheme=EulerDVISimulation, **_coarse)

h_rk4, h_dvi = (compute_rel_energy_err(_traj) for _traj in (z_rk4, z_dvi))

In [None]:
# Phase-space portrait of the long-time runs: exact orbit, RK4 samples coloured
# by time, and one detected period of the DVI trajectory.
fig, ax = plt.subplots(figsize=(4, 3.5))

_exact_style = {**ref_params, "label": "exact"}
(p_ref,) = ax.plot(*z_ref[:n_ref:5].T, zorder=2.05, **_exact_style)

# Every 20th RK4 sample, as small dots.
p_rk4 = ax.scatter(*z_rk4[::20].T, c=t_rk4[::20], s=1, **rk4_params)

n_dvi = find_index_period(z_dvi)
(p_dvi,) = ax.plot(*z_dvi[:n_dvi].T, marker="o", zorder=2.1, **dvi_params)

handles, labels = ax.get_legend_handles_labels()
# The second handle is the colour-mapped artist and needs the custom handler.
_cmap_handler = HandlerColormap(base_params["cmap"])
ax.legend(handles=handles, labels=labels, handler_map={handles[1]: _cmap_handler})

ax.set_xlabel("$x$")
ax.set_ylabel("$y$")
fig.tight_layout()

figname = os.path.join("figures", "intro", "long_time_phasespace.pdf")
fig.savefig(figname, bbox_inches="tight")

In [None]:
# Relative energy error of the long-time runs, with an inset zooming on the
# start of the time axis.
fig, ax = plt.subplots(figsize=(4, 3.5))

# Inset placement (axes fraction) and the data window it shows.
bounds = (0.2, 0.62, 0.72, 0.25)
xlim = (0.0, 220.0)
ylim = (-0.05, 0.05)
x_range = xlim[1] - xlim[0]  # not used further in this cell
axins = ax.inset_axes(bounds, xlim=xlim, ylim=ylim, xticks=xlim, yticks=ylim)

_thin_dvi = {**dvi_params, "lw": 1}

# Main axes: subsampled curves.
(p_dvi,) = ax.plot(t_dvi[::10], h_dvi[::10], **_thin_dvi)
colored_line(t_rk4[::50], h_rk4[::50], t_rk4[::50], ax, **rk4_params)

# Inset: full-resolution curves.
(p_dvi,) = axins.plot(t_dvi, h_dvi, **_thin_dvi)
colored_line(t_rk4, h_rk4, t_rk4, axins, **rk4_params)

ax.set_xlabel("$t$")
ax.set_ylabel("rel. err. on $H$")

axins.tick_params(axis="both", which="major", labelsize=8)

ax.set_ylim(-0.07, 0.055)

handles, labels = ax.get_legend_handles_labels()
# The second handle is the multicoloured line and needs the custom handler.
_cmap_handler = HandlerColormap(base_params["cmap"])
ax.legend(
    loc="lower center",
    handles=handles,
    labels=labels,
    handler_map={handles[1]: _cmap_handler},
)

ax.indicate_inset_zoom(axins, edgecolor="black")

fig.tight_layout()
figname = os.path.join("figures", "intro", "long_time_hamiltonian.pdf")
fig.savefig(figname, bbox_inches="tight")