-
-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Events #387
Open
cholberg
wants to merge
26
commits into
patrick-kidger:dev
Choose a base branch
from
cholberg:dev
base: dev
Could not load branches
Branch not found: {{ refName }}
Could not load tags
Nothing to show
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Events #387
Commits on May 15, 2024
-
Changes to how events are handled in diffrax.
The main changes are: 1. Added the generic Event class: ``` class Event: event_function: PyTree[EventFn] root_finder: Optional[optx.AbstractRootFinder] = None ``` EventFn is defined as: ``` class EventFn(eqx.Module): cond_fn: Callable[..., Union[BoolScalarLike, RealScalarLike]] transition_fn: Optional[Callable[[PyTree[ArrayLike]], PyTree[ArrayLike]]] = ( lambda x: x ) ```` 2. Added root finding procedure in diffeqsolve to find exact event times that are differentiable. This is only done when root_finder is not None in the given Event class. 3. Added event_mask to the Solution class so that, when multiple event functions are passed, the user can see which one was triggered for a given solve. Hopefully this new event-handling is sufficiently generic to handle all kinds of events in a unified and simple manner. The main benefit is that we can now differentiate through events. So far the current implementation is enough to deal with ODEs, but I suspect more is needed for dealing with SDEs. The new approach reduces to the old approach when passing only one EventFn with a boolean cond_fn and no transition_fn. For now the transition_fn is not used, but it will be useful when adding non-terminating events. Similarly, we might add other attributes to EventFn to distinguish between different types of events. No event cases in root-finding At the end of the root-fining step (L1146 in _integrate.py), I changed: ``` return jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) ``` to ``` results = jtu.tree_map( _call_real, event.event_fn, final_state.event_result, final_state.event_compare, is_leaf=_is_event_fn, ) results_ravel, _ = jfu.ravel_pytree(results) return jnp.where(event_happened, results_ravel, final_state.tprev - t) ``` Thus, if no event occurs the root-find will return tprev as desired. Before call_real() was constantly 0 in this case which caused in error in the root-find. Added EventFn and Event to diffrax/__init__.py Added tests for new event handling I added new tests for the updated event implementation which, apart from the old ones, also checks that the right event time is found in the case where a root-find is called and that the derivatives match the theoretical derivatives. Furthermore, I marked the following tests, that rely on the old event implementation, with @pytest.mark.skip: - test_event.py::test_discrete_terminate1 - test_event.py::test_discrete_terminate2 - test_event.py::test_event_backsolve - test_adjoint.py::test_implicit In order to avoid pyright errors I had to add # pyright: ignore in a few places in the the old test referenced above. Deleted old event implementation I deleted the following two classes: - diffrax._event.DiscreteTerminatingEvent - diffrax._event.SteadyStateEvent These were also removed from the diffrax.__init__.py Minor changes to event hadnling The changes are the following: - Tweaked the event API and got rid of the EventFn class. Now there is only an Event class: ``` class Event(eqx.Module): cond_fn: PyTree[Callable[..., Union[BoolScalarLike, RealScalarLike]]] root_finder: Optional[optx.AbstractRootFinder] = None ``` - Changed the way boolean condition functions are handled in the root finding step. Now instead of calling _bool_event_gradient, we simply return result = final_state.tprev - t. - Removed all cases where jtu.ravel_pytree was used. - Changed "teventprev" to "tprevprev" and "event_compare" to "event_mask" in the State class. - Updated tests.py and __init__.py to reflect the changes. Minor changes for simplicity I slightly changed the initialization of the event attributes in the state in _integrate.py mainly for aesthetic reasons. Made changes according to comments on patrick-kidger#387 No event case Changed it so that the final value of the solve is returned in cases where no event happens instead of evaluating the interpolator.
Configuration menu - View commit details
-
Copy full SHA for 52b4d37 - Browse repository at this point
Copy the full SHA 52b4d37View commit details -
Configuration menu - View commit details
-
Copy full SHA for 5e9e1fa - Browse repository at this point
Copy the full SHA 5e9e1faView commit details -
Previously, updating the last element of ys and ts did not handle the case where multiple `SubSaveAt`s were used. This is now fixed by adding a `jtu.tree_map` in the appropriate place.
Configuration menu - View commit details
-
Copy full SHA for 8fd300c - Browse repository at this point
Copy the full SHA 8fd300cView commit details -
Configuration menu - View commit details
-
Copy full SHA for dcd91a0 - Browse repository at this point
Copy the full SHA dcd91a0View commit details
Commits on May 19, 2024
-
Adjustments to patrick-kidger#387 (events):
- Semantic change: boolean events now trigger when they become truthy (before they occurred when they swap being falsy<->truthy). Note that this required twiddling around a few things as previously it was impossible for an event to trigger on the first step; now they can. - Semantic change: event functions now have the signature `(t, y, args *, terms, solver, **etc)` for consistency with vector fields and with `SaveAt(fn=...)`. - Feature: now backward-compatible with the old discrete terminating events. - Feature: added `diffrax.steady_state_event`. - Bugfix: the final `t` and `y` from an event are now saved in the correct index of `ts` and `ys`, rather than just always being saved at index `-1`. - Bugfix: at one point `args` referred to the `args` coming from a root find rather than the overall `diffeqsolve`. - Bugfix: the current `state.tprev` was used instead of the previous state's `tnext`. (These are usually but not always the same -- in particular when around jumps.) - Bugfix: added some checks when the condition function of an event does not return a bool/float scalar. - Performance: includes a fastpath for skipping the rootfind if no events are triggered. - Performance: now avoiding tracing for the shape of `dense_info` twice when using adaptive step size controllers alongside events. - Performance: avoided quadratic loop for figuring out what was the first event to trigger. - Chore: added support for the possibility of the final root find (for the time of the event) failing. - Chore: removed some dead code (`_bool_event_gradient`). - Chore: removed references in the docs to the old `discrete_terminating_event`. In addition, some drive-bys: - Fixed warnings about pending deprecations `jnp.clip(..., a_min=..., a_max=...)`. - Had `aux_stats` (in `_integrate.py`) forward to the overall output statistics. In practice this is empty but it's worth doing for the future.
Configuration menu - View commit details
-
Copy full SHA for 1470707 - Browse repository at this point
Copy the full SHA 1470707View commit details
Commits on May 23, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 28a7172 - Browse repository at this point
Copy the full SHA 28a7172View commit details
Commits on May 24, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 91082c5 - Browse repository at this point
Copy the full SHA 91082c5View commit details -
Configuration menu - View commit details
-
Copy full SHA for 269323a - Browse repository at this point
Copy the full SHA 269323aView commit details
Commits on May 26, 2024
-
Added a bunch of additional tests for events. Also changed the way `save_index` was updated to handle PyTrees of subsaveats.
Configuration menu - View commit details
-
Copy full SHA for 05f636f - Browse repository at this point
Copy the full SHA 05f636fView commit details
Commits on May 27, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 8450a40 - Browse repository at this point
Copy the full SHA 8450a40View commit details -
Configuration menu - View commit details
-
Copy full SHA for f3b8d3e - Browse repository at this point
Copy the full SHA f3b8d3eView commit details -
Configuration menu - View commit details
-
Copy full SHA for 0623211 - Browse repository at this point
Copy the full SHA 0623211View commit details -
Configuration menu - View commit details
-
Copy full SHA for 4e4414b - Browse repository at this point
Copy the full SHA 4e4414bView commit details -
When passing `SaveAt(steps=True, ts=ts)` for some array `ts` values will be saved at the times in `ts` in the time increments of each step of the solver. In practice this means that some of the saved values might be after the event time. I changed it so that these values are deleted.
Configuration menu - View commit details
-
Copy full SHA for 3bcc094 - Browse repository at this point
Copy the full SHA 3bcc094View commit details -
Configuration menu - View commit details
-
Copy full SHA for 1202f98 - Browse repository at this point
Copy the full SHA 1202f98View commit details
Commits on May 28, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 02c3c16 - Browse repository at this point
Copy the full SHA 02c3c16View commit details -
Configuration menu - View commit details
-
Copy full SHA for 1e4b9c7 - Browse repository at this point
Copy the full SHA 1e4b9c7View commit details -
Configuration menu - View commit details
-
Copy full SHA for c5e7037 - Browse repository at this point
Copy the full SHA c5e7037View commit details -
Configuration menu - View commit details
-
Copy full SHA for 5a15a87 - Browse repository at this point
Copy the full SHA 5a15a87View commit details -
Configuration menu - View commit details
-
Copy full SHA for cfc8653 - Browse repository at this point
Copy the full SHA cfc8653View commit details -
Configuration menu - View commit details
-
Copy full SHA for 3c267df - Browse repository at this point
Copy the full SHA 3c267dfView commit details
Commits on Jun 1, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 39a6557 - Browse repository at this point
Copy the full SHA 39a6557View commit details
Commits on Jun 10, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 24e9ab1 - Browse repository at this point
Copy the full SHA 24e9ab1View commit details
Commits on Jun 15, 2024
-
Configuration menu - View commit details
-
Copy full SHA for a12d069 - Browse repository at this point
Copy the full SHA a12d069View commit details
Commits on Jun 16, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 84bd8cb - Browse repository at this point
Copy the full SHA 84bd8cbView commit details
Commits on Jun 18, 2024
-
Configuration menu - View commit details
-
Copy full SHA for 05bafb2 - Browse repository at this point
Copy the full SHA 05bafb2View commit details
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.