From fee585b5a3d7a61a5a62df81cae596b486f409f2 Mon Sep 17 00:00:00 2001 From: tmanlaibaatar Date: Thu, 29 Oct 2020 11:58:10 -0700 Subject: [PATCH] Correctly mark unannotated NamedTuple field to be inferred TensorType (#46969) Summary: If there is no annotation given, we want to show users that the type is inferred Pull Request resolved: https://github.com/pytorch/pytorch/pull/46969 Test Plan: Added a new test case that throws an error with the expected error message Fixes https://github.com/pytorch/pytorch/issues/46326 Reviewed By: ZolotukhinM Differential Revision: D24614450 Pulled By: gmagogsfm fbshipit-source-id: dec555a53bfaa9cdefd3b21b5142f5e522847504 --- test/test_jit.py | 19 ++++++++++++++++++- torch/_jit_internal.py | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 9994754dcc42..f4c3a8d95b69 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -82,7 +82,7 @@ from itertools import product import itertools from textwrap import dedent -from typing import List, Dict, Optional, Tuple, Union +from typing import List, Dict, NamedTuple, Optional, Tuple, Union import inspect import math import functools @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_namedtuple_type_inference(self): + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) + _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) + + def test_check_named_tuple_value(): + named_tuple = _AnnotatedNamedTuple(1) + return named_tuple.value + + self.checkScript(test_check_named_tuple_value, ()) + + def test_error(): + return _UnannotatedNamedTuple(1) + + with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' " + r"for argument \'value\' but instead found type \'int\'."): + torch.jit.script(test_error) + def test_isinstance_dynamic(self): @torch.jit.script def foo(a): diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4396155c73ea..a1c800debb59 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj): the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range()) annotations.append(the_type) else: - annotations.append(torch._C.TensorType.get()) + annotations.append(torch._C.TensorType.getInferred()) return type(obj).__name__, fields, annotations