class FloorDivide(torch.nn.Module):
    """Math-equivalent reference module for ``aten.floor_divide``.

    Divides in floating point, floors the quotient, then casts the result
    back to the first input's dtype. Used as the export template for
    decomposing floor_divide on integer inputs that QNN cannot handle.
    """

    def forward(self, x, y):
        quotient = torch.floor(torch.div(x, y))
        return quotient.to(x.dtype)
+ """ + + def __init__(self) -> None: + super().__init__() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + for node in graph.nodes: + model = FloorDivide() + if ( + torch.ops.aten.floor_divide.default == node.target + and not torch.is_floating_point(node.meta["val"]) + ): + decomposed_module = torch.export.export( + model, + (node.args[0].meta["val"], node.args[1].meta["val"]), + strict=True, + ).module() + with graph.inserting_before(node): + # remap is used to map original node values to new node values, + # which ensures that reference to nodes are correctly updated in the new graph + remap = {"x": node.args[0], "y": node.args[1]} + merge_decomposed_graph( + remap=remap, + target_node=node, + target_graph=graph, + decomposed_graph_module=decomposed_module, + ) + graph.erase_node(node) + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index 796662ca6b3..360581a2929 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -23,6 +23,7 @@ DecomposeColIm, DecomposeEinsum, DecomposeExpM1, + DecomposeFloorDivide, DecomposeGlu, DecomposeLinalgVectorNorm, DecomposeMinMaxDim, @@ -223,6 +224,11 @@ def transform_for_export_pipeline( self.add_pass(DecomposeThreshold()) self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True)) self.add_pass(DecomposeExpM1()) + # DecomposeFloorDivide does not apply to the annotation pipeline, + # since the CPU QDQ model would reduce accuracy. + # We keep div and floor operations in floating-point to maintain precision. + # This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass. 
+ self.add_pass(DecomposeFloorDivide()) self.add_pass(DecomposeWrapWithAutocast()) # this pass will rewrite state_dict, it needs to be accomplished before # to_edge_transform_and_lower diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 8b2e5cc84d8..9c62f3af6a1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -398,8 +398,8 @@ def test_qnn_backend_cumsum(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_einsum_outer_product(self): module = EinsumOuterProduct() # noqa: F405 @@ -467,8 +467,8 @@ def test_qnn_backend_element_wise_add(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 @@ -506,8 +506,8 @@ def test_qnn_backend_element_wise_div(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_mul(self): test_comb = [ @@ -533,8 +533,8 @@ def test_qnn_backend_element_wise_mul(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_element_wise_or(self): test_comb = [ @@ -608,8 +608,8 @@ def test_qnn_backend_element_wise_sub(self): 
for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) @unittest.expectedFailure def test_qnn_backend_elu(self): @@ -651,10 +651,10 @@ def test_qnn_backend_expand(self): for module in modules: for sample_input in sample_inputs: with self.subTest(i=index): + index += 1 self.lower_module_and_test_output( module, sample_input, passes_job=passes_job ) - index += 1 def test_qnn_backend_expm1(self): sample_input = (torch.randn(3, 4, 5),) @@ -677,6 +677,21 @@ def test_qnn_backend_floor_divide(self): { QCOM_MODULE: [FloorDiv()], # noqa: F405 QCOM_SAMPLE_INPUTS: [ + (torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)), + ( + torch.randint(-100, 100, (10, 10)).float(), + torch.full((10, 10), 2.5), + ), + (torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)), + (torch.tensor([10]), torch.arange(1, 5)), # Failed + (torch.arange(-10, 10), torch.tensor([2])), + (torch.randint(-100, 100, (20,)), torch.full((20,), 2)), + (torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)), + (torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)), + ( + torch.randint(-100, 100, (2, 3, 4, 5)), + torch.full((2, 3, 4, 5), 2), + ), (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), ], @@ -692,8 +707,8 @@ def test_qnn_backend_floor_divide(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) @@ -1136,8 +1151,8 @@ def test_qnn_backend_leaky_relu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - 
self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_less_equal(self): test_comb = [ @@ -1392,8 +1407,8 @@ def test_qnn_backend_prelu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_relu(self): module = Relu() # noqa: F405 @@ -1520,8 +1535,8 @@ def test_qnn_backend_slice_scatter(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - self.lower_module_and_test_output(module, sample_input) index += 1 + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_stack(self): module = Stack() # noqa: F405 @@ -2332,9 +2347,9 @@ def test_qnn_backend_element_wise_add(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_element_wise_and(self): module = And(torch.tensor(1.7), torch.tensor(0.2)) # noqa: F405 @@ -2373,9 +2388,9 @@ def test_qnn_backend_element_wise_div(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_element_wise_mul(self): test_comb = [ @@ -2401,9 +2416,9 @@ def test_qnn_backend_element_wise_mul(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def 
test_qnn_backend_element_wise_or(self): test_comb = [ @@ -2479,9 +2494,9 @@ def test_qnn_backend_element_wise_sub(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 gm = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(gm, sample_input) - index += 1 def test_qnn_backend_elu(self): module = Elu() # noqa: F405 @@ -2530,11 +2545,11 @@ def test_qnn_backend_expand(self): for module in modules: for sample_input in sample_inputs: with self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output( module, sample_input, passes_job=passes_job ) - index += 1 def test_qnn_backend_expm1(self): sample_input = (torch.randn(3, 4, 5),) @@ -2560,6 +2575,21 @@ def test_qnn_backend_floor_divide(self): { QCOM_MODULE: [FloorDiv()], # noqa: F405 QCOM_SAMPLE_INPUTS: [ + (torch.randint(-100, 100, (10, 10)), torch.full((10, 10), 3)), + ( + torch.randint(-100, 100, (10, 10)).float(), + torch.full((10, 10), 2.5), + ), + (torch.randint(-1000, 1000, (10, 10)), torch.full((10, 10), 100)), + (torch.tensor([10]), torch.arange(1, 5)), + (torch.arange(-10, 10), torch.tensor([2])), + (torch.randint(-100, 100, (20,)), torch.full((20,), 2)), + (torch.randint(-100, 100, (5, 10)), torch.full((5, 10), 2)), + (torch.randint(-100, 100, (3, 4, 5)), torch.full((3, 4, 5), 2)), + ( + torch.randint(-100, 100, (2, 3, 4, 5)), + torch.full((2, 3, 4, 5), 2), + ), (torch.randn(2, 5, 1, 3), eps + torch.randn(2, 5, 1, 3)), (torch.randn([2, 5, 1, 3]), eps + torch.randn([4, 1])), ], @@ -2575,9 +2605,12 @@ def test_qnn_backend_floor_divide(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - gm = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(gm, sample_input) index += 1 + # Support int input cases with bypass_check=True + gm = self.get_qdq_module( + module, 
sample_input, bypass_check=True + ) + self.lower_module_and_test_output(gm, sample_input) def test_qnn_backend_fold(self): sample_input = (torch.randn(3, 512, 256),) @@ -3048,9 +3081,9 @@ def test_qnn_backend_leaky_relu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - index += 1 def test_qnn_backend_less_equal(self): test_comb = [ @@ -3352,9 +3385,9 @@ def test_qnn_backend_prelu(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): - qdq_module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(qdq_module, sample_input) index += 1 + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_relu(self): module = Relu() # noqa: F405 @@ -3504,9 +3537,9 @@ def test_qnn_backend_slice_scatter(self): for module in comb[QCOM_MODULE]: for sample_input in comb[QCOM_SAMPLE_INPUTS]: with self.subTest(i=index): + index += 1 module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - index += 1 def test_qnn_backend_softmax(self): modules = [Softmax(dim=1), Softmax(dim=-1)] # noqa: F405