In [None]:
import importlib.util
import os
import pathlib
import sys

# Load the local development checkout of `tornado` directly from its
# `__init__.py`, bypassing any installed copy. The location can be overridden
# with the TORNADO_INIT_PATH environment variable so the notebook also runs on
# other machines; the default is the original author's checkout.
file_path = pathlib.Path(
    os.environ.get(
        "TORNADO_INIT_PATH",
        "/Users/jschmidt/Documents/Uni/MSc Machine Learning/Semesters/4 Thesis/Paper/high-dim-solvers/tornado/tornado/__init__.py",
    )
)
module_name = "tornado"

if not file_path.is_file():
    raise FileNotFoundError(
        f"tornado package not found at {file_path}; "
        "set TORNADO_INIT_PATH to the path of tornado/__init__.py"
    )

spec = importlib.util.spec_from_file_location(module_name, file_path)
tornado = importlib.util.module_from_spec(spec)
# Register before executing (as in the importlib "import a source file
# directly" recipe) so imports of `tornado` during execution find this module.
sys.modules[module_name] = tornado
spec.loader.exec_module(tornado)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

In [None]:
# Benchmark problem. Alternative test problem (swap in to compare):
#   ivp = tornado.ivp.vanderpol(t0=0.0, tmax=10.0, stiffness_constant=10.0)
brusselator_n = 100
ivp = tornado.ivp.brusselator(N=brusselator_n)

In [None]:
# Length of the initial value vector.
d = ivp.y0.shape[0]

# Shared absolute/relative tolerance for the adaptive step rule; the first
# positional argument (1e-2) is passed through unchanged to AdaptiveSteps.
tol = 1e-7
steps = tornado.step.AdaptiveSteps(1e-2, abstol=tol, reltol=tol)

# NOTE(review): `dt` is not referenced by any cell shown here.
dt = 5e-2

# Passed as `num_derivatives` to both solvers below.
num_derivatives = 5
# Passed as `ensemble_size` to the EnKF solver.
ensemble_size = 1_000

In [None]:
# High-accuracy reference solution. SciPy's default tolerances
# (rtol=1e-3, atol=1e-6) are far looser than the `tol` the probabilistic
# solvers run at. With them, the "errors" computed later would mostly measure
# SciPy's own error. Solve the reference three orders of magnitude tighter.
reference_tol = 1e-3 * tol
scipy_sol = solve_ivp(
    ivp.f,
    t_span=(ivp.t0, ivp.tmax),
    y0=ivp.y0,
    method="Radau",
    rtol=reference_tol,
    atol=reference_tol,
    dense_output=True,
)
if not scipy_sol.success:
    raise RuntimeError(f"SciPy reference solve failed: {scipy_sol.message}")

final_t_scipy = scipy_sol.t[-1]
final_y_scipy = scipy_sol.y[:, -1]

In [None]:
# Ensemble filter sharing the derivative count and step rule with the EK1.
enkf0 = tornado.enkf.EnK0(
    num_derivatives=num_derivatives,
    steprule=steps,
    ensemble_size=ensemble_size,
)

In [None]:
# Reference EK1 solver, configured like the ensemble filter above.
ek1 = tornado.ek1.ReferenceEK1(
    num_derivatives=num_derivatives,
    steprule=steps,
)

In [None]:
%%time
enkf_states = list(enkf0.solution_generator(ivp))

In [None]:
%%time
ek1_states = list(ek1.solution_generator(ivp))

In [None]:
def _state_mean(state):
    """Return the full state-space mean of a single solver state.

    States exposing a callable ``mean()`` (the EnKF states) are asked for it;
    otherwise the mean is read from ``state.y.mean`` (the Gaussian filter
    states).
    """
    mean = getattr(state, "mean", None)
    if callable(mean):
        return mean()
    return state.y.mean


