Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Track and list model params for scripting #47348

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
f2b417b
Track and list model params for frozen module
neginraoof Nov 4, 2020
466a36c
Fix for training attr
neginraoof Nov 12, 2020
2aa4180
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 16, 2020
5fa6f05
Update list_model_parameters.cpp
neginraoof Nov 16, 2020
5515ea3
Update list_model_parameters.cpp
neginraoof Nov 16, 2020
dfdf28c
Clang fix
neginraoof Nov 16, 2020
028c54e
Update list_model_parameters.cpp
neginraoof Nov 16, 2020
d7e1dee
Fix for feedback
neginraoof Nov 18, 2020
843bac0
Update torch/csrc/jit/passes/onnx/list_model_parameters.cpp
neginraoof Nov 18, 2020
144cafb
Enable new APIs by default
neginraoof Nov 18, 2020
dee7395
Adding batch_norm training test
neginraoof Nov 18, 2020
00e1cf0
Adding dropout training test
neginraoof Nov 18, 2020
559bdec
Update utils.py
neginraoof Nov 19, 2020
b570c36
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 19, 2020
e6fb0c5
rebase
neginraoof Nov 19, 2020
67021af
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 19, 2020
26b3df5
clang
neginraoof Nov 19, 2020
2c1581b
Fix for quantization
neginraoof Nov 19, 2020
eebf713
Merge branch 'neraoof/scripting_params' of github.com:neginraoof/pyto…
neginraoof Nov 19, 2020
f7d92e5
Fix for jit test, error when encounter PythonOp
neginraoof Nov 20, 2020
37f9051
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 20, 2020
8747e6b
clang-tidy fix
neginraoof Nov 21, 2020
612c369
Fix for feedback.
neginraoof Nov 21, 2020
ba4c63c
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 21, 2020
d7966fc
Fix for feedback
neginraoof Nov 21, 2020
65ecd22
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Nov 21, 2020
41a3843
Update list_model_parameters.cpp
neginraoof Dec 2, 2020
71b301a
Adding test for params
neginraoof Dec 2, 2020
79f4a4c
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Dec 3, 2020
90309a6
Merge branch 'neraoof/scripting_params' of github.com:neginraoof/pyto…
neginraoof Dec 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/build_variables.bzl
Expand Up @@ -503,6 +503,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
"torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
"torch/csrc/jit/passes/onnx/list_model_parameters.cpp",
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
"torch/csrc/jit/passes/onnx/helper.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",
Expand Down
164 changes: 164 additions & 0 deletions torch/csrc/jit/passes/onnx/list_model_parameters.cpp
@@ -0,0 +1,164 @@
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

// findSubModuleAttr function chases getAttr chains to locate the submodules.
// For example:
// module M {
// attributes {
// A = <SubModule at ...>
// }
// ...
// %A = prim::GetAttr[name="A"](%self)
// ...
// %B = prim::GetAttr[name="B"](%A)
// ...
// %weight = prim::GetAttr[name="scale"](%B)
// ...
std::deque<std::string> findSubModuleAttr(
    Value* input,
    std::string& name,
    Module& attrModule,
    std::shared_ptr<Graph>& graph) {
  Node* node = input->node();
  std::deque<std::string> moduleNames;

  // Walk the prim::GetAttr chain upwards until we hit the graph's `self`
  // input (recognized by its type matching the top-level module type).
  while (!(node->outputs()[0]->type() == graph->inputs()[0]->type())) {
    if (node->kind() == prim::GetAttr) {
      moduleNames.push_front(node->s(attr::name));
      node = node->inputs()[0]->node();
    } else {
      // The chain is interrupted by a non-GetAttr node; return what was
      // collected so far without resolving attrModule.
      return moduleNames;
    }
  }

  // Resolve attrModule to the innermost submodule named by the chain.
  for (auto& moduleName : moduleNames) {
    attrModule = attrModule.attr(moduleName).toModule();
  }
  return moduleNames;
}

// Appends attribute `attr` as a new argument named `name` to the function's
// schema and graph, returning the fresh graph input Value typed like `attr`.
Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
  auto schema = function->getSchema();
  auto args = schema.arguments();
  args.emplace_back(Argument(name, nullptr, c10::nullopt, attr));
  auto new_schema = FunctionSchema(
      schema.name(),
      schema.overload_name(),
      args,
      schema.returns(),
      schema.is_vararg(),
      schema.is_varret());
  function->setSchema(std::move(new_schema));
  return function->graph()->addInput(name)->setType(attr.type());
}

