Skip to content

Commit

Permalink
More descriptive error message for inferred types #46326
Browse files Browse the repository at this point in the history
Summary: If there is no annotation given, we want to show users that the type is inferred

Test Plan: Added a new test case that throws an error with the expected error message

Reviewers: Yanan Cao

Subscribers:

Tasks: #46326

Tags:
  • Loading branch information
tmanlaibaatar authored and tugsbayasgalan committed Oct 29, 2020
1 parent c2a3951 commit b059224
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
19 changes: 18 additions & 1 deletion test/test_jit.py
Expand Up @@ -82,7 +82,7 @@
from itertools import product
import itertools
from textwrap import dedent
from typing import List, Dict, Optional, Tuple, Union
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import inspect
import math
import functools
Expand Down Expand Up @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x):
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))

def test_namedtuple_type_inference(self):
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])

def test_check_named_tuple_value():
named_tuple = _AnnotatedNamedTuple(1)
return named_tuple.value

self.checkScript(test_check_named_tuple_value, ())

def test_error():
return _UnannotatedNamedTuple(1)

with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
r"for argument \'value\' but instead found type \'int\'."):
torch.jit.script(test_error)

def test_isinstance_dynamic(self):
@torch.jit.script
def foo(a):
Expand Down
2 changes: 1 addition & 1 deletion torch/_jit_internal.py
Expand Up @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj):
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
annotations.append(torch._C.TensorType.getInferred())
return type(obj).__name__, fields, annotations


Expand Down

0 comments on commit b059224

Please sign in to comment.