Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correctly mark unannotated NamedTuple field to be inferred TensorType #46969

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 18 additions & 1 deletion test/test_jit.py
Expand Up @@ -82,7 +82,7 @@
from itertools import product
import itertools
from textwrap import dedent
from typing import List, Dict, Optional, Tuple, Union
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import inspect
import math
import functools
Expand Down Expand Up @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x):
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))

def test_namedtuple_type_inference(self):
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])

def test_check_named_tuple_value():
named_tuple = _AnnotatedNamedTuple(1)
return named_tuple.value

self.checkScript(test_check_named_tuple_value, ())

def test_error():
tugsbayasgalan marked this conversation as resolved.
Show resolved Hide resolved
return _UnannotatedNamedTuple(1)

with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
r"for argument \'value\' but instead found type \'int\'."):
torch.jit.script(test_error)

def test_isinstance_dynamic(self):
@torch.jit.script
def foo(a):
Expand Down
2 changes: 1 addition & 1 deletion torch/_jit_internal.py
Expand Up @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj):
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
annotations.append(torch._C.TensorType.getInferred())
return type(obj).__name__, fields, annotations


Expand Down