#### Imports


In [None]:
import numpy as np
import jax.numpy as jnp
import matplotlib.pyplot as plt

from ode_filters.priors.GMP_priors import taylor_mode_initialization, PrecondIWP
from ode_filters.measurement.measurement_models import (
    ODEconservation,
    ODEmeasurement,
    ODEconservationmeasurement,
)
from ode_filters.filters.ODE_filter_loop import (
    ekf1_sqr_loop_preconditioned,
    rts_sqr_smoother_loop_preconditioned,
)

### SIR model with conserved population

$$\frac{d}{dt}\begin{bmatrix} S \\ I \\ R \end{bmatrix} = \begin{bmatrix} -\beta IS \\ \beta IS - \gamma I \\ \gamma I  \end{bmatrix}, \quad \begin{bmatrix} \beta \\ \gamma \end{bmatrix} = \begin{bmatrix} 0.5 \\ 0.1  \end{bmatrix}, \quad \begin{bmatrix} S_0 \\ I_0 \\ R_0 \end{bmatrix} = \begin{bmatrix} 0.99 \\ 0.01 \\ 0 \end{bmatrix}, \quad t \in [0,100]$$

#### Conservation law

$$ S + I + R = P = 1 \quad \forall t$$


In [None]:
def vf(x, *, t, beta=0.5, gamma=0.1):
    """Right-hand side of the SIR model, dx/dt = f(x).

    Args:
        x: State vector indexed as (S, I, R).
        t: Time; keyword-only and accepted for the solver interface, unused here.
        beta: Infection rate.
        gamma: Recovery rate.

    Returns:
        A length-3 ``jnp`` array with (dS/dt, dI/dt, dR/dt); its entries sum to 0,
        which is why S + I + R stays constant along exact solutions.
    """
    susceptible, infected = x[0], x[1]
    infection = beta * susceptible * infected  # flow S -> I
    recovery = gamma * infected  # flow I -> R
    return jnp.array([-infection, infection - recovery, recovery])


# Initial state (S0, I0, R0) = (0.99, 0.01, 0) on t in [0, 100].
x0 = np.array([0.99, 0.01, 0.0])
tspan = [0.0, 100.0]
d = len(x0)  # state dimension

In [None]:
# Prior: PrecondIWP with q derivatives per state component and Xi = identity.
q = 2
D = (q + 1) * d  # d components times (q + 1) derivative orders
xi = np.eye(d) * 1.0
prior = PrecondIWP(q, d, Xi=xi)
# Initial mean and square-root covariance computed by taylor_mode_initialization from vf at x0.
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf, x0, q)


# Domain discretization: uniform grid with N steps, i.e. N + 1 points.
N = 100
ts = np.linspace(*tspan, num=N + 1)

#### Define the linear conservation law as $Cx = P$

$$ C x = \begin{bmatrix} 1 & 1 & 1 \end{bmatrix} \begin{bmatrix} S \\ I \\ R \end{bmatrix} = S + I + R = P = 1 $$


In [None]:
# Conservation law C x = pop with C = [[1, 1, 1]], i.e. S + I + R = 1.
C = np.ones((1, d))
pop = np.array([1.0])
measure = ODEconservation(vf, C, pop, d=d, q=q)

In [None]:
# Display what the conservation measurement model's get_noise returns at t = 0.
measure.get_noise(t=0.0)

In [None]:
# Run the square-root EKF1 filter over N steps of tspan, then RTS-smooth backwards
# over the same N steps, starting from the final filter state.
n_steps = N
(
    m_seq, P_seq_sqr,
    m_seq_bar, P_seq_sqr_bar,
    m_pred_seq_bar, P_pred_seq_sqr_bar,
    G_back_seq_bar, d_back_seq_bar, P_back_seq_sqr_bar,
    mz_seq, Pz_seq_sqr,
    T_h,
) = ekf1_sqr_loop_preconditioned(mu_0, Sigma_0_sqr, prior, measure, tspan, n_steps)

smoother_args = (
    m_seq[-1],
    P_seq_sqr[-1],
    m_seq_bar[-1],
    P_seq_sqr_bar[-1],
    G_back_seq_bar,
    d_back_seq_bar,
    P_back_seq_sqr_bar,
    n_steps,
    T_h,
)
m_smoothed, P_smoothed_sqr = rts_sqr_smoother_loop_preconditioned(*smoother_args)

In [None]:
m_seq = np.array(m_seq)
sir_means = m_seq[:, :3]  # first three state entries, plotted as S, I, R
fig, ax = plt.subplots(figsize=(10, 4), dpi=600)
ax.set_xlabel("t")
ax.set_ylabel("x(t)")
ax.plot(ts, sir_means)
# The sum S + I + R should stay near the conserved population P = 1.
ax.plot(ts, sir_means.sum(axis=1))
plt.show()

