In [None]:
import jax.numpy as np
import jax.random as jrandom
import matplotlib.pyplot as plt

from ode_filters.filters import ekf1_sqr_loop, rts_sqr_smoother_loop
from ode_filters.measurement import Measurement, SecondOrderODEInformation
from ode_filters.priors import IWP, JointPrior, taylor_mode_initialization

In [None]:
# Damped harmonic oscillator: d²x/dt² = -ω²x - γ(dx/dt)
# Parameters
omega = 1.0  # angular frequency
zeta = 0.1  # damping ratio; must be < 1 for the underdamped frequency sqrt(1 - zeta**2) * omega
gamma = 2 * zeta * omega  # damping coefficient

# Initial conditions, defined exactly once as length-d arrays (the form passed to
# taylor_mode_initialization below). Previously scalar x0/v0 were declared and then
# silently shadowed, so editing v0 had no effect.
x0 = np.array([1.0])  # initial position
dx0 = np.array([0.0])  # initial velocity
tspan = [0, 30]
d = x0.shape[0]


def vf_2(y, dy, t):
    """Damped harmonic oscillator vector field: d²x/dt² = -ω²x - γ(dx/dt).

    Args:
        y: position x (same shape as x0).
        dy: velocity dx/dt, same shape as y.
        t: time; unused because the system is autonomous.

    Returns:
        The acceleration d²x/dt², same shape as y.
    """
    return -(omega**2) * y - gamma * dy


# IWP prior of order q with scale matrix xi, initialised from (x0, dx0) via Taylor mode.
q = 2
xi = 0.5 * np.eye(d)
prior = IWP(q, d, Xi=xi)
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf_2, (x0, dx0), q, order=2)


In [None]:
# Number of steps across tspan; the resulting time grid has N + 1 points.
N = 100
# ODE-information measurement built from the prior's E0 / E1 / E2 projections.
measure = SecondOrderODEInformation(vf_2, prior.E0, prior.E1, prior.E2)

# Forward pass: square-root EKF over the time grid.
[
    m_seq,
    P_seq_sqr,
    m_pred_seq,
    P_pred_seq_sqr,
    G_back_seq,
    d_back_seq,
    P_back_seq_sqr,
    mz_seq,
    Pz_seq_sqr,
] = ekf1_sqr_loop(mu_0, Sigma_0_sqr, prior, measure, tspan, N)

# Backward pass: smoother started from the final filtered moments, using the
# backward-transition quantities collected during filtering.
m_smoothed, P_smoothed_sqr = rts_sqr_smoother_loop(
    m_seq[-1],
    P_seq_sqr[-1],
    G_back_seq,
    d_back_seq,
    P_back_seq_sqr,
    N,
)

In [None]:
# Smoothed moments under their own names, so the filtered m_seq / P_seq_sqr
# produced by the previous cell are not overwritten.
m_smooth = np.array(m_smoothed)
P_smooth_sqr = np.array(P_smoothed_sqr)
ts = np.linspace(tspan[0], tspan[1], N + 1)

# Exact solution of x'' + 2ζω x' + ω² x = 0 (underdamped, ζ < 1) with
# x(0) = x0, x'(0) = dx0:
#   x(t) = e^{-ζωt} [x0 cos(ω_d t) + (dx0 + ζω x0) / ω_d · sin(ω_d t)],  ω_d = ω√(1-ζ²)
# The bare e^{-ζωt} cos(ω_d t) form only satisfies x'(0) = -ζω x0, not dx0 = 0.
omega_d = np.sqrt(1 - zeta**2) * omega
x_exact = np.exp(-zeta * omega * ts) * (
    x0[0] * np.cos(omega_d * ts)
    + (dx0[0] + zeta * omega * x0[0]) / omega_d * np.sin(omega_d * ts)
)

# Covariances from the square-root factors (P = R^T R), then a 2σ band on x.
P_smooth = np.matmul(np.transpose(P_smooth_sqr, (0, 2, 1)), P_smooth_sqr)
margin = 2 * np.sqrt(P_smooth[:, 0, 0])

