
torchax fails on a simple matrix slicing example. #9644

@joaospinto

Description

🐛 Bug

torchax fails on a simple matrix slicing example.

To Reproduce

Here is the code to reproduce the issue:

import torch
import torchax as tx
import torchax.export

import jax
import jax.numpy as jnp

import sys


tx.enable_globally()


def f(M, p):
    return M[torch.arange(M.shape[0]), p]


class Wrapper(torch.nn.Module):
    def forward(self, M, p):
        return f(M, p)


def main():
    torch_outputs = Wrapper()(torch.arange(4).reshape([2, 2]), torch.tensor([1, 0]))

    print(f"{torch_outputs=}")

    M = jnp.arange(4).reshape([2, 2])
    p = jnp.array([1, 0])
    sample_input = (M, p)

    weights, jfunc = tx.extract_jax(Wrapper())

    def jfunc_inlined(args):
        return jfunc(weights, args)

    jitted = jax.jit(jfunc_inlined)

    jax_outputs = jitted(sample_input)

    print(f"{jax_outputs=}")


if __name__ == "__main__":
    main()

If you run it, you'll get:

AssertionError: Expect a Tensor or a View but got <class 'torch.Tensor'>; usually this means there is a mixed math between XLATensor and torch.Tensor
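The assertion suggests a plain torch.Tensor is being mixed with a torchax tensor inside the traced function; a likely suspect is the index tensor created by torch.arange inside f. As a hedged workaround sketch (untested against torchax; f_gather is a hypothetical name), the same row-wise selection can be written with torch.gather so that no new tensor is materialized from a factory function inside the traced call:

import torch

def f_gather(M, p):
    # Same result as M[torch.arange(M.shape[0]), p]: pick column p[i] from row i.
    # Only operates on the input tensors, so no fresh index tensor is created here.
    return torch.gather(M, 1, p.unsqueeze(1)).squeeze(1)

If the arange index really is the trigger, substituting f_gather for f in the Wrapper above may avoid the assertion, but the advanced-indexing form in f should still work as reported.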

Expected behavior

jax_outputs should be computed without errors and match the torch_outputs value.
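For reference, the equivalent indexing written directly in jax.numpy (no torchax involved) gives the same value torch produces, so the expected jax_outputs is unambiguous. A minimal sketch with the sample inputs above:

import jax.numpy as jnp

M = jnp.arange(4).reshape([2, 2])        # [[0, 1], [2, 3]]
p = jnp.array([1, 0])
expected = M[jnp.arange(M.shape[0]), p]  # -> [1, 2], matching torch_outputs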

Environment

einops==0.8.1
filelock==3.19.1
fsspec==2025.9.0
jax==0.7.1
jaxlib==0.7.1
Jinja2==3.1.6
MarkupSafe==3.0.2
ml_dtypes==0.5.3
mpmath==1.3.0
networkx==3.5
numpy==2.3.3
opt_einsum==3.4.0
scipy==1.16.2
setuptools==80.9.0
sympy==1.14.0
torch==2.8.0
torchax==0.0.7
typing_extensions==4.15.0

Additional context

No additional context; should be pretty clear.

Labels

torchxla2, triage review (issues that need to be reviewed by the triage team)
