From 60492e46ebb23f995204aa267489a9a9ec4f0087 Mon Sep 17 00:00:00 2001
From: Arik Horodniceanu
Date: Mon, 6 Apr 2026 16:42:08 -0700
Subject: [PATCH] Qualcomm AI Engine Direct - Adding QNN backend support for remainder core ATen ops

---
 backends/qualcomm/_passes/__init__.py         |   2 +
 .../qualcomm/_passes/decompose_remainder.py   |  94 ++++++++
 backends/qualcomm/_passes/qnn_pass_manager.py |   3 +
 backends/qualcomm/_passes/utils.py            |   2 +
 backends/qualcomm/tests/models.py             |  24 ++
 backends/qualcomm/tests/test_qnn_delegate.py  | 249 +++++++++++++-----
 6 files changed, 306 insertions(+), 68 deletions(-)
 create mode 100644 backends/qualcomm/_passes/decompose_remainder.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index c0da8da56e6..72df4240eb8 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -27,6 +27,7 @@
 from .decompose_maxpool3d import DecomposeMaxPool3d
 from .decompose_minmaxdim import DecomposeMinMaxDim
 from .decompose_reciprocal import DecomposeReciprocal
+from .decompose_remainder import DecomposeRemainder
 from .decompose_roll import DecomposeRoll
 from .decompose_silu import DecomposeSilu
 from .decompose_threshold import DecomposeThreshold
@@ -80,6 +81,7 @@
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
     DecomposeReciprocal,
