Fix for https://github.com/patrick-kidger/optimistix/issues/48 #694
This is a fix for patrick-kidger/optimistix#48
It happens occasionally that JAX will insert a spurious vmap. It seems like we're hitting one of these cases.
In particular I found that:
- The `nonbatchable` call at `equinox/equinox/internal/_loop/common.py`, line 432 in 60612c1, is raising the error, due to `step` getting batched.
- Adding `val = (nonbatchable(val[0]),) + val[1:]` here: `equinox/equinox/internal/_loop/bounded.py`, line 72 in 60612c1, before the `jax.checkpoint` (a few lines further down), will not raise an error before the previously-mentioned one.

Spelunking through the (very deep) JAX stack traces, it seems to be due to the `lax.scan` saving its evolving state for use in the backward pass -- in particular including this `step` -- and for some reason the insertion of the `jax.checkpoint` causes this state to become batched. Then when we run the backwards scan during backprop, we re-run this function with the batched state and so things explode. I didn't try digging into it further than that, as I'm not hugely surprised -- such spurious batching is a thing JAX does every now and again.

In terms of a fix: for the actual mathematical correctness, we do continue to disallow a batch from having nonconstant `step`s (I would need to think a lot harder about whether nonconstant `step` can ever be valid). However, if a batch tracer does arise, then we simply punt the nonconstancy check to runtime. This definitely isn't ideal, but this is (a) a pretty unusual edge-case (`DirectAdjoint` + `vmap` + `optimistix`), and (b) it's better than crashing. (Hopefully all of this becomes unnecessary if `jvp-of-custom_vjp` is implemented and we can remove `DirectAdjoint` entirely.)
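
For illustration only, here is a minimal sketch of what "punting the nonconstancy check to runtime" can look like in plain JAX. It is not the code in this PR: `collapse_constant_batch` is a hypothetical helper, and `jax.experimental.checkify` is used here as a stand-in for whatever runtime-error mechanism the real fix uses. The idea is that a `step` which has spuriously picked up a batch axis is accepted at trace time, and only rejected at runtime if its entries actually differ.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def collapse_constant_batch(step):
    """Hypothetical helper (not Equinox's implementation).

    `step` is assumed to have picked up a leading batch axis. Rather than
    failing at trace time, defer to a runtime check that every batch
    element holds the same value, then return the unbatched value.
    """
    first = step[0]
    # Runtime check: only fails if the batch is genuinely nonconstant.
    checkify.check(jnp.all(step == first), "`step` is not constant across the batch")
    return first


# `checkify` turns the check into an error value returned alongside the output.
checked = jax.jit(checkify.checkify(collapse_constant_batch))

# Spuriously batched but constant: passes, and we recover the scalar step.
err_ok, value = checked(jnp.array([3, 3, 3]))

# Genuinely nonconstant: the error is reported at runtime.
err_bad, _ = checked(jnp.array([3, 4, 3]))
```

Here `err_ok.get()` is `None`, whereas `err_bad.get()` returns the error message (and `err_bad.throw()` would raise). The trade-off is the one described above: the failure surfaces later than a trace-time error would, but the constant-batch case no longer crashes.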