-
Notifications
You must be signed in to change notification settings - Fork 683
support argmax/argmin without dim kwargs and fix adaptive_max_pool3d #14710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/14710
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 1 New Failure, 23 PendingAs of commit 3a5bff8 with merge base 0e74a17 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
0918c6a
to
74d7e98
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
74d7e98
to
c6642fa
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
c6642fa
to
e1de9d0
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
e1de9d0
to
f2f84e4
Compare
from executorch.exir.pass_base import ExportPass, PassResult | ||
from executorch.exir.passes import dead_code_elimination_pass | ||
|
||
class InsertReshapeForArgmax(ExportPass): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we might have other ops like argmin
which can leverage this pass as well. Do you mind rewording the class name into more generic one? Something like InsertReshapeForReduceOp
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the new commit wasn't pushed successfully. Could you try again? Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh sorry, just push
graph = graph_module.graph | ||
modified = False | ||
|
||
for n in list(graph.nodes): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can use the iterator directly: for n in graph.nodes:
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
import torch | ||
from executorch.backends.qualcomm._passes import InsertReshapeForArgmax | ||
|
||
class TestPasses(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding this file, it's helpful for us to test all the passes thoroughly.
f2f84e4
to
ea4bebf
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
3f825cf
to
af7ef91
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
214554c
to
5077d00
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
5077d00
to
a982ee6
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
Can I get another round of review on this? I fix a few more failing tests including |
a982ee6
to
553be81
Compare
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
553be81
to
4087c68
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
to_be_implemented_operator = [ | ||
exir_ops.edge.aten._adaptive_avg_pool3d.default, | ||
exir_ops.edge.aten.adaptive_max_pool2d.default, | ||
exir_ops.edge.aten.adaptive_max_pool3d.default, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this pr just about argmax? i see maxpool added here. if it is about reduce ops then please do update the title
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah updated
…ytorch#14710) Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case Differential Revision: D83606497
4087c68
to
3a5bff8
Compare
@pytorchbot cherry-pick --onto release/1.0 -c regression |
…14710) Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case edit: 1. Apply to argmin too 2. Add `exir_ops.edge.aten.adaptive_max_pool3d.default` to the to be implemented op list to pass the error Differential Revision: D83606497 (cherry picked from commit e09abea)
Cherry picking #14710The cherry pick PR is at #14868 and it is recommended to link a regression cherry pick PR with an issue. The following tracker issues are updated: Details for Dev Infra teamRaised by workflow job |
Summary: As title, in PyTorch, when dim is not set, it will flatten the input and get argmax as dim=0. Add a pass to reshape the input when dim is not set and consolidate test case
edit:
exir_ops.edge.aten.adaptive_max_pool3d.default
to the to be implemented op list to pass the errorDifferential Revision: D83606497