In [None]:
# Reload edited Python modules (e.g. symplearn, models, utils) automatically
# before executing code, so changes in the project sources are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

from symplearn.datasets import VectorFieldDataset, SnapshotDataset
from symplearn.numerics import QuasiExactSimulation, EulerDVISimulation, RK4Simulation

from models import LotkaVolterra
from utils import get_reduced_val_data, find_index_period, HandlerColormap, init_env

# `init_env` comes from the local utils module. `models` is indexed below with the
# keys "ref", "vf_reg", "vf_no_reg" and "dvi".
# NOTE(review): `pb` is not used in the cells visible here — confirm it is needed.
pb, models = init_env()

In [None]:
# State samples of each dataset split (the validation ones are scattered in the (x, y) plane below).
z_train, z_test, z_val = (VectorFieldDataset(split).z for split in ("train", "test", "val"))

# Reduced validation set: its states are the default initial conditions `z0`
# (and `dt` the default base step) of the long-time simulations.
t0, z0, dt = get_reduced_val_data()
x0, y0 = z0.tensor_split(2, dim=-1)

## Long-time simulations

In [None]:
def simul_long_time(dir, model, dt=dt, z0=z0):
    """Run, or resume from the on-disk cache, long-time simulations of ``model``.

    Three schemes are used: "ex" (``QuasiExactSimulation``, step ``0.2 * dt``,
    10_000 steps), "dvi" (``EulerDVISimulation``, step ``dt``, 100_000 steps) and
    "rk4" (``RK4Simulation``, step ``dt``, 250_000 steps). Results are cached as
    ``simul/longtime/<dir>/t_<scheme>.pt`` and ``z_<scheme>.pt``. A cached run that
    is shorter than requested is extended from its last stored state and re-saved.

    Args:
        dir: name of the cache sub-directory (e.g. "ref", "vf").
        model: model handed to each simulation scheme.
        dt: base time step (defaults to the global ``dt``, bound at definition time).
        z0: initial conditions, one row per trajectory (defaults to the global ``z0``).

    Returns:
        dict mapping each scheme name ("ex", "dvi", "rk4") to
        ``{"t": ndarray, "z": ndarray}``. The first axis indexes the initial
        conditions and the second the time samples.
    """
    sim_schemes = {
        "ex": QuasiExactSimulation,
        "dvi": EulerDVISimulation,
        "rk4": RK4Simulation,
    }
    # Per-scheme step size and number of steps. Use a separate name so the `dt`
    # argument is not rebound to a dict.
    dt_sim = {"ex": dt * 0.2, "dvi": dt, "rk4": dt}
    nt = {"ex": 10_000, "dvi": 100_000, "rk4": 250_000}

    cache_dir = os.path.join("simul", "longtime", dir)
    # Create the cache directory so torch.save does not fail on a fresh checkout.
    os.makedirs(cache_dir, exist_ok=True)

    sol = {}
    for sim, sim_sch in sim_schemes.items():
        tname = os.path.join(cache_dir, f"t_{sim}.pt")
        zname = os.path.join(cache_dir, f"z_{sim}.pt")
        if os.path.exists(tname) and os.path.exists(zname):
            t = torch.load(tname, weights_only=True)
            z = torch.load(zname, weights_only=True)
        else:
            # Fresh start: a single sample at t = 0 located at the initial conditions.
            t, z = torch.zeros(len(z0), 1), z0[:, None, :]

        if t.shape[1] < nt[sim] + 1:
            nt_sim = nt[sim] + 1 - t.shape[1]
            t_init, z_init = t[:, -1], z[:, -1]
            t_sim, z_sim = sim_sch(model, dt_sim[sim]).simulate(z_init, nt_sim, t0=t_init)

            # The last stored sample is dropped before appending.
            # NOTE(review): presumably simulate() returns the initial point as its
            # first sample, so this avoids duplicating it.
            t = torch.concat((t[:, :-1], t_sim.detach()), dim=1)
            z = torch.concat((z[:, :-1], z_sim.detach()), dim=1)

            torch.save(t, tname)
            torch.save(z, zname)

        # Convert inside the loop so every scheme is returned as NumPy.
        # Previously only the last scheme ("rk4") was converted.
        sol[sim] = {"t": t.cpu().numpy(), "z": z.cpu().numpy()}

    return sol

