Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions torch_xla/_internal/jax_workarounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def maybe_get_torchax():
return None


def maybe_get_jax():
def maybe_get_jax(log=True):
try:
jax_import_guard()
with jax_env_context():
Expand All @@ -67,6 +67,8 @@ def maybe_get_jax():
jax.config.update('jax_use_shardy_partitioner', False)
return jax
except (ModuleNotFoundError, ImportError):
logging.warn('You are trying to use a feature that requires jax/pallas.'
'You can install Jax/Pallas via pip install torch_xla[pallas]')
return None
if log:
logging.warning(
        'You are trying to use a feature that requires jax/pallas. '
'You can install Jax/Pallas via pip install torch_xla[pallas]')
return None
2 changes: 1 addition & 1 deletion torch_xla/debug/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __enter__(self):

self._jax_scope = None
# Also enter the JAX named scope, to support torchax lowering.
if jax := maybe_get_jax():
if jax := maybe_get_jax(log=False):
self._jax_scope = jax.named_scope(self.name)
self._jax_scope.__enter__()

Expand Down
3 changes: 2 additions & 1 deletion torch_xla/distributed/spmd/xla_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."

tx = maybe_get_torchax()
jax = maybe_get_jax()
# Do not log jax warnings when workarounds are available.
jax = maybe_get_jax(log=False)
if (jax is not None) and (tx is not None) and isinstance(t, tx.tensor.Tensor):
from jax.sharding import PartitionSpec as P, NamedSharding
jmesh = mesh.get_jax_mesh()
Expand Down