diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py
new file mode 100644
index 000000000000..1a75a7d23d05
--- /dev/null
+++ b/experimental/torch_xla2/test/test_context.py
@@ -0,0 +1,31 @@
+import unittest
+
+import torch
+import torch_xla2
+from torch_xla2 import tensor
+
+
+class TestContext(unittest.TestCase):
+  def test_mode_context_manager(self):
+    with torch_xla2.mode():
+      x = torch.full((3, 3), -1)
+      self.assertIsInstance(x, tensor.XLATensor2)
+      y = x.abs()
+      self.assertIsInstance(y, tensor.XLATensor2)
+
+  @staticmethod
+  @torch_xla2.mode()
+  def _test_mode_decorator():
+    x = torch.full((3, 3), -1)
+    y = x.abs()
+
+    return x, y
+
+  def test_mode_decorator(self):
+    x, y = self._test_mode_decorator()
+    self.assertIsInstance(x, tensor.XLATensor2)
+    self.assertIsInstance(y, tensor.XLATensor2)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index 4d07006fcd08..b0bb20712d42 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,13 +1,19 @@
+import contextlib
 import jax
 import torch
 from torch._functorch import make_functional
 from torch.utils import _pytree as pytree
-from torch_xla2 import tensor
-from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration
+from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions
 
 jax.config.update('jax_enable_x64', True)
 
 
+@contextlib.contextmanager
+def mode():
+  with tensor.XLADispatchMode(), functions.XLAFunctionMode():
+    yield
+
+
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
   func, weights, buffer = make_functional.make_functional_with_buffers(mod)
diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py
index 6c40455959cc..9fcd5653a86c 100644
--- a/experimental/torch_xla2/torch_xla2/functions.py
+++ b/experimental/torch_xla2/torch_xla2/functions.py
@@ -103,7 +103,6 @@ def __torch_function__(self,
                          kwargs=None) -> torch.Tensor:
     jax_func = registry.get(func)
     if not jax_func:
-      logging.warning(f'Falling back to default implementation of {func.__name__}')
       return func(*args, **(kwargs or {}))
 
     # TODO: unwrap args here or in implementations?
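
For reviewers, a minimal usage sketch of the new `torch_xla2.mode()` API, mirroring the tests above (the shapes and values here are illustrative only):

    import torch
    import torch_xla2

    # As a context manager: torch ops inside the block are intercepted by
    # XLADispatchMode/XLAFunctionMode and produce tensor.XLATensor2 values.
    with torch_xla2.mode():
      x = torch.full((3, 3), -1)  # x is a tensor.XLATensor2
      y = x.abs()                 # so is y

    # As a decorator: contextlib-based context managers also work as
    # function decorators, so the whole body runs under the mode.
    @torch_xla2.mode()
    def make_abs():
      return torch.full((3, 3), -1).abs()

Note that because `mode()` is built with `contextlib.contextmanager`, the decorator form in `test_mode_decorator` comes for free via `ContextDecorator`.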