[ONNX] Adding a pass to replace interpolate function with aten::__interpolate #35744

Status: Closed · wants to merge 53 commits

Changes shown are from 44 of the 53 commits.

Commits (all by neginraoof):
f5ad704  Added inverse op export (Mar 24, 2020)
568cd75  Adding the pass to replace aten op (Mar 26, 2020)
df0ccd5  Fixed for script module (Mar 30, 2020)
d821e90  Refactoring names (Mar 31, 2020)
9a1d882  Default (Mar 31, 2020)
4727b25  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Mar 31, 2020)
3bff668  Fixing nested scripts (Mar 31, 2020)
76ee514  clean up (Mar 31, 2020)
937f03b  Fix for build (Mar 31, 2020)
c1ef3ef  clang (Mar 31, 2020)
4d058ab  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Mar 31, 2020)
8765be6  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Mar 31, 2020)
32771e8  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Mar 31, 2020)
d8932ec  clean up headers (Mar 31, 2020)
3f33375  Enabling tests (Mar 31, 2020)
fc8afec  Added comments (Apr 2, 2020)
6d52af8  clang format (Apr 2, 2020)
cb93982  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Apr 2, 2020)
73f8d0d  Change to update optimized_graph after pre-inline pass (Apr 2, 2020)
2ee6526  Add pass to utils (Apr 2, 2020)
63e0d4d  Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i… (Apr 2, 2020)
9bf80c2  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Apr 2, 2020)
1e36322  Add option to inliner to use_graph instead of optimized_graph (Apr 2, 2020)
031e2b2  Comments/description (Apr 2, 2020)
dc935e7  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Apr 2, 2020)
bbdb725  clang (Apr 2, 2020)
0be5c7f  Fix for comments (Apr 2, 2020)
492a7ab  Fixes for build (Apr 3, 2020)
d0be8b6  comment (Apr 3, 2020)
92a6fcf  clang (Apr 3, 2020)
e79332e  qual name not used (Apr 3, 2020)
d2bf3a8  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Apr 3, 2020)
bd03739  clang (Apr 3, 2020)
6222f41  Adding nms as well (Apr 3, 2020)
25ca0ac  Disable keypoint rcnn due to bug in size (Apr 3, 2020)
8c1886e  Update inliner.h (Apr 3, 2020)
5a2ad6c  Update utils.py (Apr 3, 2020)
ec1f405  clang (Apr 3, 2020)
bfd9e5a  Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i… (Apr 3, 2020)
c2ec3f2  Fix for feedback. Changed the pass name + removed nms (Apr 3, 2020)
ee99925  changed use_graph arg based on comments (Apr 3, 2020)
8aa3d2b  Fix for CallMethod (Apr 8, 2020)
bf6aadd  new line (Apr 8, 2020)
e4986a7  formatting (Apr 8, 2020)
d334b65  Refactor files (Apr 8, 2020)
7313674  Merge branch 'master' of https://github.com/pytorch/pytorch into nera… (Apr 8, 2020)
4c9371b  Added tests (Apr 8, 2020)
ab38d36  Fixed logs (Apr 9, 2020)
104cab3  Fixed interpolate tests (Apr 9, 2020)
787190a  Add opset 9 (Apr 9, 2020)
6303a92  Merge branch 'master' into neraoof/interpolate (Apr 10, 2020)
14328b8  Update test_pytorch_onnx_onnxruntime.py (Apr 10, 2020)
9df600a  Fixed missing TORCH_API (Apr 13, 2020)
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -405,6 +405,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     ${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/insert_guards.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inliner.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/preclude_inlining.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/lift_closures.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/inline_forked_closures.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/dead_code_elimination.cpp
3 changes: 0 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -346,7 +346,6 @@ def get_test_images(self):
         return images

     @skipIfUnsupportedMinOpsetVersion(11)
-    @unittest.skip("disabled due to removal of aten::__interpolate")
     def test_mask_rcnn(self):
         model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200,
                                                                              max_size=300)
@@ -1198,8 +1197,6 @@ def forward(self, x):
         self.run_test(MyModel(), x)

     def _interpolate_script(self, x, mode, use_size, is_upsample, align_corners=False):
-        # test disabled
-        return

         class MyModel(torch.jit.ScriptModule):
             __constants__ = ['mode', 'use_size', 'is_upsample', 'size', 'scale', 'size_array', 'scale_array', 'align_corners']
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
@@ -71,6 +71,7 @@ set(TORCH_PYTHON_SRCS
     ${TORCH_SRC_DIR}/csrc/jit/python/init.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp
+    ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/preclude_inlining.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/helper.cpp
     ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
28 changes: 21 additions & 7 deletions torch/csrc/jit/ir/ir.cpp
@@ -1820,16 +1820,30 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
   return g.insertNode(g.createTupleUnpack(v))->outputs();
 }

-std::vector<Value*> inlineCallTo(Node* to_replace, Function* callee) {
+// inline_optimized_graph argument is used to preclude inlining functions for
+// ONNX conversion
+std::vector<Value*> inlineCallTo(
+    Node* to_replace,
+    Function* callee,
+    bool inline_optimized_graph /*=true*/) {
   WithInsertPoint guard(to_replace);
   TORCH_INTERNAL_ASSERT(callee->isGraphFunction());
   std::unordered_map<Value*, Value*> value_map;
-  auto new_outputs = insertGraph(
-      *to_replace->owningGraph(),
-      *(callee->optimized_graph()),
-      to_replace->inputs(),
-      value_map);
-
+  std::vector<torch::jit::Value*> new_outputs;
+
+  if (inline_optimized_graph) {
+    new_outputs = insertGraph(
+        *to_replace->owningGraph(),
+        *(callee->optimized_graph()),
+        to_replace->inputs(),
+        value_map);
+  } else {
+    new_outputs = insertGraph(
+        *to_replace->owningGraph(),
+        *(callee->graph()),
+        to_replace->inputs(),
+        value_map);
+  }
   std::unordered_map<InlinedCallStack*, InlinedCallStackPtr>
       new_callstack_entries;

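A small Python-visible sketch of what the new flag preserves (this assumes the scripting behavior the PR relies on, where function calls are not inlined into the graph at script time; the function and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def f(x):
        return F.interpolate(x, scale_factor=2.0, mode="nearest")

    # f.graph is the un-inlined graph: the interpolate call is still a
    # prim::CallFunction node. callee->optimized_graph() would inline such
    # calls away; passing the flag as false makes inlineCallTo use
    # callee->graph(), so the calls stay visible to the ONNX pass below.
    print(f.graph)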
5 changes: 4 additions & 1 deletion torch/csrc/jit/ir/ir.h
@@ -1367,7 +1367,10 @@ TORCH_API std::vector<Value*> insertGraph(
  * This asserts that the number of outputs of the original node and the
  * graph are the same.
  */
-TORCH_API std::vector<Value*> inlineCallTo(Node* to_replace, Function* callee);
+TORCH_API std::vector<Value*> inlineCallTo(
+    Node* to_replace,
+    Function* callee,
+    bool use_graph = true);

 /** If there is only one value in \p OUTPUTS and its kind is Tuple, insert a
  * tuple unpack node and return the resulting values.
87 changes: 87 additions & 0 deletions torch/csrc/jit/passes/onnx/preclude_inlining.cpp
@@ -0,0 +1,87 @@
#include <torch/csrc/jit/passes/onnx/preclude_inlining.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

namespace torch {
namespace jit {

void functionCallSubstitution(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* cur = *it++;
    switch (cur->kind()) {
      case prim::CallFunction: {
        AT_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
        auto function_constant = cur->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();

        if ((fun_type->function()->qualname().qualifiedName().find(
                 "torch.nn.functional") != std::string::npos) &&
            (fun_type->function()->qualname().qualifiedName().find(
                 "interpolate") != std::string::npos)) {
          // Replace the CallFunction node with a single aten::__interpolate
          // node so the existing ONNX symbolic can still handle it.
          cur->removeInput(0);
          Node* interpolate_node = block->owningGraph()->create(
              Symbol::fromQualString("aten::__interpolate"),
              {cur->inputs()},
              cur->outputs().size());
          interpolate_node->output()->copyMetadata(cur->output());
          interpolate_node->insertAfter(cur);
          cur->replaceAllUsesWith(interpolate_node);
          cur->removeAllInputs();
          cur->destroy();
        } else {
          // Not an interpolate call: recurse into the callee, then inline it
          // from its un-optimized graph via inlineCallTo(..., false).
          cur->removeInput(0);
          functionCallSubstitution(fun_type->function()->graph()->block());
          GRAPH_UPDATE(
              "Inlining in ONNX preclude inlining function '",
              fun_type->function()->name(),
              "' to ",
              *cur);
          GRAPH_UPDATE(
              "Function in ONNX preclude inlining body: ",
              *fun_type->function()->optimized_graph());
          inlineCallTo(cur, fun_type->function(), false);
        }
      } break;
      case prim::CallMethod: {
        const std::string& name = cur->s(attr::name);

        [Inline review thread on this line]
        Contributor: Can you add a test for the new code path?
        Contributor Author (neginraoof): I'm adding a test in onnxruntime backend tests to cover this case (interpolate call within submodule). Let me know if you think of a more generalized way of testing this.
        (A sketch of such a test follows this file's diff.)

        if (auto class_type = cur->input(0)->type()->cast<ClassType>()) {
          auto function = class_type->getMethod(name);
          if (!function->isGraphFunction()) {
            continue;
          }
          functionCallSubstitution(function->graph()->block());
          GRAPH_UPDATE(
              "Inlining in ONNX preclude inlining function '",
              function->name(),
              "' to ",
              *cur);
          GRAPH_UPDATE(
              "Function in ONNX preclude inlining body: ",
              function->optimized_graph());
          inlineCallTo(cur, function, false);
        }
      } break;
      default: {
        for (auto b : cur->blocks()) {
          functionCallSubstitution(b);
        }
      } break;
    }
  }
}

// This pass is to be used for ONNX conversion only. The ONNX converter depends
// on a number of deprecated aten operators. These operators have been removed
// from the IR and replaced by compiled Python function code. However, in order
// to maintain the behavior for ONNX conversion, we replace these function
// calls with the aten symbolic, which can still be used by the ONNX converter.
void ONNXFunctionCallSubstitution(Graph& graph) {
  GRAPH_DUMP("Before stop-inlining calls: ", &graph);
  functionCallSubstitution(graph.block());
  GRAPH_DUMP("After stop-inlining calls: ", &graph);
}

} // namespace jit
} // namespace torch
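A minimal sketch of the kind of test described in the review thread above (an interpolate call inside a submodule, exercising the prim::CallMethod path); the module names, shapes, and output file are illustrative, not the exact test added in this PR:

    import torch
    import torch.nn.functional as F

    class Inner(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # interpolate is called from a submodule, so the pass must recurse
            # through prim::CallMethod before reaching the prim::CallFunction.
            return F.interpolate(x, scale_factor=2.0, mode="nearest")

    class Outer(torch.jit.ScriptModule):
        def __init__(self):
            super(Outer, self).__init__()
            self.inner = Inner()

        @torch.jit.script_method
        def forward(self, x):
            return self.inner(x)

    model = Outer()
    x = torch.randn(1, 3, 4, 4)
    # With this pass in place, export should succeed and emit a resize-style
    # ONNX op instead of failing on the un-inlined function call.
    torch.onnx.export(model, (x,), "interp_submodule.onnx",
                      example_outputs=model(x), opset_version=11)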
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/onnx/preclude_inlining.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

void ONNXFunctionCallSubstitution(Graph& graph);

} // namespace jit
} // namespace torch
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -31,6 +31,7 @@
 #include <torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h>
 #include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
 #include <torch/csrc/jit/passes/onnx/peephole.h>
+#include <torch/csrc/jit/passes/onnx/preclude_inlining.h>
 #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
 #include <torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.h>
 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
@@ -122,6 +123,7 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
       .def("_jit_pass_onnx", ToONNX)
       .def("_jit_pass_lower_all_tuples", LowerAllTuples)
+      .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
       .def(
           "_jit_pass_onnx_peephole",
           [](std::shared_ptr<Graph>& graph,
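As a sketch of what the new _jit_pass_onnx_function_substitution binding registered above does to a scripted graph (the module and shapes here are illustrative):

    import torch
    import torch.nn.functional as F

    class M(torch.nn.Module):
        def forward(self, x):
            return F.interpolate(x, scale_factor=2.0, mode="nearest")

    m = torch.jit.script(M())
    graph = m.forward.graph
    # Before the pass: interpolate appears as a prim::CallFunction into the
    # compiled torch.nn.functional.interpolate code.
    torch._C._jit_pass_onnx_function_substitution(graph)
    # After the pass: the call is replaced by a single aten::__interpolate
    # node, and other function/method calls are inlined from their
    # un-optimized graphs.
    print(graph)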
9 changes: 7 additions & 2 deletions torch/onnx/utils.py
@@ -333,7 +333,9 @@ def _model_to_graph(model, args, verbose=False,
     if isinstance(model, torch.jit.ScriptModule):
         assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
         try:
-            method_graph, params = torch._C._jit_pass_lower_graph(model.forward.graph, model._c)
+            graph = model.forward.graph
+            torch._C._jit_pass_onnx_function_substitution(graph)
+            method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
             in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
             graph = _propagate_and_assign_input_shapes(
                 method_graph, tuple(in_vars), False, propagate)
@@ -344,8 +346,10 @@
         method = model
         params = ()
         in_vars, in_desc = torch.jit._flatten(tuple(args))
+        graph = model.graph
+        torch._C._jit_pass_onnx_function_substitution(graph)
         graph = _propagate_and_assign_input_shapes(
-            model.graph, tuple(in_vars), False, propagate)
+            graph, tuple(in_vars), False, propagate)
     else:
         graph, torch_out = _trace_and_get_graph_from_model(model, args)
         state_dict = _unique_state_dict(model)
@@ -357,6 +361,7 @@
     for i, inp in enumerate(graph_inputs):
         if i >= user_input_num:
             inp.setDebugName(param_names[i - user_input_num])
+    torch._C._jit_pass_onnx_function_substitution(graph)

     input_and_param_names = [val.debugName() for val in graph.inputs()]
     param_names = input_and_param_names[len(input_and_param_names) - len(params):]
Expand Down