Conversation

@jpbrodrick89 (Contributor) commented Jun 20, 2025

Relevant issue or PR

Fixes #1

Description of changes

Implement a simple sequential batching rule that enables the use of jax.vmap and jax.jacobian by simply calling the relevant endpoint multiple times. The implementation was lifted directly from jax._src.ffi.ffi_batching_rule, with variable names refactored and dead code eliminated. No private JAX API is used; instead, split_args and unflatten_args were refactored to serve additional purposes. The only thing I'm really not sure about is removing the update of the output_layout, which I have just commented out (the relevant code now refers to output_pytreedef_expanded).
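For readers unfamiliar with JAX batching rules, the sketch below illustrates the general idea behind a sequential rule: move each batched argument's batch axis to the front, call the underlying function once per batch element, and stack the results along a new leading axis. This is a minimal illustration, not the actual Tesseract-JAX code; `apply_fn` is a placeholder for whatever invokes the apply endpoint.

```python
import jax
import jax.numpy as jnp


def sequential_batching_rule(apply_fn, batched_args, batch_dims):
    """Illustrative sequential batching: loop over the batch axis and stack.

    batch_dims follows the usual JAX batching-rule convention: one batch
    axis per argument, or None if that argument is unbatched.
    """
    # Batch size, taken from any argument that actually carries a batch axis.
    size = next(
        arg.shape[dim] for arg, dim in zip(batched_args, batch_dims) if dim is not None
    )
    # Normalise: move every batched argument's batch axis to the front.
    front = [
        arg if dim is None else jnp.moveaxis(arg, dim, 0)
        for arg, dim in zip(batched_args, batch_dims)
    ]
    # Call the endpoint once per batch element, sequentially.
    results = []
    for i in range(size):
        slice_args = [a if d is None else a[i] for a, d in zip(front, batch_dims)]
        results.append(apply_fn(*slice_args))
    # Stack each output leaf along a new leading axis; outputs are batched on axis 0.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=0), *results)
    return stacked, 0
```

This mirrors the behaviour of the sequential vmap handling in JAX's own FFI machinery that the implementation was lifted from.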

Should we provide attribution to JAX somewhere, in accordance with its Apache 2.0 license?

Testing done

  • CI passes
  • jax.jacfwd and jax.jacrev run and produce expected results with vectoradd_jax
  • Added CI tests of jacobian
  • Manually test more exotic uses of vmap and then add a selection to CI

Further work

In the future we can look at adding alternative vmap_methods: the main ones I think would be relevant are async (like sequential, but allowing workers to run in parallel once that support is eventually added) and auto (expand dimensions where the InputSchema has an ellipsis in the shape definition). For tangent/cotangent vectors I think we should actually consider supporting (or, better yet, requiring) our jvp/vjp endpoints to accept a batch of vectors by default: Jacobian calculations using the sequential method will be slow for large array sizes, and there is no way to tell JAX to use our jacobian endpoint instead.
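To illustrate why batched cotangents matter: jax.jacrev is essentially a vmap of the vjp over one basis cotangent per output element, so under the sequential rule every Jacobian row becomes a separate endpoint call. A small, self-contained illustration in plain JAX (the `f` below is just a stand-in, not a Tesseract endpoint):

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * x  # stand-in for a Tesseract apply endpoint


x = jnp.arange(1.0, 5.0)
_, vjp_fn = jax.vjp(f, x)

# jacrev is roughly a vmap of the vjp over the rows of an identity matrix,
# which a sequential batching rule turns into one call per row.
basis = jnp.eye(x.size)
jac = jax.vmap(lambda ct: vjp_fn(ct)[0])(basis)

assert jnp.allclose(jac, jax.jacrev(f)(x))
```

A vjp endpoint that accepted the whole `basis` batch in one call would collapse those per-row calls into a single request.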

@jpbrodrick89 (Contributor, Author)

@johnbcoughlin you might want to try this branch out along with the main version of lineax.

@jpbrodrick89 (Contributor, Author)

@dionhaefner For now this should be enough to demonstrate that the vmap rule works as expected for a variety of use cases with a single output. I will add nested_tesseract vmap tests on Monday to see whether we get the right error if we abuse out_axes.
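For context, a minimal plain-JAX illustration of what out_axes controls (nothing Tesseract-specific, just the shape bookkeeping these tests exercise):

```python
import jax
import jax.numpy as jnp

xs = jnp.ones((8, 3))          # a batch of 8 length-3 vectors
f = lambda x: jnp.outer(x, x)  # each call returns a (3, 3) matrix

print(jax.vmap(f)(xs).shape)              # (8, 3, 3): batch axis in front (default out_axes=0)
print(jax.vmap(f, out_axes=1)(xs).shape)  # (3, 8, 3): batch axis transposed to position 1
```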

@dionhaefner (Contributor) left a comment

Thanks, this looks great!

Let's add at least one test with out_axes and see what happens. If things break, IMO we could add a check that raises an exception if out_axes is given (i.e., explicitly mark it as unsupported for now) and merge early.

@jpbrodrick89 (Contributor, Author)

Thanks @dionhaefner, I added a test over out_axes earlier this morning but didn't push; all worked fine there. I've been working on tests with nested Tesseracts and am getting a bit confused about whether I've implemented the right behaviour. The current implementation will batch every output, even those not affected by the batched inputs. Is that what you'd expect as a user of vmap? Or should there be some way to change the unaffected outputs back to their original dimension with the standard vmap API (None doesn't work, as the outputs have already been batched)?

@jpbrodrick89 (Contributor, Author)

The other consideration I had was whether we should add any assertions that batch dims are equal or just rely on vmap's standard checks.

@dionhaefner (Contributor) commented Jun 23, 2025

> Thanks @dionhaefner, I added a test over out_axes earlier this morning but didn't push; all worked fine there. I've been working on tests with nested Tesseracts and am getting a bit confused about whether I've implemented the right behaviour. The current implementation will batch every output, even those not affected by the batched inputs. Is that what you'd expect as a user of vmap? Or should there be some way to change the unaffected outputs back to their original dimension with the standard vmap API (None doesn't work, as the outputs have already been batched)?

There's no way to know which outputs are affected by which inputs, so yes this is the intended behavior. We may want to add additional bells and whistles later, but for now this should be the best we can do without a lot more work.

@dionhaefner (Contributor)

> The other consideration I had was whether we should add any assertions that batch dims are equal or just rely on vmap's standard checks.

I don't think we need that; error messages should be informative enough in that case.

@jpbrodrick89 (Contributor, Author)

@dionhaefner this should be good for the final sweep now. I agree that fixing the non-batched arguments is not possible without tracing through the Tesseract; maybe in the future we could add a "differentiable_only" option so that only Differentiable outputs are batched, which is probably the most common use case.

By the way, the second argument of the batching rule is the axis that has been batched over; vmap refers to it when the result needs to be transposed to another axis.
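Schematically, the convention looks like this (a toy rule for illustration only, not the actual implementation):

```python
import jax.numpy as jnp


def toy_batching_rule(batched_args, batch_dims):
    # batch_dims holds, per argument, the axis vmap batched over (None if unbatched).
    (x,), (bdim,) = batched_args, batch_dims
    x = jnp.moveaxis(x, bdim, 0)  # normalise the batch axis to the front
    out = jnp.sin(x)              # stand-in for the underlying computation
    # Return the batched output together with the axis it is batched along;
    # vmap later moves that axis to wherever out_axes requests.
    return out, 0
```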

With regard to supporting multiple tangent vectors by default, which of the three options below should I go with?

  1. Raise an issue to discuss options and point to it on Slack, perhaps leading to a mini-psync if it is hard to resolve
  2. Open a discussion on dev-tesseract
  3. Outline ideas and justification in a Confluence page and schedule a mini-psync

@dionhaefner (Contributor)

Option 1 sounds like a good default, although I'd nitpick that describing problems is more useful than describing solutions (solutions tend to emerge from the team / community by themselves once the problem is described well enough).

@dionhaefner (Contributor) left a comment

Thanks @jpbrodrick89!

@dionhaefner changed the title from "feat: support simple sequential batching rule" to "feat: support simple sequential batching rule (support for jax.vmap and jax.jacobian)" on Jun 23, 2025
Co-authored-by: Dion Häfner <dion.haefner@simulation.science>
@jpbrodrick89 merged commit d52ef6b into main on Jun 23, 2025
10 checks passed
@jpbrodrick89 deleted the jpb/batching branch on June 23, 2025 19:36
@pasteurlabs locked and limited conversation to collaborators on Jun 23, 2025