Update example in "Adding Jax support for Ops" #654
Comments
@HarshvirSandhu thanks for reporting the issue. That tutorial was written back when JAX jitting was more flexible. These days all jitted functions must have constant shapes, which means a graph like the one in the example can never be translated to JAX, since it is fundamentally a function with dynamic shapes. The first error can be avoided by specifying that the scalar has an integer dtype, but that only kicks the can down the road. After also updating the dispatch function signature we get this:

import jax.numpy as jnp
from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt


@jax_funcify.register(Eye)
def jax_funcify_Eye(op, node, **kwargs):
    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""
    x_at = pt.scalar(dtype="int")
    eye_var = pt.eye(x_at)
    out_fg = FunctionGraph(outputs=[eye_var])
    compare_jax_and_py(out_fg, [3])


test_jax_Eye()

We should use a different example, and perhaps keep this one in a section of its own explaining that not all PyTensor graphs can be converted into a jitted JAX graph, typically those with dynamic shapes.
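
The dynamic-shape limitation can be reproduced in plain JAX, without PyTensor. The following is a minimal sketch, not taken from the tutorial: the helper name `eye` and the variables `dynamic_ok` and `static_eye` are made up for illustration. It assumes a recent JAX release in which `jax.jit` traces its arguments by default and accepts `static_argnums`.

```python
import jax
import jax.numpy as jnp


def eye(n):
    # The output shape (n, n) depends on the *value* of n.
    return jnp.eye(n)


# Under jit, n is traced as an abstract value, so the output shape is unknown
# at trace time and JAX raises an error. The exact exception type can vary
# between JAX versions, so a broad `Exception` is caught here.
try:
    jax.jit(eye)(3)
    dynamic_ok = True
except Exception:
    dynamic_ok = False

# Marking n as static makes it a compile-time constant. The function is then
# recompiled for every distinct value of n.
static_eye = jax.jit(eye, static_argnums=0)
result = static_eye(3)

print("jit with traced n succeeded:", dynamic_ok)
print("shape with static n:", result.shape)
```

Marking the size as static only works when the value is known outside the jitted function, which is roughly what passing a Python integer to `pt.eye` achieves on the PyTensor side.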
Hi, I would like to take up this issue as part of my GSoC application. I will open a PR after switching to a more suitable example.
Issue with current documentation:
While reproducing this example, I encountered the following error from test_jax_Eye():
AssertionError: (Eye{dtype='float64'}(<Scalar(float64, shape=())>, <Scalar(float64, shape=())>, 0), 'float64')
Complete code to reproduce the error:
Idea or request for content:
Instead of passing x_at to pt.eye, an integer can be used, like so:
pytensor/tests/link/jax/test_tensor_basic.py, lines 207 to 212 in e8693bd