def extract_states(solver, states):
    """Collect time points and projected means from a sequence of states.

    Parameters
    ----------
    solver
        The solver that produced ``states``; its ``E0`` matrix is applied to
        each state's mean. (Previously the global ``enkf0.E0`` was used no
        matter which solver was passed.)
    states
        Iterable of solver states, each with a ``t`` attribute and either a
        callable ``mean()`` or a ``y.mean`` attribute.

    Returns
    -------
    ts : list
        ``state.t`` for every state, in order.
    means : list
        ``solver.E0 @ mean`` for every state, in order.
    """
    # Materialise once: the states are traversed twice below.
    states = list(states)
    E0 = solver.E0
    ts = [s.t for s in states]
    # Dispatch per state instead of a bare `except:`, which also swallowed
    # genuine errors (and KeyboardInterrupt) and silently switched code paths.
    means = [E0 @ _state_mean(s) for s in states]
    return ts, means

In [None]:
# Time grids and projected means for both solvers (EnKF first, then EK1).
(enkf_ts, enkf_means), (ek1_ts, ek1_means) = (
    extract_states(enkf0, enkf_states),
    extract_states(ek1, ek1_states),
)

In [None]:
# Side-by-side trajectories: EnKF mean, EK1 mean, and the SciPy reference,
# all on a shared y-axis.
fig, (ax_enkf, ax_ek1, ax_scipy) = plt.subplots(1, 3, sharey=True, figsize=(15, 4))

ax_enkf.plot(enkf_ts, enkf_means)
ax_ek1.plot(ek1_ts, ek1_means)
# Evaluate SciPy's dense output on the EnKF time grid.
ax_scipy.plot(enkf_ts, scipy_sol.sol(enkf_ts).T)

for axis, title in ((ax_enkf, "EnKF"), (ax_ek1, "ReferenceEK1"), (ax_scipy, "SciPy reference")):
    axis.set_title(title)
    axis.set_xlabel("t")
ax_enkf.set_ylabel("y(t)")

display(fig)
# fig.savefig("enkf.pdf")
plt.close(fig)

In [None]:
# EnKF: error of every ensemble member against the SciPy reference.
# The old `s.samples[:1000]` sliced the first axis of the matmul operand,
# i.e. the state axis, not the ensemble axis. That breaks as soon as the
# state has more than 1000 coordinates. Use all samples instead.
# Assumes `s.samples` is (state_dim, ensemble_size), consistent with the
# trailing-axis broadcast below — TODO confirm.
enkf_projected = np.stack([enkf0.E0 @ s.samples for s in enkf_states])  # (T, d, M)
enkf_reference = scipy_sol.sol(enkf_ts).T  # (T, d)
enkf_error = np.linalg.norm(enkf_projected - enkf_reference[..., None], axis=1)  # (T, M)

# EK1: error of the filter mean; `ek1_means` already holds E0 @ y.mean per step.
ek1_error = np.linalg.norm(np.asarray(ek1_means) - scipy_sol.sol(ek1_ts).T, axis=1)  # (T,)

In [None]:
# Error against the SciPy reference on a log scale, shared y-axis.
err_fig, (enkf_err_ax, ek1_err_ax) = plt.subplots(1, 2, sharey=True, figsize=(10, 4))

# `enkf_error` has one column per ensemble member, so one curve per member;
# `ek1_error` is a single curve for the filter mean.
enkf_err_ax.plot(enkf_ts, enkf_error)
ek1_err_ax.plot(ek1_ts, ek1_error)

for axis, title in ((enkf_err_ax, "EnKF (per member)"), (ek1_err_ax, "ReferenceEK1 (mean)")):
    axis.set_yscale("log")
    axis.set_title(title)
    axis.set_xlabel("t")
enkf_err_ax.set_ylabel("2-norm error vs. SciPy reference")

display(err_fig)
plt.close(err_fig)

In [None]:
# Step sizes between consecutive returned states of both adaptive solvers.
step_fig, step_ax = plt.subplots()
step_ax.semilogy(enkf_ts[:-1], np.diff(enkf_ts), label="enkf")
step_ax.semilogy(ek1_ts[:-1], np.diff(ek1_ts), label="ek1")
step_ax.set(xlabel="t", ylabel="step size", title="Adaptive step sizes")
step_ax.legend()
plt.show()