diff --git a/test/test_jax_interop.py b/test/test_jax_interop.py
index 7f82555d7127..7ad41cf54c53 100644
--- a/test/test_jax_interop.py
+++ b/test/test_jax_interop.py
@@ -9,9 +9,7 @@ class TestJaxInterop(absltest.TestCase):
 
   def test_call_jax(self):
-    """
-    Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing.
-    """
+    """Test that we can call a JAX function from PyTorch/XLA lazy tensor tracing."""
 
     dev = xm.xla_device()
     a = torch.ones((3, 3), device=dev)
 
@@ -26,9 +24,8 @@ def f(a, b):
         b, torch.sin(torch.ones(3, 3)) + 1, check_device=False)
 
   def test_call_jax_pytree(self):
-    """
-    Test that call_jax works with PyTree inputs.
-    """
+    """Test that call_jax works with PyTree inputs."""
+
     dev = xm.xla_device()
     a = torch.ones((2, 2), device=dev)
     b = torch.ones((2, 2), device=dev) * 2
@@ -52,6 +49,60 @@ def f(inputs):
         ),
         check_device=False)
 
+  def test_call_jax_some_arg_unused(self):
+    """Test when the jax function doesn't use some input arguments."""
+
+    dev = xm.xla_device()
+    a = torch.randn((3, 3), device=dev)
+    b = torch.randn((3, 3), device=dev)
+    c = torch.randn((3, 3), device=dev)
+    d = torch.randn((3, 3), device=dev)
+
+    def f(a, b, c, d):
+      import jax.numpy as jnp
+      return a + jnp.sin(b)
+
+    o = xb.call_jax(f, (a, b, c, d), {}, 'my_test')
+    torch_xla.sync()
+    torch.testing.assert_close(o, a + torch.sin(b), check_device=False)
+
+  def test_call_jax_grad(self):
+    """Test calling a simple jax.grad transformed function."""
+
+    dev = xm.xla_device()
+    a = torch.randn((3, 3), device=dev, requires_grad=True)
+    b = torch.randn((3, 3), device=dev, requires_grad=True)
+    torch_xla.sync()
+
+    import jax
+
+    def f_torch(a, b):
+      return torch.sum(a + torch.sin(b))
+
+    def f_backward_torch(f, a, b):
+      out = f(a, b)
+      out.backward()
+      return a.grad, b.grad
+
+    def f_jax(a, b):
+      import jax.numpy as jnp
+      # JAX optimizes a's grad as constant, so it will never use `a`.
+      # We should support that.
+      return jnp.sum(a + jnp.sin(b))
+
+    grad_f_jax = jax.grad(f_jax, argnums=(0, 1))
+
+    out_torch = f_torch(a, b)
+    out_grad_torch = f_backward_torch(f_torch, a, b)
+    out_jax = xb.call_jax(f_jax, (a, b), {})
+    out_grad_jax = xb.call_jax(grad_f_jax, (a, b), {})
+    torch_xla.sync()
+
+    # forward should produce same output
+    torch.testing.assert_close(out_torch, out_jax)
+    # backward should produce same gradient
+    torch.testing.assert_close(out_grad_torch, out_grad_jax)
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/torch_xla/core/xla_builder.py b/torch_xla/core/xla_builder.py
index 76de1e856ea0..02e9c5c94b70 100644
--- a/torch_xla/core/xla_builder.py
+++ b/torch_xla/core/xla_builder.py
@@ -869,8 +869,9 @@ def fn_flattened_inputs(*flattened):
                               for a in flattened)
   # `as_serialized_hlo_module_proto` is mentioned at
   # https://github.com/jax-ml/jax/discussions/22266
-  hlo_module = jax.jit(fn_flattened_inputs).lower(
-      *sample_input_shapes).compiler_ir(
+  hlo_module = jax.jit(
+      fn_flattened_inputs,
+      keep_unused=True).lower(*sample_input_shapes).compiler_ir(
           'hlo').as_serialized_hlo_module_proto()  # type: ignore
   return XlaComputation(name, hlo_module, flattened)
 
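
For context (not part of this patch): by default `jax.jit` prunes arguments that the traced function never reads, so without `keep_unused=True` the lowered HLO can end up with fewer parameters than the tensors `call_jax` flattens and passes in (for example, when `jax.grad` treats `a`'s gradient as a constant, as in the test above). The standalone sketch below is illustrative only; the names `f` and `shapes` are made up, but it uses the same `lower(...).compiler_ir('hlo')` path the patch touches to show the parameter-count difference.

    import jax
    import jax.numpy as jnp

    def f(a, b, unused):
      # `unused` is never read, mirroring test_call_jax_some_arg_unused.
      return a + jnp.sin(b)

    shapes = tuple(jax.ShapeDtypeStruct((3, 3), jnp.float32) for _ in range(3))

    # By default, jax.jit drops the unused argument from the lowered HLO.
    pruned = jax.jit(f).lower(*shapes).compiler_ir('hlo')
    # With keep_unused=True, every argument stays a parameter of the HLO.
    kept = jax.jit(f, keep_unused=True).lower(*shapes).compiler_ir('hlo')

    print(len(pruned.program_shape().parameter_shapes()))  # 2
    print(len(kept.program_shape().parameter_shapes()))    # 3

This is why the lowering in `xla_builder.py` passes `keep_unused=True`: PyTorch/XLA always supplies every flattened input tensor, so the HLO parameter list has to line up even when the JAX function ignores some of its inputs.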