In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tueplots import figsizes
from tueplots import axes
from numpy.linalg import cholesky

from ode_filters.GMP_priors import IWP

#plitting specifications
plt.rcParams.update(axes.lines())
plt.rcParams.update({"figure.dpi": 600})
figsize_config = figsizes.aaai2024_half(nrows=2, ncols=1)
plt.rcParams.update(figsize_config)
color_cycler = plt.rcParams['axes.prop_cycle']
colors = color_cycler.by_key()['color']

### Problem

In [16]:
#define the problem itself
#define IVP -> ODE + IV
def vf(x):
    return x*(x-1)

t0 = 0
t1 = 1
x0 = np.array([0.01])
d = x0.shape[0]

### Model

In [20]:
#prior 
q = 3
prior = IWP(q, d)

# domain discretization
N = 11 # or vice verca define h -> N = (t1-t0)/h + 1
ts, h = np.linspace(t0, t1, N, retstep=True)
A_h = prior.A(h)
Q_h = prior.Q(h)

In [None]:
# parameters
mu_0, Sigma_0 = taylor_mode_initialization(vf, x0, q)
b_h = np.zeros(d*(q+1))
Sigma_0_sqr = cholesky(Sigma_0, upper=True)
Q_h_sqr = cholesky(Q_h, upper=True)

In [None]:
from jax import P


k = ... #number of extra observations totla summed dimension
R_h_sqr = np.zeros(d+k)

def g(X, P=1):
    g_ODE = E1@X - vf(E0@X)
    g_conserved = np.sum(E0 @ X) - P
    g_meas


jacobian_g = ...
z_sequence = ...

In [None]:


#apply ODE filter
results = kf1_sqr_loop(
    mu_0, Sigma_0_sqr, A_h, b_h, Q_h_sqr, R_h_sqr, g, jacobian_g, z_sequence, N)

NameError: name 'GMP_priors' is not defined

In [1]:
import jax
import jax.numpy as jnp
import jax.experimental.jet
import numpy as np

In [2]:
def vf(y):  # noqa: ARG001
    """Evaluate the dynamics of the logistic ODE."""
    return y * (1 - y)

y0 = jnp.array([0.01])

In [3]:
vf(y0)
vf_grad = jax.grad(vf)
vf_jacobi = jax.jacfwd(vf)
print(vf_grad(0.01))
print(vf_jacobi(y0))

0.98
[[0.98]]


In [4]:
# Compute the Jacobian of vf using jax.jacfwd
mu0 = [y0, vf(y0)]
jacobi1 = lambda a: jax.jacfwd(vf)(a)*vf(a)
x_next = jacobi1(y0)
mu0.append(x_next.squeeze())
jacobi2 = lambda a: jax.jacfwd(jacobi1)(a)*vf(a)
x_next = jacobi2(y0)
mu0.append(x_next.squeeze())
print(mu0)

[Array([0.01], dtype=float32), Array([0.0099], dtype=float32), Array(0.009702, dtype=float32), Array(0.00931194, dtype=float32)]


In [5]:
def _subsets(x, /, n):
    """Compute staggered subsets.

    See example below.

    Examples
    --------
    >>> a = (1, 2, 3, 4, 5)
    >>> print(_subsets(a, n=1))
    [(1, 2, 3, 4, 5)]
    >>> print(_subsets(a, n=2))
    [(1, 2, 3, 4), (2, 3, 4, 5)]
    >>> print(_subsets(a, n=3))
    [(1, 2, 3), (2, 3, 4), (3, 4, 5)]
    """

    def mask(i):
        return None if i == 0 else i

    return [x[mask(k) : mask(k + 1 - n)] for k in range(n)]

In [6]:
# Initial Taylor series (u_0, u_1, ..., u_k)
inits = y0
primals = vf(y0)
tcoeffs = [*inits, primals]
num_arguments = 1
for _ in range(4):
    series = _subsets(tcoeffs[1:], num_arguments)
    p, s_new = jax.experimental.jet.jet(vf, primals=inits, series=series)
    tcoeffs = [*inits, p, *s_new]
    print(series)



[[Array([0.0099], dtype=float32)]]
[[Array(0.0099, dtype=float32), Array([0.009702], dtype=float32)]]
[[Array(0.0099, dtype=float32), Array(0.009702, dtype=float32), Array([0.00931194], dtype=float32)]]
[[Array(0.0099, dtype=float32), Array(0.009702, dtype=float32), Array(0.00931194, dtype=float32), Array([0.0085494], dtype=float32)]]


In [7]:
# Initial Taylor series (u_0, u_1, ..., u_k)
inits = y0
tcoeffs = [inits, vf(y0)]
for _ in range(3):
    series = [tcoeffs[1:]]
    p, s_new = jax.experimental.jet.jet(vf, primals=inits, series=series)
    coeffs = jax.tree.map(lambda c: jnp.atleast_1d(c), s_new)
    tcoeffs = [inits, p, *coeffs]            # everything has â‰¥1 dims

flat = jnp.concatenate([jnp.atleast_1d(x) for x in series[0]])
flat = jnp.concat([y0, flat])
print(flat)

[0.01       0.0099     0.009702   0.00931194]


In [8]:
def flatten_coeffs(coeffs):
    # Works for scalars, vectors, or nested pytree leaves
    leaves = jax.tree_util.tree_leaves(coeffs)
    return jnp.concatenate([jnp.ravel(arr) for arr in leaves])

