[fusion] Migrate away from CustomFuseGraph #72
[fusion] Migrate away from CustomFuseGraph gh-metadata: pytorch tvm 72 gh/bwasti/46/head
this depends on pytorch/pytorch#23210
I tried to use this to compare the TVM fusion groups we generated before and after, and it looks like we are pulling in redundant constant nodes again and again, which I think we shouldn't. Taking an example:
with tvm::CompilationGroup_1 = graph(%0 : Tensor,
%1 : Float(*, *, *),
%2 : Float(*)):
%11 : int[] = prim::Constant[value=[0]]()
%3 : int[] = prim::Constant[value=[1]]()
%4 : int[] = prim::Constant[value=[0]]()
%5 : int[] = prim::Constant[value=[1]]()
%6 : bool = prim::Constant[value=0]()
%7 : int[] = prim::Constant[value=[0]]()
%8 : int = prim::Constant[value=1]()
%9 : bool = prim::Constant[value=1]()
%x1.1 : Tensor = aten::_convolution(%0, %1, %2, %3, %4, %5, %6, %7, %8, %6, %6, %9) # code/my_noqqq.py:377:8
return (%x1.1)
with tvm::CompilationGroup_2 = graph(%0 : Tensor,
%1 : Float(*, *, *),
%2 : Float(*)):
%11 : int[] = prim::Constant[value=[0]]()
%3 : int[] = prim::Constant[value=[1]]()
%4 : int[] = prim::Constant[value=[0]]()
%5 : int[] = prim::Constant[value=[1]]()
%6 : bool = prim::Constant[value=0]()
%7 : int[] = prim::Constant[value=[0]]()
%8 : int = prim::Constant[value=1]()
%9 : bool = prim::Constant[value=1]()
%x2.1 : Tensor = aten::_convolution(%0, %1, %2, %3, %4, %5, %6, %7, %8, %6, %6, %9) # code/my_noqqq.py:380:8
return (%x2.1)
I thought the graph executor ran constant pooling somewhere, but it turns out it is not running properly. We will probably need to run more passes after we generate the fusion group, like graphFuser does:
// After FuseGraph some common subexpressions may come back
EliminateCommonSubexpression(graph);
// We might have emitted a fair amount of useless shape propagating code, so
// remove it
EliminateDeadCode(graph);
// Improve the quality of shape propagation code that was left
PeepholeOptimizeShapeExpressions(graph->block());
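To illustrate the constant-pooling part of those cleanup passes, here is a minimal self-contained sketch (not the real torch::jit API; `Node`, `PoolConstants`, and the string-keyed value encoding are all hypothetical): duplicate `prim::Constant` nodes are deduplicated by keying on their attribute value and rewiring every duplicate use to one canonical node.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a JIT node: a constant is identified here by a
// serialized attribute value (e.g. "int[]=[0]"). Illustrative only.
struct Node {
  std::string kind;
  std::string value; // only meaningful when kind == "prim::Constant"
};

// Constant pooling sketch: keep one canonical node per distinct constant
// value and redirect every duplicate use to it, mirroring what
// ConstantPooling / EliminateCommonSubexpression would do after fusion.
// Returns the number of duplicate uses that were rewired.
size_t PoolConstants(std::vector<Node*>& uses) {
  std::map<std::string, Node*> canonical;
  size_t removed = 0;
  for (auto& use : uses) {
    if (use->kind != "prim::Constant") continue;
    auto it = canonical.find(use->value);
    if (it == canonical.end()) {
      canonical.emplace(use->value, use);
    } else {
      use = it->second; // rewire duplicate use to the pooled node
      ++removed;
    }
  }
  return removed;
}
```

In the example graphs above, the two `prim::Constant[value=[0]]` nodes (%11 and %4) would collapse to a single node under such a pass.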
Also, we might consider passing constants in as arguments rather than pulling them into the fusion group; that way it's easier to run the constant-pooling passes, I think.
So I'm ignorant of the underlying goal of not using the graph fuser, but the constants seem to be an indication that it might be nice to try to use as much of PyTorch's capabilities as possible. In contrast to the current master torch_tvm (see #55), and apparently also this PR, the PyTorch fuser treats constants more delicately: it will not fuse them on their own, but instead copies constant inputs of the operation nodes to be fused into the fusion group and reconnects the fused operation to these copies. With CustomFuseGraph, just removing the fusion of
@wanchaol @t-vi good points on being careful with constants. @t-vi The issue is that GraphFuser doesn't (and probably shouldn't) handle the more complex cases of aliasing and control flow. For the sake of Relay lowering, we will want loops and conditionals, as well as tensor views, to be well supported in order to generate the best code. I put up this diff as the first step of many toward getting that all working 😃
Good point about the control flow, which would not be needed for basic graph fusion. I'm sure people will come up with other uses for control-flow-capable fusers once they have them... :)
if (auto group = tryMerge(consumer, input->node(), aliasDb)) {
  // we successfully merged, so the new group's `inputs` may have
  // changed. So rescan the new group for more merging opportunities.
  return group.value()->reverseIterator();
has_value check?
If it doesn't have a value, the `if` condition evaluates to false.
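The pattern being discussed can be shown with a small self-contained sketch (the `tryMerge`/`merged` names here are toy stand-ins, not the real fuser code): in `if (auto group = tryMerge(...))`, the `std::optional` converts to bool, which is exactly `has_value()`, so a separate check would be redundant.

```cpp
#include <cassert>
#include <optional>

// Toy tryMerge: succeeds only when the two ids are "compatible".
// Hypothetical logic; the point is the if-init idiom used in the fuser.
std::optional<int> tryMerge(int consumer, int producer) {
  if (consumer % 2 == producer % 2) return consumer;
  return std::nullopt; // merge failed
}

// `if (auto group = tryMerge(...))` tests the optional's engaged state,
// which is exactly optional::has_value().
bool merged(int consumer, int producer) {
  if (auto group = tryMerge(consumer, producer)) {
    assert(group.has_value()); // guaranteed inside this branch
    return true;
  }
  return false;
}
```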
@bwasti Could you add some reasoning to this PR? E.g., why do we need to migrate away from CustomFuseGraph?
Is this PR ready to be merged now?
@yinghai I added some notes above in response to @t-vi, who had a similar question. Largely it comes down to proper handling of alias information and enabling control flow in the future. @zrphercule I believe so, yes. I will add a few more tests for coverage.
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

void FuseSupportedOps(std::shared_ptr<torch::jit::Graph> graph);
Actually, is this function TVM-specific? If not, can we move it to the PyTorch repo?
That's a good catch. For the sake of efficiency we can "stage" a lot of simple functionality like this here and upstream it to PyTorch at some point.
I'd like to prioritize fleshing out functionality before trying to standardize it into an API. That being said, I've tried to make it very copy-and-paste-able 😛
very copy-and-paste-able
Lol sounds useful
Jokes aside, it's great to have code to follow.
There are some things, like the control flow support this will have, that Glow most likely won't be able to support for at least a while, so sharing exact code would require some abstraction of those things.
That being said, I've tried to make it very copy-and-paste-able
That's something that could be landed directly into PyTorch, in the spirit of the hackability of PyTorch. Lol
bool canHandle(Block* block, AliasDb& aliasDb) {
  for (Node* node : block->nodes()) {
    if (!canHandle(node, aliasDb)) {
      return false;
This makes blocks "all or nothing", right? Will this take away some of the ease of handing operators off between PyTorch and TVM for ops TVM doesn't support? Or maybe, if this is happening in a loop, it's not really desirable anyway?
The traditional fuser works on blocks, traversing them in the fusion pass. This here seems to be more about whether we can fuse everything in the block, in order to fuse the entire control-flow statement that uses the block.
If I understand what you mean, that's what I'm saying. It looks like the traditional fuser recurses on sub-blocks, fusing what it can within them, whereas this will only fuse a sub-block if all nodes within it are fusable (recursively). So I think the traditional fuser could, for example, fuse most of a loop body while leaving any unsupported nodes in the loop body unfused, whereas this, I think, will try to fuse the entire loop or nothing.
Yep, this means any single fusion attempt will be all or nothing. We can add attempts to recursively fuse blocks that weren't fused on a previous attempt (in a later diff, maybe?) to recover the CustomFuseGraph behavior.
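The all-or-nothing semantics under discussion can be sketched self-contained (toy `Node`/`Block` structures, not the real torch::jit IR): a node that owns sub-blocks is fusable only if every node in every sub-block is fusable, recursively, so a single unsupported node keeps the whole control-flow construct out of the fusion group.

```cpp
#include <cassert>
#include <vector>

// Toy IR: a node is either a plain op (supported or not) or owns sub-blocks
// (control flow, e.g. prim::If / prim::Loop). Hypothetical structures.
struct Block;
struct Node {
  bool supported = true;
  std::vector<Block> blocks; // non-empty for control-flow nodes
};
struct Block {
  std::vector<Node> nodes;
};

bool canHandle(const Node& node);

// All-or-nothing over a block: every node must be handleable.
bool canHandle(const Block& block) {
  for (const auto& n : block.nodes)
    if (!canHandle(n)) return false;
  return true;
}

// A node is handleable only if it is supported and all of its sub-blocks
// are handleable in their entirety, recursively.
bool canHandle(const Node& node) {
  if (!node.supported) return false;
  for (const auto& b : node.blocks)
    if (!canHandle(b)) return false;
  return true;
}
```

Under this sketch, the traditional fuser's behavior (fuse most of a loop body, skip the unsupported nodes) would instead require a per-node pass inside each sub-block rather than a single recursive check.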
gh-metadata: pytorch tvm 72 gh/bwasti/46/head Pull Request resolved: #72
Summary: This is basically the Glow version of pytorch/tvm#72. It will not use PyTorch's customFuseNode anymore. Will add comments indicating the copied code and fix the lint once finished. Please don't give a detailed review until WIP is removed, but feel free to leave any big-scope opinion. Pull Request resolved: #3403 Reviewed By: jackm321 Differential Revision: D16775646 Pulled By: zrphercule fbshipit-source-id: a6d4dd757bf0db2ec0f4092330962b7e7fdf241d