Zero gradient when using jnp.piecewise inside an ODE #363
Ah, this is a known (although pretty obscure) limitation of using autodifferentiable differential equation solvers with discontinuous vector fields.

**TL;DR: the solution**

First of all, the solution: explicitly declare the jump time in the stepsize controller, typically by passing it as the `jump_ts` argument, e.g. `diffrax.PIDController(rtol=..., atol=..., jump_ts=jnp.array([T]))`.

**What went wrong?**

As for what's going on: autodifferentiation only propagates derivatives through continuous arithmetic. A branch condition like the one inside `jnp.piecewise` contributes nothing to the gradient, so the computed sensitivity of the solution to the jump time is exactly zero. We can explain this in a few different ways.
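To make the mechanism concrete, here is a minimal, dependency-free sketch (plain Python with hand-rolled forward-mode dual numbers rather than JAX; the toy ODE `dy/dt = k(t)`, the factor of 2, and all parameter values are illustrative assumptions, not the code from this issue). Autodiff through a fixed-step solve of a discontinuous vector field returns a zero derivative with respect to the jump time, while finite differences return the true nonzero sensitivity:

```python
class Dual:
    """Minimal forward-mode autodiff: a value plus a tangent (derivative)."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__
    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__
    # Comparisons only look at the value -- the tangent is silently dropped,
    # which is exactly where the gradient information is lost.
    def __lt__(self, other):
        return self.val < (other.val if isinstance(other, Dual) else other)
    def __gt__(self, other):
        return self.val > (other.val if isinstance(other, Dual) else other)

def k(t, T, k0):
    # Discontinuous coefficient: k0 before the jump at t = T, 2*k0 after.
    return k0 if t < T else 2.0 * k0

def solve(T, k0, n=1000, t1=1.0):
    # Fixed-step Euler solve of dy/dt = k(t) from t=0 to t=t1.
    dt = t1 / n
    y = Dual(0.0)
    for i in range(n):
        y = y + dt * k(i * dt, T, k0)
    return y

# Autodiff d/dT: seed T with a unit tangent. T only feeds the `t < T`
# branch condition, so no tangent ever reaches the output.
g_T = solve(Dual(0.3, 1.0), Dual(0.5)).dot

# Finite differences *do* see the jump location move:
h = 1e-3
fd_T = (solve(Dual(0.3 + h), Dual(0.5)).val
        - solve(Dual(0.3 - h), Dual(0.5)).val) / (2 * h)

print(g_T)   # 0.0
print(fd_T)  # ≈ -0.5, the true value: d/dT [k0*T + 2*k0*(1 - T)] = -k0
```

The same thing happens, for the same reason, with `jax.grad` through `jnp.piecewise`: the jump time participates only in a boolean predicate, which has no derivative.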
**Why does the solution above work?**

So how do we fix this? We've seen that writing the vector field as a single `jnp.piecewise` gives a zero gradient with respect to the jump time. Our first insight into fixing this is to observe that we could instead have split the integration into two separate solves, one on each side of the jump. In fact, go ahead and test this, and you'll get the expected gradient! Reasoning in terms of the computation graphs described above, the reason is that the jump time now enters the computation continuously, through the integration limits and hence the step sizes, rather than only through a branch condition. Using `jump_ts` recovers the same effect without splitting the solve by hand: it forces the solver to place a step boundary exactly at the jump.

**Can we do better?**

This is an unfortunate user footgun! But I don't know of an automatic solution to this; so far as I know it may be an open question in the theory of autodifferentiation. (?) Spitballing, I imagine this could maybe be solved by having the ODE solver try to detect when it thinks a jump has occurred, if so solve a root-finding problem to locate the jump, and then use that in its step size control. I think investigating this might be an interesting research question in autodifferentiation, for those curious enough to try :)
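The "split at the jump" insight can be sketched with a small dependency-free example (hand-rolled forward-mode dual numbers in plain Python; the toy ODE `dy/dt = k(t)` with `k = k0` before the jump and `2*k0` after, and all values, are illustrative assumptions). Because each sub-solve's step size now depends smoothly on the jump time `T`, the derivative with respect to `T` flows through ordinary arithmetic and comes out correct:

```python
class Dual:
    """Minimal forward-mode autodiff: a value plus a tangent (derivative)."""
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot
    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__
    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__
    def __rsub__(self, other):  # other - self, with other a plain number
        return Dual(other - self.val, -self.dot)

def solve_split(T, k0, n=1000):
    # Piece 1: dy/dt = k0 on [0, T]. The step size T/n depends smoothly
    # on T, so T's tangent enters the result through ordinary arithmetic.
    dt1 = T * (1.0 / n)
    y = Dual(0.0)
    for _ in range(n):
        y = y + dt1 * k0
    # Piece 2: dy/dt = 2*k0 on [T, 1], step size (1 - T)/n, again smooth in T.
    dt2 = (1.0 - T) * (1.0 / n)
    for _ in range(n):
        y = y + dt2 * (2.0 * k0)
    return y

# The exact solution at t=1 is k0*T + 2*k0*(1 - T), so d/dT = -k0 = -0.5.
# Forward-mode autodiff through the split solve now recovers it:
g_T = solve_split(Dual(0.3, 1.0), Dual(0.5)).dot
print(g_T)  # ≈ -0.5
```

Passing `jump_ts` to the stepsize controller plays the same role inside a single `diffrax.diffeqsolve`: the solver steps exactly to the jump instead of straddling it.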
Thank you, Patrick, for your detailed and instructive answer!
Hi,
Applying `jax.grad` to a function that uses diffrax to integrate a piecewise-defined ODE, I observe that one partial derivative is unexpectedly zero. The ODE solver returns correct function values; only the gradient is wrong. I'm wondering whether this is a bug, or whether I'm doing something wrong.
Thanks in advance!
David
Example:
Consider the piecewise defined ODE
to which the solution reads
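For concreteness, one piecewise ODE in the parameters $T$ and $k_0$ of the kind described (an illustrative example only, not necessarily the poster's exact system) is

```math
\frac{\mathrm{d}y}{\mathrm{d}t} = k(t),
\qquad
k(t) = \begin{cases} k_0, & t < T, \\ 2k_0, & t \ge T, \end{cases}
\qquad
y(0) = 0,
```

with solution

```math
y(t) = \begin{cases} k_0\,t, & t < T, \\ k_0\,T + 2k_0\,(t - T), & t \ge T, \end{cases}
```

so that $\partial y(t_1)/\partial T = -k_0$ and $\partial y(t_1)/\partial k_0 = 2t_1 - T$ for $t_1 \ge T$.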
I'm interested in the partial derivatives w.r.t. $T$ and $k_0$. In the code example below, I compare the gradient obtained from integrating the ODE using diffrax to the analytical solution and to a finite-difference calculation.
prints out the following:
I'm using Python 3.11.7, jax 0.4.23, jaxlib 0.4.23.dev20231223, diffrax 0.5.0, macOS 14.2.1, x86_64, running on CPU.