Skip to content

Commit

Permalink
Add named tuple's error message and workaround for RET failure
Browse files Browse the repository at this point in the history
Pull Request resolved: #46347

Added the named tuple's error messages & workarounds when it returns from a function of a class in Pytorch Mobile.

To identify the error cases (returning NamedTuple type), I used the following coditions:
1) ins.op == RET  (for returing)
2) type->kind() == TypeKind::TupleType  (for pruning non-tuple types)
3) type->cast<TupleType>().name()  (for pruning Tuple type)
  - I could use the type's str (str() or repr_str()) directly, but I used whether it has the "name" attribute. Please give the comment for this.


[Information of Tuple and NamedTuple types]
1. Tuple
type->str(): (int, int)
type->repr_str(): Tuple[int, int]
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type()->cast<NamedType>()>name(): False    # different with NamedTuple

2. NamedTuple
type->str():  __torch__.myNamedTuple
type->repr_str(): __torch__.myNamedTuple
type->kind():  TypeKind::TupleType         # different with other types
type()->cast<NamedType>(): True
type->cast<TupleType>().name() = True      # different with Tuple

(From the next diff, I will handle the other error cases: 1) returning List<module class>, Dict<module class> and 2) accessing Module class's member functions)
ghstack-source-id: 114361762

Differential Revision: [D24291962](https://our.internmc.facebook.com/intern/diff/D24291962/)
  • Loading branch information
jinwoop committed Oct 15, 2020
1 parent 635aebd commit e6bfb9e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
30 changes: 30 additions & 0 deletions test/mobile/test_lite_script_module.py
Expand Up @@ -3,6 +3,8 @@
import torch.utils.bundled_inputs

import io
from typing import NamedTuple
from collections import namedtuple

from torch.jit.mobile import _load_for_lite_interpreter

Expand Down Expand Up @@ -138,7 +140,35 @@ def forward(self, arg):
r"define a pytorch class \(class Foo\(torch\.nn\.Module\)\)\.$"):
script_module._save_to_buffer_for_lite_interpreter()

def test_unsupported_return_typing_namedtuple(self):
myNamedTuple = NamedTuple('myNamedTuple', [('a', torch.Tensor)])

class MyTestModule(torch.nn.Module):
def forward(self):
return myNamedTuple(torch.randn(1))

script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"A named tuple type is not supported in mobile module. "
r"Workaround: instead of using a named tuple type\'s fields, "
r"use a dictionary type\'s key-value pair itmes or "
r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
script_module._save_to_buffer_for_lite_interpreter()

def test_unsupported_return_collections_namedtuple(self):
myNamedTuple = namedtuple('myNamedTuple', [('a')])

class MyTestModule(torch.nn.Module):
def forward(self):
return myNamedTuple(torch.randn(1))

script_module = torch.jit.script(MyTestModule())
with self.assertRaisesRegex(RuntimeError,
r"A named tuple type is not supported in mobile module. "
r"Workaround: instead of using a named tuple type\'s fields, "
r"use a dictionary type\'s key-value pair itmes or "
r"a pytorch class \(class Foo\(torch\.nn\.Module\)\)\'s attributes."):
script_module._save_to_buffer_for_lite_interpreter()

if __name__ == '__main__':
unittest.main()
16 changes: 16 additions & 0 deletions torch/csrc/jit/serialization/export_module.cpp
Expand Up @@ -117,6 +117,22 @@ std::pair<IValue, c10::optional<IValue>> getFunctionTuple(
TORCH_INTERNAL_ASSERT(
false, "Unsupported node kind on CALL opcode for mobile");
}
} else if (ins.op == RET) {
auto node = code.instructions_source()[i];
for (const auto& input : node->inputs()) {
const auto& input_type = input->type();
if (input_type->kind() == TypeKind::TupleType) {
if (const auto& name_typed_input =
input_type->cast<at::NamedType>()) {
TORCH_CHECK(
!name_typed_input->name(),
"A named tuple type is not supported in mobile module. ",
"Workaround: instead of using a named tuple type's fields, ",
"use a dictionary type's key-value pair itmes or ",
"a pytorch class (class Foo(torch.nn.Module))'s attributes.'");
}
}
}
} else {
TORCH_CHECK(
ins.op != CREATE_OBJECT,
Expand Down

0 comments on commit e6bfb9e

Please sign in to comment.