In [1]:
"""Demonstrate how probabilistic solvers work via conditioning on constraints."""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from diffeqzoo import backend

from probdiffeq import ivpsolve, ivpsolvers, stats, taylor

### Quick Start guide

In [None]:
"""Solve the logistic equation."""

@jax.jit
def vf(y, *, t):  # noqa: ARG001
    """Right-hand side of the logistic ODE  y' = 2 * y * (1 - y).

    Args:
        y: Current state (an array; this notebook uses shape ``(1,)``).
        t: Current time, keyword-only. It is unused because the logistic
            equation is autonomous. It is kept so the function can be called
            as ``vf(y, t=...)``, which is how this notebook calls it.

    Returns:
        The time derivative of ``y``, with the same shape as ``y``.
    """
    rate = 2 * y  # evaluated first, exactly like `2 * y * (1 - y)` parses
    return rate * (1 - y)


# --- Time discretisation -----------------------------------------------------
# Integrate from t0 to t1 using N uniform steps.
t0 = 0.0
t1 = 5.0
N = 100
grid = jnp.linspace(t0, t1, N + 1)  # N + 1 nodes for the naive fixed-grid solve
h = (t1 - t0) / N  # step size of `grid`; NOTE(review): unused in this cell

# Initial value of the logistic ODE, stored as a length-1 array.
u0 = jnp.asarray([0.1])

# --- State-space model ---------------------------------------------------------
# Taylor coefficients of the solution at t0 (num=1, i.e. one derivative beyond
# the initial value), which then parameterise the integrated-Wiener prior.
tcoeffs = taylor.odejet_padded_scan(lambda u: vf(u, t=t0), (u0,), num=1)
init, ibm, ssm = ivpsolvers.prior_wiener_integrated(tcoeffs, ssm_fact="dense")

# --- Solver assembly -----------------------------------------------------------
# Combine prior, correction (ODE constraint) and strategy into one solver.
strategy = ivpsolvers.strategy_filter(ssm=ssm)
ts = ivpsolvers.correction_ts1(vf, ssm=ssm, ode_order=1)
solver = ivpsolvers.solver(ssm=ssm, strategy=strategy, prior=ibm, correction=ts)
# NOTE(review): the adaptive wrapper is built but not used by the fixed-grid
# solve below; kept in case later cells rely on it.
adaptive_solver = ivpsolvers.adaptive(solver, ssm=ssm)

# --- Solve on the fixed grid -----------------------------------------------------
solution = ivpsolve.solve_fixed_grid(init, grid=grid, solver=solver, ssm=ssm)

# Print only the array shapes of the initial state and the solution pytrees.
print(f"\ninitial = {jax.tree.map(jnp.shape, init)}")
print(f"\nsolution = {jax.tree.map(jnp.shape, solution)}")


initial = Normal(mean=(2,), cholesky=(2, 2))

solution = IVPSolution(t=(101,), u=[(101, 1), (101, 1)], u_std=[(101, 1), (101, 1)], output_scale=(100,), marginals=Normal(mean=(101, 2), cholesky=(101, 2, 2)), posterior=Normal(mean=(101, 2), cholesky=(101, 2, 2)), num_steps=(100,), ssm=FactImpl(name='dense', prototypes=<probdiffeq.impl._prototypes.DensePrototype object at 0x1481db770>, normal=<probdiffeq.impl._normal.DenseNormal object at 0x14821c4d0>, stats=<probdiffeq.impl._stats.DenseStats object at 0x14821c710>, linearise=<probdiffeq.impl._linearise.DenseLinearisation object at 0x14821c5f0>, conditional=<probdiffeq.impl._conditional.DenseConditional object at 0x14821c830>, num_derivatives=1, unravel=<jax._src.util.HashablePartial object at 0x1481f7350>))