inits = y0
tcoeffs = [inits, vf(inits)]

for _ in range(3):
    p, s_new = jax.experimental.jet.jet(
        vf,
        primals=inits,
        series=[tcoeffs[1:]],  # pass only higher-order terms
    )
    tcoeffs = [inits, p, *s_new]

flat = flatten_coeffs(tcoeffs)
print(flat)

[0.01       0.0099     0.009702   0.00931194 0.0085494 ]


In [9]:
X_0 = np.array([y0, vf(y0), (1-y0)*(1-2*y0)*y0, ((1-2*y0)**2 - 2*y0*(1-y0))*y0*(1-y0)])
print(X_0.flatten())

[0.01       0.0099     0.009702   0.00931194]


In [10]:
from probdiffeq import taylor, ivpsolvers 
# this computes taylor expansion of vf, up to order num, arround u0, and return the coefficient
tcoeffs = taylor.odejet_padded_scan(vf, (y0,), num=4)
print((tcoeffs))

tcoeffs2 = taylor.odejet_unroll(vf, (y0,), 4)
print(tcoeffs2)

[Array([0.01], dtype=float32), Array([0.0099], dtype=float32), Array([0.009702], dtype=float32), Array([0.00931194], dtype=float32), Array([0.0085494], dtype=float32)]
[Array([0.01], dtype=float32), Array([0.0099], dtype=float32), Array([0.009702], dtype=float32), Array([0.00931194], dtype=float32), Array([0.0085494], dtype=float32)]


In [13]:
def taylor_mode_initialization2(vf, inits, q: int) -> jnp.ndarray:
    """Return flattened Taylor-mode coefficients produced via JAX Jet.

    Parameters
    ----------
    vf : callable
        Vector field whose Taylor coefficients are required.
    inits : array-like
        Initial value around which the expansion takes place.
    q : int
        Number of higher-order coefficients to compute.

    Returns
    -------
    jnp.ndarray
        Concatenated Taylor coefficients (including the initial value).
    """

    if q < 0:
        raise ValueError("q must be a non-negative integer.")

    base_state = jnp.asarray(inits)
    coefficients: list[jnp.ndarray] = [base_state]

    if q == 0:
        leaves = jax.tree_util.tree_leaves(coefficients)
        return jnp.concatenate([jnp.ravel(arr) for arr in leaves])

    first_term = jnp.asarray(vf(base_state))
    coefficients.append(first_term)

    if q == 1:
        leaves = jax.tree_util.tree_leaves(coefficients)
        return jnp.concatenate([jnp.ravel(arr) for arr in leaves])

    series_terms: list[jnp.ndarray] = [first_term]

    for _ in range(1, q):
        primals_out, series_out = jax.experimental.jet.jet(
            vf,
            primals=(base_state,),
            series=(tuple(series_terms),),
        )
        # primals_out should match first_term; we keep series_terms authoritative.
        new_term = jnp.asarray(series_out[0][-1])
        series_terms.append(new_term)
        coefficients.append(new_term)

    leaves = jax.tree_util.tree_leaves(coefficients)
    return jnp.concatenate([jnp.ravel(arr) for arr in leaves])

In [14]:
#from ode_filters.GMP_priors import taylor_mode_initialization

taylor_mode_initialization2(vf, y0, 3)

Array([0.01    , 0.0099  , 0.009702, 0.009702], dtype=float32)

In [16]:
from probdiffeq import taylor, ivpsolvers 
#compare init for higher order inputs
def vf(y):
    """Evaluate the Lotka-Volterra vector field."""
    y0, y1 = y[0], y[1]

    y0_new = 0.5 * y0 - 0.05 * y0 * y1
    y1_new = -0.5 * y1 + 0.05 * y0 * y1
    return jnp.asarray([y0_new, y1_new])

y0 = jnp.array([1,2])
tcoeffs = taylor.odejet_padded_scan(vf, (y0,), num=3)
print((tcoeffs))

[Array([1, 2], dtype=int32), Array([ 0.4, -0.9], dtype=float32), Array([0.20500001, 0.445     ], dtype=float32), Array([ 0.09574999, -0.21574998], dtype=float32)]


In [17]:
taylor_mode_initialization(vf, y0, 4)

AssertionError: 

In [16]:
# Initial Taylor series (u_0, u_1, ..., u_k)
inits = y0
primals = vf(y0)
tcoeffs = [*inits, *primals]
num_arguments = 2
for _ in range(4):
    series = _subsets(tcoeffs[1:], num_arguments)
    p, s_new = jax.experimental.jet.jet(vf, primals=inits, series=series)
    tcoeffs = [*inits, p, *s_new]
    print(series)

[[], []]
[[], []]
[[], []]
[[], []]


In [17]:
inits = y0
tcoeffs = [inits, vf(inits)]
print(tcoeffs[1:][0])
print(inits)
print(tcoeffs[1:])

for _ in range(3):
    p, s_new = jax.experimental.jet.jet(
        vf,
        primals=inits,
        series=tcoeffs[1:],  # pass only higher-order terms
    )
    tcoeffs = [inits, p, *s_new]
    print(tcoeffs)

[0.0099]
[0.01]
[Array([0.0099], dtype=float32)]
[Array([0.01], dtype=float32), Array(0.0099, dtype=float32), Array(0.009702, dtype=float32)]


TypeError: len() of unsized object