+    DecomposeRemainder,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
diff --git a/backends/qualcomm/_passes/decompose_remainder.py b/backends/qualcomm/_passes/decompose_remainder.py
new file mode 100644
index 00000000000..70b32efa8d1
--- /dev/null
+++ b/backends/qualcomm/_passes/decompose_remainder.py
@@ -0,0 +1,94 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
+
+from .utils import copy_meta, get_const_node
+
+
+class DecomposeRemainder(ExportPass):
+    """
+    Decompose remainder.Scalar and remainder.Tensor using the identity:
+        remainder(x, y) = x - floor(x / y) * y
+
+    The replacement ops come from the same dialect (ATen or edge) as the
+    remainder node being decomposed.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.remainder_targets = {
+            torch.ops.aten.remainder.Scalar,
+            torch.ops.aten.remainder.Tensor,
+            exir_ops.edge.aten.remainder.Scalar,
+            exir_ops.edge.aten.remainder.Tensor,
+        }
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        """
+        Replace each remainder node in graph_module with div, floor, mul and
+        sub nodes, then return the recompiled module in a PassResult.
+        """
+        graph = graph_module.graph
+        # Cache the constant node created for each scalar divisor so that a
+        # divisor shared by several remainder ops is registered only once.
+        # The key includes the Python type: 3 and 3.0 hash and compare equal,
+        # so keying on the bare value would let an int and a float divisor
+        # share one constant.
+        const_cache = {}
+
+        for node in list(graph.nodes):
+            if node.op != "call_function" or node.target not in self.remainder_targets:
+                continue
+
+            x_node, y_arg = node.args[0], node.args[1]
+            is_edge = isinstance(node.target, EdgeOpOverload)
+            # Pick the op namespace matching the dialect of the original node.
+            aten = exir_ops.edge.aten if is_edge else torch.ops.aten
+
+            # A scalar divisor on an edge op is replaced by a constant node; in
+            # every other case the divisor is passed through unchanged.
+            y_node = y_arg
+            if is_edge and not isinstance(y_arg, torch.fx.Node):
+                cache_key = (type(y_arg), y_arg)
+                if cache_key not in const_cache:
+                    attr_name = get_new_attr_name_with_prefix("_remainder_const_")(
+                        graph_module
+                    )
+                    const_cache[cache_key] = get_const_node(
+                        graph, graph_module, attr_name, y_arg, node
+                    )
+                y_node = const_cache[cache_key]
+
+            # NOTE(review): div.Tensor is true division, so integer inputs go
+            # through floating point here; confirm that is acceptable.
+            with graph.inserting_before(node):
+                div_node = graph.create_node(
+                    "call_function", aten.div.Tensor, (x_node, y_node)
+                )
+                floor_node = graph.create_node(
+                    "call_function", aten.floor.default, (div_node,)
+                )
+                mul_node = graph.create_node(
+                    "call_function", aten.mul.Tensor, (floor_node, y_node)
+                )
+                sub_node = graph.create_node(
+                    "call_function", aten.sub.Tensor, (x_node, mul_node)
+                )
+                # Every new node gets copy_meta() of the original node's meta.
+                for new_node in (div_node, floor_node, mul_node, sub_node):
+                    new_node.meta = copy_meta(node.meta)
+
+            node.replace_all_uses_with(sub_node)
+
+        # The replaced remainder nodes have no users left and are removed here.
+        graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py
index b3b59b9134c..ca7faa244cf 100644
--- a/backends/qualcomm/_passes/qnn_pass_manager.py
+++ b/backends/qualcomm/_passes/qnn_pass_manager.py
@@ -32,6 +32,7 @@
     DecomposeMaxPool3d,
     DecomposeMinMaxDim,
     DecomposeReciprocal,
+    DecomposeRemainder,
     DecomposeRoll,
     DecomposeSilu,
     DecomposeThreshold,
@@ -106,6 +107,7 @@ def get_capture_program_passes():
         (DecomposeLogVariants, True),
         (DecomposeMaxPool3d, True),
         (DecomposeMinMaxDim, True),
+        (DecomposeRemainder, True),
         (DecomposeTrunc, True),
         (ExpandBroadcastTensorShape, True),
         (FixedLinearKeepDim, True),
@@ -239,6 +241,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         # Decompose Reciprocal into Div for these 2 backend
         # TODO: Skip this pass for CPU backend (Dependency: Backend-aware passes manager)
         self.add_pass(DecomposeReciprocal())
+        self.add_pass(DecomposeRemainder())
         self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
         self.add_pass(DecomposeLogVariants())
         self.add_pass(ReplaceInfValues())
diff --git a/backends/qualcomm/_passes/utils.py b/backends/qualcomm/_passes/utils.py
index ed4b02ae1b0..24835702ce4 100755
--- a/backends/qualcomm/_passes/utils.py
+++ b/backends/qualcomm/_passes/utils.py
@@ -71,6 +71,7 @@ def
get_passes_dependency_for_capture_program():
         DecomposeLinalgVectorNorm,
         DecomposeLogVariants,
         DecomposeMaxPool3d,
+        DecomposeRemainder,
         DecomposeTrunc,
         ExpandBroadcastTensorShape,
         FixedLinearKeepDim,
@@ -101,6 +102,7 @@ def get_passes_dependency_for_capture_program():
         DecomposeLinalgVectorNorm: [RemoveRedundancy],
         DecomposeLogVariants: [RemoveRedundancy],
         DecomposeMaxPool3d: [RemoveRedundancy],
+        DecomposeRemainder: [RemoveRedundancy],
         DecomposeTrunc: [RemoveRedundancy],
         ExpandBroadcastTensorShape: [FoldQDQ],
         FixedLinearKeepDim: [FoldQDQ],
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 3df9c437447..053e5d26455 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -1889,6 +1889,30 @@ def forward(self, x):
         return x.repeat(1, 2, 3, 4)
 
 
+class RemainderScalar(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.remainder(x, 3.0)
+
+
+class RemainderTensor(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.remainder(x, y)
+
+
+class RemainderMultiNode(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.remainder(x, 3.0), torch.remainder(x, y)
+
+
 class ReWriteObs(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 223fa7d9386..48f07da06e9 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -128,17 +128,29 @@ def test_qnn_backend_abs(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_acos(self):
-        module = Acos()  # noqa: F405
-        sample_input = (torch.rand(3, 4) * 2 - 1,)
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: [Acos()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(3, 4) * 2 - 1,)],
+            },
+            {
+                QCOM_MODULE: [AcosMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
+                        torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
+                    )
+                ],
+            },
+        ]
 
-    def test_qnn_backend_acos_multi_node(self):
-        module = AcosMultiNode()  # noqa: F405
-        sample_input = (
-            torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
-            torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
-        )
-        self.lower_module_and_test_output(module, sample_input)
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_adaptive_avg_pool1d(self):
         module = AdaptiveAvgPool1D()  # noqa: F405
@@ -1480,28 +1492,38 @@ def test_qnn_backend_log_softmax(self):
         sample_input = (torch.randn([1, 4, 8, 8]),)
         self.lower_module_and_test_output(module, sample_input)
 
-    def test_qnn_backend_log_variants_multi_node(self):
-        module = LogVariantsMultiNode()  # noqa: F405
-        sample_input = (
-            torch.abs(torch.rand(2, 3, 4)) + 0.1,
-            torch.abs(torch.rand(2, 3, 4)) + 0.1,
-        )
-        self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_log10(self):
-        module = Log10()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_log1p(self):
-        module = Log1p()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        self.lower_module_and_test_output(module, sample_input)
+    def test_qnn_backend_log_variants(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [Log10()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [Log1p()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [Log2()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [LogVariantsMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.abs(torch.rand(2, 3, 4)) + 0.1,
+                        torch.abs(torch.rand(2, 3, 4)) + 0.1,
+                    )
+                ],
+            },
+        ]
 
-    def test_qnn_backend_log2(self):
-        module = Log2()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        self.lower_module_and_test_output(module, sample_input)
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_maximum(self):
         module = Maximum()  # noqa: F405
@@ -1749,6 +1771,42 @@ def test_qnn_backend_repeat(self):
         sample_input = (torch.randn([2, 2, 2, 2]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_remainder(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [RemainderScalar()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.tensor([1.0, 2.5, 4.0, -1.0, -2.5, 7.5]).reshape(2, 3),)
+                ],
+            },
+            {
+                QCOM_MODULE: [RemainderTensor()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([7.0, 5.0, 4.0, -3.0, 8.0, 1.0]).reshape(2, 3),
+                        torch.tensor([2.0, 3.0, 1.5, 2.0, 3.0, 4.0]).reshape(2, 3),
+                    )
+                ],
+            },
+            {
+                QCOM_MODULE: [RemainderMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([1.0, 2.0, 4.0, 5.0, 7.0, 8.0]).reshape(1, 2, 3),
+                        torch.tensor([2.0, 3.0, 4.0, 2.0, 3.0, 4.0]).reshape(1, 2, 3),
+                    )
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_reshape(self):
         module = Reshape()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)
@@ -2407,19 +2465,30 @@ def test_qnn_backend_abs(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_acos(self):
-        module = Acos()  # noqa: F405
-        sample_input = (torch.rand(3, 4) * 2 - 1,)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        test_comb = [
+            {
+                QCOM_MODULE: [Acos()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.rand(3, 4) * 2 - 1,)],
+            },
+            {
+                QCOM_MODULE: [AcosMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
+                        torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
+                    )
+                ],
+            },
+        ]
 
-    def test_qnn_backend_acos_multi_node(self):
-        module = AcosMultiNode()  # noqa: F405
-        sample_input = (
-            torch.tensor([0.0, 0.5, -0.5, 1.0, -1.0]),
-            torch.tensor([0.1, -0.1, 0.9, -0.9, 0.0]),
-        )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
 
     def test_qnn_backend_adaptive_avg_pool1d(self):
         module = AdaptiveAvgPool1D()  # noqa: F405
@@ -3884,32 +3953,39 @@ def test_qnn_backend_log_softmax(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
-    def test_qnn_backend_log_variants_multi_node(self):
-        module = LogVariantsMultiNode()  # noqa: F405
-        sample_input = (
-            torch.abs(torch.rand(2, 3, 4)) + 0.1,
-            torch.abs(torch.rand(2, 3, 4)) + 0.1,
-        )
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_log10(self):
-        module = Log10()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
-
-    def test_qnn_backend_log1p(self):
-        module = Log1p()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+    def test_qnn_backend_log_variants(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [Log10()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [Log1p()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [Log2()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [(torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)],
+            },
+            {
+                QCOM_MODULE: [LogVariantsMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.abs(torch.rand(2, 3, 4)) + 0.1,
+                        torch.abs(torch.rand(2, 3, 4)) + 0.1,
+                    )
+                ],
+            },
+        ]
 
-    def test_qnn_backend_log2(self):
-        module = Log2()  # noqa: F405
-        sample_input = (torch.abs(torch.rand(2, 5, 1, 3) + 0.1),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
 
     def test_qnn_backend_maximum(self):
         module = Maximum()  # noqa: F405
@@ -4185,6 +4261,43 @@ def test_qnn_backend_repeat(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_remainder(self):
+        test_comb = [
+            {
+                QCOM_MODULE: [RemainderScalar()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.tensor([1.0, 2.5, 4.0, -1.0, -2.5, 7.5]).reshape(2, 3),)
+                ],
+            },
+            {
+                QCOM_MODULE: [RemainderTensor()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([7.0, 5.0, 4.0, -3.0, 8.0, 1.0]).reshape(2, 3),
+                        torch.tensor([2.0, 3.0, 1.5, 2.0, 3.0, 4.0]).reshape(2, 3),
+                    )
+                ],
+            },
+            {
+                QCOM_MODULE: [RemainderMultiNode()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (
+                        torch.tensor([1.0, 2.0, 4.0, 5.0, 7.0, 8.0]).reshape(1, 2, 3),
+                        torch.tensor([2.0, 3.0, 4.0, 2.0, 3.0, 4.0]).reshape(1, 2, 3),
+                    )
+                ],
+            },
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        index += 1
+                        qdq_module = self.get_qdq_module(module, sample_input)
+                        self.lower_module_and_test_output(qdq_module, sample_input)
+
     def test_qnn_backend_reshape(self):
         module = Reshape()  # noqa: F405
         sample_input = (torch.randn([3, 4]),)