Skip to content

Commit

Permalink
[pytorch] recast infer_type error and amend with name and item that f…
Browse files Browse the repository at this point in the history
…ailed inferring

Summary:
When type inference fails when JITing torchscript module, the error message does not give any implication where the error fails. For example:  "Cannot create dict for key type 'int?', only int, float, complex, Tensor and string keys are supported".

This adds the variable name and item to the error message.

Reviewed By: ajaech

Differential Revision: D26327483

fbshipit-source-id: d8c85e7550258d7c56530f5826ff9683ca8b2b94
  • Loading branch information
Aapo Kyrola authored and facebook-github-bot committed Feb 10, 2021
1 parent 12d85b5 commit e964d77
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions torch/jit/_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,20 @@ def infer_type(name, item):
# isinstance on typing things doesn't seem to work: isinstance(list, Callable)
# is also true!
inferred = False
if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
attr_type = torch._C.InferredType(ann_to_type)
elif isinstance(item, torch.jit.Attribute):
ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
attr_type = torch._C.InferredType(ann_to_type)
else:
attr_type = torch._C._jit_try_infer_type(item)
inferred = True
try:
if name in class_annotations and class_annotations[name] != torch.nn.Module.__annotations__["forward"]:
ann_to_type = torch.jit.annotations.ann_to_type(class_annotations[name], _jit_internal.fake_range())
attr_type = torch._C.InferredType(ann_to_type)
elif isinstance(item, torch.jit.Attribute):
ann_to_type = torch.jit.annotations.ann_to_type(item.type, _jit_internal.fake_range())
attr_type = torch._C.InferredType(ann_to_type)
else:
attr_type = torch._C._jit_try_infer_type(item)
inferred = True
except RuntimeError as re:
raise RuntimeError(
"Error inferring type for {name}: {item}: {re}".format(name=name, item=item, re=re)
)

return attr_type, inferred

Expand Down

0 comments on commit e964d77

Please sign in to comment.