Implement `wrap_jax` and rename `as_op` to `wrap_py` #1614
Conversation
Pull Request Overview

This PR implements the `as_jax_op` decorator, which allows JAX functions to be used within PyTensor graphs. The decorator wraps JAX functions to make them compatible with PyTensor's variable system while preserving gradient computation capabilities (a usage sketch follows the list below).
- Implements `JAXOp` class for wrapping JAX functions as PyTensor operations
- Creates `as_jax_op` decorator for easy conversion of JAX functions to PyTensor-compatible operations
- Adds comprehensive test coverage for various input/output patterns and data types
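A minimal usage sketch of the decorator (the wrapped function, variable names, and import path are illustrative assumptions; the exact signature may differ from the merged implementation):

```python
import jax.numpy as jnp
import pytensor.tensor as pt
from pytensor import as_jax_op  # later renamed to wrap_jax in this PR

@as_jax_op
def scaled_logistic(x, scale):
    # Any JAX-traceable computation; receives JAX arrays when traced.
    return scale / (1 + jnp.exp(-x))

x = pt.tensor("x", shape=(3,))
scale = pt.scalar("scale")
y = scaled_logistic(x, scale)  # a regular PyTensor variable
g = pt.grad(y.sum(), x)        # gradients flow through the wrapped function
```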
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
| --- | --- |
| pytensor/link/jax/ops.py | Core implementation of the `JAXOp` class and `as_jax_op` decorator |
| tests/link/jax/test_as_jax_op.py | Comprehensive test suite covering various use cases and data types |
| pytensor/__init__.py | Exports `as_jax_op` with a fallback for missing dependencies |
| pytensor/link/jax/dispatch/basic.py | JAX dispatch registration for the new `JAXOp` |
| doc/library/index.rst | Documentation entry for the new functionality |
| doc/environment.yml | Updates documentation environment to include JAX dependencies |
| doc/conf.py | Adds Equinox to intersphinx mapping |
| .github/workflows/test.yml | Updates CI to install Equinox dependency |
Codecov Report

❌ Your patch status has failed because the patch coverage (89.44%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main    #1614      +/-   ##
==========================================
- Coverage   81.69%   81.67%   -0.03%
==========================================
  Files         230      232       +2
  Lines       52950    53132     +182
  Branches     9404     9410       +6
==========================================
+ Hits        43260    43396     +136
- Misses       7256     7283      +27
- Partials     2434     2453      +19
```
pytensor/link/jax/ops.py (outdated)
```python
if any(s is None for s in shape):
    _, shape = pt.basic.infer_static_shape(var.shape)
    if any(s is None for s in shape):
        raise ValueError(
```
Can we use this instead? https://docs.jax.dev/en/latest/export/shape_poly.html#shape-polymorphism — PyTensor only needs to know the dtype and rank of the outputs.
I don't think that would be reliable.
I think JAX will throw an error if the code tries to broadcast arrays when it cannot prove that they have compatible shapes.
If we have dims, we could use those to generate JAX symbolic shapes (but only with object dims, not with string ones, I think?).
The best we could do right now is create a new shape variable for every input dimension that is not statically known.
But then it would fail as soon as you even add two of those together.
Does it fail or does it infer they must match?
It fails:

```python
import jax
import numpy as np
from jax import export

u, v = export.symbolic_shape("u, v")
x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((v,), dtype=np.int32)
export.export(jax.jit(lambda x, y: x + y))(x1, x2)
# add got incompatible shapes for broadcasting: (u,), (v,).
```
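For contrast, a sketch of mine (not from the thread): export goes through when both inputs share the same symbolic dimension, since JAX can then prove the shapes are compatible:

```python
import jax
import numpy as np
from jax import export

# A single symbol used for both inputs: shapes are provably equal.
(u,) = export.symbolic_shape("u")
x1 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
x2 = jax.ShapeDtypeStruct((u,), dtype=np.int32)
exported = export.export(jax.jit(lambda x, y: x + y))(x1, x2)  # succeeds
```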
I kept most of what was in the original PR, but made a few changes: […]

I think the biggest problem right now is that the […]. If we want to avoid having the user specify the output types, we need to call the function at least once. We can do that with […]. I'm not sure right now what the best way to handle this is.
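One candidate for the elided mechanism (my sketch; not necessarily what the PR settled on) is `jax.eval_shape`, which runs the function abstractly on shape/dtype specs, so output types can be inferred without a real call:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
    return jnp.outer(x, y), x.sum()

# No actual computation runs; only shapes and dtypes are propagated.
specs = (
    jax.ShapeDtypeStruct((3,), np.float32),
    jax.ShapeDtypeStruct((4,), np.float32),
)
out_shapes = jax.eval_shape(f, *specs)
# (ShapeDtypeStruct(shape=(3, 4), dtype=float32),
#  ShapeDtypeStruct(shape=(), dtype=float32))
```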
Yes, sure. Sorry that I dropped the ball.
I did use pytensor.compile.builders.infer_shape to get static shapes in the original PR. It worked for me for PyMC models where initial static shapes are lost because of a pt.cumsum. However, if I remember correctly, I didn't test whether it works with pm.Data, i.e. shared variables in the graph, or what happens when the shape of a shared variable is changed between runs by setting new values.
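To illustrate the shared-variable concern (a sketch of mine, not from the PR): PyTensor shared variables have an unknown static shape by default, and their value's shape may legitimately change between runs:

```python
import numpy as np
import pytensor

s = pytensor.shared(np.zeros(3), name="s")
print(s.type.shape)       # (None,): the static shape is not fixed
s.set_value(np.zeros(5))  # allowed: the runtime shape can change
```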
I wrote it for ODEs that depend on time-dependent parameters; we need a function that takes a time point and returns some time-varying variables that interpolate between parameters. Wrapping the callable was the most user-friendly way to achieve this, as it allows defining the interpolation function and the ODE solver separately. However, I agree it was somewhat hackish and not easily parsable, and its usage can be reasonably well avoided if both the interpolation function and the ODE solver are defined in a single function.
+1 for `wrap_jax`
Small comments; besides that, we should decide on the name. Don't love the jax/ops.py filename either.
@ricardoV94 I (hopefully) addressed all comments I didn't reply to directly. On top of that, I also removed the equinox dependency. We only needed two small functions from it, so I just copied those over with a note in the source. It's Apache 2.0, so I think that's ok? (Also, it's just a few lines anyway...) I still like the name `wrap_jax`.
Sure, let's go with that. Feel free to open an issue to deprecate the old `as_op`.
Ah, you already did the renaming...
doc/conf.py (outdated)

```python
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
"equinox": ("https://docs.kidger.site/equinox/", None),
```
Suggested change: remove this entry.

```python
"equinox": ("https://docs.kidger.site/equinox/", None),
```
@aseyboldt nudge so we get this in the next release?
I fixed the doctests and the extra reference to equinox you mentioned.
Changed the deprecation to a FutureWarning, as users usually wouldn't see a DeprecationWarning. Will merge once tests pass. Thanks a ton @jdehning and @aseyboldt
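(A side note, mine rather than the PR's: Python's default warning filters hide `DeprecationWarning` when it is raised from library code, while `FutureWarning` is always shown, which is why the switch matters for end users.)

```python
import warnings

# Always shown to end users under Python's default filters:
warnings.warn("as_jax_op is deprecated; use wrap_jax instead", FutureWarning)

# Hidden by default unless triggered in __main__ or under a test runner:
warnings.warn("as_jax_op is deprecated", DeprecationWarning)
```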
Getting 95% there is not dropping the ball in my book ;) Thanks!
@ricardoV94 Is there still anything missing, or can we merge (and soon release...) this now?
Ugh, missed it. It was meant to have gone out with the last release.
This PR revisits #1120, which seems abandoned.
@jdehning I hope it's OK if I continue this PR?
📚 Documentation preview 📚: https://pytensor--1614.org.readthedocs.build/en/1614/