diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index a286bf8b1ae..26b2bdc96c9 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -33,6 +33,7 @@
 from .i64_to_i32 import I64toI32
 from .insert_io_qdq import InsertIOQDQ
 from .insert_requantize import InsertRequantize
+from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
 from .layout_transform import LayoutTransform
 from .lift_constant_scalar_operands import LiftConstantScalarOperands
 from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
@@ -45,7 +46,6 @@
 from .seq_mse import SeqMSE
 from .tag_quant_io import TagQuantIO
 
-
 __all__ = [
     AnnotateAdaptiveAvgPool1D,
     AnnotateQuantAttrs,
@@ -75,6 +75,7 @@
     FuseConsecutiveTranspose,
     I64toI32,
     InsertIOQDQ,
+    InsertReshapeForReduceOps,
     InsertRequantize,
     LayoutTransform,
     LiftConstantScalarOperands,
diff --git a/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py b/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py
new file mode 100644
index 00000000000..52f9546c28e
--- /dev/null
+++ b/backends/qualcomm/_passes/insert_reshape_for_reduce_ops.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class InsertReshapeForReduceOps(ExportPass):
+    """
+    Rewrite `aten.argmax.default` and `aten.argmin.default` with `dim=None`
+    into a reshape-to-1D followed by argmax/argmin along dim=0.
+
+    PyTorch semantics:
+        torch.argmax(x, dim=None) -> flatten(x), then argmax along axis=0
+
+    QNN requires an explicit axis, so we insert the reshape.
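+
+    A concrete example of the equivalence this rewrite relies on
+    (illustrative values only):
+
+        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
+        torch.argmax(x)                     # tensor(1), index into flattened x
+        torch.argmax(x.reshape(-1), dim=0)  # tensor(1), the rewritten form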
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default}
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        graph = graph_module.graph
+        modified = False
+
+        for n in graph.nodes:
+            if n.target in self.op_map:
+                dim_arg = None if len(n.args) == 1 else n.args[1]
+
+                if dim_arg is None:
+                    inp = n.args[0]
+
+                    # Insert a flattening reshape before the reduce op
+                    with graph.inserting_before(n):
+                        reshape_node = graph.create_node(
+                            "call_function",
+                            torch.ops.aten.reshape.default,
+                            (inp, [-1]),
+                            {},
+                        )
+                        reshape_node.meta = dict(inp.meta)
+                        if "val" in inp.meta:
+                            reshape_node.meta["val"] = inp.meta["val"].reshape(-1)
+
+                    # Rewrite the reduce op: take reshape_node as input, set dim=0
+                    n.args = (reshape_node, 0, *n.args[2:])
+
+                    modified = True
+
+        if modified:
+            graph_module.recompile()
+            dead_code_elimination_pass(graph_module)
+
+        return PassResult(graph_module, modified)
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index a377f0f4eb4..796662ca6b3 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -38,6 +38,7 @@
     I64toI32,
     InsertIOQDQ,
     InsertRequantize,
+    InsertReshapeForReduceOps,
     LayoutTransform,
     LiftConstantScalarOperands,
     RecomposePixelUnshuffle,
@@ -209,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(ReplaceInfValues())
         self.add_pass(LiftConstantScalarOperands())
+        self.add_pass(InsertReshapeForReduceOps())
         return self._transform(graph_module)
 
     def transform_for_export_pipeline(
@@ -229,6 +231,7 @@ def transform_for_export_pipeline(
         self.add_pass(ConvertLinearToConv2d(exported_program))
         self.add_pass(ConvertSquareToPow())
         self.add_pass(LiftConstantScalarOperands())
+        self.add_pass(InsertReshapeForReduceOps())
         self._transform(exported_program.graph_module)
         ep = lift_constant_tensor_pass(exported_program)
         return ep
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index 7a2924fe756..0a947759538 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -17,6 +17,7 @@
 to_be_implemented_operator = [
     exir_ops.edge.aten._adaptive_avg_pool3d.default,
     exir_ops.edge.aten.adaptive_max_pool2d.default,
+    exir_ops.edge.aten.adaptive_max_pool3d.default,
     exir_ops.edge.aten.avg_pool3d.default,
     exir_ops.edge.aten.div.Tensor_mode,
     exir_ops.edge.aten.log10.default,
diff --git a/backends/qualcomm/tests/TARGETS b/backends/qualcomm/tests/TARGETS
index 639303c7eb8..d968f954485 100644
--- a/backends/qualcomm/tests/TARGETS
+++ b/backends/qualcomm/tests/TARGETS
@@ -47,3 +47,17 @@ runtime.python_library(
         ":test_qnn_delegate"
     ]
 )
+
+runtime.python_test(
+    name = "test_passes",
+    srcs = [
+        "test_passes.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/expecttest:expecttest",  # @manual
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/backends/qualcomm/_passes:passes",
+        "//executorch/backends/qualcomm/builders:builders",
+    ],
+)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 7b1663d09f6..3240ad7a018 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -171,21 +171,24 @@ def forward(self, y):
 
 
 class Argmax(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
         super().__init__()
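+        # dim=None mirrors torch.argmax/argmin defaults: reduce over the flattened tensor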
+        self.dim = dim
+        self.keepdim = keepdim
 
     def forward(self, x):
-        x = torch.argmax(x, dim=0, keepdim=True)
-        return x
+        return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)
 
 
 class Argmin(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
         super().__init__()
+        self.dim = dim
+        self.keepdim = keepdim
 
     def forward(self, x):
-        x = torch.argmin(x, dim=0, keepdim=True)
-        return x
+        return torch.argmin(x, dim=self.dim, keepdim=self.keepdim)
 
 
 class ArgminViewSqueezeConv2D(torch.nn.Module):
diff --git a/backends/qualcomm/tests/test_passes.py b/backends/qualcomm/tests/test_passes.py
new file mode 100644
index 00000000000..94a5d08acc1
--- /dev/null
+++ b/backends/qualcomm/tests/test_passes.py
@@ -0,0 +1,52 @@
+import unittest
+
+import torch
+from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps
+
+
+class TestPasses(unittest.TestCase):
+    def test_insert_reshape_for_argmax(self):
+        class ArgmaxModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.argmax(x, dim=None)
+
+        mod = ArgmaxModule()
+
+        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
+        ep = torch.export.export(mod, (x,))
+        # Run original module for reference
+        ref = mod(x)
+
+        reshape_nodes = [
+            n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
+        ]
+        argmax_nodes = [
+            n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
+        ]
+        self.assertTrue(len(reshape_nodes) == 0, "Unexpected reshape node before pass")
+        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
+
+        InsertReshapeForReduceOps()(ep.graph_module)
+
+        # Check graph structure: argmax should take a reshape as input
+        reshape_nodes = [
+            n for n in ep.graph.nodes if n.target == torch.ops.aten.reshape.default
+        ]
+        argmax_nodes = [
+            n for n in ep.graph.nodes if n.target == torch.ops.aten.argmax.default
+        ]
+        self.assertTrue(len(reshape_nodes) == 1, "Reshape node should be inserted")
+        self.assertTrue(len(argmax_nodes) == 1, "Argmax node missing")
+
+        argmax_node = argmax_nodes[0]
+        self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0")
+
+        # Execute the rewritten graph and compare with the reference
+        out = ep.graph_module(x)
+        self.assertTrue(
+            torch.equal(out[0], ref), f"Output mismatch: got {out}, expected {ref}"
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index e18c5b05a97..fd0454e3250 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -173,14 +173,66 @@ def test_qnn_backend_arange(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_argmax(self):
-        module = Argmax()  # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmax(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=0, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=1, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=None, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=2, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
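+                # dim=None cases exercise the InsertReshapeForReduceOps rewrite;
+                # cases with an explicit dim lower to QNN argmax/argmin directly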
+                self.lower_module_and_test_output(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
 
     def test_qnn_backend_argmin(self):
-        module = Argmin()  # noqa: F405
-        sample_input = (torch.rand(3, 4),)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmin(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=0, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=1, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=None, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=2, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
 
     @unittest.expectedFailure
     def test_qnn_backend_asin(self):
@@ -1797,16 +1847,66 @@ def test_qnn_backend_arange(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_argmax(self):
-        module = Argmax()  # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmax(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=0, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=1, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=None, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmax(dim=2, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])
 
     def test_qnn_backend_argmin(self):
-        module = Argmin()  # noqa: F405
-        sample_input = (torch.randn(16, 3, 4, 4),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_cases = [
+            {
+                QCOM_MODULE: Argmin(),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=0, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=1, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=None, keepdim=False),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
+            },
+            {
+                QCOM_MODULE: Argmin(dim=2, keepdim=True),  # noqa: F405
+                QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
+            },
+        ]
+
+        for i, case in enumerate(test_cases):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(
+                    case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
+                )
+                self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])
 
     def test_qnn_backend_asin(self):
         module = Asin()  # noqa: F405