Fix for https://github.com/patrick-kidger/optimistix/issues/48 #694

Merged 1 commit into dev from loop-constant-step on Mar 31, 2024

@patrick-kidger patrick-kidger commented Mar 30, 2024

This is a fix for patrick-kidger/optimistix#48

It happens occasionally that JAX will insert a spurious vmap. It seems like we're hitting one of these cases.

In particular I found that:

  • We are hitting this nonbatchable call, return nonbatchable(out), due to step getting batched.
  • Inserting val = (nonbatchable(val[0]),) + val[1:] here, at def scan_fn(val, _):, will also trigger an error, and it does so before the previously-mentioned one. But shifting this additional check to the other side of the jax.checkpoint (a few lines further down) will not raise an error before the previously-mentioned one.

Spelunking through the (very deep) JAX stack traces, it seems to be due to the lax.scan saving its evolving state for use in the backward pass -- in particular including this step -- and for some reason the insertion of the jax.checkpoint causes this state to become batched. Then when we run the backwards scan during backprop, we re-run this function with the batched state and so things explode. I didn't try digging into it further than that, as I'm not hugely surprised -- such spurious batching is a thing JAX does every now and again.
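To make the ingredients above concrete, here is a minimal sketch of a lax.scan whose carry contains a step counter, with a jax.checkpoint-wrapped body, differentiated under jax.vmap. It is an illustration only: run_loop, body and the arithmetic are hypothetical, this is not the library's actual loop, and it is not claimed to reproduce the spurious batching. It just shows that step is the same for every batch element even though the scan state is saved for the backward pass.

```python
import jax
import jax.numpy as jnp
from jax import lax


def run_loop(x, num_steps=5):
    # Hypothetical loop: the carry is (step, y), where `step` is a counter
    # that should never depend on the batch element.
    @jax.checkpoint
    def body(carry, _):
        step, y = carry
        y = jnp.sin(y) + 0.1 * step
        return (step + 1, y), None

    init = (jnp.asarray(0, dtype=jnp.int32), x)
    (step, y), _ = lax.scan(body, init, xs=None, length=num_steps)
    return y, step


xs = jnp.array([0.1, 0.5, 1.0])

# Backprop through the scan, batched over xs. The scan state (including
# `step`) is what gets saved for the backward pass.
grads, steps = jax.vmap(jax.grad(run_loop, has_aux=True))(xs)
```

Mathematically, every entry of steps is identical here, which is why a batched step inside the loop counts as spurious rather than meaningful.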

In terms of a fix: for mathematical correctness, we continue to disallow nonconstant steps across a batch (I would need to think a lot harder about whether a nonconstant step can ever be valid). However, if a batch tracer does arise, then we simply punt the nonconstancy check to runtime. This definitely isn't ideal, but (a) this is a pretty unusual edge case (DirectAdjoint + vmap + optimistix), and (b) it's better than crashing. (Hopefully all of this becomes unnecessary if jvp-of-custom_vjp is implemented and we can remove DirectAdjoint entirely.)
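The idea of punting the nonconstancy check to runtime can be sketched with jax.custom_batching.custom_vmap. This is an assumption-laden illustration, not the code in this PR: unbatch_if_constant is a hypothetical name, and a sentinel value of -1 stands in for whatever runtime error the real implementation raises. If the value arrives batched, the rule compares the entries at runtime and hands back a single unbatched value.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def unbatch_if_constant(step):
    # Hypothetical helper. Outside of vmap this is just the identity.
    return step


@unbatch_if_constant.def_vmap
def _unbatch_rule(axis_size, in_batched, step):
    (step_is_batched,) = in_batched
    if not step_is_batched:
        # Nothing to check: the value is already unbatched.
        return step, False
    # The value got batched (possibly spuriously). We cannot inspect the
    # entries at trace time, so the comparison is deferred to runtime.
    first = step[0]
    all_equal = jnp.all(step == first)
    # Assumption: -1 is a stand-in for a proper runtime error.
    out = jnp.where(all_equal, first, -1)
    # out_batched=False marks the result as unbatched again.
    return out, False


def use_step(x, step):
    return x + unbatch_if_constant(step)


xs = jnp.array([1.0, 2.0, 3.0])
constant_steps = jnp.array([3, 3, 3])
varying_steps = jnp.array([3, 4, 3])

ok = jax.vmap(use_step)(xs, constant_steps)
bad = jax.vmap(use_step)(xs, varying_steps)
```

With constant_steps the batched call behaves as if step had never been batched. With varying_steps the sentinel shows up in the output, which is where a real implementation would raise instead.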

@patrick-kidger patrick-kidger merged commit 0d94077 into dev Mar 31, 2024
2 checks passed
@patrick-kidger patrick-kidger deleted the loop-constant-step branch March 31, 2024 20:37