Skip to content

Commit

Permalink
Added first DDE version
Browse files Browse the repository at this point in the history
  • Loading branch information
thibmonsel committed Jul 26, 2023
1 parent b0758f4 commit 995accc
Show file tree
Hide file tree
Showing 22 changed files with 12,220 additions and 42 deletions.
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
RecursiveCheckpointAdjoint,
)
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .delays import Delays
from .event import (
AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent,
Expand Down
379 changes: 379 additions & 0 deletions diffrax/delays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
from typing import Callable, Optional, Sequence, Type, Union

import equinox as eqx
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu

from .custom_types import Array, Bool, Int, PyTree, Scalar
from .global_interpolation import DenseInterpolation
from .local_interpolation import AbstractLocalInterpolation
from .misc import rms_norm
from .misc.omega import ω
from .misc.unvmap import unvmap_any
from .nonlinear_solver import NewtonNonlinearSolver
from .term import VectorFieldWrapper


class Delays(eqx.Module):
"""Module that incorportes all the information needed for integrating DDEs"""

delays: PyTree[Callable]
initial_discontinuities: Union[None, Array, Sequence[Scalar]] = jnp.array([0.0])
max_discontinuities: Int = 100
recurrent_checking: Bool = False
sub_intervals: int = 10
max_steps: int = 20
rtol: float = 1e-3
atol: float = 1e-6


class HistoryVectorField(eqx.Module):
"""VectorField equivalent for a DDE solver that incorporates former
estimated values of y(t).
**Arguments:**
- `vector_field`: vector field of the delayed differential equation.
- `t0`: global integration start time
- `tprev`: start time of current integration step
- `tnext`: end time of current integration step
- `dense_info` : dense_info from current integration step
- `y0_history` : DDE's history function
- `delays` : DDE's different deviated arguments
"""

vector_field: Callable
t0: float
tprev: float
tnext: float
dense_info: PyTree[Array]
dense_interp: Optional[DenseInterpolation]
interpolation_cls: Type[AbstractLocalInterpolation]
y0_history: Callable
delays: PyTree[Callable]

def __call__(self, t, y, args):
history_vals = []
delays, treedef = jtu.tree_flatten(self.delays)
if self.dense_interp is None:
assert self.dense_info is None
for delay in self.delays:
delay_val = delay(t, y, args)
alpha_val = t - delay_val
y0_val = self.y0_history(alpha_val)
history_vals.append(y0_val)
else:
assert self.dense_info is not None
for delay in delays:
delay_val = delay(t, y, args)
alpha_val = t - delay_val

is_before_t0 = alpha_val < self.t0
is_before_tprev = alpha_val < self.tprev
at_most_t0 = jnp.where(alpha_val < self.t0, alpha_val, self.t0)
t0_to_tprev = jnp.clip(alpha_val, self.t0, self.tprev)
at_least_tprev = jnp.maximum(self.tprev, alpha_val)
step_interpolation = self.interpolation_cls(
t0=self.tprev, t1=self.tnext, **self.dense_info
)
switch = jnp.where(is_before_t0, 0, jnp.where(is_before_tprev, 1, 2))
history_val = lax.switch(
switch,
[
lambda: self.y0_history(at_most_t0),
lambda: self.dense_interp.evaluate(t0_to_tprev),
lambda: step_interpolation.evaluate(at_least_tprev),
],
)
history_vals.append(history_val)

history_vals = jtu.tree_unflatten(treedef, history_vals)
history_vals = tuple(history_vals)
return self.vector_field(t, y, args, history=history_vals)


def bind_history(
terms,
delays,
dense_info,
dense_interp,
solver,
direction,
t0,
tprev,
tnext,
y0_history,
):
delays_fn = jtu.tree_map(
lambda x: (lambda t, y, args: x(t, y, args) * direction), delays.delays
)

is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)

def _apply_history(
x,
):
if is_vf_wrapper(x):
vector_field = HistoryVectorField(
x.vector_field,
t0,
tprev,
tnext,
dense_info,
dense_interp,
solver.interpolation_cls,
y0_history,
delays_fn,
)
return VectorFieldWrapper(vector_field)
else:
return x

terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
return terms_


def history_extrapolation_implicit(
implicit_step,
terms,
dense_interp,
solver,
delays,
t0,
y0_history,
state,
args,
):
def _cond_fun(_val):
_, _, _, _, _, pred, step = _val
return (implicit_step & pred) | (jnp.invert(implicit_step) & (step == 0))

def _body_fun(_val):
y_prev, _, dense_info, _, _, _, step = _val
terms_ = bind_history(
terms,
delays,
dense_info,
dense_interp,
solver,
1,
t0,
state.tprev,
state.tnext,
y0_history,
)

(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms_,
state.tprev,
state.tnext,
state.y,
args,
state.solver_state,
state.made_jump,
)

_pred = (
rms_norm(
(
(ω(y).call(jnp.abs) - y_prev**ω)
/ (delays.atol + delays.rtol * ω(y).call(jnp.abs))
).ω
)
< 1
)
_pred = _pred & (step < 10)
return (
y,
y_error,
dense_info,
solver_state,
solver_result,
_pred,
step + 1,
)

_init_val = (
state.y,
state.y,
jtu.tree_map(lambda x: x[state.dense_save_index - 1], state.dense_infos),
state.solver_state,
0,
True,
0,
)
(
y,
y_error,
dense_info,
solver_state,
solver_result,
_,
final_step,
) = lax.while_loop(_cond_fun, _body_fun, _init_val)

