Skip to content

Commit

Permalink
Properly serialize types that only appear at function input
Browse files Browse the repository at this point in the history
When serializing graphs, we check every node for named types referenced,
so that we can register them as dependencies. We were skipping this
check for the graph inputs themselves. Since types used at input are
almost always used somewhere in the graph, we never noticed this gap
until a user reported an issue with NamedTuples.

ghstack-source-id: 50b22f2c7b1f1a39c374ab125e3981a06bf9b2f7
Pull Request resolved: #47775
  • Loading branch information
suo committed Nov 11, 2020
1 parent 4cb73f5 commit a69016b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
38 changes: 37 additions & 1 deletion test/jit/test_save_load.py
Expand Up @@ -7,7 +7,7 @@
from itertools import product as product
from torch import Tensor
from torch.testing._internal.common_utils import TemporaryFileName
from typing import NamedTuple
from typing import NamedTuple, Optional

# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
Expand Down Expand Up @@ -947,3 +947,39 @@ def forward(self, x):
script_module = torch.jit.script(Foo())
with self.assertRaises(RuntimeError):
script_module.save("NonExist/path/test.pt")

def test_save_namedtuple_input_only(self):
    """
    A NamedTuple that appears solely in a forward() signature (never in
    the graph body) must still round-trip through save/load.
    """
    global FooTuple  # see [local resolution in python]

    class FooTuple(NamedTuple):
        a: int

    class Model(torch.nn.Module):
        def forward(self, x: FooTuple) -> torch.Tensor:
            return torch.tensor(3)

    scripted = torch.jit.script(Model())
    reloaded = self.getExportImportCopy(scripted)
    result = reloaded(FooTuple(a=5))
    self.assertEqual(result, torch.tensor(3))

def test_save_namedtuple_output_only(self):
    """
    A NamedTuple that appears solely in a forward() return annotation
    (never in the graph body) must still round-trip through save/load.
    """
    global FooTuple  # see [local resolution in python]

    class FooTuple(NamedTuple):
        a: int

    class Model(torch.nn.Module):
        def forward(self) -> Optional[FooTuple]:
            return None

    scripted = torch.jit.script(Model())
    reloaded = self.getExportImportCopy(scripted)
    result = reloaded()
    self.assertEqual(result, None)
8 changes: 5 additions & 3 deletions torch/csrc/jit/serialization/python_print.cpp
Expand Up @@ -1255,6 +1255,7 @@ struct PythonPrintImpl {
body_ << "def " << func.name() << "(";
auto param_it = graph.inputs().begin();
for (const Argument& arg : schema.arguments()) {
registerClassDependencies(arg.type());
std::string arg_name = genName(arg.name());
if (param_it == graph.inputs().begin()) {
// the first argument may omit its type when it is implied by context
Expand All @@ -1273,9 +1274,10 @@ struct PythonPrintImpl {
assignValue(*param_it++, arg_name);
}

body_ << ") -> "
<< schema.returns().at(0).type()->annotation_str(type_printer_)
<< ":\n";
const auto& returnType = schema.returns().at(0).type();
body_ << ") -> " << returnType->annotation_str(type_printer_) << ":\n";
registerClassDependencies(returnType);

printBody(graph.block());
}

Expand Down

0 comments on commit a69016b

Please sign in to comment.