From 6f2f1a88b40834d6be46ae7ead167c55c1f68002 Mon Sep 17 00:00:00 2001 From: jinwoop Date: Wed, 14 Oct 2020 16:45:52 -0700 Subject: [PATCH] Add named tuple's error message and workaround for RET failure Pull Request resolved: https://github.com/pytorch/pytorch/pull/46347 Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile. To identify the error cases (returning NamedTuple type), I used the following coditions: 1) ins.op == RET (for returing) 2) type->kind() == TypeKind::TupleType (for pruning non-tuple types) 3) type->cast().name() (for pruning Tuple type) - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this. [Information of Tuple and NamedTuple types] 1. Tuple type->str(): (int, int) type->repr_str(): Tuple[int, int] type->kind(): TypeKind::TupleType # different with other types type()->cast(): True type()->cast()>name(): False # different with NamedTuple 2. NamedTuple type->str(): __torch__.myNamedTuple type->repr_str(): __torch__.myNamedTuple type->kind(): TypeKind::TupleType # different with other types type()->cast(): True type->cast().name() = True # different with Tuple (From the next diff, I will handle the other error cases: 1) returning List, Dict and 2) accessing Module class's member functions) ghstack-source-id: 114339694 Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/) --- test/mobile/test_lite_script_module.py | 30 +++++++++++++++++++ .../csrc/jit/serialization/export_module.cpp | 16 ++++++++++ 2 files changed, 46 insertions(+) diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 253b45be2217..16558953611b 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -3,6 +3,8 @@ import torch.utils.bundled_inputs import io +from typing import NamedTuple +from collections import namedtuple from torch.jit.mobile import _load_for_lite_interpreter @@ -138,7 +140,35 @@ def forward(self, arg): r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"): script_module._save_to_buffer_for_lite_interpreter() + def test_unsupported_return_typing_namedtuple(self): + myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)]) + class MyTestModule(torch.nn.Module): + def forward(self): + return myNamedTuple(torch.randn(1)) + + script_module = torch.jit.script(MyTestModule()) + with self.assertRaisesRegex(RuntimeError, + r"A named tuple type is not supported in mobile module. " + r"Workaround: instead of using a named tuple type\'s fields, " + r"use a dictionary type\'s key-value pair itmes or " + r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."): + script_module._save_to_buffer_for_lite_interpreter() + + def test_unsupported_return_collections_namedtuple(self): + myNamedTuple = namedtuple('myNamedTuple', [('a')]) + + class MyTestModule(torch.nn.Module): + def forward(self): + return myNamedTuple(torch.randn(1)) + + script_module = torch.jit.script(MyTestModule()) + with self.assertRaisesRegex(RuntimeError, + r"A named tuple type is not supported in mobile module. " + r"Workaround: instead of using a named tuple type\'s fields, " + r"use a dictionary type\'s key-value pair itmes or " + r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."): + script_module._save_to_buffer_for_lite_interpreter() if __name__ == '__main__': unittest.main() diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index 7c42bf052583..548c816bac6a 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -117,6 +117,22 @@ std::pair> getFunctionTuple( TORCH_INTERNAL_ASSERT( false, "Unsupported node kind on CALL opcode for mobile"); } + } else if (ins.op == RET) { + auto node = code.instructions_source()[i]; + for (const auto& input : node->inputs()) { + const auto& input_type = input->type(); + if (input_type->kind() == TypeKind::TupleType) { + if (const auto& name_typed_input = + input_type->cast()) { + TORCH_CHECK( + !name_typed_input->name(), + "A named tuple type is not supported in mobile module. ", + "Workaround: instead of using a named tuple type's fields, ", + "use a dictionary type's key-value pair itmes or ", + "a pytorch class (class Foo(torch.nn.Module))'s attributes.'"); + } + } + } } else { TORCH_CHECK( ins.op != CREATE_OBJECT,