### Run (or load cached) simulations

In [None]:
# Reference model, cached under simul/longtime/ref.
sol_ref = simul_long_time(dir="ref", model=models["ref"])

In [None]:
# Vector-field model trained with regularization ("VF with reg." in the figures).
sol_vf = simul_long_time(dir="vf", model=models["vf_reg"])

In [None]:
# Vector-field model trained without regularization ("VF no reg." in the figures).
sol_cmy = simul_long_time(dir="cmy", model=models["vf_no_reg"])

In [None]:
# Model from scheme learning ("sch. learning" in the figures).
sol_sch = simul_long_time(dir="sch", model=models["dvi"])

In [None]:
# Stored trajectory shapes of the DVI and RK4 runs for every model.
# Jupyter only renders a cell's last expression, so the four separate tuple lines
# used to show just the "cmy" shapes. Collect everything into one displayed mapping.
{
    name: {scheme: sol[scheme]["t"].shape for scheme in ("dvi", "rk4")}
    for name, sol in {"ref": sol_ref, "sch": sol_sch, "vf": sol_vf, "cmy": sol_cmy}.items()
}

In [None]:
# Colorblind-friendly colormaps (to the best of my ability): one per trajectory.
cmaps = [mpl.colormaps[name] for name in ("Blues", "Greens", "Oranges", "Purples")]
cmap0, cmap1, cmap2, cmap3 = cmaps

# A mixed-shade palette, plus uniformly dark (0.9) and uniformly light (0.6) variants.
colors = [cm(level) for cm, level in zip(cmaps, (0.9, 0.75, 0.6, 0.6))]
colors_dark = [cm(0.9) for cm in cmaps]
colors_light = [cm(0.6) for cm in cmaps]

In [None]:
# Validation data with the selected initial conditions highlighted.
fig_init_cond, ax_init_cond = plt.subplots(figsize=(5, 3))

# Full validation set in the background.
ax_init_cond.scatter(*z_val.T, s=3, c="grey", label="val. data")
# Empty proxy artist: gives the colored initial conditions a single black legend entry.
ax_init_cond.scatter([], [], s=20, marker="D", c="black", label="select.")
# The selected initial conditions themselves, one dark color per trajectory.
ax_init_cond.scatter(*z0.T, s=20, marker="D", c=colors_dark, label=None)

ax_init_cond.set_xlabel("$x$")
ax_init_cond.set_ylabel("$y$")
ax_init_cond.legend()

fig_init_cond.tight_layout(pad=0.2)

In [None]:
def plot_phasespace(ax, sol, label=None, title=None, step=1, end=None, **kwargs):
    """Scatter the trajectories of ``sol`` over the exact reference orbits.

    For each trajectory ``k`` (one per colormap in ``cmaps``), the reference
    ``sol_ref["ex"]["z"][k]`` is drawn in grey up to the index returned by
    ``find_index_period`` (presumably one period). The samples of ``sol`` are then
    scattered with colormap ``k``, and later times get higher colormap values
    (from 0.2 to 0.8).

    Args:
        ax: matplotlib axes to draw on.
        sol: one scheme entry of ``simul_long_time``'s output, i.e. a dict with
            ``"t"`` indexed as [trajectory, sample] and ``"z"`` as [trajectory, sample, state].
        label: if given, add a legend with an entry named ``label`` plus a "ref." entry.
        title: legend title (only used when ``label`` is given).
        step: keep every ``step``-th sample.
        end: exclusive upper bound on the sample index (``None`` keeps all samples).
        **kwargs: accepted for call compatibility; currently ignored.
    """
    # Style of the grey reference curves; defined once because the legend below uses it too.
    ref_plt = dict(color="grey", lw=1.5, alpha=0.5)

    for k, cmapk in enumerate(cmaps):
        zk_ref_ex = sol_ref["ex"]["z"][k]
        nt_ref_ex = find_index_period(zk_ref_ex) + 1
        ax.plot(*zk_ref_ex[:nt_ref_ex].T, label=None, **ref_plt)

        tk, zk = sol["t"][k, :end:step], sol["z"][k, :end:step]
        # Normalize time to [0, 1]. A constant time range (e.g. a single sample)
        # used to divide by zero and yield NaN ("bad", invisible) colors; it now maps to 0.
        t_span = tk.max() - tk.min()
        tk = tk - tk.min()
        if t_span > 0:
            tk = tk / t_span
        ck = cmapk(0.2 + 0.6 * tk)
        ax.scatter(*zk.T, c=ck, s=1, label=None)

    if label is not None:
        # Proxy artists for the legend: the model entry and the reference curve.
        ax.plot([], [], color="black", label=label, lw=1.5)
        ax.plot([], [], label="ref.", **ref_plt)
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(
            title=title,
            handles=handles,
            labels=labels,
            # handles[-2] is the model entry. NOTE(review): HandlerColormap
            # presumably draws it as a grey colormap swatch.
            handler_map={handles[-2]: HandlerColormap(mpl.colormaps["Greys"])},
        )


