experimental/torch_xla2/test/test_context.py (31 additions, 0 deletions)
@@ -0,0 +1,31 @@
+import unittest
+
+import torch
+import torch_xla2
+from torch_xla2 import tensor
+
+
+class TestContext(unittest.TestCase):
+  def test_mode_context_manager(self):
+    with torch_xla2.mode():
+      x = torch.full((3, 3), -1)
+      self.assertIsInstance(x, tensor.XLATensor2)
+      y = x.abs()
+      self.assertIsInstance(y, tensor.XLATensor2)
+
+  @staticmethod
+  @torch_xla2.mode()
+  def _test_mode_decorator():
+    x = torch.full((3, 3), -1)
+    y = x.abs()
+
+    return x, y
+
+  def test_mode_decorator(self):
+    x, y = self._test_mode_decorator()
+    self.assertIsInstance(x, tensor.XLATensor2)
+    self.assertIsInstance(y, tensor.XLATensor2)
+
+
+if __name__ == "__main__":
+  unittest.main()
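
Note: a minimal usage sketch of the API these tests exercise, assuming a torch_xla2 checkout that includes this PR; the shapes and the compute() helper below are illustrative and not part of the change itself.

import torch
import torch_xla2
from torch_xla2 import tensor

# Context-manager form: factory calls and tensor ops inside the block are
# intercepted and produce tensor.XLATensor2 values instead of plain CPU tensors.
with torch_xla2.mode():
  x = torch.full((3, 3), -1)
  y = x.abs()
assert isinstance(y, tensor.XLATensor2)

# Decorator form: the object returned by mode() can also wrap a function,
# which is what test_mode_decorator checks through _test_mode_decorator.
@torch_xla2.mode()
def compute():
  return torch.full((3, 3), -1).abs()

assert isinstance(compute(), tensor.XLATensor2)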
experimental/torch_xla2/torch_xla2/__init__.py (8 additions, 2 deletions)
@@ -1,13 +1,19 @@
+import contextlib
 import jax
 import torch
 from torch._functorch import make_functional
 from torch.utils import _pytree as pytree
-from torch_xla2 import tensor
-from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration
+from torch_xla2 import export, _ops, ops_registry, tensor, tf_integration, functions
 
 jax.config.update('jax_enable_x64', True)
 
 
+@contextlib.contextmanager
+def mode():
+  with tensor.XLADispatchMode(), functions.XLAFunctionMode():
+    yield
+
+
 def extract_jax(mod: torch.nn.Module):
   """Returns a pytree of jax.ndarray and a jax callable."""
   func, weights, buffer = make_functional.make_functional_with_buffers(mod)
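
Note on the decorator usage in the test above: a function wrapped with contextlib.contextmanager returns a _GeneratorContextManager, which inherits from contextlib.ContextDecorator, so the object produced by mode() works both in a with statement and as a decorator. A standalone sketch of that standard-library behavior (the tracing() helper and f() are illustrative names, not part of this PR):

import contextlib


@contextlib.contextmanager
def tracing():
  # The returned _GeneratorContextManager is also a contextlib.ContextDecorator,
  # hence the two usage styles below.
  print('enter')
  try:
    yield
  finally:
    print('exit')


# Context-manager style.
with tracing():
  print('inside with')


# Decorator style: every call to f() re-enters the context around the body.
@tracing()
def f():
  print('inside f')


f()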
experimental/torch_xla2/torch_xla2/functions.py (0 additions, 1 deletion)
@@ -103,7 +103,6 @@ def __torch_function__(self,
                          kwargs=None) -> torch.Tensor:
     jax_func = registry.get(func)
     if not jax_func:
-      logging.warning(f'Falling back to default implementation of {func.__name__}')
       return func(*args, **(kwargs or {}))
 
     # TODO: unwrap args here or in implementations?
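
Note on the removed line: XLAFunctionMode looks up a JAX override for the intercepted torch function and, if none is registered, falls back to the default torch implementation; this hunk drops the warning that used to fire on every fallback. A generic sketch of that fallback pattern (LoggingFallbackMode and _registry are illustrative stand-ins, not torch_xla2's actual objects; the warning is kept only to show the old behavior):

import logging

import torch
from torch.overrides import TorchFunctionMode

# Illustrative stand-in for the override table kept in torch_xla2/functions.py.
_registry = {}


class LoggingFallbackMode(TorchFunctionMode):
  """Forwards any torch function without a registered override to the default
  implementation, logging a warning on each fallback."""

  def __torch_function__(self, func, types, args=(), kwargs=None):
    override = _registry.get(func)
    if override is None:
      logging.warning('Falling back to default implementation of %s',
                      getattr(func, '__name__', func))
      # The mode is not re-entered for this call, so func runs its default path.
      return func(*args, **(kwargs or {}))
    return override(*args, **(kwargs or {}))


with LoggingFallbackMode():
  torch.full((2, 2), -1)  # no override registered -> default path plus a warning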