In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

torch.set_default_dtype(torch.float64)

from symplearn.numerics import EulerDVISimulation, QuasiExactSimulation

from models import LotkaVolterra
from utils import load_models, find_index_period

In [None]:
class PerturbedLotkaVolterra(LotkaVolterra):
    """Lotka-Volterra model whose one-form is perturbed by ``cos(w x) / w``.

    The perturbation depends on ``x`` only. ``oneform`` adds it to the parent's
    one-form, and ``euler_lagrange_maps`` adds its x-derivative, ``-sin(w x)``,
    to the parent's ``dx_q``.
    """

    def __init__(self, *args, w=2.0, **kwargs):
        """Build the parent model; ``w`` is the perturbation frequency (default 2.0)."""
        super().__init__(*args, **kwargs)
        # The term is scaled by 1/w, so its x-derivative -sin(w x) has unit amplitude.
        self.w = w

    def oneform(self, x, y):
        """Return the parent one-form plus ``cos(w x) / w``."""
        q = super().oneform(x, y)
        return q + torch.cos(self.w * x) / self.w

    def euler_lagrange_maps(self, x, y, t):
        """Return the parent's maps with the perturbation's derivative added to ``dx_q``."""
        (dx_q, dy_q), (dx_h, dy_h) = super().euler_lagrange_maps(x, y, t)
        # d/dx [cos(w x) / w] = -sin(w x). Build a new tensor instead of
        # subtracting in place, so the tensor returned by the parent (which may
        # be shared, or saved by autograd for the backward pass) is not mutated.
        # unsqueeze(-1) matches the original layout; presumably dx_q carries a
        # trailing axis that x lacks -- TODO confirm against LotkaVolterra.
        dx_q = dx_q - torch.sin(self.w * x).unsqueeze(-1)
        return (dx_q, dy_q), (dx_h, dy_h)


model_ref = LotkaVolterra()
model_pert = PerturbedLotkaVolterra()

# Call load_models() once and take both trained variants from the same result,
# instead of loading everything twice.
trained_models = load_models()
model_nn = trained_models["vf_no_reg"]
model_reg = trained_models["vf_reg"]

In [None]:
z0_default = torch.tensor([[4.0, 3.0]])
# Alternative initial condition: torch.tensor([[1.0, 4.0]])


