In [None]:
import jax.numpy as np
import matplotlib.pyplot as plt

from ode_filters.filters import ekf1_sqr_loop, rts_sqr_smoother_loop
from ode_filters.measurement import ODEInformation
from ode_filters.priors import IWP, taylor_mode_initialization

In [None]:
omega_0 = 1.0  # angular frequency used by the vector field below


def vf_2(x, dx, *, t):
    """Second-order ODE of the harmonic oscillator: d²x/dt² = -omega_0² * x.

    Args:
        x: Current position.
        dx: Current velocity. Accepted for the second-order signature but not
            used, because the oscillator has no damping term.
        t: Time (keyword-only). Not used, because the right-hand side is
            autonomous.

    Returns:
        The acceleration ``-(omega_0**2) * x``, with the same shape as ``x``.
    """
    return -(omega_0**2) * x


# Initial value problem: position and velocity at the start of the time span.
x0 = np.array([1.0])
dx0 = np.array([0.0])  # initial velocity
tspan = [0, 10]
(d,) = x0.shape  # dimension of the ODE state

# Prior configuration. q is passed both to the prior and to the initialization;
# presumably it is the number of derivatives the prior models -- TODO confirm.
q = 3
xi = np.eye(d) * 0.5
prior = IWP(q, d, Xi=xi)

# Initial state for the filter, derived from the vector field and (x0, dx0).
initial_values = (x0, dx0)
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf_2, initial_values, q, order=2)


In [None]:
# Display mu_0, the first value returned by taylor_mode_initialization.
mu_0

In [None]:
# N is forwarded to both loops below; presumably the number of filter steps
# (the plotting cell pairs the results with N + 1 time points) -- TODO confirm.
N = 21

# Measurement model built from the vector field and the prior's E0/E1/E2.
measure = ODEInformation(vf_2, prior.E0, prior.E1, E2=prior.E2, order=2)

# Forward pass: run the filter loop once, then unpack its nine outputs.
filter_outputs = ekf1_sqr_loop(mu_0, Sigma_0_sqr, prior, measure, tspan, N)
(
    m_seq,
    P_seq_sqr,
    m_pred_seq,
    P_pred_seq_sqr,
    G_back_seq,
    d_back_seq,
    P_back_seq_sqr,
    mz_seq,
    Pz_seq_sqr,
) = filter_outputs

# Backward pass: start from the last filtered state and reuse the backward
# quantities (G, d, P) returned by the filter loop.
m_smoothed, P_smoothed_sqr = rts_sqr_smoother_loop(
    m_seq[-1],
    P_seq_sqr[-1],
    G_back_seq,
    d_back_seq,
    P_back_seq_sqr,
    N,
)

In [None]:
# Evaluate measure.g at the initial state mu_0 with t=1.0 and display the result.
# NOTE(review): presumably this is the ODE residual at the initial state -- confirm.
measure.g(mu_0, t=1.0)


In [None]:
# Stack the filter outputs into arrays so they can be sliced per component.
m_seq = np.array(m_seq)
P_seq_sqr = np.array(P_seq_sqr)
ts = np.linspace(tspan[0], tspan[1], N + 1)

# Rebuild the covariances from the stored factors as S^T S for every time
# step, then take two standard deviations of the first state component.
P_seq = np.swapaxes(P_seq_sqr, 1, 2) @ P_seq_sqr
margin = 2 * np.sqrt(P_seq[:, 0, 0])
x_mean = m_seq[:, 0]

# Filtered mean of the first component with its 2-sigma band.
fig, ax = plt.subplots(figsize=(10, 4), dpi=600)
ax.set_xlabel("t")
ax.set_ylabel("x(t)")
ax.plot(ts, x_mean, label="filtered estimate")
ax.fill_between(
    ts,
    x_mean - margin,
    x_mean + margin,
    alpha=0.5,
    label=r"2 $\sigma$ interval",
)
ax.legend()
plt.show()