Skip to content

Commit

Permalink
Correctly mark unannotated NamedTuple field to be inferred TensorType (
Browse files Browse the repository at this point in the history
…#46969)

Summary:
If there is no annotation given, we want to show users that the type is inferred

Pull Request resolved: #46969

Test Plan:
Added a new test case that throws an error with the expected error message

Fixes #46326

Reviewed By: ZolotukhinM

Differential Revision: D24614450

Pulled By: gmagogsfm

fbshipit-source-id: dec555a53bfaa9cdefd3b21b5142f5e522847504
  • Loading branch information
tmanlaibaatar authored and facebook-github-bot committed Oct 29, 2020
1 parent 1e275bc commit fee585b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
19 changes: 18 additions & 1 deletion test/test_jit.py
Expand Up @@ -82,7 +82,7 @@
from itertools import product
import itertools
from textwrap import dedent
from typing import List, Dict, Optional, Tuple, Union
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import inspect
import math
import functools
Expand Down Expand Up @@ -13796,6 +13796,23 @@ def test_non_primitive_types(x):
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))

def test_namedtuple_type_inference(self):
_AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)])
_UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value'])

def test_check_named_tuple_value():
named_tuple = _AnnotatedNamedTuple(1)
return named_tuple.value

self.checkScript(test_check_named_tuple_value, ())

def test_error():
return _UnannotatedNamedTuple(1)

with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor \(inferred\)\' "
r"for argument \'value\' but instead found type \'int\'."):
torch.jit.script(test_error)

def test_isinstance_dynamic(self):
@torch.jit.script
def foo(a):
Expand Down
2 changes: 1 addition & 1 deletion torch/_jit_internal.py
Expand Up @@ -839,7 +839,7 @@ def _get_named_tuple_properties(obj):
the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
annotations.append(the_type)
else:
annotations.append(torch._C.TensorType.get())
annotations.append(torch._C.TensorType.getInferred())
return type(obj).__name__, fields, annotations


Expand Down

0 comments on commit fee585b

Please sign in to comment.