Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,9 @@ variables:
your code.
type: bool
default_value: false
TORCH_XLA_ENABLE_JAX:
description:
- Controls JAX integration in PyTorch/XLA. When set to 1, enables JAX
features and imports. When set to 0, disables JAX features silently.
When unset, shows a warning message with guidance.
type: int
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ function run_xla_op_tests1 {
run_test "$_TEST_DIR/test_fp8.py"
run_xla_ir_debug run_test "$_TEST_DIR/test_env_var_mapper.py"
run_xla_hlo_debug run_test "$_TEST_DIR/test_env_var_mapper.py"
run_test "$_TEST_DIR/test_jax_env_var.py"
run_xla_hlo_debug run_test "$_TEST_DIR/stablehlo/test_stablehlo_save_load.py"
run_save_tensor_ir run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py"
run_save_tensor_hlo run_test "$_TEST_DIR/spmd/test_spmd_graph_dump.py"
Expand Down
62 changes: 62 additions & 0 deletions test/test_jax_env_var.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
from absl.testing import absltest
from unittest.mock import patch
from torch_xla._internal.jax_workarounds import maybe_get_jax


class TestJaxEnvVar(absltest.TestCase):

def setUp(self):
# Clean up environment
if 'TORCH_XLA_ENABLE_JAX' in os.environ:
del os.environ['TORCH_XLA_ENABLE_JAX']

def test_jax_enabled_attempts_import(self):
os.environ['TORCH_XLA_ENABLE_JAX'] = '1'
with patch(
'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn:
result = maybe_get_jax()
self.assertIsNone(result)
mock_warn.assert_called_once()
self.assertIn('JAX explicitly enabled but not installed',
mock_warn.call_args[0][0])

def test_jax_unset_warns_with_guidance(self):
"""Test that unset environment variable warns with guidance"""
with patch(
'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn:
result = maybe_get_jax()
self.assertIsNone(result)
mock_warn.assert_called_once()
warning_msg = mock_warn.call_args[0][0]
self.assertIn('You are trying to use a feature that requires JAX',
warning_msg)
self.assertIn('TORCH_XLA_ENABLE_JAX=1', warning_msg)
self.assertIn('TORCH_XLA_ENABLE_JAX=0', warning_msg)

def test_jax_disabled_values_silent(self):
for val in ['0']:
with self.subTest(val=val):
os.environ['TORCH_XLA_ENABLE_JAX'] = val
with patch(
'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn:
result = maybe_get_jax()
self.assertIsNone(result)
mock_warn.assert_not_called()

def test_jax_enabled_values(self):
"""Test various values that enable JAX"""
for val in ['1']:
with self.subTest(val=val):
os.environ['TORCH_XLA_ENABLE_JAX'] = val
with patch(
'torch_xla._internal.jax_workarounds.logging.warning') as mock_warn:
result = maybe_get_jax()
self.assertIsNone(result)
mock_warn.assert_called_once()
self.assertIn('JAX explicitly enabled but not installed',
mock_warn.call_args[0][0])


if __name__ == '__main__':
absltest.main()
33 changes: 22 additions & 11 deletions torch_xla/_internal/jax_workarounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, Any
import functools
import logging
import torch_xla.utils.utils as xu


# TODO(https://github.com/pytorch/xla/issues/8793): Get rid of this hack.
Expand Down Expand Up @@ -59,14 +60,24 @@ def maybe_get_torchax():


def maybe_get_jax():
try:
jax_import_guard()
with jax_env_context():
import jax
# TorchXLA still expects SPMD style sharding
jax.config.update('jax_use_shardy_partitioner', False)
return jax
except (ModuleNotFoundError, ImportError):
logging.warn('You are trying to use a feature that requires jax/pallas.'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do something simpler.

Instead of the env variable, just remove this logging.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably fix the test if that's a concern? Or do you see some other issue?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just the tests. Although I don't need the logging in this function, so it can be simplified if so.

i.e. it's reasonable to push the logging to the callsites of maybe_get_jax.

'You can install Jax/Pallas via pip install torch_xla[pallas]')
return None
env_val = xu.getenv_as('TORCH_XLA_ENABLE_JAX', int, None)

if env_val == 1:
try:
jax_import_guard()
with jax_env_context():
import jax
return jax
except (ModuleNotFoundError, ImportError):
logging.warning(
'JAX explicitly enabled but not installed. '
'You can install Jax/Pallas via pip install torch_xla[pallas]')
return None

if env_val is None:
logging.warning(
'You are trying to use a feature that requires JAX. '
'You can install Jax/Pallas via pip install torch_xla[pallas] and Set TORCH_XLA_ENABLE_JAX=1 to enable JAX features, or TORCH_XLA_ENABLE_JAX=0 to suppress this warning'
)
# If explicitly disabled (0), return silently
return None
Loading