#### imports


In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt
from numpy.linalg import cholesky

from ode_filters.priors.GMP_priors import taylor_mode_initialization, PrecondIWP
from ode_filters.measurement.measurement_models import (
    ODEconservation,
)
from ode_filters.filters.ODE_filter_loop import (
    ekf1_sqr_loop_preconditioned,
    rts_sqr_smoother_loop_preconditioned,
)

### SIR model with conserved population

$$\frac{\mathrm{d}}{\mathrm{d}t}\begin{bmatrix} S \\ I \\ R \end{bmatrix} = \begin{bmatrix} -\beta IS \\ \beta IS - \gamma I \\ \gamma I \end{bmatrix}, \quad \begin{bmatrix} \beta \\ \gamma \end{bmatrix} = \begin{bmatrix} 0.5 \\ 0.1 \end{bmatrix}, \quad \begin{bmatrix} S_0 \\ I_0 \\ R_0 \end{bmatrix} = \begin{bmatrix} 0.99 \\ 0.01 \\ 0 \end{bmatrix}, \quad t \in [0,100]$$

#### Conservation Law:

$$ S + I + R = P = 1 \quad \forall t$$


In [None]:
def vf(x, *, t, beta=0.5, gamma=0.1):
    """Right-hand side of the SIR model, d/dt [S, I, R].

    Args:
        x: State vector; ``x[0]`` is S and ``x[1]`` is I (R does not enter
            the dynamics).
        t: Time (keyword-only). Unused: the system is autonomous, but the
            argument is part of the vector-field call signature.
        beta: Infection rate.
        gamma: Recovery rate.

    Returns:
        ``jnp`` array ``[-beta*S*I, beta*S*I - gamma*I, gamma*I]``. The
        components analytically sum to zero, so S + I + R is conserved.
    """
    susceptible, infected = x[0], x[1]
    infection = beta * susceptible * infected  # flow S -> I
    recovery = gamma * infected  # flow I -> R
    return jnp.array([-infection, infection - recovery, recovery])


# Initial condition (S0, I0, R0), integration interval [t0, t1], state dimension.
x0 = np.array([0.99, 0.01, 0.0])
t0, t1 = 0.0, 100.0
d = len(x0)

In [None]:
# --- prior: preconditioned integrated Wiener process (PrecondIWP) ---
q = 2  # prior order; each state component carries q + 1 entries
D = d * (q + 1)  # dimension of the full prior state
xi = np.eye(d) * 1.0  # Xi matrix handed to the prior
prior = PrecondIWP(q, d, Xi=xi)
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf, x0, q)


# --- domain discretization (uniform grid) ---
N = 100  # number of steps
ts, h = np.linspace(t0, t1, num=N + 1, retstep=True)  # N + 1 grid points, step h
A_bar = prior.A()
Q_sqr_bar = cholesky(prior.Q(), upper=True)  # upper-triangular Cholesky factor of Q
b_bar = np.zeros(D)
T_h = prior.T(h)

#### Define the linear conservation law as $Cx = z$

$$ \begin{bmatrix} 1 & 1 & 1 \end{bmatrix} \begin{bmatrix} S \\ I \\ R \end{bmatrix} = S + I + R = P $$


In [None]:
# Linear conservation law C x = z. C = [1, 1, ..., 1] (one row) sums the
# state components, so C x = S + I + R.
C = np.ones(d).reshape(1, -1)
# z is the conserved total P. Take it from the initial condition
# (C @ x0 = S0 + I0 + R0 = 1) rather than hard-coding 0.0: 0.0 would
# contradict S + I + R = P = 1.
# NOTE(review): assumes ODEconservation enforces C @ x = pop on the state
# itself. If it instead constrains the derivative (C @ x' = 0), pop should be
# 0.0 — confirm against ODEconservation.
pop = C @ x0  # shape (1,)
k = pop.shape[0]  # number of conservation constraints
measure = ODEconservation(vf, C, pop, d=d, q=q)
g = measure.g
jacobian_g = measure.jacobian_g
# Square-root measurement noise for the d ODE residuals plus the k
# constraints. All zeros means no measurement noise.
R_h_sqr = np.zeros((d + k, d + k))

In [None]:
# Sanity check at t0: measurement-model output g(mu_0) and the shape of its
# Jacobian. A notebook cell only displays its last expression, so both values
# are combined into one tuple to display them together.
g(mu_0, t=0.0), jacobian_g(mu_0, t=0.0).shape

In [None]:
# Run the preconditioned square-root EKF1 ODE filter over the grid ts.
# NOTE(review): the per-line groupings below are inferred from the variable
# names (`_bar` presumably = preconditioned coordinates); confirm against
# ekf1_sqr_loop_preconditioned.
(
    m_seq, P_seq_sqr,  # filter means / sqrt covariances
    m_seq_bar, P_seq_sqr_bar,  # same quantities, `_bar` variant
    m_pred_seq_bar, P_pred_seq_sqr_bar,  # predicted means / sqrt covariances
    G_back_seq_bar, d_back_seq_bar, P_back_seq_sqr_bar,  # consumed by the smoother
    mz_seq, Pz_seq_sqr,  # measurement-space means / sqrt covariances
) = ekf1_sqr_loop_preconditioned(
    mu_0, Sigma_0_sqr,
    T_h, A_bar, b_bar, Q_sqr_bar,
    R_h_sqr, measure,
    N, ts,
)

# Run the RTS smoother starting from the last filter state, reusing the
# backward terms returned by the filter.
m_smoothed, P_smoothed_sqr = rts_sqr_smoother_loop_preconditioned(
    m_seq[-1], P_seq_sqr[-1],
    m_seq_bar[-1], P_seq_sqr_bar[-1],
    G_back_seq_bar, d_back_seq_bar, P_back_seq_sqr_bar,
    N, T_h,
)

In [None]:
# Plot the filtered S, I, R trajectories and their sum. By the conservation
# law S + I + R = 1, the sum should stay at 1 over the whole interval.
plt.figure(figsize=(10, 4), dpi=600)
plt.xlabel("t")
plt.ylabel("x(t)")
# NOTE(review): assumes the first d entries of each filter mean are the states
# (S, I, R) themselves, i.e. the zeroth derivatives. Confirm the state ordering.
states = m_seq[:, :d]
for label, column in zip(["S", "I", "R"], states.T):
    plt.plot(ts, column, label=label)
plt.plot(ts, states.sum(axis=1), "k--", label="S + I + R")
plt.legend()
plt.show()