From 9f26f5deeea6de0035ad51b4aaed1a6774e092c8 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Mon, 8 Apr 2024 18:41:05 +0000
Subject: [PATCH 1/3] Combined tensor.XLADispatchMode() and functions.XLAFunctionMode()

---
 experimental/torch_xla2/test/test_context.py | 35 +++++++++++++++++++
 .../torch_xla2/torch_xla2/__init__.py        | 10 ++++--
 2 files changed, 43 insertions(+), 2 deletions(-)
 create mode 100644 experimental/torch_xla2/test/test_context.py

diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py
new file mode 100644
index 000000000000..5f535c639ff4
--- /dev/null
+++ b/experimental/torch_xla2/test/test_context.py
@@ -0,0 +1,35 @@
+import unittest
+
+import torch
+import torch_xla2
+from torch_xla2 import tensor
+
+
+class TestContext(unittest.TestCase):
+  def test_mode_context_manager(self):
+    with torch_xla2.mode():
+      x = torch.full((3, 3), -1)
+      self.assertIsInstance(x, tensor.XLATensor2)
+      y = x.abs()
+      self.assertIsInstance(y, tensor.XLATensor2)
+      # TODO: remove print
+      print(y)
+
+  @staticmethod
+  @torch_xla2.mode()
+  def _test_mode_decorator():
+    x = torch.full((3, 3), -1)
+    y = x.abs()
+
+    return x, y
+
+  def test_mode_decorator(self):
+    x, y = self._test_mode_decorator()
+    self.assertIsInstance(x, tensor.XLATensor2)
+    self.assertIsInstance(y, tensor.XLATensor2)
+    # TODO: remove print
+    print(x, y)
+
+
+if __name__ == "__main__":
+  unittest.main()
diff --git a/experimental/torch_xla2/torch_xla2/__init__.py b/experimental/torch_xla2/torch_xla2/__init__.py
index 4d07006fcd08..b0bb20712d42 100644
--- a/experimental/torch_xla2/torch_xla2/__init__.py
+++ b/experimental/torch_xla2/torch_xla2/__init__.py
@@ -1,13 +1,19 @@
+import contextlib
 import jax
 import torch
 from torch._functorch import make_functional
 from torch.utils import _pytree as pytree
-from torch_xla2 import tensor
-from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration
+from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions
 
 jax.config.update('jax_enable_x64', True)
 
 
+@contextlib.contextmanager
+def mode():
+  with tensor.XLADispatchMode(), functions.XLAFunctionMode():
+    yield
+
+
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
   func, weights, buffer = make_functional.make_functional_with_buffers(mod)

From c64998717bf36577376359cd651e8f404b16f3b5 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Tue, 9 Apr 2024 17:41:32 +0000
Subject: [PATCH 2/3] Reduce complaint about unimplemented function to warning

---
 experimental/torch_xla2/test/test_context.py    | 4 ----
 experimental/torch_xla2/torch_xla2/functions.py | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/experimental/torch_xla2/test/test_context.py b/experimental/torch_xla2/test/test_context.py
index 5f535c639ff4..1a75a7d23d05 100644
--- a/experimental/torch_xla2/test/test_context.py
+++ b/experimental/torch_xla2/test/test_context.py
@@ -12,8 +12,6 @@ def test_mode_context_manager(self):
       self.assertIsInstance(x, tensor.XLATensor2)
       y = x.abs()
       self.assertIsInstance(y, tensor.XLATensor2)
-      # TODO: remove print
-      print(y)
 
   @staticmethod
   @torch_xla2.mode()
@@ -27,8 +25,6 @@ def test_mode_decorator(self):
     x, y = self._test_mode_decorator()
     self.assertIsInstance(x, tensor.XLATensor2)
     self.assertIsInstance(y, tensor.XLATensor2)
-    # TODO: remove print
-    print(x, y)
 
 
 if __name__ == "__main__":
diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py
index 6c40455959cc..e49c74ed0100 100644
--- a/experimental/torch_xla2/torch_xla2/functions.py
+++ b/experimental/torch_xla2/torch_xla2/functions.py
@@ -103,7 +103,7 @@ def __torch_function__(self,
                          kwargs=None) -> torch.Tensor:
     jax_func = registry.get(func)
     if not jax_func:
-      logging.warning(f'Falling back to default implementation of {func.__name__}')
+      logging.info(f'Falling back to default implementation of {func.__name__}')
       return func(*args, **(kwargs or {}))
 
     # TODO: unwrap args here or in implementations?

From abbf31189cba1ae4a9b7a1f0ea46648fe4b823c5 Mon Sep 17 00:00:00 2001
From: Will Cromar
Date: Wed, 10 Apr 2024 16:29:52 +0000
Subject: [PATCH 3/3] remove log entirely

---
 experimental/torch_xla2/torch_xla2/functions.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/experimental/torch_xla2/torch_xla2/functions.py b/experimental/torch_xla2/torch_xla2/functions.py
index e49c74ed0100..9fcd5653a86c 100644
--- a/experimental/torch_xla2/torch_xla2/functions.py
+++ b/experimental/torch_xla2/torch_xla2/functions.py
@@ -103,7 +103,6 @@ def __torch_function__(self,
                          kwargs=None) -> torch.Tensor:
     jax_func = registry.get(func)
     if not jax_func:
-      logging.info(f'Falling back to default implementation of {func.__name__}')
       return func(*args, **(kwargs or {}))
 
     # TODO: unwrap args here or in implementations?
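
A minimal usage sketch of the mode() helper introduced in PATCH 1/3, mirroring
the tests added in test_context.py. Only torch_xla2.mode, tensor.XLATensor2,
and the torch calls shown in the diffs above come from the patches; the helper
name make_tensors below is illustrative, not part of the change:

  import torch
  import torch_xla2
  from torch_xla2 import tensor

  # While the mode is active, torch factory calls and tensor methods are
  # intercepted by XLAFunctionMode/XLADispatchMode and return XLATensor2.
  with torch_xla2.mode():
    x = torch.full((3, 3), -1)
    y = x.abs()
  assert isinstance(y, tensor.XLATensor2)

  # The same helper works as a decorator, as in _test_mode_decorator above.
  @torch_xla2.mode()
  def make_tensors():
    a = torch.full((3, 3), -1)
    return a, a.abs()

  a, b = make_tensors()
  assert isinstance(a, tensor.XLATensor2) and isinstance(b, tensor.XLATensor2)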