Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions test/test_python_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def test_call_boxed(self) -> None:


class TestPythonRegistration(TestCase):
test_ns = '_test_python_registration'

def tearDown(self):
if hasattr(torch.ops, self.test_ns):
del torch.ops._test_python_registration

def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
my_lib1 = Library("aten", "IMPL")
Expand Down Expand Up @@ -95,7 +101,7 @@ def test_error_if_fn_not_callable(self):

def test_finalizer(self):
impls_refcnt = sys.getrefcount(torch.library._impls)
lib = Library("_torch_testing", "FRAGMENT")
lib = Library(self.test_ns, "FRAGMENT")
lib.define("foo123(Tensor x) -> Tensor")

# 1 for `lib`, 1 for sys.getrefcount
Expand All @@ -110,8 +116,8 @@ def test_finalizer(self):
def foo123(x):
pass

lib.impl("_torch_testing::foo123", foo123, "CPU")
key = '_torch_testing/foo123/CPU'
lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
key = f'{self.test_ns}/foo123/CPU'
self.assertTrue(key in torch.library._impls)

saved_op_impls = lib._op_impls
Expand Down Expand Up @@ -287,7 +293,7 @@ def my_sum(*args, **kwargs):
del my_lib1

def test_create_new_library(self) -> None:
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")

my_lib1.define("sum(Tensor self) -> Tensor")

Expand All @@ -297,27 +303,28 @@ def my_sum(*args, **kwargs):
return args[0].clone()

x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum(x), x)
op = getattr(torch.ops, self.test_ns).sum
self.assertEqual(op(x), x)

my_lib2 = Library("foo", "IMPL")
my_lib2 = Library(self.test_ns, "IMPL")

# Example 2
@torch.library.impl(my_lib2, torch.ops.foo.sum.default, "ZeroTensor")
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
def my_sum_zt(*args, **kwargs):
if args[0]._is_zerotensor():
return torch._efficientzerotensor(args[0].shape)
else:
return args[0].clone()

y = torch._efficientzerotensor(3)
self.assertTrue(torch.ops.foo.sum(y)._is_zerotensor())
self.assertEqual(torch.ops.foo.sum(x), x)
self.assertTrue(op(y)._is_zerotensor())
self.assertEqual(op(x), x)

del my_lib2
del my_lib1

def test_create_new_library_fragment_no_existing(self):
my_lib = Library("foo", "FRAGMENT")
my_lib = Library(self.test_ns, "FRAGMENT")

my_lib.define("sum2(Tensor self) -> Tensor")

Expand All @@ -326,15 +333,15 @@ def my_sum(*args, **kwargs):
return args[0]

x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum2(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

del my_lib

def test_create_new_library_fragment_with_existing(self):
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")

# Create a fragment
my_lib2 = Library("foo", "FRAGMENT")
my_lib2 = Library(self.test_ns, "FRAGMENT")

my_lib2.define("sum4(Tensor self) -> Tensor")

Expand All @@ -343,10 +350,10 @@ def my_sum4(*args, **kwargs):
return args[0]

x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum4(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

# Create another fragment
my_lib3 = Library("foo", "FRAGMENT")
my_lib3 = Library(self.test_ns, "FRAGMENT")

my_lib3.define("sum3(Tensor self) -> Tensor")

Expand All @@ -355,7 +362,7 @@ def my_sum3(*args, **kwargs):
return args[0]

x = torch.tensor([1, 2])
self.assertEqual(torch.ops.foo.sum3(x), x)
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

del my_lib1
del my_lib2
Expand All @@ -364,7 +371,7 @@ def my_sum3(*args, **kwargs):
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
def test_alias_analysis(self):
def test_helper(alias_analysis=""):
my_lib1 = Library("foo", "DEF")
my_lib1 = Library(self.test_ns, "DEF")

called = [0]

Expand All @@ -374,9 +381,9 @@ def _op(*args, **kwargs):

@torch.jit.script
def _test():
torch.ops.foo._op()
torch.ops._test_python_registration._op()

assert "foo::_op" in str(_test.graph)
assert "_test_python_registration::_op" in str(_test.graph)

with self.assertRaises(AssertionError):
test_helper("") # alias_analysis="FROM_SCHEMA"
Expand All @@ -399,14 +406,14 @@ def test_returning_symint(self) -> None:

s0, s1 = ft.shape

tlib = Library("tlib", "DEF")
tlib = Library(self.test_ns, "DEF")
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

@impl(tlib, "sqsum", "CompositeExplicitAutograd")
def sqsum(a: SymInt, b: SymInt):
return a * a + b * b

out = torch.ops.tlib.sqsum.default(s0, s1)
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
out_val = shape_env.evaluate_expr(out.node.expr)
self.assertEquals(out_val, 13)

Expand Down