From 9af8b54ffd0ef3693cf5faf8510bd14db7fec592 Mon Sep 17 00:00:00 2001 From: Raj Thakur Date: Wed, 3 Sep 2025 01:23:39 +0000 Subject: [PATCH 1/3] feat: make JAX optional with env variable control --- test/test_jax_env_var.py | 55 ++++++++++++++++++++++++++ torch_xla/_internal/jax_workarounds.py | 29 ++++++++------ 2 files changed, 73 insertions(+), 11 deletions(-) create mode 100644 test/test_jax_env_var.py diff --git a/test/test_jax_env_var.py b/test/test_jax_env_var.py new file mode 100644 index 00000000000..8ffe1ae817b --- /dev/null +++ b/test/test_jax_env_var.py @@ -0,0 +1,55 @@ +import os +from absl.testing import absltest +from unittest.mock import patch +from torch_xla._internal.jax_workarounds import maybe_get_jax + + +class TestJaxEnvVar(absltest.TestCase): + + def setUp(self): + # Clean up environment + if 'TORCH_XLA_ENABLE_JAX' in os.environ: + del os.environ['TORCH_XLA_ENABLE_JAX'] + + def test_jax_enabled_attempts_import(self): + os.environ['TORCH_XLA_ENABLE_JAX'] = '1' + with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + self.assertIn('JAX explicitly enabled but not installed', mock_warn.call_args[0][0]) + + def test_jax_unset_warns_with_guidance(self): + """Test that unset environment variable warns with guidance""" + with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + warning_msg = mock_warn.call_args[0][0] + self.assertIn('You are trying to use a feature that requires JAX', warning_msg) + self.assertIn('TORCH_XLA_ENABLE_JAX=1', warning_msg) + self.assertIn('TORCH_XLA_ENABLE_JAX=0', warning_msg) + + def test_jax_disabled_values_silent(self): + for val in ['0', 'false', 'FALSE', 'no', 'NO', 'random', 'disabled']: + with self.subTest(val=val): + os.environ['TORCH_XLA_ENABLE_JAX'] = val + with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_not_called() + + def test_jax_enabled_values(self): + """Test various values that enable JAX""" + for val in ['1', 'true', 'TRUE', 'yes', 'YES']: + with self.subTest(val=val): + os.environ['TORCH_XLA_ENABLE_JAX'] = val + with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + result = maybe_get_jax() + self.assertIsNone(result) + mock_warn.assert_called_once() + self.assertIn('JAX explicitly enabled but not installed', mock_warn.call_args[0][0]) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index d2d66570418..8166415fce7 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -59,14 +59,21 @@ def maybe_get_torchax(): def maybe_get_jax(): - try: - jax_import_guard() - with jax_env_context(): - import jax - # TorchXLA still expects SPMD style sharding - jax.config.update('jax_use_shardy_partitioner', False) - return jax - except (ModuleNotFoundError, ImportError): - logging.warn('You are trying to use a feature that requires jax/pallas.' - 'You can install Jax/Pallas via pip install torch_xla[pallas]') - return None \ No newline at end of file + env_val = os.environ.get('TORCH_XLA_ENABLE_JAX', '') + + if env_val.lower() in ('1', 'true', 'yes'): + try: + jax_import_guard() + with jax_env_context(): + import jax + return jax + except (ModuleNotFoundError, ImportError): + logging.warning('JAX explicitly enabled but not installed. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') + return None + + if env_val == '': + logging.warning('You are trying to use a feature that requires JAX. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas] and Set TORCH_XLA_ENABLE_JAX=1 to enable JAX features, or TORCH_XLA_ENABLE_JAX=0 to suppress this warning') + # If explicitly disabled (0, false, no, etc.), return silently + return None \ No newline at end of file From bb9e308cf641c420ce4b34374495ee95052dc53a Mon Sep 17 00:00:00 2001 From: Raj Thakur Date: Fri, 5 Sep 2025 08:22:53 +0000 Subject: [PATCH 2/3] Use xu.getenv_as for TORCH_XLA_ENABLE_JAX env var --- configuration.yaml | 6 ++++++ test/test_jax_env_var.py | 29 ++++++++++++++++---------- torch_xla/_internal/jax_workarounds.py | 26 +++++++++++++---------- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/configuration.yaml b/configuration.yaml index 6c1a1844ca6..d1a160e6f46 100644 --- a/configuration.yaml +++ b/configuration.yaml @@ -388,3 +388,9 @@ variables: your code. type: bool default_value: false + TORCH_XLA_ENABLE_JAX: + description: + - Controls JAX integration in PyTorch/XLA. When set to 1, enables JAX + features and imports. When set to 0, disables JAX features silently. + When unset, shows a warning message with guidance. + type: int diff --git a/test/test_jax_env_var.py b/test/test_jax_env_var.py index 8ffe1ae817b..f8d3b8e53e2 100644 --- a/test/test_jax_env_var.py +++ b/test/test_jax_env_var.py @@ -13,43 +13,50 @@ def setUp(self): def test_jax_enabled_attempts_import(self): os.environ['TORCH_XLA_ENABLE_JAX'] = '1' - with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: result = maybe_get_jax() self.assertIsNone(result) mock_warn.assert_called_once() - self.assertIn('JAX explicitly enabled but not installed', mock_warn.call_args[0][0]) + self.assertIn('JAX explicitly enabled but not installed', + mock_warn.call_args[0][0]) def test_jax_unset_warns_with_guidance(self): """Test that unset environment variable warns with guidance""" - with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: result = maybe_get_jax() self.assertIsNone(result) mock_warn.assert_called_once() warning_msg = mock_warn.call_args[0][0] - self.assertIn('You are trying to use a feature that requires JAX', warning_msg) + self.assertIn('You are trying to use a feature that requires JAX', + warning_msg) self.assertIn('TORCH_XLA_ENABLE_JAX=1', warning_msg) self.assertIn('TORCH_XLA_ENABLE_JAX=0', warning_msg) def test_jax_disabled_values_silent(self): - for val in ['0', 'false', 'FALSE', 'no', 'NO', 'random', 'disabled']: + for val in ['0']: with self.subTest(val=val): os.environ['TORCH_XLA_ENABLE_JAX'] = val - with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: result = maybe_get_jax() self.assertIsNone(result) mock_warn.assert_not_called() def test_jax_enabled_values(self): """Test various values that enable JAX""" - for val in ['1', 'true', 'TRUE', 'yes', 'YES']: + for val in ['1']: with self.subTest(val=val): os.environ['TORCH_XLA_ENABLE_JAX'] = val - with patch('torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: + with patch( + 'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn: result = maybe_get_jax() - self.assertIsNone(result) + self.assertIsNone(result) mock_warn.assert_called_once() - self.assertIn('JAX explicitly enabled but not installed', mock_warn.call_args[0][0]) + self.assertIn('JAX explicitly enabled but not installed', + mock_warn.call_args[0][0]) if __name__ == '__main__': - absltest.main() \ No newline at end of file + absltest.main() diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index 8166415fce7..224a7dc5b5d 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -3,6 +3,7 @@ from typing import Callable, Any import functools import logging +import torch_xla.utils.utils as xu # TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack. @@ -59,21 +60,24 @@ def maybe_get_torchax(): def maybe_get_jax(): - env_val = os.environ.get('TORCH_XLA_ENABLE_JAX', '') - - if env_val.lower() in ('1', 'true', 'yes'): + env_val = xu.getenv_as('TORCH_XLA_ENABLE_JAX', int, None) + + if env_val == 1: try: jax_import_guard() with jax_env_context(): import jax return jax except (ModuleNotFoundError, ImportError): - logging.warning('JAX explicitly enabled but not installed. ' - 'You can install Jax/Pallas via pip install torch_xla[pallas]') + logging.warning( + 'JAX explicitly enabled but not installed. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') return None - - if env_val == '': - logging.warning('You are trying to use a feature that requires JAX. ' - 'You can install Jax/Pallas via pip install torch_xla[pallas] and Set TORCH_XLA_ENABLE_JAX=1 to enable JAX features, or TORCH_XLA_ENABLE_JAX=0 to suppress this warning') - # If explicitly disabled (0, false, no, etc.), return silently - return None \ No newline at end of file + + if env_val is None: + logging.warning( + 'You are trying to use a feature that requires JAX. ' + 'You can install Jax/Pallas via pip install torch_xla[pallas] and Set TORCH_XLA_ENABLE_JAX=1 to enable JAX features, or TORCH_XLA_ENABLE_JAX=0 to suppress this warning' + ) + # If explicitly disabled (0), return silently + return None From 79731a29843e7a8c8991787e5a38eea56da5ee67 Mon Sep 17 00:00:00 2001 From: Raj Thakur Date: Fri, 5 Sep 2025 08:44:28 +0000 Subject: [PATCH 3/3] Add test_jax_env_var.py to test suite in run_tests --- test/run_tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_tests.sh b/test/run_tests.sh index 033089d651f..8abdb149403 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -179,6 +179,7 @@ function run_xla_op_tests1 { run_test "$_TEST_DIR/test_fp8.py" run_xla_ir_debug run_test "$_TEST_DIR/test_env_var_mapper.py" run_xla_hlo_debug run_test "$_TEST_DIR/test_env_var_mapper.py" + run_test "$_TEST_DIR/test_jax_env_var.py" run_xla_hlo_debug run_test "$_TEST_DIR/stablehlo/test_stablehlo_save_load.py" run_save_tensor_ir run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py" run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py"