[ONNX] Track and list model params for scripting #47348

Closed · wants to merge 30 commits
Commits (30)
- f2b417b  Track and list model params for freezed module (neginraoof, Nov 4, 2020)
- 466a36c  Fix for training attr (neginraoof, Nov 12, 2020)
- 2aa4180  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 16, 2020)
- 5fa6f05  Update list_model_parameters.cpp (neginraoof, Nov 16, 2020)
- 5515ea3  Update list_model_parameters.cpp (neginraoof, Nov 16, 2020)
- dfdf28c  Clang fix (neginraoof, Nov 16, 2020)
- 028c54e  Update list_model_parameters.cpp (neginraoof, Nov 16, 2020)
- d7e1dee  Fix for feedback (neginraoof, Nov 18, 2020)
- 843bac0  Update torch/csrc/jit/passes/onnx/list_model_parameters.cpp (neginraoof, Nov 18, 2020)
- 144cafb  Enable new APIs by default (neginraoof, Nov 18, 2020)
- dee7395  Adding batch_norm training test (neginraoof, Nov 18, 2020)
- 00e1cf0  Adding dropout training test (neginraoof, Nov 18, 2020)
- 559bdec  Update utils.py (neginraoof, Nov 19, 2020)
- b570c36  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 19, 2020)
- e6fb0c5  rebase (neginraoof, Nov 19, 2020)
- 67021af  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 19, 2020)
- 26b3df5  clang (neginraoof, Nov 19, 2020)
- 2c1581b  Fix for quantization (neginraoof, Nov 19, 2020)
- eebf713  Merge branch 'neraoof/scripting_params' of github.com:neginraoof/pyto… (neginraoof, Nov 19, 2020)
- f7d92e5  Fix for jit test, error when encounter PythonOp (neginraoof, Nov 20, 2020)
- 37f9051  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 20, 2020)
- 8747e6b  clang-tidy fix (neginraoof, Nov 21, 2020)
- 612c369  Fix for feedback. (neginraoof, Nov 21, 2020)
- ba4c63c  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 21, 2020)
- d7966fc  Fix for feedback (neginraoof, Nov 21, 2020)
- 65ecd22  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Nov 21, 2020)
- 41a3843  Update list_model_parameters.cpp (neginraoof, Dec 2, 2020)
- 71b301a  Adding test for params (neginraoof, Dec 2, 2020)
- 79f4a4c  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (neginraoof, Dec 3, 2020)
- 90309a6  Merge branch 'neraoof/scripting_params' of github.com:neginraoof/pyto… (neginraoof, Dec 3, 2020)
72 changes: 47 additions & 25 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -472,6 +472,7 @@ def forward(self, x_in):
x = {"test_key_in": torch.randn(1, 2, 3)}
self.run_test(MyModel(), (x,))

@disableScriptTest()
def test_none_as_input(self):
class Model(torch.nn.Module):
def forward(self, x, y):
@@ -482,6 +483,7 @@ def forward(self, x, y):
x = torch.randn(2, 3)
self.run_test(Model(), (x, None))

@disableScriptTest()
def test_none_as_tuple_input(self):
class Model(torch.nn.Module):
def forward(self, x, y):
@@ -495,6 +497,7 @@ def forward(self, x, y):
y = torch.randn(2, 3)
self.run_test(Model(), (x, (None, y)))

@disableScriptTest()
def test_none_as_named_input(self):
class Model(torch.nn.Module):
def forward(self, x, y=None, z=None):
@@ -678,23 +681,11 @@ def __init__(self):
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)

class ScriptModel(torch.jit.ScriptModule):
def __init__(self):
super(ScriptModel, self).__init__()
self.conv1 = torch.nn.Conv1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

@torch.jit.script_method
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)

x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)

self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)

def test_conv_shape_inference(self):
class Model(torch.nn.Module):
@@ -721,23 +712,11 @@ def __init__(self):
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)

class ScriptModel(torch.jit.ScriptModule):
def __init__(self):
super(ScriptModel, self).__init__()
self.conv1 = torch.nn.ConvTranspose1d(16, 33, 3, stride=2)
self.conv2 = torch.nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
self.conv3 = torch.nn.ConvTranspose3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0))

@torch.jit.script_method
def forward(self, input1, input2, input3):
return self.conv1(input1), self.conv2(input2), self.conv3(input3)

x1 = torch.randn(20, 16, 50)
x2 = torch.randn(20, 16, 50, 100)
x3 = torch.randn(20, 16, 10, 50, 100)

self.run_test(TraceModel(), (x1, x2, x3), atol=10e-5)
self.run_test(ScriptModel(), (x1, x2, x3), atol=10e-5)

# Conversion of Transpose depends on input shape to be known.
# The following test only works when onnx shape inference is enabled.
@@ -5021,9 +5000,17 @@ def forward(self, x):
ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))

[np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in zip(pytorch_out, ort_outs)]

model_export = torch.jit.script(MyModule())
ort_sess = convert_to_onnx(model_export, input=(x,), opset_version=self.opset_version,
example_outputs=out,
training=torch.onnx.TrainingMode.TRAINING,
use_new_jit_passes=True, onnx_shape_inference=True)
ort_outs = run_ort(ort_sess, input=(x,))
[np.testing.assert_allclose(p_out, ort_out, atol=10e-3, rtol=10e-3) for p_out, ort_out in
zip(pytorch_out, ort_outs)]

@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training(self):
class MyModule(torch.nn.Module):
@@ -5045,6 +5032,14 @@ def forward(self, x):
ort_outs = run_ort(ort_sess, input=(x,))
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))

script_model = torch.jit.script(model)
output = model(x)
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=output, use_new_jit_passes=True,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))
assert not torch.all(torch.eq(x, torch.from_numpy(ort_outs[0])))

@skipIfUnsupportedMinOpsetVersion(12)
def test_dropout_training_zero(self):
class MyModule(torch.nn.Module):
@@ -5073,7 +5068,21 @@ def forward(self, x):

y = model(input)
output = y.cpu().numpy()
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
pyt_mask = np.where(output != 0, 1, 0)

ratio_pytorch = np.sum(pyt_mask) / nb_elements
ratio_ort = np.sum(ort_mask) / nb_elements

np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

script_model = torch.jit.script(model)
y = model(input)
output = y.cpu().numpy()
ort_sess = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=y, use_new_jit_passes=True,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs = run_ort(ort_sess, input=(x,))
ort_mask = np.where(ort_outs[0] != 0, 1, 0)
pyt_mask = np.where(output != 0, 1, 0)

@@ -5105,6 +5114,19 @@ def forward(self, x):
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
zip(ort_outs1, ort_outs2)]

script_model = torch.jit.script(model)
outputs = model(x)
ort_sess1 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=outputs, use_new_jit_passes=True,
training=torch.onnx.TrainingMode.TRAINING)
ort_outs1 = run_ort(ort_sess1, input=(x,))
ort_sess2 = convert_to_onnx(script_model, input=(x,), opset_version=self.opset_version,
example_outputs=outputs, use_new_jit_passes=True,
training=torch.onnx.TrainingMode.EVAL)
ort_outs2 = run_ort(ort_sess2, input=(x,))
[np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
zip(ort_outs1, ort_outs2)]

def test_multiple_conv_bn(self):
class MyModule(torch.nn.Module):
def __init__(self):
25 changes: 25 additions & 0 deletions test/onnx/test_utility_funs.py
@@ -678,6 +678,31 @@ def forward(self, x):

assert len(params_dict) == 2

def test_scripting_param(self):
class MyModule(torch.nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
self.bn = torch.nn.BatchNorm2d(16, affine=True)

def forward(self, x):
x = self.conv(x)
bn = self.bn(x)
return bn

model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
example_outputs = model(x)
f = io.BytesIO()
_set_opset_version(self.opset_version)
_set_operator_export_type(OperatorExportTypes.ONNX)
graph, _, __ = utils._model_to_graph(model, (x,), do_constant_folding=True, example_outputs=example_outputs,
operator_export_type=OperatorExportTypes.ONNX)

graph_input_params = [param.debugName() for param in graph.inputs()]
assert all(item in graph_input_params for item in dict(model.named_parameters())), \
"Graph parameter names does not match model parameters."

def test_modifying_params(self):
class MyModel(torch.nn.Module):
def __init__(self):
1 change: 1 addition & 0 deletions tools/build_variables.bzl
@@ -520,6 +520,7 @@ libtorch_python_core_sources = [
"torch/csrc/jit/passes/onnx/constant_fold.cpp",
"torch/csrc/jit/passes/onnx/eliminate_unused_items.cpp",
"torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp",
"torch/csrc/jit/passes/onnx/list_model_parameters.cpp",
"torch/csrc/jit/passes/onnx/function_substitution.cpp",
"torch/csrc/jit/passes/onnx/helper.cpp",
"torch/csrc/jit/passes/onnx/peephole.cpp",
4 changes: 3 additions & 1 deletion torch/csrc/jit/passes/freeze_module.cpp
@@ -442,7 +442,9 @@ class AttributePropagator {
if (!isEval || preserveParameters_) {
auto type = attrModule.type();
auto slot = *type->findAttributeSlot(name);
if (type->is_parameter(slot) || type->is_buffer(slot)) {
if (type->is_parameter(slot) || type->is_buffer(slot) ||
(attr.isObject() &&
!attr.toObjectRef().type()->is_module())) {
continue;
} else {
attr = overrideGradient(attr);
186 changes: 186 additions & 0 deletions torch/csrc/jit/passes/onnx/list_model_parameters.cpp
@@ -0,0 +1,186 @@
#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch {
namespace jit {

// findSubModuleAttr function chases getAttr chains backwards to locate the
// submodules. For example: module M {
// attributes {
// A = <SubModule at ...>
// }
// ...
// %A = prim::GetAttr[name="A"](%self)
// ...
// %B = prim::GetAttr[name="B"](%A)
// ...
// %weight = prim::GetAttr[name="scale"](%B)
// ...

std::deque<std::string> findSubModuleAttr(
Value* input,
std::string& name,
Module& attrModule,
std::shared_ptr<Graph>& graph) {
Node* node = input->node();
std::deque<std::string> moduleNames;

  // The loop starts from the inner submodule and follows the chain until it
  // reaches the top module.
while (node->outputs().at(0)->type() != graph->inputs().at(0)->type()) {
if (node->kind() == prim::GetAttr) {
moduleNames.push_front(node->s(attr::name));
node = node->inputs()[0]->node();
}
}

// Assign the inner module to attrModule.
for (auto& moduleName : moduleNames) {
attrModule = attrModule.attr(moduleName).toModule();
}
return moduleNames;
}

Value* addParamAsArgument(Function* function, std::string& name, IValue& attr) {
auto schema = function->getSchema();
auto args = schema.arguments();
args.emplace_back(Argument(name, nullptr, c10::nullopt, attr));
auto new_schema = FunctionSchema(
schema.name(),
schema.overload_name(),
args,
schema.returns(),
schema.is_vararg(),
schema.is_varret());
function->setSchema(new_schema);
return function->graph()->addInput(name)->setType(attr.type());
}

std::vector<IValue> getParamAttributes(
std::shared_ptr<Graph>& graph,
const Module& module_,
Function* function_) {
std::vector<IValue> attrValues;
auto isEval = !module_.hasattr("training") || !module_.is_training();
auto block = graph->block();
std::vector<Block*> blocks({block});

Node* m = *block->nodes().begin();
WithInsertPoint guard(m);

while (!blocks.empty()) {
Block* block = blocks.back();
blocks.pop_back();
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
Node* n = *it;
it++; // node n can be destroyed

for (Block* sub_block : n->blocks()) {
blocks.emplace_back(sub_block);
}
if (n->kind() == prim::SetAttr &&
n->s(attr::name) == "num_batches_tracked") {
n->destroy();
} else if (n->kind() == prim::GetAttr) {
for (auto use : n->output()->uses()) {
if (use.user->kind() == prim::PythonOp)
throw ErrorReport(n->sourceRange())
<< "Couldn't export Python method.";
}

auto name = n->s(attr::name);
auto attrModule = module_;
auto input = n->inputs()[0];

auto moduleNames = findSubModuleAttr(input, name, attrModule, graph);
if (!attrModule.hasattr(name)) {
continue;
}
Value* paramConst = nullptr;

auto attr = attrModule.attr(name);

std::string fullName("");
for (auto& name : moduleNames) {
fullName += name + '.';
}
fullName += name;

auto type = attrModule.type();
auto slot = *type->findAttributeSlot(name);

if (type->is_parameter(slot) || type->is_buffer(slot) ||
(attr.isObject() && !attr.toObjectRef().type()->is_module()) ||
name == "training") {
if (attr.isTensor()) {
TORCH_INTERNAL_ASSERT(attr.isTensor());
auto tensor_ = attr.toTensor();
if (isEval && tensor_.requires_grad()) {
tensor_ = tensor_.detach();
tensor_.set_requires_grad(false);
attr = IValue(tensor_);
}
attrValues.emplace_back(attr.toTensor());
paramConst = addParamAsArgument(function_, fullName, attr);
} else if (
attr.isObject() && !attr.toObjectRef().type()->is_module()) {
// Only below registered torch classes are supported.
auto type = attr.type();
TORCH_CHECK(
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
(type ==
getCustomClass(
"__torch__.torch.classes.quantized.LinearPackedParamsBase")),
"Unknown type ",
type->repr_str(),
" encountered in handling model params. This type is not supported in ONNX export.");
attrValues.emplace_back(
script::Object(attr.toObject()).run_method("__getstate__"));
paramConst = addParamAsArgument(function_, fullName, attr);
} else if (attr.isNone() || name == "training") {
auto attrVal = tryInsertConstant(*graph, attr);
paramConst = *attrVal;
}
n->output()->replaceAllUsesWith(paramConst);
n->removeAllInputs();

GRAPH_UPDATE("Folding GetAttr %", n->outputs()[0]->debugName());
}
}
}
}
return attrValues;
}

std::pair<Module, std::vector<IValue>> list_module_parameters(
const Module& module) {
Module moduleClone = module.clone(true);
Method method = moduleClone.get_method("forward");
auto function = &method.function();
std::vector<IValue> modelParams;

GRAPH_DEBUG("List attributes for function: " + function->name());
auto graph = function->graph();
// Add model_parameters and model_buffers as model inputs. Order is based on
// the appearance in the graph.
auto attributes = getParamAttributes(graph, moduleClone, function);

modelParams.reserve(attributes.size());
for (auto& attr_ : attributes) {
modelParams.push_back(attr_);
}
GRAPH_DEBUG("Cleaning up module");
EliminateDeadCode(graph->block());

return std::make_pair(moduleClone, modelParams);
}

} // namespace jit
} // namespace torch
13 changes: 13 additions & 0 deletions torch/csrc/jit/passes/onnx/list_model_parameters.h
@@ -0,0 +1,13 @@
#pragma once

#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

TORCH_API std::pair<Module, std::vector<IValue>> list_module_parameters(
const Module& module);

} // namespace jit
} // namespace torch
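
For reference, here is a minimal end-to-end sketch of the export flow these changes enable, mirroring the scripted-model tests above. The snippet is not part of the diff: it assumes a PyTorch build that includes this PR, uses the public torch.onnx.export API instead of the convert_to_onnx/run_ort test helpers, and passes example_outputs, which was still required when exporting a ScriptModule at the time.

import io
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
        self.bn = torch.nn.BatchNorm2d(16, affine=True)

    def forward(self, x):
        return self.bn(self.conv(x))

model = torch.jit.script(MyModule())
x = torch.randn(10, 3, 128, 128)
example_outputs = model(x)

# With this PR, the list_model_parameters pass lifts the conv and batch-norm
# parameters and buffers of the scripted module into graph inputs, so the
# exporter can emit them as ONNX initializers.
f = io.BytesIO()
torch.onnx.export(model, (x,), f,
                  opset_version=12,
                  example_outputs=example_outputs,
                  training=torch.onnx.TrainingMode.TRAINING)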