[ONNX] Adding a pass to replace interpolate function with aten::__interpolate #35744

Closed

neginraoof wants to merge 53 commits into pytorch:master from neginraoof:neraoof/interpolate
Changes from 14 commits
Commits (53)
f5ad704
Added inverse op export
neginraoof Mar 24, 2020
568cd75
Adding the pass to replace aten op
neginraoof Mar 26, 2020
df0ccd5
Fixed for script module
neginraoof Mar 30, 2020
d821e90
Refactoring names
neginraoof Mar 31, 2020
9a1d882
Default
neginraoof Mar 31, 2020
4727b25
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Mar 31, 2020
3bff668
Fixing nested scripts
neginraoof Mar 31, 2020
76ee514
clean up
neginraoof Mar 31, 2020
937f03b
Fix for build
neginraoof Mar 31, 2020
c1ef3ef
clang
neginraoof Mar 31, 2020
4d058ab
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Mar 31, 2020
8765be6
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Mar 31, 2020
32771e8
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Mar 31, 2020
d8932ec
clean up headers
neginraoof Mar 31, 2020
3f33375
Enabling tests
neginraoof Mar 31, 2020
fc8afec
Added comments
neginraoof Apr 2, 2020
6d52af8
clang format
neginraoof Apr 2, 2020
cb93982
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Apr 2, 2020
73f8d0d
Change to update optimized_graph after pre-inline pass
neginraoof Apr 2, 2020
2ee6526
Add pass to utils
neginraoof Apr 2, 2020
63e0d4d
Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i…
neginraoof Apr 2, 2020
9bf80c2
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Apr 2, 2020
1e36322
Add option to inliner to use_graph instead of optimized_graph
neginraoof Apr 2, 2020
031e2b2
Comments/description
neginraoof Apr 2, 2020
dc935e7
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Apr 2, 2020
bbdb725
clang
neginraoof Apr 2, 2020
0be5c7f
Fix for comments
neginraoof Apr 2, 2020
492a7ab
Fixes for build
neginraoof Apr 3, 2020
d0be8b6
comment
neginraoof Apr 3, 2020
92a6fcf
clang
neginraoof Apr 3, 2020
e79332e
qual name not used
neginraoof Apr 3, 2020
d2bf3a8
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Apr 3, 2020
bd03739
clang
neginraoof Apr 3, 2020
6222f41
Adding nms as well
neginraoof Apr 3, 2020
25ca0ac
Disable keypoint rcnn due to bug in size
neginraoof Apr 3, 2020
8c1886e
Update inliner.h
neginraoof Apr 3, 2020
5a2ad6c
Update utils.py
neginraoof Apr 3, 2020
ec1f405
clang
neginraoof Apr 3, 2020
bfd9e5a
Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i…
neginraoof Apr 3, 2020
c2ec3f2
Fix for feedback. Changed the pass name + removed nms
neginraoof Apr 3, 2020
ee99925
changed use_graph arg based on comments
neginraoof Apr 3, 2020
8aa3d2b
Fix for CallMethod
neginraoof Apr 8, 2020
bf6aadd
new line
neginraoof Apr 8, 2020
e4986a7
formatting
neginraoof Apr 8, 2020
d334b65
Refactor files
neginraoof Apr 8, 2020
7313674
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof Apr 8, 2020
4c9371b
Added tests
neginraoof Apr 8, 2020
ab38d36
Fixed logs
neginraoof Apr 9, 2020
104cab3
Fixed interpolate tests
neginraoof Apr 9, 2020
787190a
Add opset 9
neginraoof Apr 9, 2020
6303a92
Merge branch 'master' into neraoof/interpolate
neginraoof Apr 10, 2020
14328b8
Update test_pytorch_onnx_onnxruntime.py
neginraoof Apr 10, 2020
9df600a
Fixed missing TORCH_API
neginraoof Apr 13, 2020
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
@@ -405,6 +405,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/inline_autodiff_subgraphs.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/insert_guards.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inliner.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/preinline_onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/lift_closures.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/inline_forked_closures.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/dead_code_elimination.cpp
1 change: 1 addition & 0 deletions torch/CMakeLists.txt
@@ -71,6 +71,7 @@ set(TORCH_PYTHON_SRCS
${TORCH_SRC_DIR}/csrc/jit/python/init.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/preinline_onnx.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/helper.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp
2 changes: 2 additions & 0 deletions torch/csrc/jit/api/function_impl.cpp
@@ -1,5 +1,6 @@
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/onnx/preinline_onnx.h>

#include <torch/csrc/jit/frontend/error_report.h>

@@ -64,6 +65,7 @@ const c10::FunctionSchema& GraphFunction::getSchema() const {
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
  // TODO: Invoke cleanup passes before and after inlining to reduce amount of
  // code we're copying.
  PreInlineONNX(*graph);
Contributor:

This shouldn't be running every time we run inlining on a graph, just as an explicitly invoked pass by ONNX before ONNX runs inlining.

Contributor Author:

ONNX calls the tracer, which internally inlines functions. I'm trying to figure out the best place to put the pre-inline pass, maybe somewhere in the tracer.

Contributor:

When we looked at the example together, it seems like the tracer doesn't eagerly inline, so it should be possible outside of the tracer.

@torch.jit.script
def fn(x):
    return x + 2
@torch.jit.script
def fn2(x):
    return fn(x) + 3
def fn3(x):
    return x + fn2(x)
traced = torch.jit.trace(fn3, (torch.rand(3, 4),))
print(traced.graph)

graph(%x : Double(3, 4)):
  %1 : Function = prim::Constant[name="fn2"]()
  %2 : Tensor = prim::CallFunction(%1, %x)
  %3 : int = prim::Constant[value=1]() # test/test_jit.py:16104:0
  %4 : Double(3, 4) = aten::add(%x, %2, %3) # test/test_jit.py:16104:0
  return (%4)
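
For contrast, explicitly running the inlining pass on the traced graph does fold the calls away (a quick check under the same fn3 trace as above; torch._C._jit_pass_inline mutates the graph in place):

torch._C._jit_pass_inline(traced.graph)
print(traced.graph)
# the prim::CallFunction node is gone; the bodies of fn2 and fn appear as inlined aten ops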


Contributor Author:

Actually, putting this pre-inline pass before calling inline does not fix the graph. So the problem is with function->optimized_graph: the optimized_graph is inlined within the tracer, and it does not get updated when function->graph is updated. I may be able to add an API to update the optimized_graph.

Here it is:

std::shared_ptr<Graph> optimized_graph() const override {

Contributor Author (@neginraoof, Apr 2, 2020):

Because this pass modifies the function's graph, not the function's optimized_graph, and the latter is what the downstream code uses. Let me know if it's easier to have a quick call about this.

  Inline(*graph);
}

Expand Down
52 changes: 52 additions & 0 deletions torch/csrc/jit/passes/onnx/preinline_onnx.cpp
@@ -0,0 +1,52 @@
#include <torch/csrc/jit/passes/onnx/preinline_onnx.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/helper.h>

namespace torch {
namespace jit {

void replaceFunctions(Node* to_replace, Function* callee) {
  if (callee->name() == "interpolate") {
    // Drop the function constant input, then rebuild the call as a single
    // aten::__interpolate node carrying the remaining inputs and the same
    // number of outputs.
    to_replace->removeInput(0);
    Node* interpolate_node = to_replace->owningGraph()->create(
        Symbol::fromQualString("aten::__interpolate"),
        {to_replace->inputs()},
        to_replace->outputs().size());
    interpolate_node->output()->copyMetadata(to_replace->output());
    interpolate_node->insertAfter(to_replace);
    to_replace->replaceAllUsesWith(interpolate_node);
    to_replace->removeAllInputs();
    to_replace->destroy();
    return;
  }
}

void PreInlineCalls(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end();
       it != end;) {
    Node* cur = *it++;
    switch (cur->kind()) {
      case prim::CallFunction: {
        AT_ASSERT(cur->input(0)->node()->kind() == prim::Constant);
        auto function_constant = cur->input(0)->node();
        auto fun_type =
            function_constant->output()->type()->expect<FunctionType>();
        replaceFunctions(cur, fun_type->function());
      } break;
      default: {
        // Recurse into sub-blocks (e.g. the bodies of prim::If and prim::Loop).
        for (auto b : cur->blocks()) {
          PreInlineCalls(b);
        }
      } break;
    }
  }
}

void PreInlineONNX(Graph& graph) {
  GRAPH_DUMP("Before Pre-inlining: ", &graph);
  PreInlineCalls(graph.block());
  GRAPH_DUMP("After Pre-inlining: ", &graph);
}

} // namespace jit
} // namespace torch
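
The case this pass targets looks roughly like the following sketch (the module, shapes, and file name are illustrative, not taken from this PR; exporting a ScriptModule requires example_outputs):

import torch
import torch.nn.functional as F

class Upsample(torch.nn.Module):
    def forward(self, x):
        # in the script graph this call appears as a prim::CallFunction to "interpolate",
        # which PreInlineONNX rewrites into a single aten::__interpolate node
        return F.interpolate(x, scale_factor=2.0, mode="nearest")

model = torch.jit.script(Upsample())
x = torch.randn(1, 3, 8, 8)
torch.onnx.export(model, (x,), "upsample.onnx",
                  example_outputs=model(x), opset_version=11)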
11 changes: 11 additions & 0 deletions torch/csrc/jit/passes/onnx/preinline_onnx.h
@@ -0,0 +1,11 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {

void PreInlineONNX(Graph& graph);

} // namespace jit
} // namespace torch
2 changes: 2 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -31,6 +31,7 @@
#include <torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h>
#include <torch/csrc/jit/passes/onnx/fixup_onnx_loop.h>
#include <torch/csrc/jit/passes/onnx/peephole.h>
#include <torch/csrc/jit/passes/onnx/preinline_onnx.h>
#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/prepare_inplace_ops_for_onnx.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
@@ -122,6 +123,7 @@ void initJITBindings(PyObject* module) {
.def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
.def("_jit_pass_onnx", ToONNX)
.def("_jit_pass_lower_all_tuples", LowerAllTuples)
.def("_jit_pass_onnx_preinline", PreInlineONNX)
.def(
"_jit_pass_onnx_peephole",
[](std::shared_ptr<Graph>& graph,
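With this binding in place, the pass can also be exercised directly from Python (a sketch; the exact graph printout depends on the TorchScript version):

import torch
import torch.nn.functional as F

@torch.jit.script
def resize(x):
    return F.interpolate(x, scale_factor=2.0, mode="nearest")

graph = resize.graph
torch._C._jit_pass_onnx_preinline(graph)
print(graph)  # the interpolate prim::CallFunction should now be an aten::__interpolate node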
8 changes: 6 additions & 2 deletions torch/onnx/utils.py
@@ -334,7 +334,9 @@ def _model_to_graph(model, args, verbose=False,
     if isinstance(model, torch.jit.ScriptModule):
         assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
         try:
-            method_graph, params = torch._C._jit_pass_lower_graph(model.forward.graph, model._c)
+            graph = model.forward.graph
+            torch._C._jit_pass_onnx_preinline(graph)
+            method_graph, params = torch._C._jit_pass_lower_graph(graph, model._c)
             in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
             graph = _propagate_and_assign_input_shapes(
                 method_graph, tuple(in_vars), False, propagate)
@@ -345,8 +347,10 @@
         method = model
         params = ()
         in_vars, in_desc = torch.jit._flatten(tuple(args))
+        graph = model.graph
+        torch._C._jit_pass_onnx_preinline(graph)
         graph = _propagate_and_assign_input_shapes(
-            model.graph, tuple(in_vars), False, propagate)
+            graph, tuple(in_vars), False, propagate)
else:
graph, torch_out = _trace_and_get_graph_from_model(model, args)
state_dict = _unique_state_dict(model)
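End to end, the effect is visible in the exported model (a sketch using the onnx package, assuming the upsample.onnx file from the earlier example):

import onnx

model = onnx.load("upsample.onnx")
print([node.op_type for node in model.graph.node])
# expect a Resize (opset >= 10) or Upsample node where the interpolate call was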