In [None]:
import jax.numpy as np
import jax.random as jrandom
import matplotlib.pyplot as plt

from ode_filters.filters import ekf1_sqr_loop, rts_sqr_smoother_loop
from ode_filters.measurement import (
    Measurement,
    SecondOrderODEInformation,
    SecondOrderODEInformationWithHidden,
)
from ode_filters.priors import IWP, JointPrior, taylor_mode_initialization

# Damped harmonic oscillator  x'' + gamma * x' + omega^2 * x = 0
omega = 1.0  # natural (undamped) angular frequency
zeta = 0.1  # damping ratio; zeta < 1 means the oscillator is underdamped
gamma = 2 * zeta * omega  # damping coefficient appearing in the ODE
omega_d = omega * np.sqrt(1 - zeta**2)  # damped angular frequency
tspan = [0, 30]  # integration interval [t0, t1]
N = 100  # number of solver steps
ts = np.linspace(tspan[0], tspan[1], N + 1)  # N + 1 grid points, both endpoints included

In [None]:
# Basic second-order ODE: d^2x/dt^2 = -omega^2 * x - gamma * dx/dt
def vf(x, dx, *, t):
    """Acceleration of the damped oscillator at position ``x`` and velocity ``dx``.

    ``t`` is keyword-only and unused by the body (the system is autonomous).
    """
    spring = -(omega**2) * x
    damping = -gamma * dx
    return spring + damping


# Initial state x(0) = 1, x'(0) = 0.
x0 = np.array([1.0])
dx0 = np.array([0.0])
# IWP prior with q = 2 on a 1-d state (presumably a q-times integrated Wiener process).
prior = IWP(q=2, d=1, Xi=np.eye(1))
# Initial mean and (by its name, square-root) covariance obtained from vf and (x0, dx0).
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf, (x0, dx0), q=2, order=2)
# ODE information operator built from vf and the prior's E0/E1/E2 projections.
measure = SecondOrderODEInformation(vf, prior.E0, prior.E1, prior.E2)


In [None]:
# Forward pass: square-root EKF1 over the N-step grid. Besides the filtered means and
# square-root covariances, keep the backward-transition terms consumed by the smoother.
(
    m_seq,
    P_sqr,
    _,
    _,
    G_back,
    d_back,
    P_back_sqr,
    *_,
) = ekf1_sqr_loop(mu_0, Sigma_0_sqr, prior, measure, tspan, N)

# Backward pass: RTS smoother started from the last filtered mean / square-root covariance.
m_smooth, P_smooth_sqr = rts_sqr_smoother_loop(
    m_seq[-1], P_sqr[-1], G_back, d_back, P_back_sqr, N
)
# Convert to arrays (presumably the loop returns per-step sequences; np.array stacks them).
m_smooth = np.array(m_smooth)
P_smooth_sqr = np.array(P_smooth_sqr)

In [None]:
# Plot smoothed estimate vs the closed-form solution.
# For the underdamped oscillator (zeta < 1) with x(0) = x0, x'(0) = dx0:
#   x(t) = exp(-zeta*omega*t) * (x0*cos(omega_d t) + (dx0 + zeta*omega*x0)/omega_d * sin(omega_d t))
# The sine term is required to satisfy x'(0) = dx0; a bare damped cosine would have
# x'(0) = -zeta*omega*x0 and would not match the ODE's initial velocity.
sin_coeff = (dx0[0] + zeta * omega * x0[0]) / omega_d
x_true = np.exp(-zeta * omega * ts) * (
    x0[0] * np.cos(omega_d * ts) + sin_coeff * np.sin(omega_d * ts)
)

# Per-step covariance S^T S from the square-root factor S
# (assumes the library's convention is P = S^T S — TODO confirm with ode_filters).
P_smooth = np.einsum("ijk,ijl->ikl", P_smooth_sqr, P_smooth_sqr)
margin = 2 * np.sqrt(P_smooth[:, 0, 0])  # +/- 2 standard deviations of x(t)

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(ts, m_smooth[:, 0], label="smoothed")
ax.plot(ts, x_true, "k--", label="true")
ax.fill_between(ts, m_smooth[:, 0] - margin, m_smooth[:, 0] + margin, alpha=0.3)
ax.set(xlabel="t", ylabel="x(t)", title="Damped oscillator: smoothed estimate vs. exact solution")
ax.legend()
plt.show()

In [None]:
# Joint state-parameter estimation: infer damping gamma from noisy observations.
# Synthetic data: the reference trajectory at every grid point after t0, corrupted by
# i.i.d. Gaussian noise with standard deviation noise_std (fixed PRNG key for reproducibility).
noise_std = 1e-2
key = jrandom.PRNGKey(42)
noise = noise_std * jrandom.normal(key, shape=(N,))
z = (x_true[1:] + noise).reshape(-1, 1)  # (N, 1): one scalar observation per time
z_t = ts[1:]

