Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e42a1af
Preserve types during empty container assignment
May 25, 2021
f40d6d9
Update on "Preserve types during empty container assignment"
May 27, 2021
97fed79
Update on "Preserve types during empty container assignment"
Jun 7, 2021
75f4fb5
Update on "Preserve types during empty container assignment"
Jun 26, 2021
1fe5881
Update on "Preserve types during empty container assignment"
Jul 28, 2021
f3bb684
Update on "Preserve types during empty container assignment"
Aug 11, 2021
c88c7c2
Update on "Preserve types during empty container assignment"
Aug 17, 2021
e877050
Update on "Preserve types during empty container assignment"
Aug 19, 2021
53a823f
Update on "Preserve types during empty container assignment"
Aug 27, 2021
dc0b7f5
Update on "Preserve types during empty container assignment"
Aug 27, 2021
cc51257
Update on "Preserve types during empty container assignment"
Aug 30, 2021
cdabd90
Update on "Preserve types during empty container assignment"
Aug 30, 2021
c6dfc33
Update on "Preserve types during empty container assignment"
Sep 7, 2021
c30031b
Update on "Preserve types during empty container assignment"
Sep 7, 2021
0ec92f9
Update on "Preserve types during empty container assignment"
Sep 7, 2021
72068c5
Update on "Preserve types during empty container assignment"
Sep 7, 2021
55233de
Update on "Preserve types during empty container assignment"
Sep 7, 2021
b4c0723
Update on "Preserve types during empty container assignment"
Sep 7, 2021
aa3aee9
Update on "Preserve types during empty container assignment"
Sep 8, 2021
ab20f43
Update on "Preserve types during empty container assignment"
Sep 8, 2021
abd9fc1
Update on "Preserve types during empty container assignment"
Sep 8, 2021
7a9b4f4
Update on "Preserve types during empty container assignment"
Sep 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions test/jit/test_list_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,9 @@ def fn():
self.checkScript(fn, ())

def test_dict_keyword_with_mismatched_annotations(self):
err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
"match the types of the actual dict items"
err_msg = r"Dict type annotation `Dict\[int, str\]` did not "\
"match the type of an actual key type `str`"
highlight_msg = "dict([(\"foo\", 1), (\"bar\", 2), (\"baz\", 3"
with self.assertRaisesRegexWithHighlight(RuntimeError, err_msg, highlight_msg):
err_msg = r"is annotated with type Dict\[int, str\] but is " \
r"being assigned to a value of type Dict\[str, int\]"
with self.assertRaisesRegex(RuntimeError, err_msg):
@torch.jit.script
def fn():
x: Dict[int, str] = dict([("foo", 1), ("bar", 2), ("baz", 3)]) # noqa: C406
Expand Down Expand Up @@ -1328,7 +1325,7 @@ def test_list_none(self):
x = torch._C.ListType(None)

def test_list_unification_hint(self):
with self.assertRaisesRegex(RuntimeError, "Expected a List type hint"):
with self.assertRaisesRegex(RuntimeError, "Expected an annotation of type List"):
@torch.jit.script
def x():
b : int = [2, 3]
Expand Down
19 changes: 10 additions & 9 deletions test/jit/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,25 @@ def fn():
l1 = [1, 2, "foo", 3]
l2 = ["foo", "bar", "baz", "qux"]
d: Dict[int, str] = {k : v for k, v in zip(l1, l2)}
return l
return d

with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
r" `Dict\[int, str\]` did not match"
" the type of an actual key type"):
with self.assertRaisesRegex(RuntimeError, "Dicts may only "
"contain homogeneous keys, but the "
"type of the first generated key "
r"was Union\[int, str\]"):
torch.jit.script(fn)

def test_dict_type_refinement_annotation_value_mismatch(self):
def fn():
l1 = ["foo", "bar", "baz", "qux"]
l2 = [1, 2, "foo", 3]
d: Dict[str, int] = {k : v for k, v in zip(l1, l2)}
return l
return d

