# Advanced Features Tutorial

This notebook demonstrates advanced ODE solving capabilities:

1. **First-Order ODEs with Hidden States** (Joint state-parameter estimation)
2. **Second-Order ODEs with Hidden States**
3. **Conservation Constraints** (Algebraic constraints)
4. **Linear Measurements** (Time-varying observations)

`BlackBoxMeasurement` and `TransformedMeasurement` are imported below for reference but are not demonstrated in this notebook.


In [None]:
# Note: `np` is jax.numpy throughout this notebook, not NumPy
import jax.numpy as np
import jax.random as jrandom
import matplotlib.pyplot as plt

from ode_filters.filters import ekf1_sqr_loop, rts_sqr_smoother_loop
from ode_filters.measurement import (
    BlackBoxMeasurement,
    Conservation,
    Measurement,
    ODEInformation,
    ODEInformationWithHidden,
    SecondOrderODEInformation,
    SecondOrderODEInformationWithHidden,
    TransformedMeasurement,
)
from ode_filters.priors import IWP, JointPrior, taylor_mode_initialization


## 1. First-Order ODEs with Hidden States

### Problem: Exponential Decay with Unknown Rate

Consider a radioactive decay problem where we observe the amount of material but don't know the decay rate:

$$\frac{dx}{dt} = -\lambda x, \quad x(0) = 1.0$$

where $\lambda$ is an **unknown parameter** we want to infer from (noisy) observations.

We model $\lambda$ as a hidden state with its own prior (e.g., constant or slowly varying).
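
One way to read "hidden state" is as state augmentation: $\lambda$ is appended to the ODE state with trivial dynamics $d\lambda/dt = 0$. The sketch below is plain Python and independent of `ode_filters`; the names `augmented_vf` and `euler_solve` are illustrative only. It integrates the augmented system with forward Euler and compares the result with the closed form $e^{-\lambda t}$.

```python
import math


def augmented_vf(state):
    """Augmented dynamics: d/dt [x, lam] = [-lam * x, 0]."""
    x, lam = state
    return (-lam * x, 0.0)


def euler_solve(state, t_end, n_steps):
    """Integrate the augmented system with forward Euler."""
    h = t_end / n_steps
    for _ in range(n_steps):
        dx, dlam = augmented_vf(state)
        state = (state[0] + h * dx, state[1] + h * dlam)
    return state


x_end, lam_end = euler_solve((1.0, 0.5), t_end=10.0, n_steps=10_000)
print(x_end, math.exp(-0.5 * 10.0))  # Euler result next to the closed form
print(lam_end)  # lambda is carried along unchanged
```

In the filtering setting the parameter is not held exactly constant: its prior lets it drift slightly, which is what allows the observations to correct the initial guess.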


In [None]:
# True decay rate (unknown to the solver)
lambda_true = 0.5


# Vector field with hidden parameter
def vf_decay(x, lam, *, t):
    """dx/dt = -lambda * x, where lambda is the hidden state"""
    return -lam * x


# Initial conditions
x0 = np.array([1.0])  # Initial amount
lambda0 = np.array([0.3])  # Initial guess for decay rate
tspan = [0, 10]
N = 50


### Set Up the Joint Prior for State and Hidden Parameter

We use `JointPrior` to combine:

- Prior for state `x` (IWP with q=2)
- Prior for parameter `λ` (IWP with q=1, since it's slowly varying)
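
The stacked layout can be pictured as follows. This is a sketch with plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias for `jax.numpy`), not the library's implementation. It assumes the joint state is ordered `[x, x', x'', λ, λ']`, which matches the indexing used later in this notebook, and builds hypothetical selection matrices by hand.

```python
import numpy as onp

# Assumed layout for q=2 (state) and q=1 (parameter), with d=1 each:
# index 0 -> x, 1 -> x', 2 -> x'', 3 -> lambda, 4 -> lambda'
joint_dim = (2 + 1) + (1 + 1)


def selector(index, dim):
    """Row matrix that picks a single entry out of the joint state."""
    row = onp.zeros((1, dim))
    row[0, index] = 1.0
    return row


E0_x_sketch = selector(0, joint_dim)       # picks x
E1_sketch = selector(1, joint_dim)         # picks x'
E0_hidden_sketch = selector(3, joint_dim)  # picks lambda

joint_state = onp.array([1.0, -0.3, 0.09, 0.3, 0.0])
print(E0_x_sketch @ joint_state)
print(E1_sketch @ joint_state)
print(E0_hidden_sketch @ joint_state)
```

The shapes printed by the notebook cell below should be read against the actual `JointPrior` attributes; the matrices here only illustrate the idea of extracting components from a stacked state.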


In [None]:
# Prior for the state x (d=1, q=2)
prior_x = IWP(q=2, d=1, Xi=0.5 * np.eye(1))

# Prior for the hidden parameter lambda (d=1, q=1, smaller diffusion)
prior_lambda = IWP(q=1, d=1, Xi=0.01 * np.eye(1))

# Combine into joint prior
joint_prior = JointPrior(prior_x, prior_lambda)

print(f"State extraction matrix E0_x shape: {joint_prior.E0_x.shape}")
print(f"Hidden extraction matrix E0_hidden shape: {joint_prior.E0_hidden.shape}")
print(f"Derivative extraction matrix E1 shape: {joint_prior.E1.shape}")


### Initialize Joint State

For first-order ODEs with hidden states, we need to initialize both the state and the hidden parameter.
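
For the decay ODE with a fixed guess $\lambda_0$, the derivatives at $t = 0$ follow by repeated differentiation: $x' = -\lambda_0 x$, $x'' = \lambda_0^2 x$. The sketch below computes them by hand in plain Python with illustrative names. Whether `taylor_mode_initialization` returns plain derivatives or rescaled Taylor coefficients is not stated in this notebook, so treat any comparison with its output as an assumption.

```python
def decay_derivatives(x0, lam, q):
    """Return [x, x', ..., x^(q)] at t=0 for dx/dt = -lam * x."""
    derivs = [x0]
    for _ in range(q):
        # Each differentiation multiplies by -lam.
        derivs.append(-lam * derivs[-1])
    return derivs


mu_x_by_hand = decay_derivatives(x0=1.0, lam=0.3, q=2)
# Hidden parameter block: the guess itself plus a zero first derivative.
mu_lambda_by_hand = [0.3, 0.0]
mu_0_by_hand = mu_x_by_hand + mu_lambda_by_hand
print(mu_0_by_hand)
```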


In [None]:
# We need to wrap the vector field for initialization
def vf_for_init(x, *, t):
    """For initialization, we use our best guess for lambda"""
    return -lambda0[0] * x


# Initialize state coefficients
mu_x, _ = taylor_mode_initialization(vf_for_init, x0, q=2)

# Initialize hidden parameter (constant, so higher derivatives are zero)
mu_lambda = np.concatenate([lambda0, np.zeros(1)])  # [lambda, d_lambda/dt]

# Combine into joint initialization
mu_0 = np.concatenate([mu_x, mu_lambda])
D_total = mu_0.shape[0]
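# Zero initial covariance: the initial values, including the guess for lambda,
# are treated as exact at t=0
Sigma_0_sqr = np.zeros((D_total, D_total))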

print(f"Joint initial state dimension: {D_total}")
print(f"Initial state values: x={mu_0[:3]}, lambda={mu_0[3:]}")


### Generate Synthetic Data and Set Up the Measurement


In [None]:
# Generate true solution
ts = np.linspace(tspan[0], tspan[1], N + 1)
x_true = np.exp(-lambda_true * ts)

# Add noise to observations (observe x, not lambda)
key = jrandom.PRNGKey(42)
noise_std = 0.05
z = x_true[1:] + noise_std * jrandom.normal(key, shape=(N,))
z = z.reshape(-1, 1)
z_t = ts[1:]

plt.figure(figsize=(10, 3))
plt.scatter(z_t, z, s=10, alpha=0.5, label="Noisy observations", color="orange")
plt.plot(ts, x_true, "k--", label=f"True (λ={lambda_true})")
plt.xlabel("t"), plt.ylabel("x(t)")
plt.legend(), plt.title("Exponential Decay with Unknown Rate")
plt.show()


In [None]:
# Measurement matrix: observe only x (not lambda)
A = np.array([[1.0]])  # Extract first state component
measurement = Measurement(A, z, z_t, noise=noise_std**2)

# Create ODE measurement model with hidden states
measure = ODEInformationWithHidden(
    vf=vf_decay,
    E0=joint_prior.E0_x,  # Extract x from joint state
    E1=joint_prior.E1,  # Extract dx/dt
    E0_hidden=joint_prior.E0_hidden,  # Extract lambda
    constraints=[measurement],
)


### Run Filter and Smoother


In [None]:
# Run EKF
m_seq, P_sqr, _, _, G_back, d_back, P_back_sqr, *_ = ekf1_sqr_loop(
    mu_0, Sigma_0_sqr, joint_prior, measure, tspan, N
)

# Run RTS smoother
m_smooth, P_smooth_sqr = rts_sqr_smoother_loop(
    m_seq[-1], P_sqr[-1], G_back, d_back, P_back_sqr, N
)
m_smooth = np.array(m_smooth)
P_smooth_sqr = np.array(P_smooth_sqr)


### Visualize Results: State and Parameter Estimation


In [None]:
# Extract state x and parameter lambda
x_est = m_smooth[:, 0]  # First component is x
lambda_est = m_smooth[:, 3]  # Fourth component is lambda (after x, dx, d2x)

P_smooth = np.einsum("ijk,ijl->ikl", P_smooth_sqr, P_smooth_sqr)
x_std = np.sqrt(P_smooth[:, 0, 0])
lambda_std = np.sqrt(P_smooth[:, 3, 3])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Plot state x
ax1.scatter(z_t, z, s=10, alpha=0.5, color="orange", label="Observations")
ax1.plot(ts, x_true, "k--", linewidth=2, label="True")
ax1.plot(ts, x_est, "b-", linewidth=2, label="Estimated")
ax1.fill_between(ts, x_est - 2 * x_std, x_est + 2 * x_std, alpha=0.3)
ax1.set_xlabel("t"), ax1.set_ylabel("x(t)")
ax1.set_title("State Estimation")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot parameter lambda
ax2.axhline(
    lambda_true, color="k", linestyle="--", linewidth=2, label=f"True λ={lambda_true}"
)
ax2.plot(ts, lambda_est, "r-", linewidth=2, label="Estimated λ")
ax2.fill_between(
    ts, lambda_est - 2 * lambda_std, lambda_est + 2 * lambda_std, alpha=0.3, color="red"
)
ax2.set_xlabel("t"), ax2.set_ylabel("λ(t)")
ax2.set_title(f"Parameter Estimation (init: λ₀={lambda0[0]})")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal estimate: λ = {lambda_est[-1]:.3f} ± {2 * lambda_std[-1]:.3f}")
print(f"True value: λ = {lambda_true}")
print(f"Initial guess: λ₀ = {lambda0[0]}")
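
The `einsum` call in the cell above rebuilds each covariance from its square-root factor. The index pattern `"ijk,ijl->ikl"` computes $P_i = S_i^\top S_i$ for every time step $i$; that the library stores its factors in this orientation is an assumption taken from that index pattern. A minimal check with random factors, using plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias):

```python
import numpy as onp

rng = onp.random.default_rng(0)
# A batch of 4 hypothetical square-root factors, each 5x5.
S_batch = rng.normal(size=(4, 5, 5))

# Same index pattern as in the notebook: sum over the shared row index j.
P_batch = onp.einsum("ijk,ijl->ikl", S_batch, S_batch)

# Equivalent explicit form: P[i] = S[i].T @ S[i].
P_explicit = onp.stack([S.T @ S for S in S_batch])

# Marginal standard deviations are square roots of the diagonal entries.
std_batch = onp.sqrt(onp.diagonal(P_batch, axis1=1, axis2=2))
print(onp.allclose(P_batch, P_explicit), std_batch.shape)
```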


## 2. Second-Order ODEs with Hidden States

### Problem: Damped Harmonic Oscillator with Unknown Damping

Consider a spring-mass-damper system where the damping coefficient is unknown:

$$\frac{d^2x}{dt^2} = -\omega^2 x - \gamma \frac{dx}{dt}$$

where $\gamma$ is the **unknown damping coefficient** we want to infer from position observations.
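
For the underdamped case ($\gamma < 2\omega$) this ODE has a closed-form solution. The sketch below, in plain Python with illustrative names, writes it out for $x(0) = 1$, $x'(0) = 0$ and checks it against the ODE using central finite differences. Note the sine term: without it the initial velocity would be $-\gamma/2$ rather than zero.

```python
import math

omega, gamma = 1.0, 0.3
a = gamma / 2.0                       # decay rate of the envelope
omega_d = math.sqrt(omega**2 - a**2)  # damped angular frequency


def x_closed(t):
    """Underdamped solution with x(0) = 1 and x'(0) = 0."""
    return math.exp(-a * t) * (
        math.cos(omega_d * t) + (a / omega_d) * math.sin(omega_d * t)
    )


def ode_residual(t, h=1e-4):
    """x'' + gamma * x' + omega^2 * x, via central differences."""
    dx = (x_closed(t + h) - x_closed(t - h)) / (2 * h)
    d2x = (x_closed(t + h) - 2 * x_closed(t) + x_closed(t - h)) / h**2
    return d2x + gamma * dx + omega**2 * x_closed(t)


initial_velocity = (x_closed(1e-6) - x_closed(-1e-6)) / 2e-6
max_residual = max(abs(ode_residual(t)) for t in (0.5, 2.0, 7.5, 20.0))
print(x_closed(0.0), initial_velocity, max_residual)
```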


In [None]:
# System parameters
omega = 1.0  # Natural frequency (known)
gamma_true = 0.3  # True damping coefficient (unknown to solver)


# Vector field: d²x/dt² = f(x, dx/dt, gamma, t)
def vf_damped(x, dx, gamma, *, t):
    """Second-order ODE with hidden damping parameter"""
    return -(omega**2) * x - gamma * dx


# Initial conditions
x0_2nd = np.array([1.0])  # Initial position
dx0_2nd = np.array([0.0])  # Initial velocity
gamma0 = np.array([0.1])  # Initial guess for damping
tspan_2nd = [0, 30]
N_2nd = 100


In [None]:
# Prior for state x (q=2 for second-order)
prior_x_2nd = IWP(q=2, d=1, Xi=1.0 * np.eye(1))

# Prior for hidden parameter gamma (q=1, slowly varying)
prior_gamma = IWP(q=1, d=1, Xi=0.01 * np.eye(1))

# Joint prior
joint_prior_2nd = JointPrior(prior_x_2nd, prior_gamma)


# Vector field for initialization
def vf_for_init_2nd(x, dx, *, t):
    return -(omega**2) * x - gamma0[0] * dx


# Initialize state (x, dx, d²x)
mu_x_2nd, _ = taylor_mode_initialization(
    vf_for_init_2nd, (x0_2nd, dx0_2nd), q=2, order=2
)

# Initialize hidden parameter
mu_gamma = np.concatenate([gamma0, np.zeros(1)])

# Joint initialization
mu_0_2nd = np.concatenate([mu_x_2nd, mu_gamma])
Sigma_0_sqr_2nd = np.zeros((mu_0_2nd.shape[0], mu_0_2nd.shape[0]))

print(f"Initial state: x={mu_0_2nd[0]:.2f}, dx={mu_0_2nd[1]:.2f}, γ={mu_0_2nd[3]:.2f}")


In [None]:
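# True solution (underdamped oscillator with x(0) = 1, dx/dt(0) = 0)
ts_2nd = np.linspace(tspan_2nd[0], tspan_2nd[1], N_2nd + 1)
omega_d = np.sqrt(omega**2 - (gamma_true / 2) ** 2)  # Damped frequency
# The sine term is required for zero initial velocity
x_true_2nd = np.exp(-gamma_true * ts_2nd / 2) * (
    np.cos(omega_d * ts_2nd) + gamma_true / (2 * omega_d) * np.sin(omega_d * ts_2nd)
)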

# Noisy observations of position
key = jrandom.PRNGKey(43)
noise_std_2nd = 0.05
z_2nd = x_true_2nd[1:] + noise_std_2nd * jrandom.normal(key, shape=(N_2nd,))
z_2nd = z_2nd.reshape(-1, 1)
z_t_2nd = ts_2nd[1:]

# Measurement: observe position only
A_2nd = np.array([[1.0]])
measurement_2nd = Measurement(A_2nd, z_2nd, z_t_2nd, noise=noise_std_2nd**2)

# Second-order ODE with hidden states
measure_2nd = SecondOrderODEInformationWithHidden(
    vf=vf_damped,
    E0=joint_prior_2nd.E0_x,
    E1=joint_prior_2nd.E1,
    E2=joint_prior_2nd.E2,
    E0_hidden=joint_prior_2nd.E0_hidden,
    constraints=[measurement_2nd],
)


In [None]:
# Run EKF and Smoother
m_seq_2nd, P_sqr_2nd, _, _, G_back_2nd, d_back_2nd, P_back_sqr_2nd, *_ = ekf1_sqr_loop(
    mu_0_2nd, Sigma_0_sqr_2nd, joint_prior_2nd, measure_2nd, tspan_2nd, N_2nd
)

m_smooth_2nd, P_smooth_sqr_2nd = rts_sqr_smoother_loop(
    m_seq_2nd[-1], P_sqr_2nd[-1], G_back_2nd, d_back_2nd, P_back_sqr_2nd, N_2nd
)
m_smooth_2nd = np.array(m_smooth_2nd)
P_smooth_sqr_2nd = np.array(P_smooth_sqr_2nd)

# Extract estimates
x_est_2nd = m_smooth_2nd[:, 0]
gamma_est = m_smooth_2nd[:, 3]

P_smooth_2nd = np.einsum("ijk,ijl->ikl", P_smooth_sqr_2nd, P_smooth_sqr_2nd)
x_std_2nd = np.sqrt(P_smooth_2nd[:, 0, 0])
gamma_std = np.sqrt(P_smooth_2nd[:, 3, 3])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Plot position
ax1.scatter(z_t_2nd, z_2nd, s=10, alpha=0.5, color="orange", label="Observations")
ax1.plot(ts_2nd, x_true_2nd, "k--", linewidth=2, label="True")
ax1.plot(ts_2nd, x_est_2nd, "b-", linewidth=2, label="Estimated")
ax1.fill_between(
    ts_2nd, x_est_2nd - 2 * x_std_2nd, x_est_2nd + 2 * x_std_2nd, alpha=0.3
)
ax1.set_xlabel("t"), ax1.set_ylabel("x(t)")
ax1.set_title("Position Estimation")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot damping parameter
ax2.axhline(
    gamma_true, color="k", linestyle="--", linewidth=2, label=f"True γ={gamma_true}"
)
ax2.plot(ts_2nd, gamma_est, "r-", linewidth=2, label="Estimated γ")
ax2.fill_between(
    ts_2nd, gamma_est - 2 * gamma_std, gamma_est + 2 * gamma_std, alpha=0.3, color="red"
)
ax2.set_xlabel("t"), ax2.set_ylabel("γ(t)")
ax2.set_title(f"Damping Coefficient (init: γ₀={gamma0[0]})")
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nFinal estimate: γ = {gamma_est[-1]:.3f} ± {2 * gamma_std[-1]:.3f}")
print(f"True value: γ = {gamma_true}")
print(f"Initial guess: γ₀ = {gamma0[0]}")


## 3. Conservation Constraints

The SIR epidemiological model has a natural conservation law: $S + I + R = N_{\text{pop}}$ (constant total population). The example below works in population fractions, so $N_{\text{pop}} = 1$.

We can enforce this as a hard constraint using the `Conservation` class.
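
Why the sum is conserved can be checked directly: the three right-hand sides of the SIR model cancel. The sketch below evaluates the vector field with illustrative rates ($\beta = 0.5$, $\gamma = 0.1$) and the residual $A x - p$ of the linear constraint. It is plain Python with hypothetical helper names and does not use the `Conservation` class.

```python
def sir_rhs(s, i, r, beta=0.5, gamma=0.1):
    """Right-hand side of the SIR model in population fractions."""
    infection = beta * s * i
    recovery = gamma * i
    return (-infection, infection - recovery, recovery)


def constraint_residual(state, weights=(1.0, 1.0, 1.0), target=1.0):
    """Residual A @ x - p of the linear constraint S + I + R = 1."""
    return sum(w * x for w, x in zip(weights, state)) - target


state = (0.99, 0.01, 0.0)
ds, di, dr = sir_rhs(*state)
print(ds + di + dr)                # derivatives cancel, so the sum cannot drift
print(constraint_residual(state))  # the initial state satisfies the constraint
```

A numerical solver does not preserve this invariant automatically, which is why it can help to impose it explicitly.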


In [None]:
# SIR model parameters
beta_sir = 0.5
gamma_sir = 0.1


def vf_sir(x, *, t):
    """SIR model: x = [S, I, R]"""
    return np.array(
        [
            -beta_sir * x[0] * x[1],
            beta_sir * x[0] * x[1] - gamma_sir * x[1],
            gamma_sir * x[1],
        ]
    )


x0_sir = np.array([0.99, 0.01, 0.0])
tspan_sir = [0, 100]
N_sir = 100

# Prior
prior_sir = IWP(q=2, d=3, Xi=1.0 * np.eye(3))
mu_0_sir, Sigma_0_sqr_sir = taylor_mode_initialization(vf_sir, x0_sir, q=2)

# Conservation constraint: S + I + R = 1
A_conservation = np.array([[1.0, 1.0, 1.0]])
p_conservation = np.array([1.0])
conservation = Conservation(A_conservation, p_conservation)

# Create measurement model with conservation
measure_sir = ODEInformation(
    vf=vf_sir, E0=prior_sir.E0, E1=prior_sir.E1, constraints=[conservation]
)

# Run filter and smoother
m_seq_sir, P_sqr_sir, _, _, G_back_sir, d_back_sir, P_back_sqr_sir, *_ = ekf1_sqr_loop(
    mu_0_sir, Sigma_0_sqr_sir, prior_sir, measure_sir, tspan_sir, N_sir
)
m_smooth_sir, _ = rts_sqr_smoother_loop(
    m_seq_sir[-1], P_sqr_sir[-1], G_back_sir, d_back_sir, P_back_sqr_sir, N_sir
)
m_smooth_sir = np.array(m_smooth_sir)

ts_sir = np.linspace(tspan_sir[0], tspan_sir[1], N_sir + 1)
S_est = m_smooth_sir[:, 0]
I_est = m_smooth_sir[:, 1]
R_est = m_smooth_sir[:, 2]
total = S_est + I_est + R_est

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

# Plot SIR trajectories
ax1.plot(ts_sir, S_est, label="S (Susceptible)", linewidth=2)
ax1.plot(ts_sir, I_est, label="I (Infected)", linewidth=2)
ax1.plot(ts_sir, R_est, label="R (Recovered)", linewidth=2)
ax1.set_xlabel("Time (days)"), ax1.set_ylabel("Population fraction")
ax1.set_title("SIR Model with Conservation Constraint")
ax1.legend()
ax1.grid(True, alpha=0.3)

# Check conservation
ax2.plot(ts_sir, total, "b-", linewidth=2, label="S + I + R")
ax2.axhline(1.0, color="k", linestyle="--", linewidth=2, label="Expected (1.0)")
ax2.set_xlabel("Time (days)"), ax2.set_ylabel("Total population")
ax2.set_title("Conservation Law Verification")
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0.999, 1.001])

plt.tight_layout()
plt.show()

print(f"\nConservation error: max = {np.max(np.abs(total - 1.0)):.2e}")
print(f"Conservation error: mean = {np.mean(np.abs(total - 1.0)):.2e}")


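## 4. Linear Measurements

### Problem: Lotka-Volterra with Sparse Predator Observations

The Lotka-Volterra predator-prey model is solved twice: once without data to obtain a reference trajectory, and once conditioned on sparse, noisy observations of the predator population, supplied through the `Measurement` class.


In [None]: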
# Lotka-Volterra parameters
alpha, beta, delta, gamma = 2 / 3, 4 / 3, 1.0, 1.0


def vf_lv(x, *, t):
    """Lotka-Volterra: x = [prey, predator]"""
    return np.array(
        [alpha * x[0] - beta * x[0] * x[1], delta * x[0] * x[1] - gamma * x[1]]
    )


x0_lv = np.array([1.0, 1.0])
tspan_lv = [0, 30]
N_lv = 60

# Get reference solution (without measurements)
prior_lv_ref = IWP(q=2, d=2, Xi=0.5 * np.eye(2))
mu_0_lv_ref, Sigma_0_sqr_lv_ref = taylor_mode_initialization(vf_lv, x0_lv, q=2)
measure_lv_ref = ODEInformation(vf_lv, prior_lv_ref.E0, prior_lv_ref.E1)
m_ref, *_ = ekf1_sqr_loop(
    mu_0_lv_ref, Sigma_0_sqr_lv_ref, prior_lv_ref, measure_lv_ref, tspan_lv, N_lv
)
m_ref = np.array(m_ref)

ts_lv = np.linspace(tspan_lv[0], tspan_lv[1], N_lv + 1)

# Create sparse noisy observations of predator (every 5th timestep)
obs_indices = np.arange(5, N_lv + 1, 5)
z_lv = m_ref[obs_indices, 1] + 0.1 * jrandom.normal(
    jrandom.PRNGKey(44), shape=(len(obs_indices),)
)
z_lv = z_lv.reshape(-1, 1)
z_t_lv = ts_lv[obs_indices]

# Solve with measurements
prior_lv = IWP(q=2, d=2, Xi=1.0 * np.eye(2))
mu_0_lv, Sigma_0_sqr_lv = taylor_mode_initialization(vf_lv, x0_lv, q=2)

# Measurement matrix: observe only predator (second component)
A_lv = np.array([[0.0, 1.0]])
# noise is a variance: 0.01 = 0.1**2, matching the observation noise std above
measurement_lv = Measurement(A_lv, z_lv, z_t_lv, noise=0.01)

measure_lv = ODEInformation(
    vf=vf_lv, E0=prior_lv.E0, E1=prior_lv.E1, constraints=[measurement_lv]
)

m_seq_lv, P_sqr_lv, _, _, G_back_lv, d_back_lv, P_back_sqr_lv, *_ = ekf1_sqr_loop(
    mu_0_lv, Sigma_0_sqr_lv, prior_lv, measure_lv, tspan_lv, N_lv
)
m_smooth_lv, _ = rts_sqr_smoother_loop(
    m_seq_lv[-1], P_sqr_lv[-1], G_back_lv, d_back_lv, P_back_sqr_lv, N_lv
)
m_smooth_lv = np.array(m_smooth_lv)

# Plot
plt.figure(figsize=(10, 4))
plt.plot(ts_lv, m_ref[:, 0], "b--", label="Prey (reference)", linewidth=2, alpha=0.5)
plt.plot(ts_lv, m_smooth_lv[:, 0], "b-", label="Prey (with measurements)", linewidth=2)
plt.plot(
    ts_lv, m_ref[:, 1], "r--", label="Predator (reference)", linewidth=2, alpha=0.5
)
plt.plot(
    ts_lv, m_smooth_lv[:, 1], "r-", label="Predator (with measurements)", linewidth=2
)
plt.scatter(z_t_lv, z_lv, s=50, color="orange", zorder=3, label="Observations")
plt.xlabel("t"), plt.ylabel("Population")
plt.title("Lotka-Volterra with Sparse Predator Observations")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
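
In the Lotka-Volterra example only the predator is observed, yet the prey estimate can change as well. In a linear-Gaussian update this happens through the cross-covariance between the two components. The sketch below is a generic textbook Kalman update with made-up numbers, written in plain NumPy (imported as `onp` to avoid clashing with this notebook's `np` alias); it is not the library's square-root implementation.

```python
import numpy as onp

# Hypothetical predicted mean and covariance for [prey, predator].
m_pred = onp.array([1.0, 1.0])
P_pred = onp.array([[0.04, 0.02],
                    [0.02, 0.04]])

A = onp.array([[0.0, 1.0]])  # observe the predator only
R = onp.array([[0.01]])      # measurement noise variance (std 0.1)
z = onp.array([1.2])         # one noisy predator observation

# Standard Kalman update for a linear measurement z = A x + noise.
S = A @ P_pred @ A.T + R              # innovation covariance
K = P_pred @ A.T @ onp.linalg.inv(S)  # Kalman gain
m_post = m_pred + K @ (z - A @ m_pred)
P_post = P_pred - K @ S @ K.T

print(m_post)            # both components move, not just the predator
print(onp.diag(P_post))  # both variances shrink
```

With zero cross-covariance in `P_pred`, the prey entry of the gain would vanish and only the predator estimate would be corrected.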


## Summary

This notebook demonstrated advanced ODE solving capabilities:

1. ✅ **First-Order ODEs with Hidden States**: Joint inference of state and parameters (decay rate)
2. ✅ **Second-Order ODEs with Hidden States**: Estimating damping coefficient from observations
3. ✅ **Conservation Constraints**: Enforcing algebraic constraints (SIR population conservation)
4. ✅ **Linear Measurements**: Sparse, noisy observations at discrete times (Lotka-Volterra)

### Key Classes and Functions

**For Hidden States:**

- `JointPrior(prior_x, prior_u)`: Combines independent priors for state and hidden parameters
- `ODEInformationWithHidden(vf, E0, E1, E0_hidden, ...)`: First-order ODE with vector field `vf(x, u, *, t)`
- `SecondOrderODEInformationWithHidden(vf, E0, E1, E2, E0_hidden, ...)`: Second-order ODE with `vf(x, dx, u, *, t)`

**For Constraints:**

- `Conservation(A, p)`: Hard algebraic constraint `A·x = p` (always active)
- `Measurement(A, z, z_t, noise)`: Time-varying linear observations `A·x = z[t]` (active at specified times)

**Other Advanced Features** (not shown here, see documentation):

- `BlackBoxMeasurement`: Custom observation models with autodiff Jacobians
- `TransformedMeasurement`: Nonlinear state transformations before measurement (e.g., softmax)
- `MaternPrior`: Gaussian process priors with specified length scales
- `PrecondIWP` and preconditioned filters: For numerical stability

These features can be **combined**: Sections 1 and 2, for example, attach a `Measurement` constraint to an ODE model with hidden states.
