Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@

import itertools

from typing import List

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op

from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
Expand All @@ -24,6 +27,22 @@ class AnnotateDecomposedMatmulPass(ExportPass):
matmul-op (can be mm or bmm).
"""

def _match_partition_to_node(
self, node: torch.fx.Node, partitioned_inputs: List[torch.fx.Node]
) -> torch.fx.Node:
"""
The partition.input_nodes order is not guaranteed. Compare these
with the matmul node inputs coming in and return the nodes
in the correct order.
"""
if not node or node in partitioned_inputs or node.op == "placeholder":
return node
else:
return self._match_partition_to_node(
node.all_input_nodes[0], partitioned_inputs
)
raise RuntimeError(f"Cannot find an input node which matches, {node}.")

def call(self, graph_module: GraphModule) -> PassResult:
matmul_partitions = get_source_partitions(
graph_module.graph,
Expand All @@ -45,28 +64,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
matmul_node = [
node for node in partition.nodes if node.target in matmul_targets
][0]

if quantized_input:
matmul_args = matmul_node.all_input_nodes
for i in range(len(matmul_args)):
input_node = partition.input_nodes[i]
matmul_input_node = matmul_args[i]
for node in matmul_args:
input_node = self._match_partition_to_node(
node, partition.input_nodes
)

# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = input_node.args[1:]
input_node_qargs = QuantArgs.from_operator(
input_node.target, input_node.args
)

with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
graph=graph_module.graph,
op_target=dq_op,
)
dq_node.args = (matmul_input_node, *input_node_qargs)
matmul_node.replace_input_with(matmul_input_node, dq_node)
dq_node.args = (node, *input_node_qargs)
matmul_node.replace_input_with(node, dq_node)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
output_node_qargs = partition_output.args[1:]
output_node_qargs = QuantArgs.from_operator(
partition_output.target, partition_output.args
)
with graph_module.graph.inserting_after(matmul_node):
# Create q-node after matmul
q_node = create_node(
Expand Down
5 changes: 0 additions & 5 deletions backends/arm/test/ops/test_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,16 @@ def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)

@parameterized.expand(MatMul.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)

@parameterized.expand(BMM.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)

@parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_bmm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)
Expand All @@ -162,7 +159,6 @@ def test_bmm_u55_BI_xfails(self, test_data_generator: Callable[[], Tuple]):

@parameterized.expand(BMM.test_data_generators)
@pytest.mark.corstone_fvp
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_bmm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
Expand All @@ -183,7 +179,6 @@ def test_bmm_single_input_u55_BI_xfails(

@parameterized.expand(BMMSingleInput.test_data_generators)
@pytest.mark.corstone_fvp
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_bmm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_ethosu_BI_pipeline(
Expand Down
3 changes: 0 additions & 3 deletions backends/arm/test/ops/test_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def test_mm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
self._test_mm_tosa_BI_pipeline(self.MM(), test_data)

@parameterized.expand(MMSingleInput.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_mm_single_input_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)
Expand All @@ -150,15 +149,13 @@ def test_mm_single_input_u55_BI(self, test_data_generator: Callable[[], Tuple]):
)

@parameterized.expand(MM.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_mm_u85_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.MM(), test_data
)

@parameterized.expand(MMSingleInput.test_data_generators)
@pytest.mark.flaky # TODO: Investigate flakyness (MLETORCH-534)
def test_mm_single_input_u85_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_mm_ethosu_BI_pipeline(
Expand Down
Loading