def plot_long_time(sol, labels=True, title=None, step_rk4=73, step_dvi=73):
    """Create a two-panel figure: RK4 phase portrait (left), DVI phase portrait (right).

    Args:
        sol: output of ``simul_long_time``; its "rk4" and "dvi" entries are plotted.
        labels: if True, label both axes with $x$ / $y$ and tighten the layout.
        title: legend title used on both panels.
        step_rk4: sub-sampling stride of the RK4 trajectories.
        step_dvi: sub-sampling stride of the DVI trajectories.

    Returns:
        The created matplotlib figure.
    """
    fig, axes = plt.subplots(1, 2, figsize=(7.1, 2.7))
    panels = [("rk4", "RK4", step_rk4), ("dvi", "DVI", step_dvi)]

    for ax, (scheme, name, stride) in zip(axes, panels):
        plot_phasespace(ax, sol[scheme], label=name, step=stride, title=title)

    if not labels:
        return fig

    for ax in axes:
        ax.set(xlabel="$x$", ylabel="$y$")
    fig.tight_layout(pad=0.2, w_pad=1.0)
    return fig

In [None]:
fig_ref = plot_long_time(sol_ref, title="ref. model")

# For the reference model the grey curve is its own "ex" run, so rename the "ref."
# legend entry (second text) to "exact". Editing the existing legend keeps the
# colormap handler_map installed by plot_phasespace. Rebuilding the legend with
# ax.legend(handles, labels) silently dropped that handler.
for ax in fig_ref.get_axes():
    ax.get_legend().get_texts()[1].set_text("exact")

os.makedirs("figures", exist_ok=True)
fname = os.path.join("figures", "longtime_ref")
fig_ref.savefig(fname + ".pdf", bbox_inches="tight")

In [None]:
# Long-time RK4 / DVI portraits of the regularized vector-field model.
fig_vf = plot_long_time(sol_vf, title="VF with reg.")

fname = os.path.join("figures", "longtime_vf")
fig_vf.savefig(f"{fname}.pdf", bbox_inches="tight")

In [None]:
# Only the first 200 DVI samples of the unregularized model are shown. Plot from a
# truncated copy instead of slicing `sol_cmy` in place and restoring it afterwards.
# The in-place version left `sol_cmy` truncated whenever a line in between raised,
# and a re-run would then store the truncated arrays as the "originals".
n_show_cmy_dvi = 200
sol_cmy_plot = {
    **sol_cmy,
    "dvi": {key: sol_cmy["dvi"][key][:, :n_show_cmy_dvi] for key in ("t", "z")},
}

fig_cmy = plot_long_time(sol_cmy_plot, labels=False, step_dvi=1, title="VF no reg.")

ax1, ax2 = fig_cmy.axes
ax1.set(xlabel="$x$", ylabel="$y$")
ax2.set(xlabel="$x$", ylabel="$y$")

# Inset on the DVI panel showing the (padded) view limits of the RK4 panel.
# `bounds` is (x0, y0, width, height) in ax2's axes coordinates.
bounds = (0.03, 0.25, 0.8, 0.71)
xlim, ylim = ax1.get_xlim(), ax1.get_ylim()
x_range, y_range = xlim[1] - xlim[0], ylim[1] - ylim[0]
xlim = (xlim[0] - 0.04 * x_range, xlim[1] + 0.15 * x_range)
ylim = (ylim[0] - 0.05 * y_range, ylim[1] + 0.1 * y_range)
axins = ax2.inset_axes(bounds, xlim=xlim, ylim=ylim, xticks=[], yticks=[])
axins.patch.set_alpha(0.8)

