-
Notifications
You must be signed in to change notification settings - Fork 684
support argmax/argmin without dim kwargs and fix adaptive_max_pool3d #14710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from executorch.exir.pass_base import ExportPass, PassResult | ||
from executorch.exir.passes import dead_code_elimination_pass | ||
|
||
|
||
class InsertReshapeForReduceOps(ExportPass): | ||
""" | ||
Rewrite `aten.argmax.default` with `dim=None` into | ||
a reshape-to-1D followed by argmax(dim=0). | ||
PyTorch semantics: | ||
torch.argmax(x, dim=None) -> flatten(x) then argmax along axis=0 | ||
QNN requires an explicit axis, so we insert the reshape. | ||
""" | ||
|
||
def __init__(self): | ||
super().__init__() | ||
self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default} | ||
|
||
def call(self, graph_module: torch.fx.GraphModule): | ||
graph = graph_module.graph | ||
modified = False | ||
|
||
for n in graph.nodes: | ||
if n.target in self.op_map: | ||
dim_arg = None if len(n.args) == 1 else n.args[1] | ||
|
||
if dim_arg is None: | ||
inp = n.args[0] | ||
|
||
# Insert reshape before argmax | ||
with graph.inserting_before(n): | ||
reshape_node = graph.create_node( | ||
"call_function", | ||
torch.ops.aten.reshape.default, | ||
(inp, [-1]), | ||
{}, | ||
) | ||
reshape_node.meta = dict(inp.meta) | ||
if "val" in inp.meta: | ||
reshape_node.meta["val"] = inp.meta["val"].reshape(-1) | ||
|
||
# Rewrite argmax: take reshape_node as input, set dim=0 | ||
n.args = (reshape_node, 0, *n.args[2:]) | ||
|
||
modified = True | ||
|
||
if modified: | ||
graph_module.recompile() | ||
dead_code_elimination_pass(graph_module) | ||
|
||
return PassResult(graph_module, modified) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
to_be_implemented_operator = [ | ||
exir_ops.edge.aten._adaptive_avg_pool3d.default, | ||
exir_ops.edge.aten.adaptive_max_pool2d.default, | ||
exir_ops.edge.aten.adaptive_max_pool3d.default, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this pr just about argmax? i see maxpool added here. if it is about reduce ops then please do update the title There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah updated |
||
exir_ops.edge.aten.avg_pool3d.default, | ||
exir_ops.edge.aten.div.Tensor_mode, | ||
exir_ops.edge.aten.log10.default, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import unittest | ||
|
||
import torch | ||
from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps | ||
|
||
|
||
class TestPasses(unittest.TestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for adding this file, it's helpful for us to test all the passes thoroughly. |
||
def test_insert_reshape_for_argmax(self): | ||
class ArgmaxModule(torch.nn.Module): | ||
def forward(self, x): | ||
return torch.argmax(x, dim=None) | ||
|
||
mod = ArgmaxModule() | ||
|
||
x = torch.tensor([[1.0, 5.0], [3.0, 2.0]]) | ||
ep = torch.export.export(mod, (x,)) | ||
# Run original module for reference | ||
ref = mod(x) | ||
|
||
reshape_nodes = [ | ||
n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default | ||
] | ||
argmax_nodes = [ | ||
n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default | ||
] | ||
self.assertTrue(len(reshape_nodes) == 0, "Reshape node not inserted") | ||
self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing") | ||
|
||
InsertReshapeForReduceOps()(ep.graph_module) | ||
|
||
out = ep.graph_module(x) | ||
|
||
# Check graph structure: argmax should take a reshape as input | ||
reshape_nodes = [ | ||
n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default | ||
] | ||
argmax_nodes = [ | ||
n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default | ||
] | ||
self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted") | ||
self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing") | ||
|
||
argmax_node = argmax_nodes[0] | ||
self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0") | ||
|
||
# Execute new graph and compare with reference | ||
out = ep.graph_module(x) | ||
self.assertTrue( | ||
torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Uh oh!
There was an error while loading. Please reload this page.