In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

from symplearn.datasets import VectorFieldDataset, SnapshotDataset
from symplearn.numerics import QuasiExactSimulation, EulerDVISimulation, RK4Simulation

from models import LotkaVolterra
from utils import get_reduced_val_data, find_index_period, HandlerColormap

from utils import init_env

# Ground-truth problem ``pb`` and the trained models (a dict keyed by model name).
(pb, models) = init_env()

In [None]:
# Trajectory states of each split. In the cells below only ``z_val`` is used,
# as the default initial conditions ``z0`` of ``compute_sols``.
z_train, z_test, z_val = (VectorFieldDataset(split).z for split in ("train", "test", "val"))

# ``dt`` becomes the default base step ``dt0`` of ``compute_sols``; ``t0``/``z0``
# are presumably an initial time and initial states (TODO confirm in utils).
t0, z0, dt = get_reduced_val_data()
# Split the last axis of ``z0`` into two equal halves.
x0, y0 = z0.tensor_split(2, -1)

## Numerical error on the validation set

In [None]:
def compute_sols(model, scheme, z0=z_val, tf=10.0, dt0=dt, rate=1.5, num_steps=7, offset=1):
    """Simulate ``model`` with ``scheme`` on a geometric ladder of step sizes.

    The target steps are ``dt0 * rate**(offset - k)`` for ``k = 0, ..., num_steps - 1``.
    Each one is adjusted to ``tf / n`` with ``n = ceil(tf / target)``, so that an
    integer number of steps reaches ``tf`` exactly.

    Parameters
    ----------
    model : object
        Passed to ``scheme`` as its first argument.
    scheme : callable
        ``scheme(model, h)`` must return an object with ``simulate(z0, n)``
        returning a pair whose second item is a torch tensor.
    z0 : tensor
        Initial conditions handed to ``simulate`` (default: ``z_val``).
    tf : float
        Final time.
    dt0, rate, num_steps, offset : float, float, int, int
        Parameters of the step-size ladder described above.

    Returns
    -------
    sols : list of np.ndarray
        One simulated trajectory per step size, detached and moved to the CPU.
    all_dt : np.ndarray
        The step sizes actually used.
    """
    exponents = offset - np.arange(num_steps)
    n_steps = np.ceil(tf / (dt0 * rate**exponents)).astype(int)
    steps = tf / n_steps

    trajectories = []
    for n, h in zip(n_steps, steps):
        _, traj = scheme(model, h).simulate(z0, n)
        trajectories.append(traj.detach().cpu().numpy())

    return trajectories, steps


# Quasi-exact solutions of the ground-truth problem on the default step-size ladder.
sol_ref_ex, all_dt_ref_ex = compute_sols(model=pb, scheme=QuasiExactSimulation)

In [None]:
def err_traj(z, z_ex):
    """Largest Euclidean distance between two trajectories.

    The last axis of ``z - z_ex`` is reduced with a Euclidean norm. The result is
    then maximised over its new last axis, which is presumably the time axis of
    the simulated trajectories (TODO confirm layout of ``simulate`` output).
    """
    squared_dist = np.sum((z - z_ex) ** 2, axis=-1)
    dist = np.sqrt(squared_dist)
    return np.max(dist, axis=-1)


def compute_err(model, scheme, compare_with_ref=True, **kwargs):
    """Error of ``scheme`` applied to ``model`` for every step size of the ladder.

    The approximate trajectories are compared with quasi-exact trajectories
    computed on the same step-size ladder. These come from the ground-truth
    problem ``pb`` when ``compare_with_ref`` is true, or from ``model`` itself
    otherwise, which isolates the discretisation error.

    ``kwargs`` are forwarded unchanged to both ``compute_sols`` calls.

    Returns
    -------
    dict
        ``{"err": errors stacked along a leading step-size axis, "dt": step sizes}``.
    """
    reference_model = pb if compare_with_ref else model
    approx, step_sizes = compute_sols(model, scheme, **kwargs)
    exact, _ = compute_sols(reference_model, QuasiExactSimulation, **kwargs)
    errors = np.stack([err_traj(a, e) for a, e in zip(approx, exact)])
    return {"err": errors, "dt": step_sizes}


def model_err(dir, model, compare_with_ref=True, **kwargs):
    """Error curves of ``model`` for Euler-DVI and RK4, cached on disk.

    Results are stored as ``simul/err/<dir>/<scheme>_<field>.pt`` with
    ``scheme`` in {"dvi", "rk4"} and ``field`` in {"err", "dt"}. When all four
    files exist they are loaded and nothing is recomputed. In that case
    ``compare_with_ref`` and ``kwargs`` are ignored, so delete the cache
    directory after changing them.

    Parameters
    ----------
    dir : str
        Name of the cache sub-directory (e.g. "ref", "vf").
    model : object
        Model whose errors are measured (forwarded to ``compute_err``).
    compare_with_ref : bool
        Forwarded to ``compute_err``.
    **kwargs
        Forwarded to ``compute_err``. For RK4, ``offset`` is forced to 3, which
        shifts the ``compute_sols`` step-size ladder towards larger steps than
        its default ``offset=1``.

    Returns
    -------
    dict
        ``{"dvi": {"err": ..., "dt": ...}, "rk4": {"err": ..., "dt": ...}}``
        holding numpy arrays.
    """
    cache_dir = os.path.join("simul", "err", dir)
    sch_keys = ("dvi", "rk4")
    trace_keys = ("err", "dt")
    paths = {
        (sch, trace): os.path.join(cache_dir, f"{sch}_{trace}.pt")
        for sch in sch_keys
        for trace in trace_keys
    }

    if all(os.path.exists(path) for path in paths.values()):
        res = {sch: {} for sch in sch_keys}
        for (sch, trace), path in paths.items():
            res[sch][trace] = torch.load(path, weights_only=True).numpy()
        return res

    res = {
        "dvi": compute_err(model, EulerDVISimulation, compare_with_ref=compare_with_ref, **kwargs),
        "rk4": compute_err(
            model, RK4Simulation, compare_with_ref=compare_with_ref, **(kwargs | {"offset": 3})
        ),
    }
    # torch.save does not create parent directories; without this the first
    # run on a fresh checkout fails with FileNotFoundError.
    os.makedirs(cache_dir, exist_ok=True)
    for (sch, trace), path in paths.items():
        torch.save(torch.from_numpy(res[sch][trace]), path)
    return res

