diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py
index 23751c4fd92b..31b0124ae802 100644
--- a/test/jit/test_save_load.py
+++ b/test/jit/test_save_load.py
@@ -7,7 +7,7 @@
 from itertools import product as product
 from torch import Tensor
 from torch.testing._internal.common_utils import TemporaryFileName
-from typing import NamedTuple
+from typing import NamedTuple, Optional
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@@ -947,3 +947,39 @@ def forward(self, x):
         script_module = torch.jit.script(Foo())
         with self.assertRaises(RuntimeError):
             script_module.save("NonExist/path/test.pt")
+
+    def test_save_namedtuple_input_only(self):
+        """
+        Even if a NamedTuple is only used as an input argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: int
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x: FooTuple) -> torch.Tensor:
+                return torch.tensor(3)
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded(FooTuple(a=5))
+        self.assertEqual(output, torch.tensor(3))
+
+    def test_save_namedtuple_output_only(self):
+        """
+        Even if a NamedTuple is only used as an output argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: int
+
+        class MyModule(torch.nn.Module):
+            def forward(self) -> Optional[FooTuple]:
+                return None
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded()
+        self.assertEqual(output, None)
diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp
index 9803829eb683..b9b1d60640c2 100644
--- a/torch/csrc/jit/serialization/python_print.cpp
+++ b/torch/csrc/jit/serialization/python_print.cpp
@@ -1255,6 +1255,7 @@ struct PythonPrintImpl {
     body_ << "def " << func.name() << "(";
     auto param_it = graph.inputs().begin();
     for (const Argument& arg : schema.arguments()) {
+      registerClassDependencies(arg.type());
       std::string arg_name = genName(arg.name());
       if (param_it == graph.inputs().begin()) {
         // the first argument may omit its type when it is implied by context
@@ -1273,9 +1274,10 @@
       assignValue(*param_it++, arg_name);
     }
 
-    body_ << ") -> "
-          << schema.returns().at(0).type()->annotation_str(type_printer_)
-          << ":\n";
+    const auto& returnType = schema.returns().at(0).type();
+    body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n";
+    registerClassDependencies(returnType);
+
     printBody(graph.block());
   }
 