In [None]:
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch

from symplearn.datasets import VectorFieldDataset, SnapshotDataset
from symplearn.numerics import QuasiExactSimulation, EulerDVISimulation, RK4Simulation

from utils import get_reduced_val_data, load_models

# All numerics in double precision on the CPU.
torch.set_default_dtype(torch.float64)
torch.set_default_device("cpu")

# Initial conditions and time step shared by every rollout below; the first
# item returned by get_reduced_val_data() is not needed in this notebook.
_, z0, dt = get_reduced_val_data()
# Trained models; keys used below are "ref", "vf_reg", "vf_no_reg" and "dvi".
models = load_models()

# The reference model drives the quasi-exact and the reference DVI rollouts.
model = models["ref"]

In [None]:
from symplearn.training.norms import GramMSNorm, ScaledMSNorm
from symplearn.training.losses import VectorFieldLoss

# Validation vector-field samples. Names carry a `_val` suffix so this cell
# does not rebind generic notebook-scope names such as `z` / `t`.
# NOTE(review): presumably (state, time, time-derivative) — confirm in symplearn.
data_vf = VectorFieldDataset("val")
z_val, t_val, dt_z_val = data_vf[:]

REG_MODEL_NAMES = ("ref", "vf_reg", "vf_no_reg")

# Output [1] of VectorFieldLoss._error_reg, evaluated sample-wise via vmap,
# for each model (presumably the regularisation residual — confirm).
reg_by_model = {
    name: torch.vmap(VectorFieldLoss(models[name])._error_reg)(z_val, t_val, dt_z_val)[1]
    for name in REG_MODEL_NAMES
}
reg_ref, reg_vf_reg, reg_vf_no_reg = (reg_by_model[name] for name in REG_MODEL_NAMES)

# Displayed result: GramMSNorm of each model's residual, in REG_MODEL_NAMES
# order. ScaledMSNorm can be swapped in for an alternative weighting.
tuple(GramMSNorm()(reg_by_model[name], dt_z_val) for name in REG_MODEL_NAMES)

In [None]:
# Fine-step (dt = 1e-3) QuasiExactSimulation of the reference model; z_ex is
# the default reference curve drawn by plot_sol.
dt_ex = 1e-3
sim_ex = QuasiExactSimulation(model, dt_ex)
tf_ex = 130.5
# round() instead of int(): int() truncates, so a quotient landing just below
# an integer (e.g. int(0.3 / 0.1) == 2) would silently drop the last step.
n_steps_ex = round(tf_ex / dt_ex)
t_ex, z_ex = sim_ex.simulate(z0, n_steps_ex)
x_ex, y_ex = z_ex[0, :, 0], z_ex[0, :, 1]

In [None]:
# Euler-DVI rollout of the reference model, started from z0 with no custom
# init_step.
sim_ref = EulerDVISimulation(model, dt)
n_steps_ref = 400
t_ref, z_ref = sim_ref.simulate(z0, n_steps_ref)

In [None]:
# Font sizes (points) used for the exported figures.
SMALL_SIZE = 7
MEDIUM_SIZE = 8
BIGGER_SIZE = 9
HUGE_SIZE = 11

plt.rcParams.update(
    {
        "axes.labelsize": MEDIUM_SIZE,  # x / y axis labels
        "axes.titlesize": HUGE_SIZE,  # axes titles
        "xtick.labelsize": SMALL_SIZE,  # tick labels
        "ytick.labelsize": SMALL_SIZE,
        "legend.fontsize": BIGGER_SIZE,  # legend entries
        "legend.title_fontsize": BIGGER_SIZE,  # legend title
        "figure.figsize": (3.4, 2.9),  # inches
    }
)

# Colorblind-friendly sequential colormaps, each sampled at the same levels.
_shade_levels = [0.1, 0.4, 0.5, 0.6]
cmap0 = mpl.colormaps["Blues_r"]
cmap1 = mpl.colormaps["Oranges_r"]
cmap2 = mpl.colormaps["Greens_r"]
c0a, c0b, c0c, c0d = cmap0(_shade_levels)
c1a, c1b, c1c, c1d = cmap1(_shade_levels)
c2a, c2b, c2c, c2d = cmap2(_shade_levels)

