From f817acba043d08f33f606544ad5a522cab3151be Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 3 Mar 2025 00:26:28 +0000
Subject: [PATCH 1/6] test: call jax from torch_xla

---
 torch_xla/core/xla_builder.py           | 23 ++++++++++++++++++++++-
 torch_xla/csrc/init_python_bindings.cpp |  8 ++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 126b0e889d9..882d92a1829 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -1,6 +1,6 @@
 import torch
 import torch_xla
-
+from torch.fx import _pytree as pytree
 
 class Type:
   F32 = 'f32'
   F64 = 'f64'
@@ -799,3 +799,24 @@ def computation_from_module_proto(name, proto):
 
 def get_computation_hlo(computation):
   return torch_xla._XLAC._xla_computation_text(computation)
+
+def call_jax(jax_func, args, kwargs=None, name=None):
+  if name is None:
+    name = 'jax_func_' + jax_func.__name__
+  kwargs = kwargs or {}
+  import jax
+  import torchax.ops.mappings as mappings
+
+  flattened, spec = pytree.tree_flatten((args, kwargs))
+  def fn_flattened_inputs(*flattened):
+    args, kwargs = pytree.tree_unflatten(flattened, spec)
+    return jax_func(*args, **kwargs)
+
+  sample_input_shapes = tuple(
+    jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+    for a in flattened
+  )
+  hlo_text = jax.jit(jax_func).lower(*sample_input_shapes).as_text('hlo')
+  hlo_proto = torch_xla._XLAC._xla_computation_text_to_proto(hlo_text)
+  computation = computation_from_module_proto(name, hlo_proto)
+  return Op.call(computation, flattened)
\ No newline at end of file
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index ea9e49f3f1a..6136cffd55b 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2159,6 +2159,14 @@ void InitXlaModuleBindings(py::module m) {
           }
           return hlo_text;
         });
+  m.def("_xla_computation_text_to_proto",
+        [](const std::string& text) {
+          auto hlo_module = xla::ParseAndReturnUnverifiedModule(
+              text
+          );
+          return py::bytes(hlo_module.value()->ToProto().SerializeAsString());
+        }
+  );
   m.def("_xla_op_shape", [](op_builder::OpPtr op) {
     const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op->op);
     return op_builder::ShapeToPyShape(shape);

From 08d6dc4b54fb7ea9e492240ed05495f89d553949 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Mon, 3 Mar 2025 18:36:16 +0000
Subject: [PATCH 2/6] add jax interop

---
 jax_interop.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)
 create mode 100644 jax_interop.py

diff --git a/jax_interop.py b/jax_interop.py
new file mode 100644
index 00000000000..1224cb58492
--- /dev/null
+++ b/jax_interop.py
@@ -0,0 +1,15 @@
+import torch
+import torch_xla.core.xla_model as xm
+
+dev = xm.xla_device()
+
+a = torch.ones((3,3), device=dev)
+
+import torch_xla.core.xla_builder as xb
+
+import jax.numpy as jnp
+def f(a, b):
+  return a + jnp.sin(b)
+
+b = xb.call_jax(f, (a, a), {}, 'hame')
+print(b)
\ No newline at end of file
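For readers new to the approach in patch 1: `call_jax` never executes the JAX function on real data. It describes each flattened torch tensor as a `jax.ShapeDtypeStruct` and asks `jax.jit(...).lower(...)` for HLO, which torch_xla can then splice into its own graph. A minimal standalone sketch of that lowering step, using plain JAX only (the function `f` and the 3x3 float32 shapes are illustrative, not part of the patch):

    import jax
    import jax.numpy as jnp

    def f(a, b):
      return a + jnp.sin(b)

    # Abstract inputs: shape/dtype only, no device buffers are needed to lower.
    abstract_args = (jax.ShapeDtypeStruct((3, 3), jnp.float32),) * 2
    lowered = jax.jit(f).lower(*abstract_args)
    print(lowered.as_text('hlo'))  # HLO text for the traced computation

Patch 1 feeds that HLO text back into XLA through the new `_xla_computation_text_to_proto` binding; patch 4 below replaces this text round trip with a directly serialized HLO proto.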
From 1f4f65d01a6daa319d2729d58f484b3bdabb43a6 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Mon, 3 Mar 2025 12:42:42 -0800
Subject: [PATCH 3/6] wip

---
 jax_interop.py                | 15 ---------------
 test/run_tests.sh             |  1 +
 test/test_jax_interop.py      | 24 ++++++++++++++++++++++++
 torch_xla/core/xla_builder.py | 14 ++++++++------
 4 files changed, 33 insertions(+), 21 deletions(-)
 delete mode 100644 jax_interop.py
 create mode 100644 test/test_jax_interop.py

diff --git a/jax_interop.py b/jax_interop.py
deleted file mode 100644
index 1224cb58492..00000000000
--- a/jax_interop.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-import torch_xla.core.xla_model as xm
-
-dev = xm.xla_device()
-
-a = torch.ones((3,3), device=dev)
-
-import torch_xla.core.xla_builder as xb
-
-import jax.numpy as jnp
-def f(a, b):
-  return a + jnp.sin(b)
-
-b = xb.call_jax(f, (a, a), {}, 'hame')
-print(b)
\ No newline at end of file
diff --git a/test/run_tests.sh b/test/run_tests.sh
index fbb970eec62..46b729338b7 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -208,6 +208,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/eager/test_eager_spmd.py"
   run_test "$CDIR/test_callback.py"
   XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
+  run_test "$CDIR/test_jax_interop.py"
 }
 
 # All the new xla op tests should go to run_xla_op_tests3
diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py
new file mode 100644
index 00000000000..39654752294
--- /dev/null
+++ b/test/test_jax_interop.py
@@ -0,0 +1,24 @@
+from absl.testing import absltest
+
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+
+
+class TestJaxInterop(absltest.TestCase):
+
+  def test_call_jax(self):
+    import jax.numpy as jnp
+
+    dev = xm.xla_device()
+    a = torch.ones((3, 3), device=dev)
+
+    def f(a, b):
+      return a + jnp.sin(b)
+
+    b = xb.call_jax(f, (a, a), {}, 'hame')
+    print(b)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 882d92a1829..9b238d69598 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -2,6 +2,7 @@
 import torch_xla
 from torch.fx import _pytree as pytree
+
 
 class Type:
   F32 = 'f32'
   F64 = 'f64'
@@ -800,23 +801,24 @@ def computation_from_module_proto(name, proto):
 
 def get_computation_hlo(computation):
   return torch_xla._XLAC._xla_computation_text(computation)
 
+
 def call_jax(jax_func, args, kwargs=None, name=None):
   if name is None:
     name = 'jax_func_' + jax_func.__name__
   kwargs = kwargs or {}
   import jax
   import torchax.ops.mappings as mappings
-
+
   flattened, spec = pytree.tree_flatten((args, kwargs))
+
   def fn_flattened_inputs(*flattened):
     args, kwargs = pytree.tree_unflatten(flattened, spec)
     return jax_func(*args, **kwargs)
-
+
   sample_input_shapes = tuple(
-    jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
-    for a in flattened
-  )
+      jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
+      for a in flattened)
   hlo_text = jax.jit(jax_func).lower(*sample_input_shapes).as_text('hlo')
   hlo_proto = torch_xla._XLAC._xla_computation_text_to_proto(hlo_text)
   computation = computation_from_module_proto(name, hlo_proto)
-  return Op.call(computation, flattened)
\ No newline at end of file
+  return Op.call(computation, flattened)
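Patch 3 still flattens the `(args, kwargs)` pytree with `torch.fx._pytree`; patch 4 below switches to `torch.utils._pytree`, whose flatten/unflatten round trip is what lets `call_jax` accept arbitrarily nested inputs. A small sketch of that round trip (the example values are made up):

    from torch.utils._pytree import tree_flatten, tree_unflatten

    args, kwargs = (1, {'b': 2}), {'scale': 3}
    flat, spec = tree_flatten((args, kwargs))
    # flat is a plain list of leaves, e.g. [1, 2, 3]; spec records the nesting.
    new_args, new_kwargs = tree_unflatten(flat, spec)
    assert new_args == args and new_kwargs == kwargs

`call_jax` hands the flat leaf list to XLA as positional parameters and rebuilds the original nested structure inside the wrapped JAX function.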
From 2a33afac9460bb6cfbcd3e5302d41fc8ff80a9f3 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Mon, 3 Mar 2025 16:52:49 -0800
Subject: [PATCH 4/6] Fix HLO verifier crash

The key is to use `as_serialized_hlo_module_proto`. Our self cooked
`_xla_computation_text_to_proto` causes an undefined op reference error.
---
 test/test_jax_interop.py                | 12 ++++++--
 torch_xla/core/xla_builder.py           | 38 ++++++++++++++++++++-----
 torch_xla/csrc/init_python_bindings.cpp |  8 ------
 3 files changed, 40 insertions(+), 18 deletions(-)

diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py
index 39654752294..5454559c914 100644
--- a/test/test_jax_interop.py
+++ b/test/test_jax_interop.py
@@ -1,6 +1,7 @@
 from absl.testing import absltest
 
 import torch
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.core.xla_builder as xb
 
@@ -8,16 +9,21 @@ class TestJaxInterop(absltest.TestCase):
 
   def test_call_jax(self):
-    import jax.numpy as jnp
+    """
+    Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
+    """
 
     dev = xm.xla_device()
     a = torch.ones((3, 3), device=dev)
 
     def f(a, b):
+      import jax.numpy as jnp
       return a + jnp.sin(b)
 
-    b = xb.call_jax(f, (a, a), {}, 'hame')
-    print(b)
+    b = xb.call_jax(f, (a, a), {}, 'my_test')
+    torch_xla.sync()
+    torch.testing.assert_close(
+        b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)
 
 
 if __name__ == "__main__":
diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 9b238d69598..f5fdac2b126 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -1,6 +1,7 @@
 import torch
 import torch_xla
-from torch.fx import _pytree as pytree
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torch_xla.experimental.custom_kernel import jax_import_guard
 
 
 class Type:
@@ -803,22 +804,45 @@ def get_computation_hlo(computation):
 
 def call_jax(jax_func, args, kwargs=None, name=None):
+  """
+  Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
+  XLA tensors.
+  """
+
   if name is None:
     name = 'jax_func_' + jax_func.__name__
   kwargs = kwargs or {}
+
+  # If we don't do this before calling jax, any torch_xla operation will hang.
+  jax_import_guard()
+
   import jax
   import torchax.ops.mappings as mappings
 
-  flattened, spec = pytree.tree_flatten((args, kwargs))
+  flattened, spec = tree_flatten((args, kwargs))
 
   def fn_flattened_inputs(*flattened):
-    args, kwargs = pytree.tree_unflatten(flattened, spec)
+    args, kwargs = tree_unflatten(flattened, spec)
     return jax_func(*args, **kwargs)
 
   sample_input_shapes = tuple(
       jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
      for a in flattened)
-  hlo_text = jax.jit(jax_func).lower(*sample_input_shapes).as_text('hlo')
-  hlo_proto = torch_xla._XLAC._xla_computation_text_to_proto(hlo_text)
-  computation = computation_from_module_proto(name, hlo_proto)
-  return Op.call(computation, flattened)
+  # `as_serialized_hlo_module_proto` is mentioned at
+  # https://github.com/jax-ml/jax/discussions/22266
+  hlo_module = jax.jit(fn_flattened_inputs).lower(
+      *sample_input_shapes).compiler_ir(
+          'hlo').as_serialized_hlo_module_proto()  # type: ignore
+  computation = computation_from_module_proto(name, hlo_module)
+
+  builder = create_builder(name)
+  params = []
+  for idx, val in enumerate(flattened):
+    params.append(mkparam(builder, idx, tensor_shape(val)))
+  call_op = Op.call(computation, params)
+  call_computation = call_op.build('call_jax')
+  result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{name}',
+                                                 flattened, call_computation)
+  if isinstance(result, list) and len(result) == 1:
+    return result[0]
+  return result
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 6136cffd55b..ea9e49f3f1a 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2159,14 +2159,6 @@ void InitXlaModuleBindings(py::module m) {
           }
           return hlo_text;
         });
-  m.def("_xla_computation_text_to_proto",
-        [](const std::string& text) {
-          auto hlo_module = xla::ParseAndReturnUnverifiedModule(
-              text
-          );
-          return py::bytes(hlo_module.value()->ToProto().SerializeAsString());
-        }
-  );
   m.def("_xla_op_shape", [](op_builder::OpPtr op) {
     const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op->op);
     return op_builder::ShapeToPyShape(shape);
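The essence of the patch 4 fix, isolated as a plain-JAX sketch (the function and shapes are illustrative; the API chain is the one the patch itself uses): instead of printing HLO text and re-parsing it, the lowered computation is exported directly as a serialized HloModuleProto, which `computation_from_module_proto` consumes without a lossy text round trip.

    import jax
    import jax.numpy as jnp

    def f(a, b):
      return a + jnp.sin(b)

    shapes = (jax.ShapeDtypeStruct((3, 3), jnp.float32),) * 2
    lowered = jax.jit(f).lower(*shapes)
    # compiler_ir('hlo') yields an XlaComputation; its serialized HloModuleProto
    # is what torch_xla's computation_from_module_proto() expects.
    proto_bytes = lowered.compiler_ir('hlo').as_serialized_hlo_module_proto()
    print(len(proto_bytes))

This sidesteps the undefined op reference error that the text-based `_xla_computation_text_to_proto` path triggered in the HLO verifier.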
From ae6815a8868fa9162fc8229e7137ec52be7976b8 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Mon, 3 Mar 2025 19:49:08 -0800
Subject: [PATCH 5/6] Fix test on CI

---
 .github/workflows/_test.yml |  3 +++
 test/test_jax_interop.py    | 27 +++++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml
index bf974c2bd22..4c840f81b5d 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test.yml
@@ -128,6 +128,9 @@ jobs:
           set -x
 
           pip install expecttest unittest-xml-reporting
+          pip install torch_xla[pallas] \
+            -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
+            -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
 
           if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
             pip install -r pytorch/xla/benchmarks/requirements.txt
diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py
index 5454559c914..7f82555d712 100644
--- a/test/test_jax_interop.py
+++ b/test/test_jax_interop.py
@@ -25,6 +25,33 @@ def f(a, b):
     torch.testing.assert_close(
         b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)
 
+  def test_call_jax_pytree(self):
+    """
+    Test that call_jax works with PyTree inputs.
+    """
+    dev = xm.xla_device()
+    a = torch.ones((2, 2), device=dev)
+    b = torch.ones((2, 2), device=dev) * 2
+
+    def f(inputs):
+      a = inputs['a']
+      b = inputs['b']
+      return a @ b
+
+    inputs = {'a': a, 'b': b}
+    c = xb.call_jax(f, (inputs,))
+    torch_xla.sync()
+    torch.testing.assert_close(
+        c,
+        torch.tensor(
+            [
+                [4, 4],
+                [4, 4],
+            ],
+            dtype=torch.float32,
+        ),
+        check_device=False)
+
 
 if __name__ == "__main__":
   absltest.main()

From 877adbe5b21a90aa923f92f00eb4cd13d6482698 Mon Sep 17 00:00:00 2001
From: Yifei Teng
Date: Mon, 3 Mar 2025 20:21:46 -0800
Subject: [PATCH 6/6] Also install torchax in PyTorch/XLA CI

---
 .github/workflows/_test.yml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml
index 4c840f81b5d..d917fa21d10 100644
--- a/.github/workflows/_test.yml
+++ b/.github/workflows/_test.yml
@@ -131,6 +131,9 @@ jobs:
           pip install torch_xla[pallas] \
             -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
             -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+
+          # Install torchax
+          pip install pytorch/xla/torchax
 
           if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
             pip install -r pytorch/xla/benchmarks/requirements.txt
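Putting the series together, a local smoke test mirroring what the new CI steps and unit tests exercise might look like the sketch below. The pip commands are the CI ones from patches 5 and 6; the torchax install path is relative to the CI workspace layout, so adjust it (and the device type) for your own checkout.

    # Assumed environment setup, mirroring .github/workflows/_test.yml:
    #   pip install torch_xla[pallas] \
    #     -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    #     -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
    #   pip install pytorch/xla/torchax   # path depends on your checkout
    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.core.xla_builder as xb

    def f(inputs):
      import jax.numpy as jnp
      return jnp.sin(inputs['a']) + inputs['b']

    dev = xm.xla_device()
    inputs = {
        'a': torch.ones((2, 2), device=dev),
        'b': torch.zeros((2, 2), device=dev),
    }
    out = xb.call_jax(f, (inputs,))
    torch_xla.sync()
    print(out)

Alternatively, run the new test file `test/test_jax_interop.py` directly, as wired into `run_tests.sh` in patch 3.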