63 changes: 57 additions & 6 deletions test/test_jax_interop.py
@@ -9,9 +9,7 @@
class TestJaxInterop(absltest.TestCase):

def test_call_jax(self):
"""
Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
"""
"""Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing."""

dev = xm.xla_device()
a = torch.ones((3, 3), device=dev)
@@ -26,9 +24,8 @@ def f(a, b):
b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)

def test_call_jax_pytree(self):
"""
Test that call_jax works with PyTree inputs.
"""
"""Test that call_jax works with PyTree inputs."""

dev = xm.xla_device()
a = torch.ones((2, 2), device=dev)
b = torch.ones((2, 2), device=dev) * 2
@@ -52,6 +49,60 @@ def f(inputs):
),
check_device=False)

def test_call_jax_some_arg_unused(self):
"""Test when the jax function doesn't use some input arguments."""

dev = xm.xla_device()
a = torch.randn((3, 3), device=dev)
b = torch.randn((3, 3), device=dev)
c = torch.randn((3, 3), device=dev)
d = torch.randn((3, 3), device=dev)

def f(a, b, c, d):
import jax.numpy as jnp
return a + jnp.sin(b)

o = xb.call_jax(f, (a, b, c, d), {}, 'my_test')
torch_xla.sync()
torch.testing.assert_close(o, a + torch.sin(b), check_device=False)

def test_call_jax_grad(self):
"""Test calling a simple jax.grad transformed function."""

dev = xm.xla_device()
a = torch.randn((3, 3), device=dev, requires_grad=True)
b = torch.randn((3, 3), device=dev, requires_grad=True)
torch_xla.sync()

import jax

def f_torch(a, b):
return torch.sum(a + torch.sin(b))

def f_backward_torch(f, a, b):
out = f(a, b)
out.backward()
return a.grad, b.grad

def f_jax(a, b):
import jax.numpy as jnp
# JAX folds a's gradient to a constant (all ones), so the lowered
# computation never reads `a`. We should support that.
return jnp.sum(a + jnp.sin(b))

grad_f_jax = jax.grad(f_jax, argnums=(0, 1))

out_torch = f_torch(a, b)
out_grad_torch = f_backward_torch(f_torch, a, b)
out_jax = xb.call_jax(f_jax, (a, b), {})
out_grad_jax = xb.call_jax(grad_f_jax, (a, b), {})
torch_xla.sync()

# forward should produce same output
torch.testing.assert_close(out_torch, out_jax)
# backward should produce same gradient
torch.testing.assert_close(out_grad_torch, out_grad_jax)


if __name__ == "__main__":
absltest.main()
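
Note (not part of the diff): test_call_jax_grad relies on JAX constant-folding the gradient with respect to `a`. A minimal sketch, assuming a recent JAX install, of why the lowered gradient never reads `a`:

import jax
import jax.numpy as jnp

def f_jax(a, b):
  return jnp.sum(a + jnp.sin(b))

# d/da of sum(a + sin(b)) is a tensor of ones, independent of `a`, so JAX
# folds it to a constant and the gradient computation never reads `a`.
da, db = jax.grad(f_jax, argnums=(0, 1))(jnp.zeros((3, 3)), jnp.zeros((3, 3)))
print(da)  # all ones
print(db)  # cos(b), i.e. all ones for b = 0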
5 changes: 3 additions & 2 deletions torch_xla/core/xla_builder.py
@@ -869,8 +869,9 @@ def fn_flattened_inputs(*flattened):
for a in flattened)
# `as_serialized_hlo_module_proto` is mentioned at
# https://github.com/jax-ml/jax/discussions/22266
hlo_module = jax.jit(fn_flattened_inputs).lower(
*sample_input_shapes).compiler_ir(
hlo_module = jax.jit(
fn_flattened_inputs,
keep_unused=True).lower(*sample_input_shapes).compiler_ir(
'hlo').as_serialized_hlo_module_proto() # type: ignore

return XlaComputation(name, hlo_module, flattened)
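
Note (not part of the diff): keep_unused=True is what makes the new tests pass. By default jax.jit prunes arguments the traced function never uses, so the lowered HLO module would declare fewer parameters than the flattened tensors call_jax passes to XLA. A minimal sketch, assuming a recent JAX install (exact HLO text may vary across versions):

import jax
import jax.numpy as jnp

def f(a, b):
  return jnp.sin(a)  # `b` is never used

x = jax.ShapeDtypeStruct((3, 3), jnp.float32)
pruned = jax.jit(f).lower(x, x).compiler_ir('hlo').as_hlo_text()
kept = jax.jit(f, keep_unused=True).lower(x, x).compiler_ir('hlo').as_hlo_text()
print(pruned.count(' parameter('))  # expected 1: `b` was dropped
print(kept.count(' parameter('))    # expected 2: both parameters are kept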