From f2b417b2f36a9353949b007343dd8596f89fa0d9 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 4 Nov 2020 01:01:30 -0800 Subject: [PATCH 01/21] Track and list model params for freezed module --- tools/build_variables.bzl | 1 + .../jit/passes/onnx/list_model_parameters.cpp | 159 ++++++++++++++++++ .../jit/passes/onnx/list_model_parameters.h | 13 ++ torch/csrc/jit/python/init.cpp | 4 + torch/onnx/utils.py | 4 +- 5 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 torch/csrc/jit/passes/onnx/list_model_parameters.cpp create mode 100644 torch/csrc/jit/passes/onnx/list_model_parameters.h diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index b717b58d7f81..7a026dd1d468 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -498,6 +498,7 @@ libtorch_python_core_sources = [ "torch/csrc/jit/passes/onnx/constant_fold.cpp", "torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp", "torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp", + "torch/csrc/jit/passes/onnx/list_model_parameters.cpp", "torch/csrc/jit/passes/onnx/function_substitution.cpp", "torch/csrc/jit/passes/onnx/helper.cpp", "torch/csrc/jit/passes/onnx/peephole.cpp", diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp new file mode 100644 index 000000000000..549918c278f6 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -0,0 +1,159 @@ +#include +#include +#include + +namespace torch { +namespace jit { + +std::deque names_; + +// findSubModuleAttr function chases getAttr chains to locate the submodules. +// For example: +// module M { +// attributes { +// A = +// } +// ... +// %A = prim::GetAttr[name="A"](%self) +// ... +// %B = prim::GetAttr[name="B"](%A) +// ... +// %weight = prim::GetAttr[name="scale"](%B) +// ... +void findSubModuleAttr( + Value* input, + std::string& name, + Module& attrModule, + std::shared_ptr& graph) { + Node* node = input->node(); + names_.clear(); + while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { + if (node->kind() == prim::GetAttr) { + names_.push_front(node->s(attr::name)); + node = node->inputs()[0]->node(); + } + } + + for (auto& moduleName : names_) { + attrModule = attrModule.attr(moduleName).toModule(); + } +} + +Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { + auto schema = function->getSchema(); + auto args = schema.arguments(); + args.emplace_back(Argument(name, nullptr, c10::nullopt, attr)); + auto new_schema = FunctionSchema( + schema.name(), + schema.overload_name(), + args, + schema.returns(), + schema.is_vararg(), + schema.is_varret()); + function->setSchema(new_schema); + return function->graph()->addInput(name)->setType(attr.type()); +} + +std::vector getParamAttributes( + std::shared_ptr& graph, + Module module_, + Function* function_) { + std::vector attrValues; + auto isEval = !module_.hasattr("training") || !module_.is_training(); + auto block = graph->block(); + std::vector blocks({block}); + + Node* m = *block->nodes().begin(); + WithInsertPoint guard(m); + + while (!blocks.empty()) { + Block* block = blocks.back(); + blocks.pop_back(); + for (auto it = block->nodes().begin(); it != block->nodes().end();) { + Node* n = *it; + it++; // node n can be destroyed + + for (Block* sub_block : n->blocks()) { + blocks.emplace_back(sub_block); + } + if (n->kind() == prim::SetAttr) { + if (!(n->outputs().size())) { + n->destroy(); + } + } else if (n->kind() == prim::GetAttr) { + auto name = n->s(attr::name); + auto attrModule = module_; + auto input = n->inputs()[0]; + + findSubModuleAttr(input, name, attrModule, graph); + + TORCH_INTERNAL_ASSERT(attrModule.hasattr(name)); + Value* paramConst = nullptr; + + auto attr = attrModule.attr(name); + + std::string fullName("self_"); + for (auto& name : names_) { + fullName += name + '_'; + } + fullName += name; + + auto type = attrModule.type(); + auto slot = *type->findAttributeSlot(name); + + if (type->is_parameter(slot) || type->is_buffer(slot)) { + if (type->is_parameter(slot) || type->is_buffer(slot)) { + if (attr.isTensor()) { + TORCH_INTERNAL_ASSERT(attr.isTensor()); + auto tensor_ = attr.toTensor(); + if (isEval && tensor_.requires_grad()) { + tensor_ = tensor_.detach(); + tensor_.set_requires_grad(false); + attr = IValue(tensor_); + } + attrValues.push_back(attr.toTensor()); + paramConst = addParamAsArgument(function_, fullName, attr); + } else if (attr.isNone()) { + auto attrVal = tryInsertConstant(*graph, attr); + paramConst = *attrVal; + } + n->output()->replaceAllUsesWith(paramConst); + n->removeAllInputs(); + + GRAPH_UPDATE( + "Folding GetAttr %", + n->outputs()[0]->debugName(), + " with ", + paramConst->debugName()); + } + } + } + } + } + return attrValues; +} + +std::pair> list_module_parameters( + const Module& module) { + Module moduleClone = module.clone(true); + Method method = moduleClone.get_method("forward"); + std::unordered_set preservedMethods_; + preservedMethods_.insert(&method.function()); + + std::vector modelParams; + for (auto function : preservedMethods_) { + GRAPH_DEBUG("List attributes for function: " + function->name()); + auto graph = function->graph(); + auto attributes = getParamAttributes(graph, moduleClone, function); + for (auto attr_ : attributes) { + modelParams.push_back(attr_); + } + GRAPH_DEBUG("Cleaning up module"); + EliminateDeadCode(graph->block()); + } + + return std::make_pair(moduleClone, modelParams); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.h b/torch/csrc/jit/passes/onnx/list_model_parameters.h new file mode 100644 index 000000000000..50d1cea2b8fe --- /dev/null +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { + +TORCH_API std::pair> list_module_parameters( + const Module& module); + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 7e3ebd86e9e9..d2aa1dfa43f3 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -279,6 +280,9 @@ void initJITBindings(PyObject* module) { "_jit_pass_quant_fusion", [](std::shared_ptr& g) { return QuantFusion(g); }) .def("_jit_pass_fold_convbn", &FoldConvBatchNorm) + .def( + "_jit_onnx_list_model_parameters", + [](Module& module) { return list_module_parameters(module); }) .def( "_freeze_module", [](Module& module, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index a26a4cdf2332..2a2c2873312e 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -358,10 +358,10 @@ def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes): if not use_new_jit_passes: method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c) else: - freezed_m = torch._C._freeze_module(model._c) + freezed_m = torch._C._freeze_module(model._c, preserveParameters=True) + freezed_m, params = torch._C._jit_onnx_list_model_parameters(freezed_m) method_graph = freezed_m._get_method('forward').graph method_graph.eraseInput(0) # Remove 'self' from model inputs - params = [] in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params)) graph = _propagate_and_assign_input_shapes( From 466a36ce95d90e48f4e437c6336ec97953898ec4 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 12 Nov 2020 14:55:55 -0800 Subject: [PATCH 02/21] Fix for training attr --- .../csrc/jit/passes/onnx/list_model_parameters.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 549918c278f6..1714f18ace4c 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -31,6 +31,8 @@ void findSubModuleAttr( if (node->kind() == prim::GetAttr) { names_.push_front(node->s(attr::name)); node = node->inputs()[0]->node(); + } else { + return; } } @@ -87,7 +89,9 @@ std::vector getParamAttributes( findSubModuleAttr(input, name, attrModule, graph); - TORCH_INTERNAL_ASSERT(attrModule.hasattr(name)); + if (!attrModule.hasattr(name)) { + continue; + } Value* paramConst = nullptr; auto attr = attrModule.attr(name); @@ -101,8 +105,10 @@ std::vector getParamAttributes( auto type = attrModule.type(); auto slot = *type->findAttributeSlot(name); - if (type->is_parameter(slot) || type->is_buffer(slot)) { - if (type->is_parameter(slot) || type->is_buffer(slot)) { + if (type->is_parameter(slot) || type->is_buffer(slot) || + name == "training") { + if (type->is_parameter(slot) || type->is_buffer(slot) || + name == "training") { if (attr.isTensor()) { TORCH_INTERNAL_ASSERT(attr.isTensor()); auto tensor_ = attr.toTensor(); @@ -113,7 +119,7 @@ std::vector getParamAttributes( } attrValues.push_back(attr.toTensor()); paramConst = addParamAsArgument(function_, fullName, attr); - } else if (attr.isNone()) { + } else if (attr.isNone() || name == "training") { auto attrVal = tryInsertConstant(*graph, attr); paramConst = *attrVal; } From 5fa6f05eb86b2e6091bae9b0b7280802a027e3b3 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Sun, 15 Nov 2020 20:13:52 -0800 Subject: [PATCH 03/21] Update list_model_parameters.cpp --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 1714f18ace4c..0f698ed7e7a5 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -78,11 +78,7 @@ std::vector getParamAttributes( for (Block* sub_block : n->blocks()) { blocks.emplace_back(sub_block); } - if (n->kind() == prim::SetAttr) { - if (!(n->outputs().size())) { - n->destroy(); - } - } else if (n->kind() == prim::GetAttr) { + if (n->kind() == prim::GetAttr) { auto name = n->s(attr::name); auto attrModule = module_; auto input = n->inputs()[0]; From 5515ea33facacb000067e5c8034e94f0aba8ac41 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Sun, 15 Nov 2020 20:15:55 -0800 Subject: [PATCH 04/21] Update list_model_parameters.cpp From dfdf28c3e22687eb75de57ca8f4f6368780e2697 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Sun, 15 Nov 2020 22:52:26 -0800 Subject: [PATCH 05/21] Clang fix --- .../jit/passes/onnx/list_model_parameters.cpp | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 0f698ed7e7a5..5e398fbbf0f5 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -5,8 +5,6 @@ namespace torch { namespace jit { -std::deque names_; - // findSubModuleAttr function chases getAttr chains to locate the submodules. // For example: // module M { @@ -20,25 +18,27 @@ std::deque names_; // ... // %weight = prim::GetAttr[name="scale"](%B) // ... -void findSubModuleAttr( +std::deque findSubModuleAttr( Value* input, std::string& name, Module& attrModule, std::shared_ptr& graph) { Node* node = input->node(); - names_.clear(); + std::deque moduleNames; + while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { if (node->kind() == prim::GetAttr) { - names_.push_front(node->s(attr::name)); + moduleNames.push_front(node->s(attr::name)); node = node->inputs()[0]->node(); } else { - return; + return moduleNames; } } - for (auto& moduleName : names_) { + for (auto& moduleName : moduleNames) { attrModule = attrModule.attr(moduleName).toModule(); } + return moduleNames; } Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { @@ -78,12 +78,15 @@ std::vector getParamAttributes( for (Block* sub_block : n->blocks()) { blocks.emplace_back(sub_block); } - if (n->kind() == prim::GetAttr) { + if (n->kind() == prim::SetAttr && + n->s(attr::name) == "num_batches_tracked") { + n->destroy(); + } else if (n->kind() == prim::GetAttr) { auto name = n->s(attr::name); auto attrModule = module_; auto input = n->inputs()[0]; - findSubModuleAttr(input, name, attrModule, graph); + auto moduleNames = findSubModuleAttr(input, name, attrModule, graph); if (!attrModule.hasattr(name)) { continue; @@ -93,7 +96,7 @@ std::vector getParamAttributes( auto attr = attrModule.attr(name); std::string fullName("self_"); - for (auto& name : names_) { + for (auto& name : moduleNames) { fullName += name + '_'; } fullName += name; @@ -113,7 +116,7 @@ std::vector getParamAttributes( tensor_.set_requires_grad(false); attr = IValue(tensor_); } - attrValues.push_back(attr.toTensor()); + attrValues.emplace_back(attr.toTensor()); paramConst = addParamAsArgument(function_, fullName, attr); } else if (attr.isNone() || name == "training") { auto attrVal = tryInsertConstant(*graph, attr); @@ -147,7 +150,7 @@ std::pair> list_module_parameters( GRAPH_DEBUG("List attributes for function: " + function->name()); auto graph = function->graph(); auto attributes = getParamAttributes(graph, moduleClone, function); - for (auto attr_ : attributes) { + for (auto& attr_ : attributes) { modelParams.push_back(attr_); } GRAPH_DEBUG("Cleaning up module"); From 028c54e46c79154e690544cbb8bec1555a073e70 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Sun, 15 Nov 2020 23:08:53 -0800 Subject: [PATCH 06/21] Update list_model_parameters.cpp --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 5e398fbbf0f5..d3201a655c14 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -58,7 +58,7 @@ Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) { std::vector getParamAttributes( std::shared_ptr& graph, - Module module_, + const Module& module_, Function* function_) { std::vector attrValues; auto isEval = !module_.hasattr("training") || !module_.is_training(); From d7e1deee7aff3d7586ea08439ab78e720b5c2cab Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 18 Nov 2020 10:20:03 -0800 Subject: [PATCH 07/21] Fix for feedback --- .../jit/passes/onnx/list_model_parameters.cpp | 53 +++++++++---------- 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index d3201a655c14..495cddc8f26f 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -5,9 +5,8 @@ namespace torch { namespace jit { -// findSubModuleAttr function chases getAttr chains to locate the submodules. -// For example: -// module M { +// findSubModuleAttr function chases getAttr chains backwards to locate the +// submodules. For example: module M { // attributes { // A = // } @@ -18,6 +17,7 @@ namespace jit { // ... // %weight = prim::GetAttr[name="scale"](%B) // ... + std::deque findSubModuleAttr( Value* input, std::string& name, @@ -26,6 +26,8 @@ std::deque findSubModuleAttr( Node* node = input->node(); std::deque moduleNames; + // Loop starts from inner submodule and follows the chain until reaches the + // top module. while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { if (node->kind() == prim::GetAttr) { moduleNames.push_front(node->s(attr::name)); @@ -35,6 +37,7 @@ std::deque findSubModuleAttr( } } + // Assign the inner module to attrModule. for (auto& moduleName : moduleNames) { attrModule = attrModule.attr(moduleName).toModule(); } @@ -87,7 +90,6 @@ std::vector getParamAttributes( auto input = n->inputs()[0]; auto moduleNames = findSubModuleAttr(input, name, attrModule, graph); - if (!attrModule.hasattr(name)) { continue; } @@ -106,31 +108,28 @@ std::vector getParamAttributes( if (type->is_parameter(slot) || type->is_buffer(slot) || name == "training") { - if (type->is_parameter(slot) || type->is_buffer(slot) || - name == "training") { - if (attr.isTensor()) { - TORCH_INTERNAL_ASSERT(attr.isTensor()); - auto tensor_ = attr.toTensor(); - if (isEval && tensor_.requires_grad()) { - tensor_ = tensor_.detach(); - tensor_.set_requires_grad(false); - attr = IValue(tensor_); - } - attrValues.emplace_back(attr.toTensor()); - paramConst = addParamAsArgument(function_, fullName, attr); - } else if (attr.isNone() || name == "training") { - auto attrVal = tryInsertConstant(*graph, attr); - paramConst = *attrVal; + if (attr.isTensor()) { + TORCH_INTERNAL_ASSERT(attr.isTensor()); + auto tensor_ = attr.toTensor(); + if (isEval && tensor_.requires_grad()) { + tensor_ = tensor_.detach(); + tensor_.set_requires_grad(false); + attr = IValue(tensor_); } - n->output()->replaceAllUsesWith(paramConst); - n->removeAllInputs(); - - GRAPH_UPDATE( - "Folding GetAttr %", - n->outputs()[0]->debugName(), - " with ", - paramConst->debugName()); + attrValues.emplace_back(attr.toTensor()); + paramConst = addParamAsArgument(function_, fullName, attr); + } else if (attr.isNone() || name == "training") { + auto attrVal = tryInsertConstant(*graph, attr); + paramConst = *attrVal; } + n->output()->replaceAllUsesWith(paramConst); + n->removeAllInputs(); + + GRAPH_UPDATE( + "Folding GetAttr %", + n->outputs()[0]->debugName(), + " with ", + paramConst->debugName()); } } } From 843bac0d6c3de50d30235f7a9b7c75c2e04135b3 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 18 Nov 2020 10:20:24 -0800 Subject: [PATCH 08/21] Update torch/csrc/jit/passes/onnx/list_model_parameters.cpp Co-authored-by: Bowen Bao --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 495cddc8f26f..62325e7d0811 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -28,7 +28,7 @@ std::deque findSubModuleAttr( // Loop starts from inner submodule and follows the chain until reaches the // top module. - while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) { + while (node->outputs().at(0)->type() != graph->inputs().at(0)->type()) { if (node->kind() == prim::GetAttr) { moduleNames.push_front(node->s(attr::name)); node = node->inputs()[0]->node(); From 144cafb4e0bebef01619e6ad28b1655d09415407 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 18 Nov 2020 10:21:41 -0800 Subject: [PATCH 09/21] Enable new APIs by default --- test/onnx/test_pytorch_onnx_onnxruntime.py | 40 +++++++++------------- torch/onnx/utils.py | 6 ++-- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 70181b10b059..5cf47d3cd31f 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -472,6 +472,7 @@ def forward(self, x_in): x = {"test_key_in": torch.randn(1, 2, 3)} self.run_test(MyModel(), (x,)) + @disableScriptTest() def test_none_as_input(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -482,6 +483,7 @@ def forward(self, x, y): x = torch.randn(2, 3) self.run_test(Model(), (x, None)) + @disableScriptTest() def test_none_as_tuple_input(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -495,6 +497,7 @@ def forward(self, x, y): y = torch.randn(2, 3) self.run_test(Model(), (x, (None, y))) + @disableScriptTest() def test_none_as_named_input(self): class Model(torch.nn.Module): def forward(self, x, y=None, z=None): @@ -678,23 +681,11 @@ def __init__(self): def forward(self, input1, input2, input3): return self.conv1(input1), self.conv2(input2), self.conv3(input3) - class ScriptModel(torch.jit.ScriptModule): - def __init__(self): - super(ScriptModel, self).__init__() - self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2) - self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) - - @torch.jit.script_method - def forward(self, input1, input2, input3): - return self.conv1(input1), self.conv2(input2), self.conv3(input3) - x1 = torch.randn(20, 16, 50) x2 = torch.randn(20, 16, 50, 100) x3 = torch.randn(20, 16, 10, 50, 100) self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) - self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5) def test_conv_shape_inference(self): class Model(torch.nn.Module): @@ -721,23 +712,11 @@ def __init__(self): def forward(self, input1, input2, input3): return self.conv1(input1), self.conv2(input2), self.conv3(input3) - class ScriptModel(torch.jit.ScriptModule): - def __init__(self): - super(ScriptModel, self).__init__() - self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2) - self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1)) - self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0)) - - @torch.jit.script_method - def forward(self, input1, input2, input3): - return self.conv1(input1), self.conv2(input2), self.conv3(input3) - x1 = torch.randn(20, 16, 50) x2 = torch.randn(20, 16, 50, 100) x3 = torch.randn(20, 16, 10, 50, 100) self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5) - self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5) # Conversion of Transpose depends on input shape to be known. # The following test only works when onnx shape inference is enabled. @@ -5095,6 +5074,19 @@ def forward(self, x): [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)] + script_model = torch.jit.script(model) + outputs = model(x) + ort_sess1 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=outputs, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs1 = run_ort(ort_sess1, input=(x,)) + ort_sess2 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=outputs, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.EVAL) + ort_outs2 = run_ort(ort_sess2, input=(x,)) + [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in + zip(ort_outs1, ort_outs2)] + def test_multiple_conv_bn(self): class MyModule(torch.nn.Module): def __init__(self): diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index a3bfec39c5d2..5c41306b9ee2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -121,7 +121,7 @@ def _split_tensor_list_constants(g, block): def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=False, fixed_batch_size=False, - params_dict=None, use_new_jit_passes=False, dynamic_axes=None, input_names=None): + params_dict=None, use_new_jit_passes=True, dynamic_axes=None, input_names=None): # Inline everything torch._C._jit_pass_inline(graph) @@ -396,7 +396,7 @@ def _model_to_graph(model, args, verbose=False, example_outputs=None, _retain_param_name=False, do_constant_folding=True, _disable_torch_constant_prop=False, fixed_batch_size=False, - training=None, use_new_jit_passes=False, + training=None, use_new_jit_passes=True, dynamic_axes=None): from torch.onnx.symbolic_helper import _export_onnx_opset_version # Special case for common case of passing a single Tensor @@ -586,7 +586,7 @@ def _export(model, args, f, export_params=True, verbose=False, training=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, fixed_batch_size=False, custom_opsets=None, add_node_names=True, enable_onnx_checker=True, use_external_data_format=False, - onnx_shape_inference=True, use_new_jit_passes=False): + onnx_shape_inference=True, use_new_jit_passes=True): if isinstance(model, torch.nn.DataParallel): raise ValueError('torch.nn.DataParallel is not supported by ONNX ' From dee7395931e3bd86053c797c5c311e3fb9642172 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 18 Nov 2020 10:31:56 -0800 Subject: [PATCH 10/21] Adding batch_norm training test --- test/onnx/test_pytorch_onnx_onnxruntime.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 5cf47d3cd31f..054a5e7e27fd 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -4990,8 +4990,16 @@ def forward(self, x): ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) - [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)] + + model_export = torch.jit.script(MyModule()) + ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, + example_outputs=out, + training=torch.onnx.TrainingMode.TRAINING, + use_new_jit_passes=True, onnx_shape_inference=True) + ort_outs = run_ort(ort_sess, input=(x,)) + [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in + zip(pytorch_out, ort_outs)] @skipIfUnsupportedMinOpsetVersion(12) def test_dropout_training(self): From 00e1cf0f96bbd5f82ea50d4dc442b8162f2286ce Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 18 Nov 2020 10:43:19 -0800 Subject: [PATCH 11/21] Adding dropout training test --- test/onnx/test_pytorch_onnx_onnxruntime.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 054a5e7e27fd..9a38d30d6eb6 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -5021,6 +5021,14 @@ def forward(self, x): training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) + + script_model = torch.jit.script(model) + output = model(x) + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=output, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) + assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) @skipIfUnsupportedMinOpsetVersion(12) def test_dropout_training_zero(self): @@ -5050,7 +5058,21 @@ def forward(self, x): y = model(input) output = y.cpu().numpy() + ort_mask = np.where(ort_outs[0] != 0, 1, 0) + pyt_mask = np.where(output != 0, 1, 0) + ratio_pytorch = np.sum(pyt_mask) / nb_elements + ratio_ort = np.sum(ort_mask) / nb_elements + + np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) + + script_model = torch.jit.script(model) + y = model(input) + output = y.cpu().numpy() + ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, + example_outputs=y, use_new_jit_passes=True, + training=torch.onnx.TrainingMode.TRAINING) + ort_outs = run_ort(ort_sess, input=(x,)) ort_mask = np.where(ort_outs[0] != 0, 1, 0) pyt_mask = np.where(output != 0, 1, 0) From 559bdec146d8281f08aa78bfc76bc9925f6ce8fc Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Wed, 18 Nov 2020 17:10:36 -0800 Subject: [PATCH 12/21] Update utils.py --- torch/onnx/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 5c41306b9ee2..239ea3b81741 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -406,6 +406,11 @@ def _model_to_graph(model, args, verbose=False, if isinstance(example_outputs, torch.Tensor): example_outputs = [example_outputs] + # Temporarily disable the new APIs for quantized models until + # we fix feeze_module to handle the attributes + if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK: + use_new_jit_passes=False + graph, params, torch_out = _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes) From e6fb0c5ee17c159948a93bb6a0f269d6d5a1b622 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 18 Nov 2020 21:31:38 -0800 Subject: [PATCH 13/21] rebase --- torch/onnx/utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 239ea3b81741..5c41306b9ee2 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -406,11 +406,6 @@ def _model_to_graph(model, args, verbose=False, if isinstance(example_outputs, torch.Tensor): example_outputs = [example_outputs] - # Temporarily disable the new APIs for quantized models until - # we fix feeze_module to handle the attributes - if operator_export_type == OperatorExportTypes.ONNX_ATEN_FALLBACK: - use_new_jit_passes=False - graph, params, torch_out = _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes) From 26b3df5c2ac4d354f0b9c3a9a7f3a504e955dd59 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 19 Nov 2020 10:01:45 -0800 Subject: [PATCH 14/21] clang --- test/onnx/test_pytorch_onnx_onnxruntime.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 9a38d30d6eb6..39b7656433d5 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -4991,7 +4991,7 @@ def forward(self, x): training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) [np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)] - + model_export = torch.jit.script(MyModule()) ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version, example_outputs=out, @@ -5021,7 +5021,7 @@ def forward(self, x): training=torch.onnx.TrainingMode.TRAINING) ort_outs = run_ort(ort_sess, input=(x,)) assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0]))) - + script_model = torch.jit.script(model) output = model(x) ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version, @@ -5065,7 +5065,7 @@ def forward(self, x): ratio_ort = np.sum(ort_mask) / nb_elements np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01) - + script_model = torch.jit.script(model) y = model(input) output = y.cpu().numpy() From 2c1581bebd58c47b95cb8af2a695656b86363270 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 19 Nov 2020 14:02:13 -0800 Subject: [PATCH 15/21] Fix for quantization --- torch/csrc/jit/passes/freeze_module.cpp | 4 +++- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/freeze_module.cpp b/torch/csrc/jit/passes/freeze_module.cpp index 35be415b1a1e..2778c7712f23 100644 --- a/torch/csrc/jit/passes/freeze_module.cpp +++ b/torch/csrc/jit/passes/freeze_module.cpp @@ -442,7 +442,9 @@ class AttributePropagator { if (!isEval || preserveParameters_) { auto type = attrModule.type(); auto slot = *type->findAttributeSlot(name); - if (type->is_parameter(slot) || type->is_buffer(slot)) { + if (type->is_parameter(slot) || type->is_buffer(slot) || + (attr.isObject() && + !attr.toObjectRef().type()->is_module())) { continue; } else { attr = overrideGradient(attr); diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 62325e7d0811..95bf9d2fb36a 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -107,6 +107,7 @@ std::vector getParamAttributes( auto slot = *type->findAttributeSlot(name); if (type->is_parameter(slot) || type->is_buffer(slot) || + (attr.isObject() && !attr.toObjectRef().type()->is_module()) || name == "training") { if (attr.isTensor()) { TORCH_INTERNAL_ASSERT(attr.isTensor()); @@ -118,6 +119,12 @@ std::vector getParamAttributes( } attrValues.emplace_back(attr.toTensor()); paramConst = addParamAsArgument(function_, fullName, attr); + } else if ( + attr.isObject() && !attr.toObjectRef().type()->is_module()) { + attrValues.emplace_back( + script::Object(attr.toObject()).run_method("__getstate__")); + paramConst = addParamAsArgument(function_, fullName, attr); + } else if (attr.isNone() || name == "training") { auto attrVal = tryInsertConstant(*graph, attr); paramConst = *attrVal; From f7d92e50fd787856de42abc3e6c8e7a8c8a94c06 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 19 Nov 2020 23:43:41 -0800 Subject: [PATCH 16/21] Fix for jit test, error when encounter PythonOp --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 95bf9d2fb36a..8573fb9d6618 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -85,6 +86,12 @@ std::vector getParamAttributes( n->s(attr::name) == "num_batches_tracked") { n->destroy(); } else if (n->kind() == prim::GetAttr) { + for (auto use : n->output()->uses()) { + if (use.user->kind() == prim::PythonOp) + throw ErrorReport(n->sourceRange()) + << "Couldn't export Python method."; + } + auto name = n->s(attr::name); auto attrModule = module_; auto input = n->inputs()[0]; From 8747e6b14d5779c9b587f4979de90251039348af Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 20 Nov 2020 19:13:39 -0800 Subject: [PATCH 17/21] clang-tidy fix --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 9ba80306cbe4..cf9aa29b6ea5 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -139,11 +139,7 @@ std::vector getParamAttributes( n->output()->replaceAllUsesWith(paramConst); n->removeAllInputs(); - GRAPH_UPDATE( - "Folding GetAttr %", - n->outputs()[0]->debugName(), - " with ", - paramConst->debugName()); + GRAPH_UPDATE("Folding GetAttr %", n->outputs()[0]->debugName()); } } } @@ -161,6 +157,8 @@ std::pair> list_module_parameters( GRAPH_DEBUG("List attributes for function: " + function->name()); auto graph = function->graph(); auto attributes = getParamAttributes(graph, moduleClone, function); + + modelParams.reserve(attributes.size()); for (auto& attr_ : attributes) { modelParams.push_back(attr_); } From 612c3691877d515f2f0c9a9d8712803512f8cf76 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 20 Nov 2020 21:26:59 -0800 Subject: [PATCH 18/21] Fix for feedback. --- .../jit/passes/onnx/list_model_parameters.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index cf9aa29b6ea5..4ab1c1be07d3 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -106,7 +106,7 @@ std::vector getParamAttributes( std::string fullName("self_"); for (auto& name : moduleNames) { - fullName += name + '_'; + fullName += name + '.'; } fullName += name; @@ -128,10 +128,23 @@ std::vector getParamAttributes( paramConst = addParamAsArgument(function_, fullName, attr); } else if ( attr.isObject() && !attr.toObjectRef().type()->is_module()) { + // Only below registered torch classes are supported. + TORCH_CHECK( + (type == + getCustomClass( + "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) || + (type == + getCustomClass( + "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) || + (type == + getCustomClass( + "__torch__.torch.classes.quantized.LinearPackedParamsBase")), + "Unknown type ", + type->repr_str(), + " encountered in handling model params. This type is not supported in ONNX export."); attrValues.emplace_back( script::Object(attr.toObject()).run_method("__getstate__")); paramConst = addParamAsArgument(function_, fullName, attr); - } else if (attr.isNone() || name == "training") { auto attrVal = tryInsertConstant(*graph, attr); paramConst = *attrVal; @@ -156,6 +169,8 @@ std::pair> list_module_parameters( GRAPH_DEBUG("List attributes for function: " + function->name()); auto graph = function->graph(); + // Add model_parameters and model_buffers as model inputs. Order is based on + // the appearance in the graph. auto attributes = getParamAttributes(graph, moduleClone, function); modelParams.reserve(attributes.size()); From d7966fc0ed326af240bae6e4b11681741303842e Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 20 Nov 2020 21:29:06 -0800 Subject: [PATCH 19/21] Fix for feedback --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index 4ab1c1be07d3..1b057c3d835a 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -33,8 +33,6 @@ std::deque findSubModuleAttr( if (node->kind() == prim::GetAttr) { moduleNames.push_front(node->s(attr::name)); node = node->inputs()[0]->node(); - } else { - return moduleNames; } } From 41a384334aceae346bbb3591973e39383f28ade5 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 1 Dec 2020 17:52:14 -0800 Subject: [PATCH 20/21] Update list_model_parameters.cpp --- torch/csrc/jit/passes/onnx/list_model_parameters.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp index dd63c159ad82..4f0b4e2b1437 100644 --- a/torch/csrc/jit/passes/onnx/list_model_parameters.cpp +++ b/torch/csrc/jit/passes/onnx/list_model_parameters.cpp @@ -102,7 +102,7 @@ std::vector getParamAttributes( auto attr = attrModule.attr(name); - std::string fullName("self_"); + std::string fullName(""); for (auto& name : moduleNames) { fullName += name + '.'; } From 71b301a273d21f15f71b4ff78154a3f5d6f0ff01 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 1 Dec 2020 17:54:18 -0800 Subject: [PATCH 21/21] Adding test for params --- test/onnx/test_utility_funs.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 88daef3d5fb0..5c1bfe8b5515 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -678,6 +678,31 @@ def forward(self, x): assert len(params_dict) == 2 + def test_scripting_param(self): + class MyModule(torch.nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True) + self.bn = torch.nn.BatchNorm2d(16, affine=True) + + def forward(self, x): + x = self.conv(x) + bn = self.bn(x) + return bn + + model = torch.jit.script(MyModule()) + x = torch.randn(10, 3, 128, 128) + example_outputs = model(x) + f = io.BytesIO() + _set_opset_version(self.opset_version) + _set_operator_export_type(OperatorExportTypes.ONNX) + graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs, + operator_export_type=OperatorExportTypes.ONNX) + + graph_input_params = [param.debugName() for param in graph.inputs()] + assert all(item in graph_input_params for item in dict(model.named_parameters())), \ + "Graph parameter names does not match model parameters." + def test_modifying_params(self): class MyModel(torch.nn.Module): def __init__(self):