fig, ax = plt.subplots(figsize=(10, 4), dpi=150)
ax.plot(ts, m_smooth[:, 0], label="smoothed estimate")
ax.plot(ts, x_exact, label="true", color="black", linestyle="--")
ax.fill_between(
    ts,
    m_smooth[:, 0] - margin,
    m_smooth[:, 0] + margin,
    alpha=0.5,
    label=r"2 $\sigma$ interval",
)
ax.set(xlabel="t", ylabel="x(t)", title="ODE filter (smoothed) vs exact solution")
ax.legend()
plt.show()

In [None]:
# Generate noisy observations from damped harmonic oscillator
key = jrandom.PRNGKey(42)
noise_std = 0.01

# Exact solution with x(0) = x0, x'(0) = dx0 (underdamped, ζ < 1):
#   x(t) = e^{-ζωt} [x0 cos(ω_d t) + (dx0 + ζω x0) / ω_d · sin(ω_d t)],  ω_d = ω√(1-ζ²)
# The sine term is required so the data come from the same initial conditions
# (x0, dx0) the joint filter is initialised with; without it x'(0) = -ζω x0.
omega_d = np.sqrt(1 - zeta**2) * omega  # damped frequency
x_true_full = np.exp(-zeta * omega * ts) * (
    x0[0] * np.cos(omega_d * ts)
    + (dx0[0] + zeta * omega * x0[0]) / omega_d * np.sin(omega_d * ts)
)
x_true = x_true_full[1:]  # observations start after t = 0

# Add i.i.d. Gaussian noise with standard deviation noise_std
noise = noise_std * jrandom.normal(key, shape=x_true.shape)
x_obs = x_true + noise

# Format for measurement model
z = x_obs.reshape(-1, 1)  # Shape: (N, 1)
z_t = ts[1:]  # Observation times (skip t=0)
A_obs = np.array([[1.0, 0.0]])  # Observe x from joint state [x, gamma]

fig, ax = plt.subplots(figsize=(10, 3), dpi=150)
ax.scatter(z_t, z, s=5, alpha=0.5, label="noisy observations")
ax.plot(ts, x_true_full, "k--", label="true solution")
ax.set(
    xlabel="t",
    ylabel="x(t)",
    title="Noisy observations of damped harmonic oscillator",
)
ax.legend()
plt.show()


In [None]:
# Joint prior: an IWP on the state x and a separate IWP on the damping gamma.
# gamma's scale matrix is much smaller (1e-4 vs 0.5), presumably so that it is
# modelled as slowly varying — confirm against the IWP definition of Xi.
q_x, d_x = 2, 1
q_gamma, d_gamma = 2, 1
xi_x, xi_gamma = 0.5 * np.eye(d_x), 1e-4 * np.eye(d_gamma)
prior_x = IWP(q_x, d_x, Xi=xi_x)
prior_gamma = IWP(q_gamma, d_gamma, Xi=xi_gamma)
prior_joint = JointPrior(prior_x, prior_gamma)

# Reference value (only used as a horizontal line in the damping plot) and the
# filter's starting guess for the damping coefficient.
gamma_true = gamma
gamma0 = np.array([1.2])  # adjust as desired


# Vector field used only to Taylor-initialise x, with the damping fixed at gamma0.
def vf_x_init(x, dx, *, t):
    """Damped oscillator acceleration d²x/dt² = -ω²x - gamma0·dx.

    ``t`` is keyword-only and unused (the system is autonomous).
    """
    damping_term = gamma0 * dx
    return -(omega**2) * x - damping_term


# x-block: Taylor-mode initialisation from (x0, dx0) with the damping guess.
mu_0_x, Sigma_0_sqr_x = taylor_mode_initialization(vf_x_init, (x0, dx0), q_x, order=2)

# Sizes of the two blocks: (q + 1) * d entries each (value plus q derivatives).
D_x = (q_x + 1) * d_x
D_gamma = (q_gamma + 1) * d_gamma

# gamma-block: value gamma0 followed by zero derivatives, square-root factor 0.01 * I.
mu_0_gamma = np.concatenate([gamma0, np.zeros(q_gamma * d_gamma)])
Sigma_0_sqr_gamma = 0.01 * np.eye(D_gamma)

