diff --git a/configuration.yaml b/configuration.yaml index 6c1a1844ca6..d1a160e6f46 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -388,3 +388,9 @@ variables: your code. type: bool default_value: false + TORCH_XLA_ENABLE_JAX: + description: + - Controls JAX integration in PyTorch/XLA. When set to 1, enables JAX + features and imports. When set to 0, disables JAX features silently. + When unset, shows a warning message with guidance. + type: int diff --git a/test/run_tests.sh b/test/run_tests.sh index 033089d651f..8abdb149403 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -179,6 +179,7 @@ function run_xla_op_tests1 { run_test "$_TEST_DIR/test_fp8.py" run_xla_ir_debug run_test "$_TEST_DIR/test_env_var_mapper.py" run_xla_hlo_debug run_test "$_TEST_DIR/test_env_var_mapper.py" + run_test "$_TEST_DIR/test_jax_env_var.py" run_xla_hlo_debug run_test "$_TEST_DIR/stablehlo/test_stablehlo_save_load.py" run_save_tensor_ir run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py" run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py" diff --git a/test/test_jax_env_var.py b/test/test_jax_env_var.py new file mode 100644 index 00000000000..f8d3b8e53e2 --- /dev/null +++ b/test/test_jax_env_var.py @@ -0,0 +1,62 @@ +import os +from absl.testing import absltest +from unittest.mock import patch +from torch_xla._internal.jax_workarounds import maybe_get_jax + + +class TestJaxEnvVar(absltest.TestCase): + + def setUp(self): + # Clean up environment + if 'TORCH_XLA_ENABLE_JAX' in os.environ: + del os.environ['TORCH_XLA_ENABLE_JAX'] + + def test_jax_enabled_attempts_import(self): + os.environ['TORCH_XLA_ENABLE_JAX'] = '1' + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + self.assertIn('JAX explicitly enabled but not installed', + mock_warn.call_args[0][0]) + + def test_jax_unset_warns_with_guidance(self): + """Test that unset environment variable warns with guidance""" + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + warning_msg = mock_warn.call_args[0][0] + self.assertIn('You are trying to use a feature that requires JAX', + warning_msg) + self.assertIn('TORCH_XLA_ENABLE_JAX=1', warning_msg) + self.assertIn('TORCH_XLA_ENABLE_JAX=0', warning_msg) + + def test_jax_disabled_values_silent(self): + for val in ['0']: + with self.subTest(val=val): + os.environ['TORCH_XLA_ENABLE_JAX'] = val + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_not_called() + + def test_jax_enabled_values(self): + """Test various values that enable JAX""" + for val in ['1']: + with self.subTest(val=val): + os.environ['TORCH_XLA_ENABLE_JAX'] = val + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + self.assertIn('JAX explicitly enabled but not installed', + mock_warn.call_args[0][0]) + + +if __name__ == '__main__': + absltest.main() diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index d2d66570418..224a7dc5b5d 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -3,6 +3,7 @@ from typing import Callable, Any import functools import logging +import torch_xla.utils.utils as xu # TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack. @@ -59,14 +60,24 @@ def maybe_get_torchax(): def maybe_get_jax(): - try: - jax_import_guard() - with jax_env_context(): - import jax - # TorchXLA still expects SPMD style sharding - jax.config.update('jax_use_shardy_partitioner', False) - return jax - except (ModuleNotFoundError, ImportError): - logging.warn('You are trying to use a feature that requires jax/pallas.' - 'You can install Jax/Pallas via pip install torch_xla[pallas]') - return None \ No newline at end of file + env_val = xu.getenv_as('TORCH_XLA_ENABLE_JAX', int, None) + + if env_val == 1: + try: + jax_import_guard() + with jax_env_context(): + import jax + return jax + except (ModuleNotFoundError, ImportError): + logging.warning( + 'JAX explicitly enabled but not installed. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') + return None + + if env_val is None: + logging.warning( + 'You are trying to use a feature that requires JAX. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas] and Set TORCH_XLA_ENABLE_JAX=1 to enable JAX features, or TORCH_XLA_ENABLE_JAX=0 to suppress this warning' + ) + # If explicitly disabled (0), return silently + return None