Skip to content

Commit

Permalink
Remove taylor-mode attribute from sovler (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Oct 11, 2022
1 parent 12077df commit 3096a70
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 116 deletions.
36 changes: 19 additions & 17 deletions odefilter/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ class Adaptive(eqx.Module):
control: Any
norm_ord: Union[int, str, None] = None

@partial(jax.jit, static_argnames=("vector_field",))
def init_fn(self, *, vector_field, initial_values, t0):
@jax.jit
def init_fn(self, *, taylor_coefficients, t0):
"""Initialise the IVP solver state."""
state_odefilter = self.odefilter.init_fn(
vector_field=vector_field, initial_values=initial_values, t0=t0
taylor_coefficients=taylor_coefficients, t0=t0
)
state_control = self.control.init_fn()

u0, f0, *_ = taylor_coefficients
error_normalised = self._normalise_error(
error_estimate=state_odefilter.error_estimate,
u=state_odefilter.u,
Expand All @@ -74,8 +75,8 @@ def init_fn(self, *, vector_field, initial_values, t0):
norm_ord=self.norm_ord,
)
dt_proposed = self._propose_first_dt_per_tol(
f=lambda *x: vector_field(*x, t=t0),
u0=initial_values,
f0=f0,
u0=(u0,),
error_order=self.error_order,
atol=self.atol,
rtol=self.rtol,
Expand Down Expand Up @@ -157,7 +158,7 @@ def _normalise_error(*, error_estimate, u, atol, rtol, norm_ord):
return jnp.linalg.norm(error_relative, ord=norm_ord)

@staticmethod
def _propose_first_dt_per_tol(*, f, u0, error_order, rtol, atol):
def _propose_first_dt_per_tol(*, f0, u0, error_order, rtol, atol):
# Taken from:
# https://github.com/google/jax/blob/main/jax/experimental/ode.py
#
Expand All @@ -166,21 +167,22 @@ def _propose_first_dt_per_tol(*, f, u0, error_order, rtol, atol):
# E. Hairer, S. P. Norsett G. Wanner,
# Solving Ordinary Differential Equations I: Nonstiff Problems, Sec. II.4.
assert len(u0) == 1
f0 = f(*u0)
scale = atol + u0[0] * rtol
a = jnp.linalg.norm(u0[0] / scale)
b = jnp.linalg.norm(f0 / scale)
dt0 = jnp.where((a < 1e-5) | (b < 1e-5), 1e-6, 0.01 * a / b)

u1 = u0[0] + dt0 * f0
f1 = f(u1)
c = jnp.linalg.norm((f1 - f0) / scale) / dt0
dt1 = jnp.where(
(b <= 1e-15) & (c <= 1e-15),
jnp.maximum(1e-6, dt0 * 1e-3),
(0.01 / jnp.max(b + c)) ** (1.0 / error_order),
)
return jnp.minimum(100.0 * dt0, dt1)
return 100 * dt0
# todo:
#
# u1 = u0[0] + dt0 * f0
# f1 = f(u1)
# c = jnp.linalg.norm((f1 - f0) / scale) / dt0
# dt1 = jnp.where(
# (b <= 1e-15) & (c <= 1e-15),
# jnp.maximum(1e-6, dt0 * 1e-3),
# (0.01 / jnp.max(b + c)) ** (1.0 / error_order),
# )
# return jnp.minimum(100.0 * dt0, dt1)

@jax.jit
def extract_fn(self, *, state): # noqa: D102
Expand Down
46 changes: 41 additions & 5 deletions odefilter/ivpsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax
import jax.numpy as jnp

from odefilter import _control_flow
from odefilter import _control_flow, taylor_series


def solve(
Expand All @@ -17,6 +17,7 @@ def solve(
t1,
solver,
parameters=(),
taylor_series_fn=None,
):
"""Solve an initial value problem.
Expand All @@ -42,6 +43,7 @@ def solution_generator(
t1,
solver,
parameters=(),
taylor_series_fn=None,
):
"""Construct a generator of an IVP solution.
Expand All @@ -50,10 +52,20 @@ def solution_generator(
"""
_assert_not_scalar(initial_values=initial_values)

taylor_series_fn = taylor_series_fn or taylor_series.TaylorMode()

def vf(*ys, t):
return vector_field(*ys, t, *parameters)

state = solver.init_fn(vector_field=vf, initial_values=initial_values, t0=t0)
def vf_auto_t0(*x):
return vf(*x, t=t0)

taylor_coefficients = taylor_series_fn(
vector_field=vf_auto_t0,
initial_values=initial_values,
num=solver.odefilter.strategy.implementation.num_derivatives,
)
state = solver.init_fn(taylor_coefficients=taylor_coefficients, t0=t0)

while state.accepted.t < t1:
yield solver.extract_fn(state=state)
Expand All @@ -73,6 +85,7 @@ def simulate_terminal_values(
t1,
solver,
parameters=(),
taylor_series_fn=None,
):
"""Simulate the terminal values of an initial value problem.
Expand All @@ -93,10 +106,20 @@ def simulate_terminal_values(
"""
_assert_not_scalar(initial_values=initial_values)

taylor_series_fn = taylor_series_fn or taylor_series.TaylorMode()

def vf(*ys, t):
return vector_field(*ys, t, *parameters)

state0 = solver.init_fn(vector_field=vf, initial_values=initial_values, t0=t0)
def vf_auto_t0(*x):
return vf(*x, t=t0)

taylor_coefficients = taylor_series_fn(
vector_field=vf_auto_t0,
initial_values=initial_values,
num=solver.odefilter.strategy.implementation.num_derivatives,
)
state0 = solver.init_fn(taylor_coefficients=taylor_coefficients, t0=t0)

solution = _advance_ivp_solution_adaptively(
state0=state0,
Expand All @@ -108,13 +131,27 @@ def vf(*ys, t):


@partial(jax.jit, static_argnames=("vector_field",))
def simulate_checkpoints(vector_field, initial_values, *, ts, solver, parameters=()):
def simulate_checkpoints(
vector_field, initial_values, *, ts, solver, parameters=(), taylor_series_fn=None
):
"""Solve an IVP and return the solution at checkpoints."""
_assert_not_scalar(initial_values=initial_values)

taylor_series_fn = taylor_series_fn or taylor_series.TaylorMode()

def vf(*ys, t):
return vector_field(*ys, t, *parameters)

def vf_auto_t0(*x):
return vf(*x, t=ts[0])

taylor_coefficients = taylor_series_fn(
vector_field=vf_auto_t0,
initial_values=initial_values,
num=solver.odefilter.strategy.implementation.num_derivatives,
)
state0 = solver.init_fn(taylor_coefficients=taylor_coefficients, t0=ts[0])

def advance_to_next_checkpoint(s, t_next):
s_next = _advance_ivp_solution_adaptively(
state0=s,
Expand All @@ -124,7 +161,6 @@ def advance_to_next_checkpoint(s, t_next):
)
return s_next, s_next

state0 = solver.init_fn(vector_field=vf, initial_values=initial_values, t0=ts[0])
_, solution = _control_flow.scan_with_init(
f=advance_to_next_checkpoint,
init=state0,
Expand Down
17 changes: 3 additions & 14 deletions odefilter/odefilters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,16 @@ class ODEFilterSolution(eqx.Module):
class ODEFilter(eqx.Module):
"""ODE filter."""

taylor_series_init: Any
strategy: Any

@partial(jax.jit, static_argnames=("vector_field",))
def init_fn(self, *, vector_field, initial_values, t0):
@jax.jit
def init_fn(self, *, taylor_coefficients, t0):
"""Initialise the IVP solver state."""

def vf(*x):
return vector_field(*x, t=t0)

taylor_coefficients = self.taylor_series_init(
vector_field=vf,
initial_values=initial_values,
num=self.strategy.implementation.num_derivatives,
)

posterior, error_estimate = self.strategy.init_fn(
taylor_coefficients=taylor_coefficients
)

u0, *_ = initial_values
u0, *_ = taylor_coefficients
return ODEFilterSolution(
t=t0,
u=u0,
Expand Down
21 changes: 4 additions & 17 deletions odefilter/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@
Tomorrow, this module might go again.
"""
from odefilter import (
adaptive,
controls,
information,
odefilters,
strategies,
taylor_series,
)
from odefilter import adaptive, controls, information, odefilters, strategies
from odefilter.implementations import dense, isotropic

ATOL_DEFAULTS = 1e-6
Expand All @@ -41,9 +34,7 @@ def dynamic_isotropic_ekf0(num_derivatives, atol=ATOL_DEFAULTS, rtol=RTOL_DEFAUL
strategy = strategies.DynamicFilter(
implementation=implementation, information=information_op
)
stepping = odefilters.ODEFilter(
taylor_series_init=taylor_series.TaylorMode(), strategy=strategy
)
stepping = odefilters.ODEFilter(strategy=strategy)
control = controls.ProportionalIntegral()
return adaptive.Adaptive(
odefilter=stepping,
Expand All @@ -67,9 +58,7 @@ def dynamic_isotropic_eks0(num_derivatives, atol=ATOL_DEFAULTS, rtol=RTOL_DEFAUL
strategy = strategies.DynamicSmoother(
implementation=implementation, information=information_op
)
stepping = odefilters.ODEFilter(
taylor_series_init=taylor_series.TaylorMode(), strategy=strategy
)
stepping = odefilters.ODEFilter(strategy=strategy)
control = controls.ProportionalIntegral()
return adaptive.Adaptive(
odefilter=stepping,
Expand All @@ -94,9 +83,7 @@ def dynamic_ekf1(
strategy = strategies.DynamicFilter(
implementation=implementation, information=information_op
)
stepping = odefilters.ODEFilter(
taylor_series_init=taylor_series.TaylorMode(), strategy=strategy
)
stepping = odefilters.ODEFilter(strategy=strategy)
control = controls.ProportionalIntegral()
return adaptive.Adaptive(
odefilter=stepping,
Expand Down
53 changes: 0 additions & 53 deletions tests/test_recipes.py

This file was deleted.

22 changes: 12 additions & 10 deletions tests/test_src/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@
from odefilter import adaptive, controls, odefilters, taylor_series


@pytest_cases.parametrize(
"tseries", [taylor_series.TaylorMode(), taylor_series.ForwardMode()]
)
@pytest_cases.parametrize_with_cases("strategy", cases=".cases_strategies")
def case_odefilter(tseries, strategy):
odefilter = odefilters.ODEFilter(
taylor_series_init=tseries,
strategy=strategy,
)
def case_odefilter(strategy):
odefilter = odefilters.ODEFilter(strategy=strategy)
control = controls.ProportionalIntegral()
atol, rtol = 1e-3, 1e-3
return adaptive.Adaptive(
Expand All @@ -34,9 +28,17 @@ def test_solver(solver, vf, u0, t0, t1, p):
def vf_p(*ys, t):
return vf(*ys, t, *p)

state0 = solver.init_fn(
vector_field=vf_p,
def vf_p_0(*ys):
return vf_p(*ys, t=t0)

taylor_series_fn = taylor_series.TaylorMode()
tcoeffs = taylor_series_fn(
vector_field=vf_p_0,
initial_values=u0,
num=solver.odefilter.strategy.implementation.num_derivatives,
)
state0 = solver.init_fn(
taylor_coefficients=tcoeffs,
t0=t0,
)
assert state0.dt_proposed > 0.0
Expand Down

0 comments on commit 3096a70

Please sign in to comment.