#### Conservation law combined with measurements of $I$


In [None]:
# Data loading: column 1 of "sird_data" is used as the observed I series, scaled by
# 1/1000 (presumably a population of 1000 -- TODO confirm against the data source).
# The .npz archive keeps its file handle open until closed, so read it inside a
# context manager; z_seq is a plain in-memory array afterwards.
with np.load("../../../pde_filters/data_info.npz") as data:
    z_seq = data["sird_data"][:, 1] / 1000

In [None]:
# Initial state (S0, I0, R0) for the data-driven run. I0 = 0.01 matches the model
# description so that S0 + I0 + R0 = 1, consistent with the conservation law
# C x = pop (pop = 1) used by the measurement model below; I0 = 0.1 would give 1.09.
x0 = np.array([0.99, 0.01, 0.0])
tspan = [0.0, 100.0]
d = x0.shape[0]

In [None]:
# Prior for the data-driven run (same settings as the first experiment).
q = 2
xi = np.eye(d) * 1.0
prior = PrecondIWP(q, d, Xi=xi)
mu_0, Sigma_0_sqr = taylor_mode_initialization(vf, x0, q)

# Coarse grid: one observation per step, N observations -> N + 1 points.
N = z_seq.shape[0]
ts = np.linspace(*tspan, num=N + 1)
# Fine grid with twice as many intervals: 2 * (N + 1) - 1 = 2 * N + 1 points,
# so ts[k] and ts2[2 * k] are the same time (up to rounding).
ts2 = np.linspace(*tspan, num=2 * (N + 1) - 1)

# Observe only state component 1 (I): a single row of the identity, shape (1, d).
eye = np.eye(d)
A_measure = eye[1:2, :]
z_seq = z_seq.reshape(-1, 1)  # one scalar observation per row

In [None]:
for grid in (ts, ts2):
    print(grid[:10])
# Exact equality is checked on purpose: presumably the measurement model matches
# observation times to filter grid times exactly -- confirm in ODEconservationmeasurement.
print(ts[1] == ts2[2])

In [None]:
# Alternative without the conservation constraint:
#   measure = ODEmeasurement(vf, A_measure, z_seq, ts[1:], d=d, q=q)
obs_times = ts[1:]  # observations at every coarse grid point after t0
measure = ODEconservationmeasurement(
    vf, A_measure, z_seq, obs_times, C, pop, d=d, q=q
)

In [None]:
# Compare get_noise at an observation time (ts[1]) with a time shifted off the grid.
for t_query in (ts[1], ts[2] + 0.12):
    print(measure.get_noise(t=t_query))

In [None]:
# The filter runs on the fine grid ts2, which has 2 * (N + 1) - 1 points, i.e.
# 2 * (N + 1) - 2 steps. The smoother walks the backward sequences produced by that
# filter run, so it must be given the same step count (not N, which would only
# smooth half of the trajectory).
n_steps = 2 * (N + 1) - 2

# apply ODE filter
(
    m_seq,
    P_seq_sqr,
    m_seq_bar,
    P_seq_sqr_bar,
    m_pred_seq_bar,
    P_pred_seq_sqr_bar,
    G_back_seq_bar,
    d_back_seq_bar,
    P_back_seq_sqr_bar,
    mz_seq,
    Pz_seq_sqr,
    T_h,
) = ekf1_sqr_loop_preconditioned(
    mu_0, Sigma_0_sqr, prior, measure, tspan, n_steps
)

# apply ODE smoother over the same n_steps, starting from the final filter state
m_smoothed, P_smoothed_sqr = rts_sqr_smoother_loop_preconditioned(
    m_seq[-1],
    P_seq_sqr[-1],
    m_seq_bar[-1],
    P_seq_sqr_bar[-1],
    G_back_seq_bar,
    d_back_seq_bar,
    P_back_seq_sqr_bar,
    n_steps,
    T_h,
)

In [None]:
# 6 points (5 intervals) vs 11 points (10 intervals): every point of the coarse
# grid reappears at every second position of the fine grid.
for num in (6, 11):
    print(np.linspace(0, 1.2, num))

In [None]:
m_seq = np.array(m_seq)
skip = 10  # leave out the first 10 fine-grid points from the plot
fig, ax = plt.subplots(figsize=(10, 4), dpi=600)
ax.set_xlabel("t")
ax.set_ylabel("x(t)")
# Filter mean of state component 1 (I) on the fine grid ts2.
ax.plot(ts2[skip:], m_seq[skip:, 1])
# Observed I values at the coarse grid points ts[1:].
ax.scatter(ts[1:], z_seq, s=1, alpha=0.4, color="orange", zorder=2)
plt.show()

In [None]:
# Inspect four consecutive entries of the filter's mz_seq output, starting at index 33.
for offset in range(4):
    print(mz_seq[33 + offset])