New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Adding a pass to replace interpolate function with aten::__interpolate #35744
Closed
Closed
Changes from 14 commits
Commits
Show all changes
53 commits
Select commit
Hold shift + click to select a range
f5ad704
Added inverse op export
neginraoof 568cd75
Adding the pass to replace aten op
neginraoof df0ccd5
Fixed for script module
neginraoof d821e90
Refactoring names
neginraoof 9a1d882
Default
neginraoof 4727b25
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 3bff668
Fixing nested sripts
neginraoof 76ee514
clean up
neginraoof 937f03b
Fix for build
neginraoof c1ef3ef
clang
neginraoof 4d058ab
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 8765be6
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 32771e8
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof d8932ec
clean up headers
neginraoof 3f33375
Enabling tests
neginraoof fc8afec
Added comments
neginraoof 6d52af8
clang format
neginraoof cb93982
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 73f8d0d
Change to update optimized_graph after pre-inline pass
neginraoof 2ee6526
Add pass to utils
neginraoof 63e0d4d
Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i…
neginraoof 9bf80c2
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 1e36322
Add option to inliner to use_graph instead of optimized_graph
neginraoof 031e2b2
Comments/description
neginraoof dc935e7
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof bbdb725
clang
neginraoof 0be5c7f
Fix for comments
neginraoof 492a7ab
Fixes for build
neginraoof d0be8b6
comment
neginraoof 92a6fcf
clang
neginraoof e79332e
qual name not used
neginraoof d2bf3a8
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof bd03739
clang
neginraoof 6222f41
Adding nms as well
neginraoof 25ca0ac
Disable keypoint rcnn due to bug in size
neginraoof 8c1886e
Update inliner.h
neginraoof 5a2ad6c
Update utils.py
neginraoof ec1f405
clang
neginraoof bfd9e5a
Merge branch 'neraoof/interpolate' of github.com:neginraoof/pytorch i…
neginraoof c2ec3f2
Fix for feedback. Changed the pass name + removed nms
neginraoof ee99925
changed use_graph arg based on comments
neginraoof 8aa3d2b
Fix for CallMethod
neginraoof bf6aadd
new line
neginraoof e4986a7
formatting
neginraoof d334b65
Refactor files
neginraoof 7313674
Merge branch 'master' of https://github.com/pytorch/pytorch into nera…
neginraoof 4c9371b
Added tests
neginraoof ab38d36
Fixed logs
neginraoof 104cab3
Fixed interpolate tests
neginraoof 787190a
Add opset 9
neginraoof 6303a92
Merge branch 'master' into neraoof/interpolate
neginraoof 14328b8
Update test_pytorch_onnx_onnxruntime.py
neginraoof 9df600a
Fixed missing TORCH_API
neginraoof File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#include <torch/csrc/jit/passes/onnx/preinline_onnx.h> | ||
#include <torch/csrc/jit/jit_log.h> | ||
#include <torch/csrc/jit/passes/onnx/helper.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
void replaceFunctions(Node* to_replace, Function* callee) { | ||
if (callee->name() == "interpolate") { | ||
neginraoof marked this conversation as resolved.
Show resolved
Hide resolved
|
||
to_replace->removeInput(0); | ||
Node* interpolate_node = to_replace->owningGraph()->create( | ||
Symbol::fromQualString("aten::__interpolate"), | ||
{to_replace->inputs()}, | ||
to_replace->outputs().size()); | ||
interpolate_node->output()->copyMetadata(to_replace->output()); | ||
interpolate_node->insertAfter(to_replace); | ||
to_replace->replaceAllUsesWith(interpolate_node); | ||
to_replace->removeAllInputs(); | ||
to_replace->destroy(); | ||
return; | ||
} | ||
} | ||
|
||
void PreInlineCalls(Block* block) { | ||
for (auto it = block->nodes().begin(), end = block->nodes().end(); | ||
it != end;) { | ||
Node* cur = *it++; | ||
switch (cur->kind()) { | ||
neginraoof marked this conversation as resolved.
Show resolved
Hide resolved
|
||
case prim::CallFunction: { | ||
AT_ASSERT(cur->input(0)->node()->kind() == prim::Constant); | ||
auto function_constant = cur->input(0)->node(); | ||
auto fun_type = | ||
function_constant->output()->type()->expect<FunctionType>(); | ||
replaceFunctions(cur, fun_type->function()); | ||
} break; | ||
default: { | ||
for (auto b : cur->blocks()) { | ||
PreInlineCalls(b); | ||
} | ||
} break; | ||
} | ||
} | ||
} | ||
|
||
void PreInlineONNX(Graph& graph) { | ||
neginraoof marked this conversation as resolved.
Show resolved
Hide resolved
|
||
GRAPH_DUMP("Before Pre-inlining: ", &graph); | ||
PreInlineCalls(graph.block()); | ||
GRAPH_DUMP("After Pre-inlining: ", &graph); | ||
} | ||
|
||
} // namespace jit | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#pragma once | ||
|
||
#include <torch/csrc/jit/ir/ir.h> | ||
|
||
namespace torch { | ||
namespace jit { | ||
|
||
void PreInlineONNX(Graph& graph); | ||
neginraoof marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
} | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be running every time we run inlining on a graph, just as an explicitly invoked pass by ONNX before ONNX runs inlining.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ONNX Calls the tracer which internally inlines functions. I'm trying to see where is the best place to put the pre-inline pass. Maybe somewhere in the tracer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When we looked at the example together, seems like tracer doesn't eagerly inline so it should be possible outside of the tracer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually putting this preinline pass before calling inline does not fix the graph.
So the problem is with function->optimized_graph
This optimized_graph is actually inlined within the tracer, and it does not get updated when the function->graph is updated.
I maybe able to add an API to update the optimize_graph.
Here it is:
pytorch/torch/csrc/jit/api/function_impl.h
Line 37 in 86f3305
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What doesn't work if you put the pass here ?
https://github.com/pytorch/pytorch/blob/master/torch/onnx/utils.py#L115
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because this pass is modifying the function graph, not the function optimized_graph. And the latter is used by the downstream code. Let me know if it's easier to have a quick call about this.