From b059224e33249729d1e1e42c05af25f8c7052259 Mon Sep 17 00:00:00 2001 From: tmanlaibaatar Date: Tue, 27 Oct 2020 20:04:19 -0700 Subject: [PATCH] More descriptive error message for inferred types #46326 Summary: If there is no annotation given, we want to show users that the type is inferred Test Plan: Added a new test case that throws an error with the expected error message Reviewers: Yanan Cao Subscribers: Tasks: #46326 Tags: --- test/test_jit.py | 19 ++++++++++++++++++- torch/_jit_internal.py | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 9994754dcc42..f4c3a8d95b69 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -82,7 +82,7 @@ from itertools import product import itertools from textwrap import dedent -from typing import List, Dict, Optional, Tuple, Union +from typing import List, Dict, NamedTuple, Optional, Tuple, Union import inspect import math import functools @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_namedtuple_type_inference(self): + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) + + def test_check_named_tuple_value(): + named_tuple = _AnnotatedNamedTuple(1) + return named_tuple.value + + self.checkScript(test_check_named_tuple_value, ()) + + def test_error(): + return _UnannotatedNamedTuple(1) + + with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " + r"for argument \'value\' but instead found type \'int\'."): + torch.jit.script(test_error) + def test_isinstance_dynamic(self): @torch.jit.script def foo(a): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4396155c73ea..a1c800debb59 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj): the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) annotations.append(the_type) else: - annotations.append(torch._C.TensorType.get()) + annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, fields, annotations