-
Notifications
You must be signed in to change notification settings - Fork 22.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SR] cover static runtime for ops that uses getitem instead of listunpack #124148
Conversation
This appears to be a diff that was exported from phabricator, but the PR author does not have sufficient permissions to run CI. @lakshmananrm1993, please do step 2 of internal wiki to get write access so you do not need to get CI approvals in the future. If you think this is a mistake, please contact the Pytorch Dev Infra team. |
|
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124148
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (2 Unrelated Failures)As of commit e7bbf7d with merge base 402b289 (): FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This pull request was exported from Phabricator. Differential Revision: D55924256 |
737e3a7
to
e195b9c
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
e195b9c
to
12bbced
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall concerned about it being easier to create cases it looks like this code didn't consider using __getitem__
Node* list_unpack_node = value_out->uses()[0].user; | ||
if (list_unpack_node->kind() != prim::ListUnpack) { | ||
Node* user_node = value_out->uses()[0].user; | ||
if (user_node->kind() != prim::ListUnpack && user_node->kind() != aten::__getitem__) { | ||
continue; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this worries me. 1) what if there is more than one __getitem__
? as written it looks like we will just leave the other ones around?
2) what if the number of __getitem__
calls in the input program doesn't match the size of the output list?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i changed it to work only fb::equally_split. This has an argument num_splits and I am using that to do all the following checks
- same as number of uses of this module
- all the uses are get_item
- get_item uses is not duplicated
only after this I do the fuse
|
||
std::vector<std::tuple<Node*, ArrayRef<Value*>>> user_to_remove; | ||
|
||
for(auto index = 0; index < value_out->uses().size(); index++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use c10::irange
12bbced
to
5d382e8
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
5d382e8
to
b5de8c0
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
b5de8c0
to
46d83ce
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
46d83ce
to
a2632cc
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
I have changed it to cover the scenarios. I do the fuse only if the following validations succeed and I changed it to work only for equally_split same as number of uses of this module |
|
||
const auto num_splits_ivalue = toIValue(value_num_splits); | ||
|
||
// check if equally_split num of splits is 1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this comment repeats the code; delete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
// Its only for fb::equally_split operator | ||
// the operators results are being accessed by get_item rather than ListUnpack | ||
void FuseEquallySplitGetItemUnpack(std::shared_ptr<torch::jit::Graph>& graph) { | ||
// replacement contains (old_node, new_node, list_unpack_node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stale comment, delete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
const Value* value_num_splits = node->inputs()[1]; | ||
|
||
// check if equally_split num of splits is equal to the number of output uses | ||
if(!toIValue(value_num_splits).has_value() || !toIValue(value_num_splits).value().isInt() || toIValue(value_num_splits).value().toInt() != value_out->uses().size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't call toIValue repeatedly; we should not be requiring LTO for good performance
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replaced it with constant_as()
|
||
const Value* value_num_splits = node->inputs()[1]; | ||
|
||
// check if equally_split num of splits is equal to the number of output uses |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment repeats the code; delete
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
break; | ||
} | ||
|
||
auto getItem_indice_out = toIValue(user->inputs()[1]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be getItem_index_out
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use constant_as<int>()
as recommended above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
bool should_fuse = true; | ||
|
||
std::vector<UserToRemove> user_to_remove_list; | ||
|
||
std::unordered_set<std::int64_t> getItem_indices; | ||
|
||
for(auto index = 0; index < value_out->uses().size(); index++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: please do not put so many blank lines; keep style consistent with the surrounding code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed the not needed blank lines
|
||
std::vector<UserToRemove> user_to_remove_list; | ||
|
||
std::unordered_set<std::int64_t> getItem_indices; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use c10::FastSet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
|
||
LOG(INFO) << "Found fb::equally_split node2: " << user->kind().toQualString(); | ||
|
||
getItem_indice_out.emplace(getItem_indice_out); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this does nothing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it should have been getItem_Indices. Updated it now
break; | ||
} | ||
|
||
LOG(INFO) << "Found fb::equally_split node2: " << user->kind().toQualString(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why node2 and not just node?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was testing comment. removed it
for(const auto& replacement : replacement_list) { | ||
auto* new_node = graph->create(replacement.new_sym, 0); | ||
for (Value* in : replacement.node_to_be_fused->inputs()) { | ||
new_node->addInput(in); | ||
} | ||
for(const auto& user_to_remove_node : replacement.user_to_remove) { | ||
for (Value* out : user_to_remove_node.user_outputs) { | ||
Value* new_out = new_node->addOutput(); | ||
new_out->copyMetadata(out); | ||
out->replaceAllUsesWith(new_out); | ||
} | ||
user_to_remove_node.user->destroy(); | ||
} | ||
new_node->insertAfter(replacement.node_to_be_fused); | ||
replacement.node_to_be_fused->destroy(); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be easier for me to review this for correctness if it ended up looking the same as the FuseListUnpack path. If you need to send a separate stacked PR to get rid of the std::tuple in FuseListUnpack, go for it
a2632cc
to
b186caa
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
0afc7ba
to
a98d4f2
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
…pack (pytorch#124148) Summary: cover static runtime for ops that uses getitem instead of listunpack. Reviewed By: snabelkabiya, swolchok Differential Revision: D55924256
a98d4f2
to
6939277
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
6939277
to
670ee82
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
670ee82
to
dd9ac1c
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
dd9ac1c
to
57694c3
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
…pack (pytorch#124148) Summary: cover static runtime for ops that uses getitem instead of listunpack. Reviewed By: snabelkabiya, swolchok Differential Revision: D55924256
57694c3
to
6a2d81c
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
6a2d81c
to
9be6709
Compare
…pack (pytorch#124148) Summary: cover static runtime for ops that uses getitem instead of listunpack. Reviewed By: snabelkabiya, swolchok Differential Revision: D55924256
This pull request was exported from Phabricator. Differential Revision: D55924256 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D55924256 |
9be6709
to
8721217
Compare
…pack (pytorch#124148) Summary: Pull Request resolved: pytorch#124148 cover static runtime for ops that uses getitem instead of listunpack. Reviewed By: snabelkabiya, swolchok Differential Revision: D55924256
8721217
to
e7bbf7d
Compare
This pull request was exported from Phabricator. Differential Revision: D55924256 |
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
Summary: cover static runtime for ops that uses getitem instead of listunpack
Differential Revision: D55924256