def compute_sol(model, dt, tf=8.0, z0=z0_default, scheme=EulerDVISimulation):
    """Simulate ``model`` from ``z0`` with time step ``dt`` up to (at least) ``tf``.

    Parameters
    ----------
    model
        Model passed to ``scheme``.
    dt : float
        Time step; must be positive.
    tf : float
        Final time. The number of steps is ``ceil(tf / dt)``, except that a
        ratio which is an integer up to floating-point round-off is treated as
        exact (e.g. ``tf=1.1, dt=0.1`` gives 11 steps, not 12).
    z0 : torch.Tensor
        Initial condition(s); the default has shape ``(1, 2)``.
    scheme
        Simulation class, instantiated as ``scheme(model, dt)``.

    Returns
    -------
    t, z
        Output of ``sim.simulate(z0, nt)``. When the leading axis of ``z`` has
        length 1, that axis is dropped from both ``t`` and ``z``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    sim = scheme(model, dt)

    # Guard ceil() against round-off that lands just above an integer.
    ratio = tf / dt
    nearest = np.round(ratio)
    if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0):
        nt = int(nearest)
    else:
        nt = int(np.ceil(ratio))

    t, z = sim.simulate(z0, nt)
    if z.shape[0] == 1:
        t, z = t[0], z[0]

    return t, z

In [None]:
SMALL_SIZE = 7
MEDIUM_SIZE = 8
BIGGER_SIZE = 9

# Global Matplotlib defaults shared by every figure in this notebook.
plt.rcParams.update(
    {
        "axes.labelsize": MEDIUM_SIZE,  # x/y axis labels
        "xtick.labelsize": SMALL_SIZE,  # tick labels
        "ytick.labelsize": SMALL_SIZE,
        "legend.fontsize": BIGGER_SIZE,  # legend entries and title
        "legend.title_fontsize": BIGGER_SIZE,
        "figure.figsize": (3, 3),
    }
)

# Line styles for comparison curves.
ref_params = {"label": "reference", "color": "gray", "linestyle": "solid", "linewidth": 1.5}
ex_params = {
    "label": "exact",
    "color": "black",
    "linestyle": "dotted",
    "linewidth": 3,
    "zorder": 2.01,
}

# One plasma-colormap colour per model family.
col_ref, col_pert, col_nn, col_reg = map(mpl.colormaps["plasma"], (0.85, 0.65, 0.4, 0.12))


def plot_with_dt(ax, z, dt, **kwargs):
    """Plot the trajectory ``z`` in the (x, y) plane on ``ax``, labelled by ``dt``.

    Only the points up to ``find_index_period(z) + 2`` are drawn (the extra
    points presumably close the orbit). The label reads ``2^k`` when ``dt`` is
    a power of two and the decimal value of ``dt`` otherwise. ``label`` and
    ``linewidth`` default to that label and 2.5 but can be overridden via
    ``kwargs``; all other keyword arguments are forwarded to ``ax.plot``.
    """
    log2_dt = np.log2(dt)
    if np.isclose(log2_dt, np.round(log2_dt)):
        # Round rather than truncate so e.g. log2 = -2.9999999 still gives -3.
        lab_dt = "$\\Delta t = 2^{" + str(int(np.round(log2_dt))) + "}$"
    else:
        lab_dt = f"$\\Delta t = {dt:g}$"

    # setdefault avoids "got multiple values for keyword argument" TypeErrors.
    kwargs.setdefault("label", lab_dt)
    kwargs.setdefault("linewidth", 2.5)

    nf = find_index_period(z) + 2
    ax.plot(z[:nf, 0], z[:nf, 1], **kwargs)

In [None]:
# "Exact" trajectory of the reference model on a fine grid. Use the same
# quasi-exact integrator as the other models' "exact" curves below, rather
# than the default Euler DVI scheme.
dt_ex = 1e-2
t_ex, z_ex = compute_sol(model_ref, dt_ex, scheme=QuasiExactSimulation)
nf_ex = find_index_period(z_ex) + 2

t_ex, z_ex = t_ex[:nf_ex], z_ex[:nf_ex]

In [None]:
## Verify that the perturbed model has a larger regularization term than the reference model

from symplearn.datasets import VectorFieldDataset
from symplearn.training.losses import VectorFieldLoss

data_vf = VectorFieldDataset("val")
z, t, dt_z = data_vf[:]

reg_ref = VectorFieldLoss(model_ref)(z, t, dt_z)[1]["abs_reg"]
reg_pert = VectorFieldLoss(model_pert)(z, t, dt_z)[1]["abs_reg"]
reg_nn = VectorFieldLoss(model_nn)(z, t, dt_z)[1]["abs_reg"]

# Make the claim above an actual check instead of only displaying the numbers.
# Assumes "abs_reg" is a scalar (float or 0-d tensor) -- TODO confirm.
assert reg_pert > reg_ref, (
    f"expected the perturbed model's regularization term ({reg_pert}) "
    f"to exceed the reference model's ({reg_ref})"
)

reg_ref, reg_pert, reg_nn

### Reference model

In [None]:
# Coarse step for the reference model and a step half as large.
dt_large_ref = 2.0 ** -3
dt_small_ref = dt_large_ref / 2

t_large_ref, z_large_ref = compute_sol(model_ref, dt=dt_large_ref)
t_small_ref, z_small_ref = compute_sol(model_ref, dt=dt_small_ref)

In [None]:
fig, ax = plt.subplots()

plot_with_dt(ax, z_large_ref, dt_large_ref, color=col_ref)
plot_with_dt(ax, z_small_ref, dt_small_ref, color=col_ref, linestyle="dashed")

ax.plot(*z_ex.T, **ex_params)

ax.legend(title="ref. $\\vartheta$")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

fig.tight_layout()
# savefig does not create missing directories (e.g. on a fresh checkout).
fig_path = os.path.join("figures", "perturbed", "dvi_ref.pdf")
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

### Perturbed model

In [None]:
# The perturbed model needs much smaller steps: 2^-7 and half of that.
dt_large_pert = 2.0 ** -7
dt_small_pert = dt_large_pert / 2

t_large_pert, z_large_pert = compute_sol(model_pert, dt=dt_large_pert)
t_small_pert, z_small_pert = compute_sol(model_pert, dt=dt_small_pert)
# Quasi-exact trajectory of the perturbed model, used as its "exact" curve.
t_ex_pert, z_ex_pert = compute_sol(model_pert, dt=dt_ex, scheme=QuasiExactSimulation)

In [None]:
fig, ax = plt.subplots()

plot_with_dt(ax, z_large_pert, dt_large_pert, color=col_pert)
plot_with_dt(ax, z_small_pert, dt_small_pert, color=col_pert, linestyle="dashed")

nf_ex_pert = find_index_period(z_ex_pert) + 2
ax.plot(*z_ex_pert[:nf_ex_pert].T, **ex_params)

ax.legend(title="pert. $\\vartheta$")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

fig.tight_layout()
# savefig does not create missing directories (e.g. on a fresh checkout).
fig_path = os.path.join("figures", "perturbed", "dvi_pert.pdf")
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

### Neural network models (without and with regularization)

In [None]:
# model_nn is already loaded in the setup cell; reuse it instead of loading
# all trained models again.
dt_large_nn = 2.0 ** -4
dt_small_nn = 2.0 ** -5

t_large_nn, z_large_nn = compute_sol(model_nn, dt_large_nn)
t_small_nn, z_small_nn = compute_sol(model_nn, dt_small_nn)
t_ex_nn, z_ex_nn = compute_sol(model_nn, dt_ex, scheme=QuasiExactSimulation)

In [None]:
fig, ax = plt.subplots()

plot_with_dt(ax, z_large_nn, dt_large_nn, color=col_nn)
plot_with_dt(ax, z_small_nn, dt_small_nn, color=col_nn, linestyle="dashed")

# Plot the whole quasi-exact trajectory (no period detection for the learned
# model). A stop of None keeps every point, whereas -1 silently dropped the last.
nf_ex_nn = None
ax.plot(*z_ex_nn[:nf_ex_nn].T, **ex_params)

ax.legend(title="neural $\\vartheta$, $H$")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

fig.tight_layout()
# savefig does not create missing directories (e.g. on a fresh checkout).
fig_path = os.path.join("figures", "perturbed", "dvi_nn.pdf")
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

In [None]:
# model_reg is already loaded in the setup cell; reuse it instead of loading
# all trained models again.
dt_large_reg = 2.0 ** -3
dt_small_reg = 2.0 ** -4

t_large_reg, z_large_reg = compute_sol(model_reg, dt_large_reg)
t_small_reg, z_small_reg = compute_sol(model_reg, dt_small_reg)
t_ex_reg, z_ex_reg = compute_sol(model_reg, dt_ex, scheme=QuasiExactSimulation)

In [None]:
fig, ax = plt.subplots()

plot_with_dt(ax, z_large_reg, dt_large_reg, color=col_reg)
plot_with_dt(ax, z_small_reg, dt_small_reg, color=col_reg, linestyle="dashed")

# Plot the whole quasi-exact trajectory (no period detection for the learned
# model). A stop of None keeps every point, whereas -1 silently dropped the last.
nf_ex_reg = None
ax.plot(*z_ex_reg[:nf_ex_reg].T, **ex_params)

ax.legend(title="reg. neural $\\vartheta$, $H$")
ax.set_xlabel("$x$")
ax.set_ylabel("$y$")

fig.tight_layout()
# savefig does not create missing directories (e.g. on a fresh checkout).
fig_path = os.path.join("figures", "perturbed", "dvi_reg.pdf")
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
fig.savefig(fig_path)

In [None]:
fig, ax = plt.subplots(figsize=(4.3, 3.5))

ref_col = mpl.colormaps["plasma"](0.85)
pert_col = mpl.colormaps["plasma"](0.4)

# Each group gets an invisible "header" entry naming the model, followed by its
# coarse (solid) and fine (dashed) runs.
groups = (
    ("ref. model", ref_col, (z_large_ref, dt_large_ref), (z_small_ref, dt_small_ref)),
    (
        "perturbed $\\vartheta$",
        pert_col,
        (z_large_pert, dt_large_pert),
        (z_small_pert, dt_small_pert),
    ),
)
for header, colour, (z_coarse, dt_coarse), (z_fine, dt_fine) in groups:
    ax.plot([], [], color="none", label=header)
    plot_with_dt(ax, z_coarse, dt_coarse, color=colour)
    plot_with_dt(ax, z_fine, dt_fine, color=colour, linestyle="dashed")

ax.plot(z_ex[:, 0], z_ex[:, 1], label="exact", color="black", linestyle="dotted", linewidth=3)

ax.tick_params(axis="both", which="major", labelsize=8)

handles, labels = ax.get_legend_handles_labels()

# Only the six model entries (two columns of header + two steps); "exact" is
# deliberately left out of the legend.
leg = ax.legend(
    handles=handles[:6],
    labels=labels[:6],
    loc="lower left",
    bbox_to_anchor=(0.02, 0.05),
    fontsize=10,
    ncols=2,
)

# Shift each header's text left by the width of its (blank) handle so the
# headers read as column titles.
for header_idx in (0, 3):
    item, label = leg.legend_handles[header_idx], leg.texts[header_idx]
    width = item.get_window_extent(fig.canvas.get_renderer()).width
    label.set_position((-width, 0))

fig.tight_layout()