diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml
index bf974c2bd22..d917fa21d10 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test.yml
@@ -128,6 +128,12 @@ jobs:
           set -x
           pip install expecttest unittest-xml-reporting
+          pip install torch_xla[pallas] \
+            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+          # Install torchax
+          pip install pytorch/xla/torchax

           if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
             pip install -r pytorch/xla/benchmarks/requirements.txt
diff --git a/test/run_tests.sh b/test/run_tests.sh
index fbb970eec62..46b729338b7 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -208,6 +208,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/eager/test_eager_spmd.py"
   run_test "$CDIR/test_callback.py"
   XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
+  run_test "$CDIR/test_jax_interop.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3
diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py
new file mode 100644
index 00000000000..7f82555d712
--- /dev/null
+++ b/test/test_jax_interop.py
@@ -0,0 +1,57 @@
+from absl.testing import absltest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+
+
+class TestJaxInterop(absltest.TestCase):
+
+  def test_call_jax(self):
+    """
+    Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
+    """
+
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, b):
+      import jax.numpy as jnp
+      return a + jnp.sin(b)
+
+    b = xb.call_jax(f, (a, a), {}, 'my_test')
+    torch_xla.sync()
+    torch.testing.assert_close(
+        b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)
+
+  def test_call_jax_pytree(self):
+    """
+    Test that call_jax works with PyTree inputs.
+    """
+    dev = xm.xla_device()
+    a = torch.ones((2, 2), device=dev)
+    b = torch.ones((2, 2), device=dev) * 2
+
+    def f(inputs):
+      a = inputs['a']
+      b = inputs['b']
+      return a @ b
+
+    inputs = {'a': a, 'b': b}
+    c = xb.call_jax(f, (inputs,))
+    torch_xla.sync()
+    torch.testing.assert_close(
+        c,
+        torch.tensor(
+            [
+                [4, 4],
+                [4, 4],
+            ],
+            dtype=torch.float32,
+        ),
+        check_device=False)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 126b0e889d9..f5fdac2b126 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -1,5 +1,7 @@
 import torch
 import torch_xla
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch_xla.experimental.custom_kernel import jax_import_guard


 class Type:
@@ -799,3 +801,48 @@ def computation_from_module_proto(name, proto):

 def get_computation_hlo(computation):
   return torch_xla._XLAC._xla_computation_text(computation)
+
+
+def call_jax(jax_func, args, kwargs=None, name=None):
+  """
+  Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
+  XLA tensors.
+  """
+
+  if name is None:
+    name = 'jax_func_' + jax_func.__name__
+  kwargs = kwargs or {}
+
+  # If we don't do this before calling jax, any torch_xla operation will hang.
+  jax_import_guard()
+
+  import jax
+  import torchax.ops.mappings as mappings
+
+  flattened, spec = tree_flatten((args, kwargs))
+
+  def fn_flattened_inputs(*flattened):
+    args, kwargs = tree_unflatten(flattened, spec)
+    return jax_func(*args, **kwargs)
+
+  sample_input_shapes = tuple(
+      jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+      for a in flattened)
+  # `as_serialized_hlo_module_proto` is mentioned at
+  # https://github.com/jax-ml/jax/discussions/22266
+  hlo_module = jax.jit(fn_flattened_inputs).lower(
+      *sample_input_shapes).compiler_ir(
+          'hlo').as_serialized_hlo_module_proto()  # type: ignore
+  computation = computation_from_module_proto(name, hlo_module)
+
+  builder = create_builder(name)
+  params = []
+  for idx, val in enumerate(flattened):
+    params.append(mkparam(builder, idx, tensor_shape(val)))
+  call_op = Op.call(computation, params)
+  call_computation = call_op.build('call_jax')
+  result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{name}',
+                                                 flattened, call_computation)
+  if isinstance(result, list) and len(result) == 1:
+    return result[0]
+  return result
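For reference, a minimal usage sketch of the `xb.call_jax` API introduced above, adapted from `test/test_jax_interop.py`. The function name `jax_add_sin` is illustrative; the sketch assumes an available XLA device and that the `jax`/`torchax` dependencies added in the workflow change are installed.

```python
# Minimal usage sketch of xb.call_jax, adapted from test/test_jax_interop.py.
# Assumes an XLA device and the jax + torchax dependencies from the CI change.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


def jax_add_sin(a, b):
  # Traced by JAX: receives JAX arrays and returns a JAX array.
  import jax.numpy as jnp
  return a + jnp.sin(b)


dev = xm.xla_device()
a = torch.ones((3, 3), device=dev)

# The JAX function is lowered to HLO and embedded into the lazy-tensor graph
# as a user computation; args/kwargs may be arbitrary pytrees of XLA tensors.
b = xb.call_jax(jax_add_sin, (a, a), {}, 'add_sin')
torch_xla.sync()  # materialize the result on the XLA device
print(b)
```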