with self.assertRaisesRegex(RuntimeError, "Dict type annotation"
r" `Dict\[str, int\]` did not match"
" the type of an actual value "
"type"):
with self.assertRaisesRegex(RuntimeError, "annotated with type "
r"Dict\[str, int\] but is being "
"assigned to a value of type "
r"Dict\[str, Union\[int, str\]\]"):
torch.jit.script(fn)

def test_dict_invalid_annotations(self):
Expand Down
270 changes: 270 additions & 0 deletions test/jit/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from torch.testing import FileCheck
from enum import Enum
from textwrap import dedent
from typing import Dict, List, Optional, Tuple, Union

# Make the helper files in test/ importable
Expand Down Expand Up @@ -655,3 +656,272 @@ def fn(x: int) -> str:

self.checkScript(fn, (1,))
self.checkScript(fn, (8,))

def _assert_passes(self, template: str, ann: str, lhs: str):
code = template.format(ann=ann, lhs=lhs)
self.checkScript(code, (), name="fn")

def _assert_raises(self, template: str, ann: str, lhs: str, msg: str):
code = template.format(ann=ann, lhs=lhs)
with self.assertRaisesRegex(RuntimeError, msg):
cu = torch.jit.CompilationUnit(code, _frames_up=1)
string_frontend = getattr(cu, "fn") # noqa: B009

def test_union_with_list_assignment(self):
template = dedent('''
def fn():
x: {ann} = {lhs}
if torch.jit.isinstance(x, List[torch.Tensor]):
x.append(torch.tensor(3))
return x
''')

lhs = {"list_literal_empty" : "[]",

"list_literal_of_tensor" : "[torch.arange(3), torch.arange(5)]",

"list_literal_of_str" : "[\"foo\", \"bar\", \"baz\"]",

"list_literal_of_mixed" : "[torch.arange(5), 1]",

"list_comprehension_of_tensor" :
"[torch.add(x, 1) for x in [torch.arange(3), torch.arange(5)]]",

"list_comprehension_of_str" :
"[x + \"!\" for x in [\"foo\", \"bar\", \"baz\"]]",

"list_comprehension_of_mixed" :
"[torch.add(1, x) for x in [torch.arange(5), 1]]"}

"""
Union[List[str], List[torch.Tensor]]
"""
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_empty"],
"there are multiple possible List type "
"candidates in the Union annotation")

self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_tensor"])

self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_str"])

self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_literal_of_mixed"],
"none of those list types can hold the "
"types of the given list elements")

self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_tensor"])

self._assert_passes(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_str"])

# TODO: Support mixed list comprehensions
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid")

"""
Union[int, torch.Tensor]
"""
self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["list_literal_empty"],
"Expected an Union type annotation with an "
"inner List type")

self._assert_raises(template, "Union[int, torch.Tensor]",
lhs["list_literal_of_tensor"],
"Expected an Union type annotation with an "
"inner List type")

self._assert_raises(template, "Union[int, torch.Tensor]",
lhs["list_comprehension_of_tensor"],
"Expected an Union type annotation with an "
"inner List type")

"""
Union[List[torch.Tensor], int]
"""
self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_empty"])

self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_literal_of_tensor"])

self._assert_raises(template, "Union[List[torch.Tensor], int]",
lhs["list_literal_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")

self._assert_raises(template, "Union[List[torch.Tensor], int]",
lhs["list_literal_of_mixed"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")

self._assert_passes(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_tensor"])

self._assert_raises(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_str"],
r"List type annotation `List\[Tensor\]` did "
"not match the types of the given list "
"elements")

# TODO: Support mixed list comprehensions
self._assert_raises(template,
"Union[List[torch.Tensor], int]",
lhs["list_comprehension_of_mixed"],
"Arguments for call are not valid")

def test_union_with_dict_assignment(self):
template = dedent('''
def fn():
x: {ann} = {lhs}
if torch.jit.isinstance(x, Dict[str, torch.Tensor]):
x["foo"] = torch.tensor(3)
return x
''')

lhs = {"dict_literal_empty" : "{}",

"dict_literal_of_str_tensor" :
"{\"foo\" : torch.arange(3), \"bar\" : torch.arange(5)}",

"dict_literal_of_str_int" :
"{\"foo\" : 1, \"bar\" : 2}",

"dict_literal_of_mixed" :
"{\"foo\" : torch.arange(3), \"bar\" : 2}",

"dict_comprehension_of_str_tensor" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [torch.arange(3), torch.arange(5)])}",

"dict_comprehension_of_str_int" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [1, 2]}",

"dict_comprehension_of_mixed" :
"{x : torch.add(y, 1) for x, y in \
zip([\"foo\", \"bar\"], [torch.arange(3), 2])}",

"dict_keyword" :
"dict(foo=torch.arange(3), baz=torch.arange(5))"}

"""
Union[Dict[str, torch.Tensor], Dict[str, int]]
"""
self._assert_raises(template,
"Union[List[str], List[torch.Tensor]]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with an "
"inner Dict type")

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_tensor"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_str_int"])

