feat: support simple sequential batching rule (support for jax.vmap and jax.jacobian)
#47
|
@johnbcoughlin you might want to try this branch out along with the main version of |
|
@dionhaefner For now this should be enough to demonstrate the |
dionhaefner
left a comment
Thanks, this looks great!
Let's add at least one test with out_axes and see what happens. If things break, IMO we could add a check that raises an exception if out_axes is given (i.e., explicitly mark it as unsupported for now) and merge early.
|
Thanks @dionhaefner I added a test over |
|
The other consideration I had was whether we should add any assertions that batch dims are equal, or just rely on vmap's standard checks. |
There's no way to know which outputs are affected by which inputs, so yes this is the intended behavior. We may want to add additional bells and whistles later, but for now this should be the best we can do without a lot more work. |
Don't think we need that, error messages should be informative enough in that case. |
|
@dionhaefner this should be good for the final sweep now. I agree that fixing the non-batched arguments is not possible without tracing through the Tesseract; maybe in the future we could add a
By the way, the second argument of the batching rule is the axis that has been batched over. When vmapping, it is consulted if the batch dimension needs to be moved to another axis.
With regard to supporting multiple tangent vectors by default, which of the three options below should I go with?
|
|
Option 1 sounds like a good default, although I'd nitpick that describing problems is more useful than describing solutions (where solutions emerge from the team / community by themselves once the problem is described well enough). |
dionhaefner
left a comment
Thanks @jpbrodrick89!
jax.vmap and jax.jacobian)
Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
Relevant issue or PR
Fixes #1
Description of changes
Implement a simple sequential batching rule enabling the use of vmap and jax.jacobian by simply calling the relevant endpoint multiple times. The implementation was directly lifted from jax._src.ffi.ffi_batching_rule with refactoring of variable names and eliminating dead code. No private JAX API is used; instead split_args and unflatten_args were refactored to serve additional purposes. The only thing I'm really not sure about is removing the update of the output_layout which I have just commented out (the relevant code now refers to output_pytreedef_expanded).
Should we provide any attributions to jax in accordance with their Apache 2.0 license somewhere?
Testing done
jax.jacfwd and jax.jacrev runs and produces expected results with vectoradd_jax
jacobian vmap and then add selection to CI
Further work
In the future we can look at adding alternative vmap_methods: the main ones I think would be relevant are async (like sequential but we allow workers to run in parallel when this support is eventually added) and auto (expand dimensions where InputSchema has an ellipsis in the shape definition). For tangent/cotangent vectors I think we should actually consider supporting (or better yet requiring) our jvp/vjp endpoints to accept a batch of vectors as default. jacobian calculations using the sequential method will be slow for large array sizes and there is no way to tell jax to use our jacobian endpoint instead.
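As a rough illustration of the sequential approach described under "Description of changes", the sketch below uses JAX's public jax.custom_batching.custom_vmap API rather than the primitive-level rule adapted in this PR, so it is not the PR's implementation. apply_endpoint is a hypothetical stand-in for a single unbatched endpoint call. The rule calls it once per batch element, stacks the results, and marks the output as batched, since there is no way to know which outputs depend on which inputs.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


@custom_vmap
def apply_endpoint(a, b):
    # Hypothetical stand-in for one unbatched call to a remote endpoint.
    return a + 2.0 * b


@apply_endpoint.def_vmap
def sequential_rule(axis_size, in_batched, a, b):
    # in_batched holds one flag per argument; batched arguments arrive
    # with their batch dimension at axis 0.
    a_batched, b_batched = in_batched
    outputs = []
    for i in range(axis_size):
        # Slice batched arguments, pass unbatched ones through unchanged.
        a_i = a[i] if a_batched else a
        b_i = b[i] if b_batched else b
        outputs.append(apply_endpoint(a_i, b_i))
    # Conservatively treat the output as batched along axis 0.
    return jnp.stack(outputs), True


a = jnp.arange(6.0).reshape(3, 2)
b = jnp.ones(2)
batched = jax.vmap(apply_endpoint, in_axes=(0, None))(a, b)
```

Each batch element costs one endpoint call here, which is why Jacobians built on this rule scale poorly with array size, as noted above.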