std::vector<IValue> getParamAttributes(
std::shared_ptr<Graph>& graph,
const Module& module_,
Function* function_) {
std::vector<IValue> attrValues;
auto isEval = !module_.hasattr("training") || !module_.is_training();
auto block = graph->block();
std::vector<Block*> blocks({block});

Node* m = *block->nodes().begin();
WithInsertPoint guard(m);

while (!blocks.empty()) {
Block* block = blocks.back();
blocks.pop_back();
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++; // node n can be destroyed
neginraoof marked this conversation as resolved.
Show resolved Hide resolved

for (Block* sub_block : n->blocks()) {
blocks.emplace_back(sub_block);
}
if (n->kind() == prim::SetAttr &&
n->s(attr::name) == "num_batches_tracked") {
n->destroy();
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
} else if (n->kind() == prim::GetAttr) {
auto name = n->s(attr::name);
auto attrModule = module_;
auto input = n->inputs()[0];

auto moduleNames = findSubModuleAttr(input, name, attrModule, graph);

if (!attrModule.hasattr(name)) {
continue;
}
Value* paramConst = nullptr;

auto attr = attrModule.attr(name);

std::string fullName("self_");
for (auto& name : moduleNames) {
fullName += name + '_';
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please ensure parameter names are the same as named_parameters(). Can we add tests in test_utility?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So named_parameters() and named_buffers() do not have the "self." prefix as part of the name.
So param name can be: %self.conv.weight for a named_parameter conv.weight.
Do you think to remove "self." from param names?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, as long as they match with named_parameters() and named_buffers().

}
fullName += name;

auto type = attrModule.type();
auto slot = *type->findAttributeSlot(name);

if (type->is_parameter(slot) || type->is_buffer(slot) ||
name == "training") {
if (type->is_parameter(slot) || type->is_buffer(slot) ||
neginraoof marked this conversation as resolved.
Show resolved Hide resolved
name == "training") {
BowenBao marked this conversation as resolved.
Show resolved Hide resolved
if (attr.isTensor()) {
TORCH_INTERNAL_ASSERT(attr.isTensor());
auto tensor_ = attr.toTensor();
if (isEval && tensor_.requires_grad()) {
tensor_ = tensor_.detach();
tensor_.set_requires_grad(false);
attr = IValue(tensor_);
}
attrValues.emplace_back(attr.toTensor());
paramConst = addParamAsArgument(function_, fullName, attr);
} else if (attr.isNone() || name == "training") {
auto attrVal = tryInsertConstant(*graph, attr);
paramConst = *attrVal;
}
n->output()->replaceAllUsesWith(paramConst);
n->removeAllInputs();

GRAPH_UPDATE(
"Folding GetAttr %",
n->outputs()[0]->debugName(),
" with ",
paramConst->debugName());
}
}
}
}
}
return attrValues;
}

std::pair<Module, std::vector<IValue>> list_module_parameters(
const Module& module) {
Module moduleClone = module.clone(true);
Method method = moduleClone.get_method("forward");
std::unordered_set<Function*> preservedMethods_;
preservedMethods_.insert(&method.function());

std::vector<IValue> modelParams;
for (auto function : preservedMethods_) {
GRAPH_DEBUG("List attributes for function: " + function->name());
auto graph = function->graph();
auto attributes = getParamAttributes(graph, moduleClone, function);
for (auto& attr_ : attributes) {
modelParams.push_back(attr_);
}
GRAPH_DEBUG("Cleaning up module");
EliminateDeadCode(graph->block());
}

return std::make_pair(moduleClone, modelParams);
}

} // namespace jit
} // namespace torch
13 changes: 13 additions & 0 deletions torch/csrc/jit/passes/onnx/list_model_parameters.h
@@ -0,0 +1,13 @@
#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

// Returns a clone of `module` whose `forward` graph takes its parameters
// and buffers as explicit inputs, paired with the values of those
// parameters/buffers (in the order the inputs were added).
TORCH_API std::pair<Module, std::vector<IValue>> list_module_parameters(
    const Module& module);

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/python/init.cpp
Expand Up @@ -39,6 +39,7 @@
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
#include <torch/csrc/jit/passes/onnx/function_substitution.h>
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
Expand Down Expand Up @@ -279,6 +280,9 @@ void initJITBindings(PyObject* module) {
"_jit_pass_quant_fusion",
[](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
.def("_jit_pass_fold_convbn", &FoldConvBatchNorm)
.def(
"_jit_onnx_list_model_parameters",
[](Module& module) { return list_module_parameters(module); })
.def(
"_freeze_module",
[](Module& module,
Expand Down
4 changes: 2 additions & 2 deletions torch/onnx/utils.py
Expand Up @@ -358,10 +358,10 @@ def _create_jit_graph(model, args, _retain_param_name, use_new_jit_passes):
if not use_new_jit_passes:
method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
else:
freezed_m = torch._C._freeze_module(model._c)
freezed_m = torch._C._freeze_module(model._c, preserveParameters=True)
freezed_m, params = torch._C._jit_onnx_list_model_parameters(freezed_m)
method_graph = freezed_m._get_method('forward').graph
method_graph.eraseInput(0) # Remove 'self' from model inputs
params = []

in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
graph = _propagate_and_assign_input_shapes(
Expand Down