Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Events #387

Merged
merged 28 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2b6d8ff
Changes to how events are handled in diffrax.
cholberg Feb 21, 2024
3e367f2
Test now fails when no root finder is provided
cholberg May 7, 2024
3f67a27
Saving events with `SubSaveAt`s
cholberg May 15, 2024
76762c1
Accounting for `SubSaveAt.fn` returning a PyTree
cholberg May 15, 2024
76cd083
Adjustments to #387 (events):
patrick-kidger May 19, 2024
e2ab3ce
Save values returned by root find when
cholberg May 23, 2024
0853e49
now returns condition function
cholberg May 24, 2024
1488206
Fixed error for . All tests pass now.
cholberg May 24, 2024
4872009
Added additional tests
cholberg May 26, 2024
a1f577c
Fixed save_index update and shape+dtype check for cond_fn
cholberg May 27, 2024
95ac30f
Added PyTree check in _outer_cond_fn
cholberg May 27, 2024
0c820f3
Added tests for checking that events error out correctly under misspe…
cholberg May 27, 2024
57d90c5
Fixed small error in the save_index update for events
cholberg May 27, 2024
1bdf1d2
Updated how events are saved
cholberg May 27, 2024
55e04e8
Added tests for different configurations of saveat
cholberg May 27, 2024
d8a8ba7
Changed to ValueError when cond_fn returns non-boolean/float.
cholberg May 28, 2024
70e044f
Added docstring to Event class
cholberg May 28, 2024
4c509b6
Updated docstring for steady_state_event
cholberg May 28, 2024
fbea794
Updated docstring for ImplicitAdjoint
cholberg May 28, 2024
b158800
Added example to Event docstring
cholberg May 28, 2024
812e5c6
Updated steady state example to use the new syntax
cholberg May 28, 2024
eb16e57
Fixed weird type checker error
cholberg Jun 1, 2024
09c92c3
Updated steady state test to use the new syntax
cholberg Jun 10, 2024
d07c8f4
Doc tweaks for events
patrick-kidger Jun 15, 2024
883841f
Typo in comment
cholberg Jun 16, 2024
0c62f4c
Simplified unsaving
cholberg Jun 18, 2024
7588482
Deleted extra unnecessary argument
cholberg Jun 25, 2024
e4935ae
Changed to strict inequality to be in line with the usual saving behv…
cholberg Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
SteadyStateEvent as SteadyStateEvent,
# Deliberately not provided with `X as X` as these are now deprecated, so we'd like
# static type checkers to warn about using them.
AbstractDiscreteTerminatingEvent, # noqa: F401
DiscreteTerminatingEvent, # noqa: F401
Event as Event,
steady_state_event as steady_state_event,
SteadyStateEvent, # noqa: F401
)
from ._global_interpolation import (
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
Expand Down
15 changes: 8 additions & 7 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def loop(
terms,
solver,
stepsize_controller,
discrete_terminating_event,
event,
saveat,
t0,
t1,
Expand Down Expand Up @@ -450,7 +450,8 @@ class ImplicitAdjoint(AbstractAdjoint):
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).

This is used when solving towards a steady state, typically using
[`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$
[`diffrax.Event`][] where the condition function is obtained by calling
[`diffrax.steady_state_event`][]. In this case, the output of the solver is $y(θ)$
for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found
through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
the solver and instead directly compute
Expand Down Expand Up @@ -563,7 +564,7 @@ def _loop_backsolve_bwd(
self,
solver,
stepsize_controller,
discrete_terminating_event,
event,
saveat,
t0,
t1,
Expand All @@ -573,7 +574,7 @@ def _loop_backsolve_bwd(
init_state,
progress_meter,
):
assert discrete_terminating_event is None
assert event is None

#
# Unpack our various arguments. Delete a lot of things just to make sure we're not
Expand Down Expand Up @@ -787,7 +788,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
discrete_terminating_event,
event,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
Expand Down Expand Up @@ -829,7 +830,7 @@ def loop(
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
"a single term."
)
if discrete_terminating_event is not None:
if event is not None:
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)
Expand All @@ -846,7 +847,7 @@ def loop(
saveat=saveat,
init_state=init_state,
solver=solver,
discrete_terminating_event=discrete_terminating_event,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
Expand Down
209 changes: 149 additions & 60 deletions diffrax/_event.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,187 @@
import abc
from collections.abc import Callable
from typing import Optional
from typing import Optional, Union

import equinox as eqx
import optimistix as optx
from jaxtyping import Array, PyTree

from ._custom_types import BoolScalarLike, RealScalarLike
from ._custom_types import BoolScalarLike, FloatScalarLike, RealScalarLike
from ._step_size_controller import AbstractAdaptiveStepSizeController


class AbstractDiscreteTerminatingEvent(eqx.Module):
"""Evaluated at the end of each integration step. If true then the solve is stopped
at that time.
"""

@abc.abstractmethod
def __call__(self, state, **kwargs) -> BoolScalarLike:
"""**Arguments:**

- `state`: a dataclass of the evolving state of the system, including in
particular the solution `state.y` at time `state.tprev`.
- `**kwargs`: the integration options held constant throughout the solve
are passed as keyword arguments: `terms`, `solver`, `args`. etc.

**Returns**
class Event(eqx.Module):
"""Can be used to terminate the solve early if a condition, or one of multiple
conditions, is triggered. It allows for both boolean and continuous condition
functions. In the latter case, a root finder can be used to find the exact time of
the event. Boolean and continuous conditions can be used together.

A boolean. If true then the solve is terminated.
"""


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
"""Terminates the solve if its condition is ever active."""

cond_fn: Callable[..., BoolScalarLike]
Instances of this class should be passed as the `event` argument of
[`diffrax.diffeqsolve`][].
"""

def __call__(self, state, **kwargs):
return self.cond_fn(state, **kwargs)
cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]]
root_finder: Optional[optx.AbstractRootFinder] = None


Event.__init__.__doc__ = """**Arguments:**

- `cond_fn`: A function or PyTree of functions `f(t, y, args, **kwargs) -> c` each
returning either a boolean or a real number. If the return value is a boolean, then
the solve will terminate on the first step on which `c` becomes `True`. If the
return value is a real number, then the solve will terminate on the step when `c`
changes sign.

- `root_finder`: An optional [root finder](../nonlinear_solver/) to use for finding
the exact time of the event. If the triggered condition function returns a real
number, then the final time will be the time at which that real number equals zero.
(If the triggered condition function returns a boolean, then the returned time will
just be the end of the step on which it becomes `True`.)
[`optimistix.Newton`](https://docs.kidger.site/optimistix/api/root_find/#optimistix.Newton)
would be a typical choice here.

!!! Example

Consider a bouncing ball dropped from some intial height $x_0$. We can model
the ball by a 2-dimensional ODE

$\\frac{dx_t}{dt} = v_t, \\quad \\frac{dv_t}{dt} = -g,$

where $x_t$ represents the height of the ball, $v_t$ its velocity,
and $g$ is the gravitational constant. With $g=8$, this corresponds to the
vector field:

```python
def vector_field(t, y, args):
_, v = y
return jnp.array([v, -8.0])
```

Figuring out exactly when the ball hits the ground amounts to
solving the ODE until the event $x_t=0$ is triggered. This can be done by using
the real-valued condition function:

```python
def cond_fn(t, y, args, **kwargs):
x, _ = y
return x
```

With $x_0=10$, this would yield:

```python
y0 = jnp.array([10.0, 0.0])
t0 = 0
t1 = jnp.inf
dt0 = 0.1
term = diffrax.ODETerm(vector_field)
root_finder = optx.Newton(1e-5, 1e-5, optx.rms_norm)
event = diffrax.Event(cond_fn, root_finder)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, event=event)
print(f"Event time: {sol.ts[0]}") # Event time: 1.58...
print(f"Velocity at event time: {sol.ys[0, 1]}") # Velocity at event time: -12.64...
```
"""


DiscreteTerminatingEvent.__init__.__doc__ = """**Arguments:**
def steady_state_event(
rtol: Optional[float] = None,
atol: Optional[float] = None,
norm: Optional[Callable[[PyTree[Array]], RealScalarLike]] = None,
):
"""Create a condition function that terminates the solve once a steady state is
achieved. The returned function should be passed as the `cond_fn` argument of
[`diffrax.Event`][].

- `cond_fn`: A function `(state, **kwargs) -> bool` that is evaluated on every step of
the differential equation solve. If it returns `True` then the solve is finished at
that timestep. `state` is a dataclass of the evolving state of the system,
including in particular the solution `state.y` at time `state.tprev`. Passed as
keyword arguments are the `terms`, `solver`, `args` etc. that are constant
throughout the solve.
"""
**Arguments:**

- `rtol`, `atol`, `norm`: the solve will terminate once
`norm(f) < atol + rtol * norm(y)`, where `f` is the result of evaluating the
vector field. Will default to the values used in the `stepsize_controller` if
they are not specified here.

class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
"""Terminates the solve once it reaches a steady state."""
**Returns:**

rtol: Optional[float] = None
atol: Optional[float] = None
norm: Callable[[PyTree[Array]], RealScalarLike] = optx.rms_norm
A function `f(t, y, args, **kwargs)`, that can be passed to
`diffrax.Event(cond_fn=..., ...)`.
"""

def __call__(self, state, *, terms, args, solver, stepsize_controller, **kwargs):
def _cond_fn(t, y, args, *, terms, solver, stepsize_controller, **kwargs):
del kwargs
msg = (
"The `rtol` and `atol` tolerances for `SteadyStateEvent` default "
"to the `rtol` and `atol` used with an adaptive step size "
"controller (such as `diffrax.PIDController`). Either use an "
"adaptive step size controller, or specify these tolerances "
"manually."
"The `rtol`, `atol`, and `norm` for `steady_state_event` default to the "
"values used with an adaptive step size controller (such as "
"`diffrax.PIDController`). Either use an adaptive step size controller, or "
"specify these tolerances manually."
)
if self.rtol is None:
if rtol is None:
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
_rtol = stepsize_controller.rtol
else:
raise ValueError(msg)
else:
_rtol = self.rtol
if self.atol is None:
_rtol = rtol
if atol is None:
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
_atol = stepsize_controller.atol
else:
raise ValueError(msg)
else:
_atol = self.atol
_atol = atol
if norm is None:
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
_norm = stepsize_controller.norm
else:
raise ValueError(msg)
else:
_norm = norm

# TODO: this makes an additional function evaluation that in practice has
# probably already been made by the solver.
vf = solver.func(terms, state.tprev, state.y, args)
return self.norm(vf) < _atol + _rtol * self.norm(state.y)
vf = solver.func(terms, t, y, args)
return _norm(vf) < _atol + _rtol * _norm(y)

return _cond_fn

SteadyStateEvent.__init__.__doc__ = """**Arguments:**

- `rtol`: The relative tolerance for determining convergence. Defaults to the
same `rtol` as passed to an adaptive step controller if one is used.
- `atol`: The absolute tolerance for determining convergence. Defaults to the
same `atol` as passed to an adaptive step controller if one is used.
- `norm`: A function `PyTree -> Scalar`, which is called to determine whether
the vector field is close to zero.
"""
#
# Backward compatibility: continue to support `AbstractDiscreteTerminatingEvent`.
# TODO: eventually remove everything below this line.
#


class AbstractDiscreteTerminatingEvent(eqx.Module):
@abc.abstractmethod
def __call__(self, state, **kwargs) -> BoolScalarLike:
pass


class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent):
cond_fn: Callable[..., BoolScalarLike]

def __call__(self, state, **kwargs):
return self.cond_fn(state, **kwargs)


class SteadyStateEvent(AbstractDiscreteTerminatingEvent):
rtol: Optional[float] = None
atol: Optional[float] = None
norm: Callable[[PyTree[Array]], RealScalarLike] = optx.rms_norm

def __call__(self, state, *, args, **kwargs):
return steady_state_event(self.rtol, self.atol, self.norm)(
state.tprev, state.y, args, **kwargs
)


class _StateCompat(eqx.Module):
tprev: FloatScalarLike
y: PyTree[Array]


class DiscreteTerminatingEventToCondFn(eqx.Module):
event: AbstractDiscreteTerminatingEvent

def __call__(self, t, y, args, **kwargs):
return self.event(_StateCompat(tprev=t, y=y), args=args, **kwargs)
Loading
Loading