Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Events #387

Open
wants to merge 26 commits into
base: dev
Choose a base branch
from
Open

Events #387

wants to merge 26 commits into from

Commits on May 15, 2024

  1. Changes to how events are handled in diffrax.

    The main changes are:
    
        1. Added the generic Event class:
        ```
        class Event:
            event_function: PyTree[EventFn]
            root_finder: Optional[optx.AbstractRootFinder] = None
        ```
        EventFn is defined as:
        ```
        class EventFn(eqx.Module):
            cond_fn: Callable[..., Union[BoolScalarLike, RealScalarLike]]
            transition_fn: Optional[Callable[[PyTree[ArrayLike]], PyTree[ArrayLike]]] = (
                lambda x: x
            )
        ````
    
        2. Added root finding procedure in diffeqsolve to find exact event times that are differentiable. This is only done when root_finder is not None in the given Event class.
    
        3. Added event_mask to the Solution class so that, when multiple event functions are passed, the user can see which one was triggered for a given solve.
    
    Hopefully this new event-handling is sufficiently generic to handle all kinds of events in a unified and simple manner. The main benefit is that we can now differentiate through events. So far the current implementation is enough to deal with ODEs, but I suspect more is needed for dealing with SDEs.
    
    The new approach reduces to the old approach when passing only one EventFn with a boolean cond_fn and no transition_fn.
    
    For now the transition_fn is not used, but it will be useful when adding non-terminating events. Similarly, we might add other attributes to EventFn to distinguish between different types of events.
    
    No event cases in root-finding
    
    At the end of the root-fining step (L1146 in _integrate.py), I changed:
    ```
    return jtu.tree_map(
        _call_real,
        event.event_fn,
        final_state.event_result,
        final_state.event_compare,
        is_leaf=_is_event_fn,
    )
    ```
    
    to
    
    ```
    results = jtu.tree_map(
        _call_real,
        event.event_fn,
        final_state.event_result,
        final_state.event_compare,
        is_leaf=_is_event_fn,
    )
    results_ravel, _ = jfu.ravel_pytree(results)
    return jnp.where(event_happened, results_ravel, final_state.tprev - t)
    ```
    
    Thus, if no event occurs the root-find will return tprev as desired. Before call_real() was constantly 0 in this case which caused in error in the root-find.
    
    Added EventFn and Event to diffrax/__init__.py
    
    Added tests for new event handling
    
    I added new tests for the updated event implementation which, apart from the old ones, also checks that the right event time is found in the case where a root-find is called and that the derivatives match the theoretical derivatives.
    
    Furthermore, I marked the following tests, that rely on the old event implementation, with @pytest.mark.skip:
     - test_event.py::test_discrete_terminate1
     - test_event.py::test_discrete_terminate2
     - test_event.py::test_event_backsolve
     - test_adjoint.py::test_implicit
    
    In order to avoid pyright errors I had to add # pyright: ignore in a few places in the the old test referenced above.
    
    Deleted old event implementation
    
    I deleted the following two classes:
    - diffrax._event.DiscreteTerminatingEvent
    - diffrax._event.SteadyStateEvent
    
    These were also removed from the diffrax.__init__.py
    
    Minor changes to event hadnling
    
    The changes are the following:
    
    - Tweaked the event API and got rid of the EventFn class. Now there is only an Event class:
    
    ```
    class Event(eqx.Module):
        cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]]
        root_finder: Optional[optx.AbstractRootFinder] = None
    ```
    
    - Changed the way boolean condition functions are handled in the root finding step. Now instead of calling _bool_event_gradient, we simply return result = final_state.tprev - t.
    
    - Removed all cases where jtu.ravel_pytree was used.
    
    - Changed "teventprev" to "tprevprev" and "event_compare" to "event_mask" in the State class.
    
    - Updated tests.py and __init__.py to reflect the changes.
    
    Minor changes for simplicity
    
    I slightly changed the initialization of the event attributes in the state in _integrate.py mainly for aesthetic reasons.
    
    Made changes according to comments on patrick-kidger#387
    
    No event case
    
    Changed it so that the final value of the solve is returned in cases where no event happens instead of evaluating the interpolator.
    cholberg committed May 15, 2024
    Configuration menu
    Copy the full SHA
    52b4d37 View commit details
    Browse the repository at this point in the history
  2. Configuration menu
    Copy the full SHA
    5e9e1fa View commit details
    Browse the repository at this point in the history
  3. Saving events with SubSaveAts

    Previously, updating the last element of ys and ts did not handle the case where multiple `SubSaveAt`s were used. This is now fixed by adding a `jtu.tree_map` in the appropriate place.
    cholberg committed May 15, 2024
    Configuration menu
    Copy the full SHA
    8fd300c View commit details
    Browse the repository at this point in the history
  4. Configuration menu
    Copy the full SHA
    dcd91a0 View commit details
    Browse the repository at this point in the history

Commits on May 19, 2024

  1. Adjustments to patrick-kidger#387 (events):

    - Semantic change: boolean events now trigger when they become truthy (before they occurred when they swap being falsy<->truthy). Note that this required twiddling around a few things as previously it was impossible for an event to trigger on the first step; now they can.
    - Semantic change: event functions now have the signature `(t, y, args *, terms, solver, **etc)` for consistency with vector fields and with `SaveAt(fn=...)`.
    - Feature: now backward-compatible with the old discrete terminating events.
    - Feature: added `diffrax.steady_state_event`.
    - Bugfix: the final `t` and `y` from an event are now saved in the correct index of `ts` and `ys`, rather than just always being saved at index `-1`.
    - Bugfix: at one point `args` referred to the `args` coming from a root find rather than the overall `diffeqsolve`.
    - Bugfix: the current `state.tprev` was used instead of the previous state's `tnext`. (These are usually but not always the same -- in particular when around jumps.)
    - Bugfix: added some checks when the condition function of an event does not return a bool/float scalar.
    - Performance: includes a fastpath for skipping the rootfind if no events are triggered.
    - Performance: now avoiding tracing for the shape of `dense_info` twice when using adaptive step size controllers alongside events.
    - Performance: avoided quadratic loop for figuring out what was the first event to trigger.
    - Chore: added support for the possibility of the final root find (for the time of the event) failing.
    - Chore: removed some dead code (`_bool_event_gradient`).
    - Chore: removed references in the docs to the old `discrete_terminating_event`.
    
    In addition, some drive-bys:
    
    - Fixed warnings about pending deprecations `jnp.clip(..., a_min=..., a_max=...)`.
    - Had `aux_stats` (in `_integrate.py`) forward to the overall output statistics. In practice this is empty but it's worth doing for the future.
    patrick-kidger committed May 19, 2024
    Configuration menu
    Copy the full SHA
    1470707 View commit details
    Browse the repository at this point in the history

Commits on May 23, 2024

  1. Configuration menu
    Copy the full SHA
    28a7172 View commit details
    Browse the repository at this point in the history

Commits on May 24, 2024

  1. Configuration menu
    Copy the full SHA
    91082c5 View commit details
    Browse the repository at this point in the history
  2. Configuration menu
    Copy the full SHA
    269323a View commit details
    Browse the repository at this point in the history

Commits on May 26, 2024

  1. Added additional tests

    Added a bunch of additional tests for events.
    
    Also changed the way `save_index` was updated to handle PyTrees of subsaveats.
    cholberg committed May 26, 2024
    Configuration menu
    Copy the full SHA
    05f636f View commit details
    Browse the repository at this point in the history

Commits on May 27, 2024

  1. Configuration menu
    Copy the full SHA
    8450a40 View commit details
    Browse the repository at this point in the history
  2. Configuration menu
    Copy the full SHA
    f3b8d3e View commit details
    Browse the repository at this point in the history
  3. Configuration menu
    Copy the full SHA
    0623211 View commit details
    Browse the repository at this point in the history
  4. Configuration menu
    Copy the full SHA
    4e4414b View commit details
    Browse the repository at this point in the history
  5. Updated how events are saved

    When passing `SaveAt(steps=True, ts=ts)` for some array `ts` values will be saved at the times in `ts` in the time increments of each step of the solver. In practice this means that some of the saved values might be after the event time. I changed it so that these values are deleted.
    cholberg committed May 27, 2024
    Configuration menu
    Copy the full SHA
    3bcc094 View commit details
    Browse the repository at this point in the history
  6. Configuration menu
    Copy the full SHA
    1202f98 View commit details
    Browse the repository at this point in the history

Commits on May 28, 2024

  1. Configuration menu
    Copy the full SHA
    02c3c16 View commit details
    Browse the repository at this point in the history
  2. Configuration menu
    Copy the full SHA
    1e4b9c7 View commit details
    Browse the repository at this point in the history
  3. Configuration menu
    Copy the full SHA
    c5e7037 View commit details
    Browse the repository at this point in the history
  4. Configuration menu
    Copy the full SHA
    5a15a87 View commit details
    Browse the repository at this point in the history
  5. Configuration menu
    Copy the full SHA
    cfc8653 View commit details
    Browse the repository at this point in the history
  6. Configuration menu
    Copy the full SHA
    3c267df View commit details
    Browse the repository at this point in the history

Commits on Jun 1, 2024

  1. Configuration menu
    Copy the full SHA
    39a6557 View commit details
    Browse the repository at this point in the history

Commits on Jun 10, 2024

  1. Configuration menu
    Copy the full SHA
    24e9ab1 View commit details
    Browse the repository at this point in the history

Commits on Jun 15, 2024

  1. Doc tweaks for events

    patrick-kidger committed Jun 15, 2024
    Configuration menu
    Copy the full SHA
    a12d069 View commit details
    Browse the repository at this point in the history

Commits on Jun 16, 2024

  1. Typo in comment

    cholberg committed Jun 16, 2024
    Configuration menu
    Copy the full SHA
    84bd8cb View commit details
    Browse the repository at this point in the history

Commits on Jun 18, 2024

  1. Simplified unsaving

    cholberg committed Jun 18, 2024
    Configuration menu
    Copy the full SHA
    05bafb2 View commit details
    Browse the repository at this point in the history