# Joint state = [x-block, gamma-block] with no initial cross-correlation.
mu_0_joint = np.concatenate([mu_0_x, mu_0_gamma])
zeros = np.zeros((D_x, D_gamma))
Sigma_0_sqr_joint = np.block([[Sigma_0_sqr_x, zeros], [zeros.T, Sigma_0_sqr_gamma]])

# Joint vector field for the second-order system.
# E0 extracts [x, gamma]; E1 extracts dx; E2 extracts d²x.


def vf_joint(pos, vel, *, t):
    """Damped oscillator acceleration with the damping read from the state.

    pos: E0-projection of the joint state, i.e. [x, gamma].
    vel: E1-projection, i.e. [dx].
    t:   keyword-only and unused (autonomous system).
    """
    x, gamma_est = pos[0], pos[-1]
    return -(omega**2) * x - gamma_est * vel[0]


# Measurement model: ODE information for the joint second-order system plus a
# data constraint observing x through A_obs at times z_t (noise passed as noise_std**2).
obs_constraint = Measurement(A=A_obs, z=z, z_t=z_t, noise=noise_std**2)
measure_joint = SecondOrderODEInformation(
    vf_joint,
    prior_joint.E0,
    prior_joint.E1,
    prior_joint.E2,
    constraints=[obs_constraint],
)

# Forward filtering pass over the joint [x, gamma] state.
[
    m_seq_joint,
    P_seq_sqr_joint,
    m_pred_seq_joint,
    P_pred_seq_sqr_joint,
    G_back_seq_joint,
    d_back_seq_joint,
    P_back_seq_sqr_joint,
    mz_seq_joint,
    Pz_seq_sqr_joint,
] = ekf1_sqr_loop(mu_0_joint, Sigma_0_sqr_joint, prior_joint, measure_joint, tspan, N)

# Array copies of the filtered trajectory, kept for the filtered-vs-smoothed plot.
m_seq_joint_filt, P_seq_sqr_joint_filt = np.array(m_seq_joint), np.array(P_seq_sqr_joint)

# Backward smoothing pass, started from the last filtered moments.
m_smoothed_joint, P_smoothed_sqr_joint = rts_sqr_smoother_loop(
    m_seq_joint[-1],
    P_seq_sqr_joint[-1],
    G_back_seq_joint,
    d_back_seq_joint,
    P_back_seq_sqr_joint,
    N,
)
m_seq_joint_smooth, P_seq_sqr_joint_smooth = (
    np.array(m_smoothed_joint),
    np.array(P_smoothed_sqr_joint),
)


In [None]:
# The joint state is [x-block (D_x entries), gamma-block], matching how mu_0_joint
# was assembled; gamma's value is the FIRST entry of its block. Index -1 is the
# highest derivative of gamma (initialised to 0), not gamma itself.
gamma_idx = D_x

# Exact reference solution for x(0) = x0, x'(0) = dx0 (underdamped, zeta < 1):
#   x(t) = e^{-ζωt} [x0 cos(ω_d t) + (dx0 + ζω x0) / ω_d · sin(ω_d t)]
x_exact = np.exp(-zeta * omega * ts) * (
    x0[0] * np.cos(omega_d * ts)
    + (dx0[0] + zeta * omega * x0[0]) / omega_d * np.sin(omega_d * ts)
)

# Covariances from the square-root factors (P = R^T R) and the 2σ half-widths.
P_seq_smooth = np.matmul(
    np.transpose(P_seq_sqr_joint_smooth, (0, 2, 1)), P_seq_sqr_joint_smooth
)
P_seq_filt = np.matmul(
    np.transpose(P_seq_sqr_joint_filt, (0, 2, 1)), P_seq_sqr_joint_filt
)
margin_smooth = 2 * np.sqrt(P_seq_smooth[:, 0, 0])
margin_gamma_filt = 2 * np.sqrt(P_seq_filt[:, gamma_idx, gamma_idx])
margin_gamma_smooth = 2 * np.sqrt(P_seq_smooth[:, gamma_idx, gamma_idx])
gamma_filt = m_seq_joint_filt[:, gamma_idx]
gamma_smooth = m_seq_joint_smooth[:, gamma_idx]