plot_phasespace(axins, sol_cmy_plot["dvi"], step=1)
ax2.indicate_inset_zoom(axins, edgecolor="black")

fig_cmy.tight_layout(pad=0.2, w_pad=1.0)

os.makedirs("figures", exist_ok=True)
fname = os.path.join("figures", "longtime_cmy")
fig_cmy.savefig(fname + ".pdf", bbox_inches="tight")

In [None]:
# Show only the first 100_000 RK4 samples. NOTE(review): this matches the DVI run's
# 100_000 steps of the same dt in simul_long_time, presumably so both panels span the
# same time. Plot from a truncated copy so `sol_sch` is never left truncated if
# plotting raises (the old mutate-then-restore pattern did not survive exceptions).
n_show_sch_rk4 = 100_000
sol_sch_plot = {
    **sol_sch,
    "rk4": {key: sol_sch["rk4"][key][:, :n_show_sch_rk4] for key in ("t", "z")},
}

fig_sch = plot_long_time(sol_sch_plot, step_rk4=31, step_dvi=31, title="sch. learning")

os.makedirs("figures", exist_ok=True)
fname = os.path.join("figures", "longtime_sch")
fig_sch.savefig(fname + ".pdf", bbox_inches="tight")

### Exact ($h \to 0^+$) solutions

In [None]:
def plot_ex(ax, sol, label="neural", title=None, ncols=1, x_margin=0.0, y_margin=0.14):
    """Compare the quasi-exact ("ex") trajectories of ``sol`` with those of ``sol_ref``.

    For each initial condition ``k`` (one per color in ``colors_dark``), the
    reference trajectory is drawn in grey and the trajectory of ``sol`` dashed in its
    dark color. Both are cut at the index returned by ``find_index_period``
    (presumably about one period).

    Args:
        ax: matplotlib axes to draw on.
        sol: output of ``simul_long_time``; only ``sol["ex"]["z"]`` is read.
        label: legend label of the dashed curves.
        title: legend title.
        ncols: number of legend columns.
        x_margin: extra space added to the right of the x-range, as a fraction of it
            (0.0 keeps the autoscaled range but freezes it).
        y_margin: extra space added above the y-range, as a fraction of it
            (presumably room for the legend).
    """
    ref_plt = dict(color="grey", lw=1.5, alpha=0.5)

    for k, ck_ex in enumerate(colors_dark):
        zk_ref_ex = sol_ref["ex"]["z"][k]
        nt_ref_ex = find_index_period(zk_ref_ex) + 1
        ax.plot(*zk_ref_ex[:nt_ref_ex].T, label=None, **ref_plt)

        zk_ex = sol["ex"]["z"][k]
        # One sample more than for the reference. NOTE(review): presumably to close
        # the dashed orbit visually — confirm.
        nt_ex = find_index_period(zk_ex) + 2
        ax.plot(*zk_ex[:nt_ex].T, c=ck_ex, ls="dashed", label=None)

    ymin, ymax = ax.get_ylim()
    ax.set_ylim(ymin, ymax + y_margin * (ymax - ymin))
    xmin, xmax = ax.get_xlim()
    ax.set_xlim(xmin, xmax + x_margin * (xmax - xmin))

    # Proxy artists for the legend entries.
    ax.plot([], [], label="ref.", **ref_plt)
    ax.plot([], [], color="black", label=label, ls="dashed", lw=1.5)
    ax.set(xlabel="$x$", ylabel="$y$")
    ax.legend(title=title, ncols=ncols, columnspacing=1.0)

In [None]:
# Quasi-exact trajectories of each learned model against the reference, one panel each.
fig_ex, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(7.1, 2.5))
# NOTE(review): these legends say "VFL ..." while the long-time figures use
# "VF with reg." / "VF no reg." titles — confirm which naming is intended.
plot_ex(ax1, sol_vf, label="VFL with reg.", ncols=2)
plot_ex(ax2, sol_cmy, label="VFL no reg.", ncols=2)
plot_ex(ax3, sol_sch, label="sch. learning", ncols=2)

fig_ex.tight_layout(pad=0.2, w_pad=0.8)

os.makedirs("figures", exist_ok=True)
fname = os.path.join("figures", "nn_ex")
fig_ex.savefig(fname + ".pdf", bbox_inches="tight")