From 0d087de3229bcda2479fc431efe551447d9584aa Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Tue, 21 May 2019 16:12:54 -0700
Subject: [PATCH 1/6] Support tuples in ScriptModule inputs/outputs

---
 torch/csrc/jit/script/init.cpp | 87 ++++++++++++++++++++++++----------
 torch/onnx/utils.py            | 20 +++++---
 2 files changed, 74 insertions(+), 33 deletions(-)

diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index b77ef8b7646281c..07ae5ebe8e0d11e 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -178,11 +178,49 @@ static Self moduleSelf(
   };
 }
 
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
-  AT_ASSERT(stack.size() == g.inputs().size());
-  for (size_t i = 0; i < stack.size(); ++i) {
-    g.inputs().at(i)->setType(
-        DimensionedTensorType::create(stack.at(i).toTensor()));
+static TypePtr getTensorType(
+    const at::Tensor& t,
+    bool isDimensionedTensor) {
+  if (isDimensionedTensor) {
+    return DimensionedTensorType::create(t);
+  }
+  auto scalar_type = t.scalar_type();
+  auto sizes = t.sizes();
+  return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+}
+
+static TupleTypePtr getTupleTensorType(
+    const Stack::const_iterator& s_iter,
+    const Stack::const_iterator& s_iter_end,
+    const TypePtr& tupleType,
+    bool isDimensionedTensor) {
+  AT_ASSERT(tupleType->kind() == TupleType::Kind);
+  AT_ASSERT(s_iter != s_iter_end);
+
+  std::vector<TypePtr> types;
+  for (auto subType : tupleType->containedTypes()) {
+    if (subType->kind() == TupleType::Kind) {
+      types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, isDimensionedTensor));
+    } else {
+      types.push_back(getTensorType(s_iter->toTensor(), isDimensionedTensor));
+    }
+  }
+  return TupleType::create(types);
+}
+
+static void setInputTensorTypes(Graph& g, const Stack& stack, bool isDimensionedTensor=true) {
+  at::ArrayRef<Value*> input_values = g.inputs();
+  auto s_iter = stack.begin();
+  for (auto v : input_values) {
+    AT_ASSERT(s_iter != stack.end());
+    if (v->type()->kind() == TupleType::Kind) {
+      AT_ASSERT(v->node()->kind() == prim::Param);
+      v->setType(
+          getTupleTensorType(s_iter, stack.end(), v->type(), isDimensionedTensor));
+    } else {
+      v->setType(getTensorType(s_iter->toTensor(), isDimensionedTensor));
+      s_iter++;
+    }
   }
 }
 
@@ -197,38 +235,32 @@ static std::shared_ptr<Graph> _propagate_shapes(
   return retval;
 }
 
-static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
+static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
     Graph& graph,
     std::vector<at::Tensor> inputs,
-    std::vector<at::Tensor> outputs,
     bool with_grad = false,
     bool propagate = true) {
   auto retval = graph.copy();
   if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs));
+    setInputTensorTypes(*retval, fmap<IValue>(inputs), propagate);
     PropagateInputShapes(retval);
   }
-  AT_ASSERT(retval->inputs().size() == inputs.size());
-  for (size_t i = 0; i < retval->inputs().size(); ++i) {
-    auto scalar_type = inputs[i].scalar_type();
-    auto sizes = inputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    retval->inputs()[i]->setType(type);
-  }
-  at::ArrayRef<Value*> output_values = retval->outputs();
-  // patch this to still work if we are returning a tuple of multiple values
-  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
-    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
-    output_values = output_values.at(0)->node()->inputs();
-  }
-  AT_ASSERT(output_values.size() == outputs.size());
+  setInputTensorTypes(*retval, fmap<IValue>(inputs), false);
+
+  return retval;
+}
+
+static std::shared_ptr<Graph> _assign_output_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> outputs) {
+  auto retval = graph.copy();
+  AT_ASSERT(retval->outputs().size() == outputs.size());
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto scalar_type = outputs[i].scalar_type();
     auto sizes = outputs[i].sizes();
     auto type =
         torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    output_values[i]->setType(type);
+    retval->outputs()[i]->setType(type);
   }
   return retval;
 }
@@ -679,8 +711,11 @@ void initJitScriptBindings(PyObject* module) {
       debugSetAutodiffSubgraphInlining);
   m.def("_propagate_shapes", _propagate_shapes);
   m.def(
-      "_propagate_and_assign_input_and_output_shapes",
-      _propagate_and_assign_input_and_output_shapes);
+      "_propagate_and_assign_input_shapes",
+      _propagate_and_assign_input_shapes);
+  m.def(
+      "_assign_output_shapes",
+      _assign_output_shapes);
   m.def("_jit_python_print", [](py::object obj) {
     std::ostringstream ss;
     std::vector<at::Tensor> constants;
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index eb2a972bce1ef50..e9403f23d6ab440 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -16,7 +16,7 @@
 from torch._six import string_classes
 from torch.jit import _unique_state_dict
 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
-from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
+from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes
 
 
 # the flag to tell the user whether it's in the middle of ONNX export or not
@@ -214,10 +214,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
     # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
     torch._C._jit_pass_prepare_division_for_onnx(graph)
 
-    # onnx only supports tensors, so we turn all out number types into tensors
-    torch._C._jit_pass_erase_number_types(graph)
     # onnx does not support tuples, so try to remove them
     torch._C._jit_pass_lower_all_tuples(graph)
+    # onnx only supports tensors, so we turn all out number types into tensors
+    torch._C._jit_pass_erase_number_types(graph)
     torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_lint(graph)
 
@@ -288,16 +288,18 @@ def _model_to_graph(model, args, verbose=False, training=False,
         assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
         try:
             method_graph, params = model.forward._lowered_graph()
-            graph = _propagate_and_assign_input_and_output_shapes(
-                method_graph, tuple(args) + tuple(params), example_outputs, False, propagate)
+            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
+            graph = _propagate_and_assign_input_shapes(
+                method_graph, tuple(in_vars), False, propagate)
         except AttributeError:
             raise RuntimeError('\'forward\' method must be a script method')
     elif isinstance(model, torch.jit.Function):
         assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
         method = model
         params = ()
-        graph = _propagate_and_assign_input_and_output_shapes(
-            model.graph, tuple(args), example_outputs, False, propagate)
+        in_vars, in_desc = torch.jit._flatten(tuple(args))
+        graph = _propagate_and_assign_input_shapes(
+            model.graph, tuple(in_vars), False, propagate)
     else:
         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
         state_dict = _unique_state_dict(model)
@@ -313,6 +315,10 @@ def _model_to_graph(model, args, verbose=False, training=False,
     graph = _optimize_graph(graph, operator_export_type,
                             _disable_torch_constant_prop=_disable_torch_constant_prop)
 
+    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function):
+        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
+        graph = _assign_output_shapes(graph, out_vars)
+
     # NB: ONNX requires complete information about output types, which might be
     # erased by some optimizations, so we need to set it explicitly again.
     if torch_out is not None:
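
Patch 1 above is the core change: tuple-typed graph inputs now receive proper
(possibly nested) tuple-of-tensor types, and output shape assignment is split
into a separate _assign_output_shapes binding fed by example_outputs. A
minimal sketch of what this enables at the user level, assuming export goes
through torch.onnx.export with the example_outputs argument that ScriptModule
export requires (the in-memory buffer target is just a stand-in):

    import io
    import torch

    class TupleModel(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, a):
            # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
            return a

    x = (torch.randn(3, 4), torch.randn(4, 3))
    f = io.BytesIO()
    # example_outputs supplies the tensors later consumed by
    # _assign_output_shapes, since a ScriptModule is not executed at export.
    torch.onnx.export(TupleModel(), (x,), f, example_outputs=(x,))

Before this series, a tuple input would trip the old
AT_ASSERT(stack.size() == g.inputs().size()) in setInputTensorTypes, since one
tuple argument expands to several stack entries.
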
From 75ba0f3624ecba481af3c0f467403bf80cd4a782 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Wed, 22 May 2019 18:26:24 -0700
Subject: [PATCH 2/6] add tests

---
 test/onnx/test_pytorch_onnx_caffe2.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index a31a1b7cc78188c..43f8a05792cbe33 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1651,6 +1651,29 @@ def forward(self, lstm_in):
         self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3,
                             use_gpu=False)
 
+    def test_tuple_input_output(self):
+        class TupleModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor])
+                return a
+
+        x = (torch.randn(3, 4), torch.randn(4, 3))
+        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(x,))
+
+    def test_nested_tuple_input_output(self):
+        class NestedTupleModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a, b):
+                # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
+                return a + b[0] + b[1][0] + b[1][1]
+
+        x = torch.randn(4, 5)
+        y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
+        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
+                            example_outputs=x + y[0] + y[1][0] + y[1][1])
+
     def test_topk(self):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
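
The tests above mirror the exporter change: inputs and example_outputs may now
contain nested tuples, which torch.jit._flatten linearizes into a flat
sequence of tensors before shapes are assigned. A small illustration of that
internal helper, whose return format here is inferred from its use in
torch/onnx/utils.py in patch 1:

    import torch

    x = torch.randn(4, 5)
    y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
    # _flatten returns the leaf tensors plus a descriptor of the nesting,
    # so (x, y) yields four tensors: x, y[0], y[1][0], y[1][1]
    in_vars, in_desc = torch.jit._flatten((x, y))
    assert len(in_vars) == 4
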
From fab7b72dc34cdbd688ed5bc1f224dbb9e22620de Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Thu, 30 May 2019 10:08:08 -0700
Subject: [PATCH 3/6] fix lint error

---
 test/onnx/test_pytorch_onnx_caffe2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 43f8a05792cbe33..9b139326ed44ec1 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1655,7 +1655,7 @@ def test_tuple_input_output(self):
         class TupleModel(torch.jit.ScriptModule):
             @torch.jit.script_method
             def forward(self, a):
-                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor])
+                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
                 return a
 
         x = (torch.randn(3, 4), torch.randn(4, 3))
From f8df73a0eba5f81ae0cdc744b5d4425d90530ee4 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Thu, 6 Jun 2019 10:13:00 -0700
Subject: [PATCH 4/6] try to fix lint again

---
 torch/csrc/jit/script/init.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 07ae5ebe8e0d11e..8eb4697da1030ef 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -198,7 +198,7 @@ static TupleTypePtr getTupleTensorType(
   AT_ASSERT(s_iter != s_iter_end);
 
   std::vector<TypePtr> types;
-  for (auto subType : tupleType->containedTypes()) {
+  for (const auto& subType : tupleType->containedTypes()) {
     if (subType->kind() == TupleType::Kind) {
       types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, isDimensionedTensor));
     } else {
@@ -237,7 +237,7 @@ static std::shared_ptr<Graph> _propagate_shapes(
 
 static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
     Graph& graph,
-    std::vector<at::Tensor> inputs,
+    const std::vector<at::Tensor>& inputs,
     bool with_grad = false,
     bool propagate = true) {
   auto retval = graph.copy();
From de0b66a1717c33398bcc84bfb1c635b09d4c1048 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Wed, 12 Jun 2019 13:21:44 -0700
Subject: [PATCH 5/6] improve parameter naming

---
 torch/csrc/jit/script/init.cpp | 34 ++++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 8eb4697da1030ef..1ba7ca3d24d55ca 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -180,35 +180,41 @@ static Self moduleSelf(
 
 static TypePtr getTensorType(
     const at::Tensor& t,
-    bool isDimensionedTensor) {
-  if (isDimensionedTensor) {
-    return DimensionedTensorType::create(t);
+    const TypeKind type_kind) {
+  AT_ASSERT(type_kind == TypeKind::DimensionedTensorType || type_kind == TypeKind::CompleteTensorType);
+  switch (type_kind) {
+    case TypeKind::DimensionedTensorType:
+      return DimensionedTensorType::create(t);
+    case TypeKind::CompleteTensorType:
+      auto scalar_type = t.scalar_type();
+      auto sizes = t.sizes();
+      return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
   }
-  auto scalar_type = t.scalar_type();
-  auto sizes = t.sizes();
-  return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
 }
 
 static TupleTypePtr getTupleTensorType(
     const Stack::const_iterator& s_iter,
     const Stack::const_iterator& s_iter_end,
     const TypePtr& tupleType,
-    bool isDimensionedTensor) {
+    const TypeKind type_kind) {
   AT_ASSERT(tupleType->kind() == TupleType::Kind);
   AT_ASSERT(s_iter != s_iter_end);
 
   std::vector<TypePtr> types;
   for (const auto& subType : tupleType->containedTypes()) {
     if (subType->kind() == TupleType::Kind) {
-      types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, isDimensionedTensor));
+      types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, type_kind));
     } else {
-      types.push_back(getTensorType(s_iter->toTensor(), isDimensionedTensor));
+      types.push_back(getTensorType(s_iter->toTensor(), type_kind));
     }
   }
   return TupleType::create(types);
 }
 
-static void setInputTensorTypes(Graph& g, const Stack& stack, bool isDimensionedTensor=true) {
+static void setInputTensorTypes(
+    Graph& g,
+    const Stack& stack,
+    const TypeKind type_kind = TypeKind::DimensionedTensorType) {
   at::ArrayRef<Value*> input_values = g.inputs();
   auto s_iter = stack.begin();
   for (auto v : input_values) {
@@ -216,9 +222,9 @@ static void setInputTensorTypes(Graph& g, const Stack& stack, bool isDimensioned
     if (v->type()->kind() == TupleType::Kind) {
       AT_ASSERT(v->node()->kind() == prim::Param);
       v->setType(
-          getTupleTensorType(s_iter, stack.end(), v->type(), isDimensionedTensor));
+          getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
     } else {
-      v->setType(getTensorType(s_iter->toTensor(), isDimensionedTensor));
+      v->setType(getTensorType(s_iter->toTensor(), type_kind));
       s_iter++;
     }
   }
@@ -242,10 +248,10 @@ static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
     bool propagate = true) {
   auto retval = graph.copy();
   if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs), propagate);
+    setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::DimensionedTensorType);
     PropagateInputShapes(retval);
   }
-  setInputTensorTypes(*retval, fmap<IValue>(inputs), false);
+  setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);
 
   return retval;
 }

From 9f53df34bd42b4afd0c90ad6319156d6259ba811 Mon Sep 17 00:00:00 2001
From: BowenBao
Date: Wed, 12 Jun 2019 13:44:41 -0700
Subject: [PATCH 6/6] fix switch warning

---
 torch/csrc/jit/script/init.cpp | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1ba7ca3d24d55ca..201945a2c32e400 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -181,14 +181,17 @@ static Self moduleSelf(
 static TypePtr getTensorType(
     const at::Tensor& t,
     const TypeKind type_kind) {
-  AT_ASSERT(type_kind == TypeKind::DimensionedTensorType || type_kind == TypeKind::CompleteTensorType);
   switch (type_kind) {
     case TypeKind::DimensionedTensorType:
       return DimensionedTensorType::create(t);
-    case TypeKind::CompleteTensorType:
+    case TypeKind::CompleteTensorType: {
       auto scalar_type = t.scalar_type();
       auto sizes = t.sizes();
       return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    }
+    default:
+      throw std::runtime_error(
+          "Attempted to call getTensorType for type kind other than DimensionedTensorType or CompleteTensorType.");
   }
 }
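
Taken together, the series replaces the single
_propagate_and_assign_input_and_output_shapes binding with two narrower ones:
input shapes are assigned (and optionally propagated) before the ONNX graph
passes run, while output shapes are assigned afterwards, once
_jit_pass_lower_all_tuples has removed tuple outputs. A condensed sketch of
the resulting flow in _model_to_graph, with names taken from the patch and
the surrounding control flow simplified:

    # flatten the (possibly tuple-structured) inputs before assigning shapes
    in_vars, _ = torch.jit._flatten(tuple(args) + tuple(params))
    graph = torch._C._propagate_and_assign_input_shapes(
        method_graph, tuple(in_vars), False, propagate)

    graph = _optimize_graph(graph, operator_export_type)

    # only after tuples are lowered away can per-tensor output shapes
    # be assigned from the flattened example outputs
    out_vars, _ = torch.jit._flatten(tuple(example_outputs))
    graph = torch._C._assign_output_shapes(graph, out_vars)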