
Conversation

Member

@aseyboldt aseyboldt commented Sep 16, 2025

Revisit #1120, which seems abandoned.

@jdehning, I hope it's OK if I continue this PR?


📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

@aseyboldt aseyboldt force-pushed the as-jax-opt2 branch 3 times, most recently from 10dfa2e to ead4ac7 on September 16, 2025 15:08
@Copilot Copilot AI left a comment

Pull Request Overview

This PR implements the as_jax_op decorator, which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation capabilities.

  • Implements JAXOp class for wrapping JAX functions as PyTensor operations
  • Creates as_jax_op decorator for easy conversion of JAX functions to PyTensor-compatible operations
  • Adds comprehensive test coverage for various input/output patterns and data types
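
As a quick illustration, a minimal usage sketch (not code from this PR; the static-shape requirement it relies on is discussed further down in the thread):

import jax.numpy as jnp
import pytensor.tensor as pt
from pytensor import as_jax_op, function

@as_jax_op
def selu(x):
    # an ordinary JAX function; the decorator wraps it in a JAXOp
    return 0.5 * jnp.where(x > 0, x, 1.67 * (jnp.exp(x) - 1))

x = pt.vector("x", shape=(3,))  # static input shape, see the discussion below
y = selu(x)                     # y is a regular PyTensor variable
f = function([x], y)
print(f([-1.0, 0.0, 1.0]))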

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

File Description
pytensor/link/jax/ops.py Core implementation of JAXOp class and as_jax_op decorator
tests/link/jax/test_as_jax_op.py Comprehensive test suite covering various use cases and data types
pytensor/__init__.py Exports as_jax_op function with fallback for missing dependencies
pytensor/link/jax/dispatch/basic.py JAX dispatch registration for the new JAXOp
doc/library/index.rst Documentation entry for the new functionality
doc/environment.yml Updates documentation environment to include JAX dependencies
doc/conf.py Adds Equinox to intersphinx mapping
.github/workflows/test.yml Updates CI to install Equinox dependency



codecov bot commented Sep 16, 2025

Codecov Report

❌ Patch coverage is 89.44444% with 19 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.67%. Comparing base (1dc982c) to head (c3874c1).
⚠️ Report is 31 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/jax/ops.py 88.95% 10 Missing and 9 partials ⚠️

❌ Your patch status has failed because the patch coverage (89.44%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1614      +/-   ##
==========================================
- Coverage   81.69%   81.67%   -0.03%     
==========================================
  Files         230      232       +2     
  Lines       52950    53132     +182     
  Branches     9404     9410       +6     
==========================================
+ Hits        43260    43396     +136     
- Misses       7256     7283      +27     
- Partials     2434     2453      +19     
Files with missing lines Coverage Δ
pytensor/compile/ops.py 83.91% <100.00%> (+0.46%) ⬆️
pytensor/link/jax/dispatch/basic.py 83.52% <100.00%> (+0.81%) ⬆️
pytensor/link/jax/ops.py 88.95% <88.95%> (ø)

... and 38 files with indirect coverage changes


if any(s is None for s in shape):
    _, shape = pt.basic.infer_static_shape(var.shape)
    if any(s is None for s in shape):
        raise ValueError(
Member

Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism

PyTensor only needs to know the dtype and rank of the outputs

Member Author

I don't think that would be reliable.
I think jax will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate jax symbolic shapes. (But only with object dims, not with string ones I think?).

Member Author

The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.

Member

Does it fail or does it infer they must match?

Member Author

It fails:

import numpy as np
import jax
from jax import export

u, v = export.symbolic_shape("u, v")

x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)

export.export(jax.jit(lambda x, y: x + y))(x1, x2)
# add got incompatible shapes for broadcasting: (u,), (v,).
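
By contrast, if both inputs share the same symbol, jax can prove the shapes match and the export goes through (a sketch along the same lines, reusing the imports above):

u, = export.symbolic_shape("u")

x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((u,), dtype=np.int32)

export.export(jax.jit(lambda x, y: x + y))(x1, x2)  # no error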

@aseyboldt
Member Author

I kept most of what was in the original PR, but made a few changes:

  • There is no longer a separate op for the gradient; that is just again a JAXOp.
  • I kept support for jax tree inputs and outputs; I think those are quite valuable. For instance, when we have a neural network in a model, or if we want to solve an ODE, it is much nicer if we don't have to take apart all the jax trees everywhere by hand (see the sketch below). I did remove wrapping of returned functions though. That led to some trouble if the jax trees contain callables that should not be wrapped, and seems overall a bit hackish to me. I also can't think of a use case where we would really need that? If one does come along, maybe we can revisit this idea.
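
To make that concrete, a minimal sketch (hypothetical names, not code from this PR): the wrapped function can take a whole jax tree of PyTensor variables, here a dict of weight matrices, without any manual flattening:

import jax.numpy as jnp
import pytensor.tensor as pt
from pytensor import as_jax_op

@as_jax_op
def net(params, x):
    # params is an ordinary jax pytree (a dict of weight matrices)
    h = jnp.tanh(x @ params["w1"])
    return h @ params["w2"]

params = {
    "w1": pt.matrix("w1", shape=(4, 8)),
    "w2": pt.matrix("w2", shape=(8, 2)),
}
x = pt.matrix("x", shape=(16, 4))
y = net(params, x)  # no manual tree flattening on the user side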

I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

If we want to avoid having the user specify the output types, we need to call the function at least once. We can do that abstractly with jax.eval_shape, but that still needs static shape info.
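
For reference, a minimal sketch of that approach (not code from this PR): jax.eval_shape traces the function abstractly, so it costs no FLOPs, but every input still needs a concrete shape.

import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
    return jnp.matmul(x, y)

# Only shapes and dtypes are propagated; no actual computation runs.
out = jax.eval_shape(
    f,
    jax.ShapeDtypeStruct((3, 4), np.float32),
    jax.ShapeDtypeStruct((4, 5), np.float32),
)
print(out.shape, out.dtype)  # (3, 5) float32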

I'm not sure right now what the best way to handle this is.

@jdehning
Contributor

Revisit #1120, which seems abandoned.

@jdehning, I hope it's OK if I continue this PR?

📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/

Yes, sure. Sorry that I dropped the ball.

@jdehning
Contributor

I think the biggest problem right now is that the as_jax_op wrapper needs pytensor inputs with static shapes.

I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It worked for me for pymc models where initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, or what happens when the shape of a shared variable is changed between runs by setting new values.
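
For illustration, a minimal sketch of the constant-folding approach (the same helper used in the snippet reviewed above; whether pt.cumsum actually drops the static shape depends on the PyTensor version):

import pytensor.tensor as pt
from pytensor.tensor.basic import infer_static_shape

x = pt.vector("x", shape=(5,))
y = pt.cumsum(x)  # the op's output type may lose the static shape here
_, static_shape = infer_static_shape(y.shape)
print(static_shape)  # (5,), recovered by constant-folding the shape graph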

@jdehning
Contributor

I did remove wrapping of returned functions though. That led to some trouble if the jax trees contain callables that should not be wrapped, and seems overall a bit hackish to me. I also can't think of a use case where we would really need that? If one does come along, maybe we can revisit this idea.

I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-changing variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve that, as it allows defining the interpolation function and the ODE solver separately. However, I agree it was somewhat hackish and not easily parsable, and its usage can be avoided reasonably well if both the interpolation function and the ODE solver are defined in a single function.
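
A self-contained sketch of that single-function pattern (a toy forward-Euler step, not the original code): the interpolation and the solver live in one function, so no callable has to travel through the jax tree:

import jax
import jax.numpy as jnp

def solve_ode(knot_ts, knot_values, y0, ts):
    def rate(t):
        # time-dependent parameter, interpolated between the knots
        return jnp.interp(t, knot_ts, knot_values)

    def step(y, t_pair):
        t0, t1 = t_pair[0], t_pair[1]
        y_new = y + (t1 - t0) * rate(t0) * y  # dy/dt = rate(t) * y
        return y_new, y_new

    t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=1)
    _, ys = jax.lax.scan(step, y0, t_pairs)
    return ys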

@jdehning
Contributor

+1 for jax_to_pytensor. For me it is the most easily understandable name.

@ricardoV94 ricardoV94 left a comment

Small comments; besides that, we should decide on the name. I don't love the jax/ops.py filename either.

@aseyboldt
Member Author

@ricardoV94 I (hopefully) addressed all comments I didn't reply to directly.

On top of that, I also removed the equinox dependency. We only needed two small functions from it, so I just copied those over with a note in the source. It's Apache 2.0, so I think that's ok? (Also, it's just a few lines anyway...).

I still like the names wrap_jax and wrap_py. Not sure what to do about this; if you prefer as_jax_op, we can also stick with that. It should be easy to just undo those two renaming commits.

@ricardoV94
Member

I still like the names wrap_jax and wrap_py

Sure, let's go with that. Feel free to open an issue to deprecate the old as_op name and change the docs. No need to slow this PR down for it.

@ricardoV94
Member

Ah you already did the renaming...

doc/conf.py Outdated
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
"equinox": ("https://docs.kidger.site/equinox/", None),
Member

Suggested change
- "equinox": ("https://docs.kidger.site/equinox/", None),

@twiecki twiecki changed the title from "Implement as_jax_op" to "Implement wrap_jax" Sep 30, 2025
@ricardoV94
Member

@aseyboldt nudge so we get this in the next release?

@aseyboldt
Member Author

I fixed the doctests and the extra reference to equinox you mentioned.
I think that's all, or did I forget something?

@ricardoV94
Member

Changed the deprecation to a FutureWarning, as users usually wouldn't see a DeprecationWarning. Will merge once tests pass. Thanks a ton @jdehning and @aseyboldt

@ricardoV94
Member

Sorry that I dropped the ball.

Getting 95% there is not dropping the ball in my book ;) Thanks!

@ricardoV94 ricardoV94 changed the title from "Implement wrap_jax" to "Implement wrap_jax and rename as_op to wrap_py" Sep 30, 2025
@aseyboldt
Member Author

@ricardoV94 Is there still anything missing, or can we merge (and soon release...) this now?

@ricardoV94
Member

Ugh, missed it; it was meant to have gone out with the last release.

@ricardoV94 ricardoV94 merged commit b67ff22 into pymc-devs:main Oct 2, 2025
63 of 64 checks passed