# Plot state x (smoothed)
fig, ax = plt.subplots(figsize=(10, 4), dpi=150)
ax.scatter(z_t, z, s=5, alpha=0.3, label="observations", color="gray")
ax.plot(ts, m_seq_joint_smooth[:, 0], label="smoothed x", color="C0")
ax.plot(ts, x_exact, "k--", label="true x")
ax.fill_between(
    ts,
    m_seq_joint_smooth[:, 0] - margin_smooth,
    m_seq_joint_smooth[:, 0] + margin_smooth,
    alpha=0.3,
    label=r"2$\sigma$ interval (smooth)",
    color="C0",
)
ax.set(xlabel="t", ylabel="x(t)", title="Smoothed state with joint damping inference")
ax.legend()
plt.show()

# Plot inferred damping: filtered vs smoothed
fig, ax = plt.subplots(figsize=(10, 4), dpi=150)
ax.plot(ts, gamma_filt, label="filtered gamma", color="C2", alpha=0.7)
ax.plot(ts, gamma_smooth, label="smoothed gamma", color="C1")
ax.axhline(gamma_true, color="k", linestyle="--", label="true gamma")
ax.fill_between(
    ts,
    gamma_filt - margin_gamma_filt,
    gamma_filt + margin_gamma_filt,
    alpha=0.2,
    label=r"2$\sigma$ interval (filtered)",
    color="C2",
)
ax.fill_between(
    ts,
    gamma_smooth - margin_gamma_smooth,
    gamma_smooth + margin_gamma_smooth,
    alpha=0.25,
    label=r"2$\sigma$ interval (smoothed)",
    color="C1",
)
ax.set(
    xlabel="t",
    ylabel="gamma",
    title="Damping coefficient inference (filtered vs smoothed)",
)
ax.legend()
plt.show()

In [None]:
# Sanity check: the plotting/observation grid must start at the beginning of the
# integration span (observations use ts[1:], so t = tspan[0] is the unobserved start).
assert ts[0] == tspan[0], "time grid does not start at tspan[0]"

In [None]:
def f_test(x1, x2):
    """Toy scalar function f(x1, x2) = 2·x1 + 3·x2 + x1·x2 used to check derivatives."""
    linear_part = 2 * x1 + 3 * x2
    return linear_part + x1 * x2


def f_test_jac(x1, x2):
    """Hand-derived gradient of f_test, returned as the array [df/dx1, df/dx2]."""
    df_dx1 = 2 + x2
    df_dx2 = 3 + x1
    return np.array([df_dx1, df_dx2])


def f_test_hess(x1, x2):
    """Hessian of f_test.

    f_test is quadratic with a single x1·x2 cross term, so the Hessian is the
    constant [[0, 1], [1, 0]]; the arguments are accepted for signature parity
    with f_test_jac and are not used.
    """
    cross = 1  # d²f/dx1dx2 == d²f/dx2dx1
    return np.array([[0, cross], [cross, 0]])





In [None]:
import jax

# Cross-check the hand-written derivatives of f_test against JAX autodiff at a
# sample point, failing loudly instead of relying on eyeballing printed values.
x1_test, x2_test = 1.0, 2.0

jacobi = jax.jacfwd(f_test, argnums=(0, 1))
jac_auto = np.array(jacobi(x1_test, x2_test))  # tuple of two scalars -> shape (2,)
jac_manual = f_test_jac(x1_test, x2_test)
print(jac_auto)
print(jac_manual)
assert np.allclose(jac_auto, jac_manual), "f_test_jac disagrees with jax.jacfwd"

# jax.hessian with a tuple argnums returns a nested 2x2 tuple of scalars.
hess_auto = np.array(jax.hessian(f_test, argnums=(0, 1))(x1_test, x2_test))
assert np.allclose(
    hess_auto, f_test_hess(x1_test, x2_test)
), "f_test_hess disagrees with jax.hessian"
