Update example in "Adding Jax support for Ops" #654

Closed
HarshvirSandhu opened this issue Feb 28, 2024 · 2 comments · Fixed by #687
Labels: docs, help wanted (Extra attention is needed), jax


@HarshvirSandhu
Contributor

Issue with current documentation:

While reproducing this example, I encountered the following error:
AssertionError: (Eye{dtype='float64'}(<Scalar(float64, shape=())>, <Scalar(float64, shape=())>, 0), 'float64') from test_jax_Eye()

Complete code to reproduce the error:

import jax.numpy as jnp

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt

@jax_funcify.register(Eye)
def jax_funcify_Eye(op):

    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""

    x_at = pt.scalar()
    eye_var = pt.eye(x_at)

    out_fg = FunctionGraph(outputs=[eye_var])

    compare_jax_and_py(out_fg, [3])


test_jax_Eye()

Idea or request for content:

Instead of passing the symbolic scalar x_at to pt.eye, a constant integer can be passed, like so:

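import pytensor.tensor.basic as ptb


def test_jax_eye():
    """Tests jaxification of the Eye operator"""
    # A constant N means the output shape is known before JAX traces the graph
    out = ptb.eye(3)
    # The graph has no symbolic inputs, so no test values are needed
    out_fg = FunctionGraph([], [out])
    compare_jax_and_py(out_fg, [])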

@ricardoV94
Member

ricardoV94 commented Mar 12, 2024

@HarshvirSandhu thanks for reporting the issue.

That tutorial was written at a time when JAX jitting was more flexible. These days, all jitted functions must have static (constant) shapes, which means a graph like the one in the example can never be translated to JAX: its output shape depends on the runtime value of x_at, so it is fundamentally a function with dynamic shapes. The first error can be avoided by specifying that the scalar has an integer dtype, but that only kicks the can further down the road. After also updating the dispatch function signature, we get this:

import jax.numpy as jnp

from pytensor.tensor.basic import Eye
from pytensor.link.jax.dispatch import jax_funcify
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt

@jax_funcify.register(Eye)
def jax_funcify_Eye(op, node, **kwargs):

    dtype = op.dtype

    def eye(N, M, k):
        return jnp.eye(N, M, k, dtype=dtype)

    return eye


def test_jax_Eye():
    """Test JAX conversion of the `Eye` `Op`."""

    x_at = pt.scalar(dtype="int")
    eye_var = pt.eye(x_at)

    out_fg = FunctionGraph(outputs=[eye_var])

    compare_jax_and_py(out_fg, [3])


test_jax_Eye()
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=1/0)>,). 'N' argument of jnp.eye().
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
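The signature change in the dispatch function above (adding node and **kwargs) matters because the converter is called with more than just the Op. The toy sketch below illustrates that failure mode with functools.singledispatch. It assumes, as the corrected signature suggests, that the caller passes extra context such as the Apply node as keyword arguments; that jax_funcify behaves like a singledispatch function is also an assumption here. ToyEye, ToyOther, toy_funcify and convert are hypothetical names, not PyTensor APIs.

```python
from functools import singledispatch


class ToyEye:
    """Hypothetical stand-in for an Op that stores its output dtype."""

    def __init__(self, dtype):
        self.dtype = dtype


class ToyOther:
    """Hypothetical Op whose converter keeps the old one-argument signature."""


@singledispatch
def toy_funcify(op, node=None, **kwargs):
    """Hypothetical dispatcher: fallback used when no converter is registered."""
    raise NotImplementedError(f"No conversion registered for {type(op).__name__}")


@toy_funcify.register(ToyEye)
def toy_funcify_eye(op, node=None, **kwargs):
    # Accepting `node` and `**kwargs` lets the caller pass extra context safely
    dtype = op.dtype

    def eye(N, M, k):
        # Nested lists stand in for an array; ones sit on diagonal k
        return [[dtype(1 if j - i == k else 0) for j in range(M)] for i in range(N)]

    return eye


@toy_funcify.register(ToyOther)
def toy_funcify_other(op):
    # Old-style signature: fails as soon as the caller passes `node`
    return lambda: None


def convert(op):
    # Hypothetical caller that, like a linker, forwards extra keyword arguments
    return toy_funcify(op, node="apply-node", storage_map={})
```

With this setup, convert(ToyEye(float)) returns a working eye function, while convert(ToyOther()) raises a TypeError about the unexpected keyword argument node, which is the kind of error the updated signature avoids.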

We should use a different example, and perhaps keep this one in a section of its own explaining that not all PyTensor graphs can be converted into a jitted JAX graph, typically those with dynamic shapes.
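The dynamic-shape limitation can be reproduced in plain JAX, independent of PyTensor. The sketch below jits a function with the same body as the eye closure from the dispatch code (with a fixed float32 dtype chosen here for illustration). Under plain jax.jit every argument is traced, so jnp.eye cannot determine its output shape. Marking the arguments static with static_argnums, as the error message suggests, makes them ordinary Python ints at trace time. The trade-off is that each new (N, M, k) combination triggers a fresh compilation, which is why this does not rescue the symbolic x_at graph in general.

```python
import jax
import jax.numpy as jnp


def eye(N, M, k):
    # Mirrors the eye() closure returned by jax_funcify_Eye, with a fixed dtype
    return jnp.eye(N, M, k, dtype=jnp.float32)


# Plain jit traces N, M and k, so the output shape is unknown at trace time
try:
    jax.jit(eye)(3, 3, 0)
    traced_failed = False
except TypeError:
    # JAX's concretization errors are TypeError subclasses
    traced_failed = True

# Static arguments are concrete Python ints during tracing; the shape is known,
# but every distinct (N, M, k) triggers a separate compilation
eye_static = jax.jit(eye, static_argnums=(0, 1, 2))
identity = eye_static(3, 3, 0)
shifted = eye_static(3, 4, 1)
```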

ricardoV94 added the help wanted (Extra attention is needed) and jax labels Mar 12, 2024
ricardoV94 changed the title DOC: Correction in Adding Jax support for Ops → Update example in "Adding Jax support for Ops" Mar 12, 2024
@HangenYuu
Contributor

Hi,

I would like to take up this issue as part of my GSoC application. I will open a PR that replaces the current example with a more suitable one.


3 participants