diff --git a/torch_xla/_internal/jax_workarounds.py b/torch_xla/_internal/jax_workarounds.py index d2d66570418..d16bc8fa034 100644 --- a/torch_xla/_internal/jax_workarounds.py +++ b/torch_xla/_internal/jax_workarounds.py @@ -58,7 +58,7 @@ def maybe_get_torchax(): return None -def maybe_get_jax(): +def maybe_get_jax(log=True): try: jax_import_guard() with jax_env_context(): @@ -67,6 +67,8 @@ def maybe_get_jax(): jax.config.update('jax_use_shardy_partitioner', False) return jax except (ModuleNotFoundError, ImportError): - logging.warn('You are trying to use a feature that requires jax/pallas.' - 'You can install Jax/Pallas via pip install torch_xla[pallas]') - return None \ No newline at end of file + if log: + logging.warning( + 'You are trying to use a feature that requires jax/pallas.' + 'You can install Jax/Pallas via pip install torch_xla[pallas]') + return None diff --git a/torch_xla/debug/profiler.py b/torch_xla/debug/profiler.py index ffbd754b1c7..37dfa501e4f 100644 --- a/torch_xla/debug/profiler.py +++ b/torch_xla/debug/profiler.py @@ -131,7 +131,7 @@ def __enter__(self): self._jax_scope = None # Also enter the JAX named scope, to support torchax lowering. - if jax := maybe_get_jax(): + if jax := maybe_get_jax(log=False): self._jax_scope = jax.named_scope(self.name) self._jax_scope.__enter__() diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index c010fd4c352..be6daca582e 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -646,7 +646,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh, f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})." tx = maybe_get_torchax() - jax = maybe_get_jax() + # Do not log jax warnings when workarounds are available. + jax = maybe_get_jax(log=False) if (jax is not None) and (tx is not None) and isinstance(t, tx.tensor.Tensor): from jax.sharding import PartitionSpec as P, NamedSharding jmesh = mesh.get_jax_mesh()