diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index 0846d973722..3feb0a0e051 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -6,9 +6,12 @@ import itertools +from typing import List + import torch from executorch.backends.arm._passes.arm_pass_utils import create_node -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op + +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule @@ -24,6 +27,22 @@ class AnnotateDecomposedMatmulPass(ExportPass): matmul-op (can be mm or bmm). """ + def _match_partition_to_node( + self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node] + ) -> torch.fx.Node: + """ + The partition.input_nodes order is not guaranteed. Compare these + with the matmul node inputs coming in and return the nodes + in the correct order. + """ + if not node or node in partitioned_inputs or node.op == "placeholder": + return node + else: + return self._match_partition_to_node( + node.all_input_nodes[0], partitioned_inputs + ) + raise RuntimeError(f"Cannot find an input node which matches, {node}.") + def call(self, graph_module: GraphModule) -> PassResult: matmul_partitions = get_source_partitions( graph_module.graph, @@ -45,28 +64,36 @@ def call(self, graph_module: GraphModule) -> PassResult: matmul_node = [ node for node in partition.nodes if node.target in matmul_targets ][0] + if quantized_input: matmul_args = matmul_node.all_input_nodes - for i in range(len(matmul_args)): - input_node = partition.input_nodes[i] - matmul_input_node = matmul_args[i] + for node in matmul_args: + input_node = self._match_partition_to_node( + node, partition.input_nodes + ) + # Remove partition input dq-node input_node.replace_all_uses_with(input_node.all_input_nodes[0]) graph_module.graph.erase_node(input_node) - input_node_qargs = input_node.args[1:] + input_node_qargs = QuantArgs.from_operator( + input_node.target, input_node.args + ) + with graph_module.graph.inserting_before(matmul_node): # Create new dq-node before matmul dq_node = create_node( graph=graph_module.graph, op_target=dq_op, ) - dq_node.args = (matmul_input_node, *input_node_qargs) - matmul_node.replace_input_with(matmul_input_node, dq_node) + dq_node.args = (node, *input_node_qargs) + matmul_node.replace_input_with(node, dq_node) partition_output = list(partition.output_nodes[0].users)[0] quantized_output = partition_output.target == q_op if quantized_output: - output_node_qargs = partition_output.args[1:] + output_node_qargs = QuantArgs.from_operator( + partition_output.target, partition_output.args + ) with graph_module.graph.inserting_after(matmul_node): # Create q-node after matmul q_node = create_node( diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index ed1a551ef33..bd6e1ef6897 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -134,19 +134,16 @@ def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]): self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) @parameterized.expand(MatMul.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) @parameterized.expand(BMM.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data) @parameterized.expand(BMMSingleInput.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data) @@ -162,7 +159,6 @@ def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]): @parameterized.expand(BMM.test_data_generators) @pytest.mark.corstone_fvp - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_bmm_ethosu_BI_pipeline( @@ -183,7 +179,6 @@ def test_bmm_single_input_u55_BI_xfails( @parameterized.expand(BMMSingleInput.test_data_generators) @pytest.mark.corstone_fvp - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_bmm_ethosu_BI_pipeline( diff --git a/backends/arm/test/ops/test_mm.py b/backends/arm/test/ops/test_mm.py index ba5b0eb1b86..d9b58da9046 100644 --- a/backends/arm/test/ops/test_mm.py +++ b/backends/arm/test/ops/test_mm.py @@ -126,7 +126,6 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]): self._test_mm_tosa_BI_pipeline(self.MM(), test_data) @parameterized.expand(MMSingleInput.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data) @@ -150,7 +149,6 @@ def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]): ) @parameterized.expand(MM.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_mm_ethosu_BI_pipeline( @@ -158,7 +156,6 @@ def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]): ) @parameterized.expand(MMSingleInput.test_data_generators) - @pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534) def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]): test_data = test_data_generator() self._test_mm_ethosu_BI_pipeline(