fig, ax = plt.subplots(figsize=(10, 3))
ax.scatter(z_t, z, s=10, alpha=0.7, color="black", label="observations")
ax.plot(ts, x_true, "k--", label="true")
ax.set_xlabel("t")
ax.set_ylabel("x(t)")
ax.legend()
plt.show()


In [None]:
# Setup joint prior: x (q=2) + gamma (q=2)
prior_x = IWP(q=2, d=1, Xi=0.1 * np.eye(1))
prior_gamma = IWP(q=2, d=1, Xi=1e-2 * np.eye(1))
prior_joint = JointPrior(prior_x, prior_gamma)

# Initial conditions: gamma starts at its true value with a tiny initial spread.
# NOTE(review): with gamma0 equal to the true gamma and a 1e-6 initial factor, the smoother
# is effectively given the answer. Perturb gamma0 and widen Sig_0_gamma to check that gamma
# is actually identifiable from the observations.
gamma0 = np.array([gamma])
mu_0_x, Sig_0_x = taylor_mode_initialization(vf, (x0, dx0), q=2, order=2)
# 3 = q + 1 entries for the q=2 gamma block (presumably [gamma, gamma', gamma'']).
mu_0_gamma = np.concatenate([gamma0, np.zeros(2)])
# Passed to the square-root filter below, so presumably a square-root factor — TODO confirm.
Sig_0_gamma = 1e-6 * np.eye(3)

# Block dimensions derived from the initial means instead of hard-coding them.
# D_x is also the index of gamma in the joint state (used when plotting).
D_x, D_gamma = mu_0_x.shape[0], mu_0_gamma.shape[0]
assert Sig_0_x.shape == (D_x, D_x), Sig_0_x.shape
assert Sig_0_gamma.shape == (D_gamma, D_gamma), Sig_0_gamma.shape

mu_0_joint = np.concatenate([mu_0_x, mu_0_gamma])
zeros = np.zeros((D_x, D_gamma))
# Block-diagonal initial factor: x and gamma start uncorrelated.
Sig_0_joint = np.block([[Sig_0_x, zeros], [zeros.T, Sig_0_gamma]])


# Vector field with hidden state u = gamma
def vf_joint(x, v, u, *, t):
    """Acceleration -omega^2 * x - u * v, where the damping u comes from the hidden state."""
    return -(omega**2) * x - u * v


# Measurement model: ODE information plus direct observations of x (A = [[1]]) at times z_t.
# NOTE(review): confirm whether Measurement's `noise` expects a standard deviation or a variance.
obs = Measurement(A=np.array([[1.0]]), z=z, z_t=z_t, noise=noise_std)
measure_joint = SecondOrderODEInformationWithHidden(
    vf_joint,
    E0=prior_joint.E0_x,
    E1=prior_joint.E1,
    E2=prior_joint.E2,
    E0_hidden=prior_joint.E0_hidden,
    constraints=[obs],
)

# Solver steps for the joint problem. With M == N the grid ts2 coincides with ts, so every
# observation time z_t = ts[1:] lies on a solver grid point.
M = N
ts2 = np.linspace(tspan[0], tspan[1], M + 1)
# Filter and smooth
m_filt, P_filt_sqr, _, _, G, d, P_back, *_ = ekf1_sqr_loop(
    mu_0_joint, Sig_0_joint, prior_joint, measure_joint, tspan, M
)
m_smooth, P_smooth_sqr = rts_sqr_smoother_loop(
    m_filt[-1], P_filt_sqr[-1], G, d, P_back, M
)
m_smooth, P_smooth_sqr = np.array(m_smooth), np.array(P_smooth_sqr)


In [None]:
# Plot results: smoothed x(t) with observations (top) and the inferred gamma (bottom).
# Per-step covariance S^T S from the smoothed square-root factors.
P = np.einsum("ijk,ijl->ikl", P_smooth_sqr, P_smooth_sqr)
x_mean, x_std = m_smooth[:, 0], np.sqrt(P[:, 0, 0])
# The gamma block starts right after the x block, i.e. at joint index D_x.
g_mean, g_std = m_smooth[:, D_x], np.sqrt(P[:, D_x, D_x])

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 7), sharex=True)

# Top panel: state x with a +/- 2 std band.
ax1.scatter(z_t, z, s=10, alpha=0.7, color="black", label="observations")
ax1.plot(ts2, x_mean, label="smoothed")
ax1.plot(ts, x_true, "k--", label="true")
ax1.fill_between(ts2, x_mean - 2 * x_std, x_mean + 2 * x_std, alpha=0.3)
ax1.set_ylabel("x(t)")
ax1.legend()

# Bottom panel: hidden damping coefficient with a +/- 2 std band.
ax2.plot(ts2, g_mean, label="smoothed gamma")
ax2.axhline(gamma, color="k", linestyle="--", label="true gamma")
ax2.fill_between(ts2, g_mean - 2 * g_std, g_mean + 2 * g_std, alpha=0.3)
ax2.set_xlabel("t")
ax2.set_ylabel("gamma")
ax2.legend()
ax2.set_ylim(0, 1.0)

plt.tight_layout()
plt.show()