### Exact solutions

For the exact solutions, see `result_longtime.ipynb`.

### Convergence of the reference model

In [None]:
# Pure discretisation error of the ground-truth problem: pb is compared with its own quasi-exact solution.
err_ref = model_err(dir="ref", model=pb, compare_with_ref=False)

In [None]:
def plot_quantiles(ax, all_dt, err, alpha=0.25, quant=0.05, **kwargs):
    """Plot the median error against the step size, with a shaded quantile band.

    Quantiles are taken over the last axis of ``err``. The median is drawn with
    ``ax.loglog``, and ``kwargs`` are passed to that call only. The band between
    the ``quant`` and ``1 - quant`` quantiles is filled in the median line's
    colour with transparency ``alpha``.

    Returns
    -------
    np.ndarray
        The median error for each step size.
    """
    lower, median, upper = (np.quantile(err, q, axis=-1) for q in (quant, 0.5, 1.0 - quant))

    # Slightly above the default line zorder (2) so the median sits on top.
    (line,) = ax.loglog(all_dt, median, zorder=2.01, **kwargs)

    # Closed polygon: upper bound in the order of all_dt, then lower bound reversed.
    band_x = np.concatenate([all_dt, all_dt[::-1]])
    band_y = np.concatenate([upper, lower[::-1]])
    ax.fill(band_x, band_y, alpha=alpha, label=None, color=line.get_color())
    return median


# Convergence of the ground-truth problem: error of each integrator against its
# own quasi-exact solution, as a function of the step size h.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
plot_quantiles(ax, err_ref["dvi"]["dt"], err_ref["dvi"]["err"], label="Euler-DVI", color="C0")
plot_quantiles(ax, err_ref["rk4"]["dt"], err_ref["rk4"]["err"], label="RK4", color="C1")
ax.set_xlabel("$h$")
ax.set_ylabel("abs. err. (reference model)")
# Trailing ';' suppresses the Legend repr in the notebook output.
ax.legend();

### Results of neural networks

In [None]:
# Vector-field model trained with regularisation, measured against the ground truth pb.
err_vf = model_err(dir="vf", model=models["vf_reg"])

In [None]:
# Vector-field model trained without regularisation, measured against the ground truth pb.
err_cmy = model_err(dir="cmy", model=models["vf_no_reg"])

In [None]:
# Model obtained by scheme learning (stored under "dvi"), measured against the ground truth pb.
err_sch = model_err(dir="sch", model=models["dvi"])

In [None]:
# Final figure: error vs. step size for the three learned models, integrated
# with RK4 (left) and Euler-DVI (right), each with a dashed reference slope.
fig, (ax_rk4, ax_dvi) = plt.subplots(1, 2, figsize=(7.1, 2.7), width_ratios=(1, 0.95))

# One style per learned model, shared by both panels.
dvi_plt = dict(label="sch. learning", color="grey", marker="s")
cmy_plt = dict(label="VFL no reg.", color="C3", marker="D", alpha=0.35)
vf_plt = dict(label="VFL with reg.", color="C0", marker="o", alpha=0.3)


def _plot_error_panel(ax, scheme, slope_fn, slope_label, ylabel):
    """Plot the error curves of the three learned models for one scheme.

    A dashed reference line ``slope_fn(h)`` is drawn first. It spans the step
    sizes of the scheme-learning run. The quantile curves of each model follow.
    """
    dt_sch = err_sch[scheme]["dt"]
    dt_span = np.array([dt_sch.min(), dt_sch.max()])
    ax.plot(dt_span, slope_fn(dt_span), "--", color="black", label=slope_label)
    for err, style in ((err_sch, dvi_plt), (err_cmy, cmy_plt), (err_vf, vf_plt)):
        plot_quantiles(ax, err[scheme]["dt"], err[scheme]["err"], **style)
    ax.set_xlabel("$h$")
    ax.set_ylabel(ylabel)


# Reference slopes: order 1 (~h) for Euler-DVI, order 4 (~h**4) for RK4.
_plot_error_panel(ax_dvi, "dvi", lambda h: 1e1 * h, "order 1", "abs. err. with DVI")
_plot_error_panel(ax_rk4, "rk4", lambda h: 1e3 * h**4, "order 4", "abs. err. with RK4")

# Extra head-room above the DVI curves, presumably to fit the two-column legend.
y_min, y_max = ax_dvi.get_ylim()
ax_dvi.set_ylim(y_min, 3.5 * y_max)

ax_dvi.legend(ncols=2)
ax_rk4.legend(loc="lower right", ncols=2)

fig.tight_layout(pad=0.2, w_pad=1.2)

fname = os.path.join("figures", "num_err.pdf")
# savefig does not create missing directories.
os.makedirs(os.path.dirname(fname), exist_ok=True)
fig.savefig(fname, bbox_inches="tight")