Generate type match guard for torch.Size input
I suppose, hypothetically, if the user code works polymorphically over the
SizeVariable in such a way that a tuple would also work, this type match is
not necessary. But we do not carefully test for this.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 887c96920f948942a1bb8d4f21eec81f87759a2a
Pull Request resolved: #96421
ezyang committed Mar 9, 2023
1 parent f96bd52 commit 760852a
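For context, torch.Size subclasses tuple and compares equal to a plain tuple with the same values, so a value-equality guard alone cannot tell the two apart. A minimal sketch (not part of this commit) of why that matters when traced user code branches on the type:

    import torch

    s = torch.Size((3,))
    t = (3,)

    print(s == t)                     # True: equality alone cannot distinguish them
    print(isinstance(s, torch.Size))  # True
    print(isinstance(t, torch.Size))  # False
    # Without a type match guard, a frame compiled for s could be reused for t,
    # which is wrong whenever the user code branches on isinstance(x, torch.Size).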
Showing 3 changed files with 28 additions and 4 deletions.
3 changes: 2 additions & 1 deletion test/dynamo/test_misc.py
@@ -565,7 +565,8 @@ def fn(x, s):
         self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
         self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
         self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
-        self.assertEqual(cnts.op_count, 2)
+        # One recompile per differing input type
+        self.assertEqual(cnts.frame_count, 3)

     def test_cell_output1(self):
         out = None
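The updated assertion counts frames instead of ops: v.size() is a torch.Size, (10, 20) a tuple, and [10, 20] a list, and with the new type match guard each exact type compiles its own frame. A small illustration (hypothetical, not the generated guard code) of why the three inputs cannot share a frame:

    import torch

    # A guard of the form ___check_type_id(s, <id>) and s == (10, 20),
    # compiled for one of these inputs, fails the type check for the other two.
    for s in (torch.Size((10, 20)), (10, 20), [10, 20]):
        print(type(s).__name__, id(type(s)))  # three distinct types and type ids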
17 changes: 17 additions & 0 deletions test/dynamo/test_repros.py
@@ -2313,6 +2313,23 @@ def f(x):
         exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5]))
         self.assertTrue(same(exported(*args), f(*args)))

+    def test_size_typematch(self):
+        def f(x, y):
+            if isinstance(x, torch.Size):
+                return y + 1
+            else:
+                return y + 2
+
+        y = torch.zeros(1)
+        x1 = torch.Size((3,))
+        x2 = (3,)
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
+        self.assertTrue(same(f(x1, y), opt_f(x1, y)))
+        self.assertTrue(same(f(x2, y), opt_f(x2, y)))
+        self.assertEqual(cnt.frame_count, 2)
+
     @torch._dynamo.config.patch("rewrite_assert_with_torch_assert", False)
     def test_not_rewrite_assert(self):
         def f(x):
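The new test exercises the fix end to end: with this patch, the frame compiled for x1 (a torch.Size) carries a type match guard, so the plain tuple x2 misses that guard and triggers a second compile, hence frame_count == 2. Presumably, without the guard, opt_f(x2, y) would reuse the Size frame and return y + 1 where eager f(x2, y) returns y + 2, failing the second same() check.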
12 changes: 9 additions & 3 deletions torch/_dynamo/guards.py
@@ -275,22 +275,28 @@ def EQUALS_MATCH(self, guard: Guard):
             self._produce_guard_code(guard, code)
             return

-        # Add type check to prevent equality check between tensor and non-tensor.
         code = list()
+
+        # If matching equality against list/tuple, we must also check that
+        # the internal types match. (TODO: what about nested lists?)
         if istype(val, (list, tuple)):
             # NB: LIST_LENGTH takes care of the outer __check_type_id test
             self.LIST_LENGTH(guard)

             for idx, elem in enumerate(val):
                 code.append(
                     f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
                 )

-        elif not istype(val, torch.Size):
+        else:
+            # Add type check to prevent equality check between tensor and non-tensor.
             code.append(f"___check_type_id({ref}, {self.id_ref(t)})")

+        if istype(val, torch.Size):
+            val = tuple(val)
+
         # TODO: It feels like it would be better to just implement our own
         # equality test in C that handles all of the necessary type checking
         # and NaN tests
         code.append(f"{ref} == {val!r}")
         self._produce_guard_code(guard, code)

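Taken together, EQUALS_MATCH now emits a type-id check followed by a tuple equality check for a torch.Size input. A hedged sketch of the resulting guard expression (the reference name s is hypothetical, plain id() stands in for self.id_ref, and concrete id values are runtime-dependent):

    import torch

    val = torch.Size((3, 4))
    ref = "s"  # hypothetical reference name for the guarded input

    code = []
    # torch.Size is not exactly list or tuple, so istype() sends it down the
    # else branch, which emits the exact-type guard:
    code.append(f"___check_type_id({ref}, {id(type(val))})")
    # The Size is then flattened to a plain tuple so the equality check is
    # against a plain tuple literal:
    val = tuple(val)
    code.append(f"{ref} == {val!r}")
    print(" and ".join(code))
    # e.g. ___check_type_id(s, 140234...) and s == (3, 4)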