self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_literal_of_mixed"],
"none of those types can hold the types "
"of the given dict elements")

# TODO: String frontend does not support tuple unpacking
# https://github.com/pytorch/pytorch/issues/64096
# self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
# lhs["dict_comprehension_of_str_tensor"])

# self._assert_passes(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
# lhs["dict_comprehension_of_str_int"])

# self._assert_raises(template, "Union[Dict[str, torch.Tensor], Dict[str, int]]",
# lhs["dict_comprehension_of_mixed"],
# "foobar")

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], Dict[str, int]]",
lhs["dict_keyword"])

"""
Union[int, torch.Tensor]
"""
self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["dict_literal_empty"],
"Expected an Union type annotation with "
"an inner Dict type")

self._assert_raises(template,
"Union[int, torch.Tensor]",
lhs["dict_literal_of_str_tensor"],
"Expected an Union type annotation with "
"an inner Dict type")

# See above--string frontend does not support tuple unpacking
# self._assert_raises(template, "Union[int, torch.Tensor]",
# lhs["dict_comprehension_of_tensor"],
# "foobar")

"""
Union[Dict[str, torch.Tensor], int]
"""
self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_empty"])

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_tensor"])

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_str_int"],
r"Type hint for dict was Dict\[str, Tensor\]"
", but the value at index 0 has type int, "
"which is not a valid subtype of Tensor")

self._assert_raises(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_literal_of_mixed"],
r"Type hint for dict was Dict\[str, Tensor\]"
", but the value at index 1 has type int, "
"which is not a valid subtype of Tensor")

# See above--string frontend does not support tuple unpacking
# self._assert_passes(template,
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_str_tensor"])

# self._assert_raises(template,
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_str_int"],
# "foobar")

# self._assert_raises(template,
# "Union[Dict[str, torch.Tensor], int]",
# lhs["dict_comprehension_of_mixed"],
# "foobar")

self._assert_passes(template,
"Union[Dict[str, torch.Tensor], int]",
lhs["dict_keyword"])
3 changes: 2 additions & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11526,7 +11526,8 @@ def bad_type_annotation():
out = torch.jit.annotate(int, [x for x in [1, 2, 3]]) # noqa: C416
return out

with self.assertRaisesRegex(Exception, "Expected list type annotation"):
with self.assertRaisesRegex(Exception, "Expected an annotation"
" of type List"):
torch.jit.script(bad_type_annotation)

def test_list_comprehension_variable_write(self):
Expand Down
2 changes: 0 additions & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Sequence
from functools import partial, wraps
import unittest
import warnings

import torch
Expand Down Expand Up @@ -685,7 +684,6 @@ class TestJit(JitCommonTestCase):
# and runtimes (eager, traced, scripted).
# TODO WARNING: inplace x {traced, scripted} not currently tested
@_variant_ops(op_db)
@unittest.skipIf(True, "Temporarily skipping while landing Union PR stack")
def test_variant_consistency_jit(self, device, dtype, op):
_requires_grad = op.supports_autograd and (dtype.is_floating_point or
op.supports_complex_autograd(torch.device(device).type))
Expand Down
Loading