y_error = jtu.tree_map(
lambda _y_error: jnp.where(final_step < 10, _y_error, jnp.inf),
y_error,
)
return y, y_error, dense_info, solver_state, solver_result


def maybe_find_discontinuity(
tprev,
tnext,
dense_info,
state,
delays,
solver,
args,
keep_step,
sub_tprev,
sub_tnext,
):
dense_discont = solver.interpolation_cls(t0=tprev, t1=tnext, **dense_info)
flat_delays = jtu.tree_leaves(delays.delays)
_gs = []

def make_g(delay):
# Creating the artifical event functions g that is used to
# detect future breaking points.
# http://www.cs.toronto.edu/pub/reports/na/hzpEnrightNA09Preprint.pdf
# page 7
def g(t):
return t - delay(t, dense_discont.evaluate(t), args) - state.discontinuities

return g

for delay in flat_delays:
_gs.append(make_g(delay))

def _find_discontinuity():
# Start by doing a cheap bisection search to reduce
# over the stored-discontinuity dimension.

def _cond_fun(_val):
_, _, _pred, _ = _val
return _pred

def _body_fun(_val):
_ta, _tb, _, _step = _val
_step = _step + 1
_tmid = _ta + 0.5 * (_tb - _ta)
_gas = jnp.stack([jnp.sign(g(_ta)) for g in _gs])
_gmids = jnp.stack([jnp.sign(g(_tmid)) for g in _gs])
_gbs = jnp.stack([jnp.sign(g(_tb)) for g in _gs])
_any_left = jnp.any(_gas != _gmids)
_next_ta = jnp.where(_any_left, _ta, _tmid)
_next_tb = jnp.where(_any_left, _tmid, _tb)
_pred = (
jnp.any(jnp.sum(_gas != _gbs, axis=1) > 1) | _step > delays.max_steps
)
return _next_ta, _next_tb, _pred, _step

_init_val = (sub_tprev, sub_tnext, True, 0)
_final_val = lax.while_loop(_cond_fun, _body_fun, _init_val)
_ta, _tb, _, _ = _final_val

# Then do a more expensive Newton search
# to find the first discontinuity.
_discont_solver = NewtonNonlinearSolver(rtol=delays.rtol, atol=delays.atol)
_disconts = []
for g, delay in zip(_gs, flat_delays):
changed_sign = jnp.sign(g(_ta)) != jnp.sign(g(_tb))
_i = jnp.argmax(changed_sign)
_d = state.discontinuities[_i]
_h = (
lambda t, args, delay=delay, _d=_d: t
- delay(t, dense_discont.evaluate(t), args)
- _d
)
_discont = _discont_solver(_h, _tb, args).root
_disconts.append(_discont)
_disconts = jnp.stack(_disconts)

best_candidate = jnp.where(
(_disconts > sub_tprev) & (_disconts < sub_tnext),
_disconts,
jnp.inf,
)
best_candidate = jnp.min(best_candidate)
discont_update = jnp.where(
jnp.isinf(best_candidate),
False,
True,
)
return best_candidate, discont_update

def _find_discontinuity_wrapper():
return lax.cond(
jnp.any(init_discont & jnp.invert(keep_step)),
_find_discontinuity,
lambda: (sub_tnext, False),
)

init_discont = jnp.stack(
[jnp.sign(g(sub_tprev)) != jnp.sign(g(sub_tnext)) for g in _gs]
)
# We might have rejected the step for normal reasons;
# skip looking for a discontinuity if so.
return lax.cond(
unvmap_any((init_discont & jnp.invert(keep_step))),
_find_discontinuity_wrapper,
lambda: (sub_tnext, False),
)


Delays.__init__.__doc__ = """
**Arguments:**
- `delays`: A `PyTree` where the leaves are the DDE's different scalar
deviated arguments.
- `initial_discontinuities`: Discontinuities given by the initial point time
and history function.
- `max_discontinuities`: Array length that tracks the discontinuity jumps
during integration (only relevant when `recurrent_checking` is True). If
`recurrent checking` is set to `True`, the computation quits unconditionally
when the total number of discontinuities detected is larger
than `max_discontinuities`.
- `recurrent_checking` : If `True`, there will be a systematic check at
integration step for potential discontinuities (this involves nonlinear solves
hence expensive). If `False`, discontinuities will only be checked when a step
is rejected. This allows to integrate faster but can also impact
the accuracy of the DDE solution.
- `sub_intervals` : Number of subintervals of the integration step where
discontinuity tracking is done.
- `rtol` : Relative tolerance for the nonlinear solver for the DDE's
implicit stepping and dichotomy for detecting breaking points.
- `atol` : Absolute tolerance for the nonlinear solver for the DDE's
implicit stepping and dichotomy for detecting breaking points.
- `max_steps` : Max iteration of the dichotomy algorithm to
find a discontinuity.
!!! example
To integrate `y'(t) = - y(t-1)`, we need to define the vector
field and the `Delays` object.
```py
y0 = lambda t: 1.2
def vector_field(t, y, args, history):
return - history[0]
delays = Delays(
delays=[lambda t, y, args: 1.0],
initial_discontinuities=jnp.array([0.0])
)
t0, t1 = 0.0, 50.0
ts = jnp.linspace(t0, t1, 500)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(vector_field),
diffrax.Tsit5(),
t0=ts[0],
t1=ts[-1],
dt0=ts[1] - ts[0],
y0 = y0_history,
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
saveat=diffrax.SaveAt(ts=ts, dense=True),
delays=delays
)
```
"""

0 comments on commit 995accc

Please sign in to comment.