Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ jobs:
set -x

pip install expecttest unittest-xml-reporting
pip install torch_xla[pallas] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Install torchax
pip install pytorch/xla/torchax

if [[ ! -z "$RUN_BENCHMARK_TESTS" ]]; then
pip install -r pytorch/xla/benchmarks/requirements.txt
Expand Down
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ function run_xla_op_tests2 {
run_test "$CDIR/eager/test_eager_spmd.py"
run_test "$CDIR/test_callback.py"
XLA_USE_SPMD=1 run_test "$CDIR/test_callback.py"
run_test "$CDIR/test_jax_interop.py"
}

# All the new xla op tests should go to run_xla_op_tests3
Expand Down
57 changes: 57 additions & 0 deletions test/test_jax_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from absl.testing import absltest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb


class TestJaxInterop(absltest.TestCase):

def test_call_jax(self):
"""
Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
"""

dev = xm.xla_device()
a = torch.ones((3, 3), device=dev)

def f(a, b):
import jax.numpy as jnp
return a + jnp.sin(b)

b = xb.call_jax(f, (a, a), {}, 'my_test')
torch_xla.sync()
torch.testing.assert_close(
b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)

def test_call_jax_pytree(self):
"""
Test that call_jax works with PyTree inputs.
"""
dev = xm.xla_device()
a = torch.ones((2, 2), device=dev)
b = torch.ones((2, 2), device=dev) * 2

def f(inputs):
a = inputs['a']
b = inputs['b']
return a @ b

inputs = {'a': a, 'b': b}
c = xb.call_jax(f, (inputs,))
torch_xla.sync()
torch.testing.assert_close(
c,
torch.tensor(
[
[4, 4],
[4, 4],
],
dtype=torch.float32,
),
check_device=False)


if __name__ == "__main__":
absltest.main()
47 changes: 47 additions & 0 deletions torch_xla/core/xla_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import torch_xla
from torch.utils._pytree import tree_flatten, tree_unflatten
from torch_xla.experimental.custom_kernel import jax_import_guard


class Type:
Expand Down Expand Up @@ -799,3 +801,48 @@ def computation_from_module_proto(name, proto):

def get_computation_hlo(computation):
return torch_xla._XLAC._xla_computation_text(computation)


def call_jax(jax_func, args, kwargs=None, name=None):
"""
Call a JAX function `jax_func` with the given `args` and `kwargs` that may contain
XLA tensors.
"""

if name is None:
name = 'jax_func_' + jax_func.__name__
kwargs = kwargs or {}

# If we don't do this before calling jax, any torch_xla operation will hang.
jax_import_guard()

import jax
import torchax.ops.mappings as mappings

flattened, spec = tree_flatten((args, kwargs))

def fn_flattened_inputs(*flattened):
args, kwargs = tree_unflatten(flattened, spec)
return jax_func(*args, **kwargs)

sample_input_shapes = tuple(
jax.ShapeDtypeStruct(a.shape, mappings.t2j_dtype(a.dtype))
for a in flattened)
# `as_serialized_hlo_module_proto` is mentioned at
# https://github.com/jax-ml/jax/discussions/22266
hlo_module = jax.jit(fn_flattened_inputs).lower(
*sample_input_shapes).compiler_ir(
'hlo').as_serialized_hlo_module_proto() # type: ignore
computation = computation_from_module_proto(name, hlo_module)

builder = create_builder(name)
params = []
for idx, val in enumerate(flattened):
params.append(mkparam(builder, idx, tensor_shape(val)))
call_op = Op.call(computation, params)
call_computation = call_op.build('call_jax')
result = torch_xla._XLAC._xla_user_computation(f'xla::call_jax_{name}',
flattened, call_computation)
if isinstance(result, list) and len(result) == 1:
return result[0]
return result
Loading