# One shade per colormap; plot_sol draws trajectory k in colors[k].
colors = np.array([c0a, c1b, c2c])

In [None]:
def plot_sol(zt, zt_ex=z_ex, title=None, labels=("DVI", "ref."), ms=7):
    """Phase portrait of simulated trajectories over their reference curves.

    Trajectory ``k`` of ``zt`` is drawn as scatter points and the matching
    reference trajectory ``zt_ex[k]`` as a thin line, both in ``colors[k]``
    (so at most ``len(colors)`` trajectories are supported).

    Args:
        zt: Simulated trajectories; ``zt[k][:, 0]`` / ``zt[k][:, 1]`` are the
            x / y series (presumably shaped (n_traj, n_steps, 2) — confirm).
        zt_ex: Reference trajectories indexed the same way. Defaults to the
            ``z_ex`` bound when this function was defined.
        title: Legend title.
        labels: ``labels[0]`` names the scatter series, ``labels[-1]`` the lines.
        ms: Scatter marker size.

    Returns:
        The created matplotlib Figure.
    """
    fig, ax = plt.subplots()
    for k, zk in enumerate(zt):
        ax.scatter(zk[:, 0], zk[:, 1], s=ms, color=colors[k])
        # Use the zt_ex argument (not the global z_ex) so callers can pass
        # their own reference trajectories.
        ax.plot(zt_ex[k, :, 0], zt_ex[k, :, 1], lw=0.5, c=colors[k])
    # Empty gray proxy artists: one legend entry per series instead of one
    # per trajectory.
    ax.scatter([], [], c="gray", label=labels[0])
    ax.plot([], [], c="gray", label=labels[-1])
    ax.legend(title=title)
    ax.set_xlabel("$x$")
    ax.set_ylabel("$y$")
    fig.tight_layout(pad=0.3)
    return fig

In [None]:
fig_ref = plot_sol(z_ref, title="ref. model", labels=["DVI", "exact"])
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)
fig_ref.savefig(os.path.join("figures", "ref_dvi.pdf"))

In [None]:
# Euler-DVI rollout of the VF-with-regularisation model. init_step wraps the
# RK4 step of the same model (presumably used to start the DVI scheme —
# confirm in symplearn.numerics).
model_vf_reg = models["vf_reg"]
sim_vf_reg = EulerDVISimulation(model_vf_reg, dt)
init_step = torch.vmap(RK4Simulation(model_vf_reg, dt).step)
t_vf_reg, zt_vf_reg = sim_vf_reg.simulate(z0, 500, init_step=init_step)

In [None]:
fig_vf_reg = plot_sol(zt_vf_reg, title="VF with reg.")
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)
fig_vf_reg.savefig(os.path.join("figures", "vf_reg_dvi.pdf"))

In [None]:
# Euler-DVI rollout of the VF-without-regularisation model. init_step wraps
# the RK4 step of the same model (presumably used to start the DVI scheme —
# confirm in symplearn.numerics).
model_vf_no_reg = models["vf_no_reg"]
sim_vf_no_reg = EulerDVISimulation(model_vf_no_reg, dt)
init_step = torch.vmap(RK4Simulation(model_vf_no_reg, dt).step)
t_vf_no_reg, zt_vf_no_reg = sim_vf_no_reg.simulate(z0, 500, init_step=init_step)

In [None]:
fig_vf_no_reg = plot_sol(zt_vf_no_reg, title="VF no reg.", ms=2)
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)
fig_vf_no_reg.savefig(os.path.join("figures", "vf_no_reg_dvi.pdf"))

In [None]:
# Euler-DVI rollout of the "dvi" model (plotted as "sch. learning"), started
# from z0 with no custom init_step.
sim_dvi = EulerDVISimulation(models["dvi"], dt)
n_steps_dvi = 500
t_dvi, zt_dvi = sim_dvi.simulate(z0, n_steps_dvi)

In [None]:
fig_sch = plot_sol(zt_dvi, title="sch. learning")
# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)
fig_sch.savefig(os.path.join("figures", "sch_dvi.pdf"))