From 65e80e691cc1a4ca0878e53473331dca72df6fe3 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 20 Aug 2025 09:30:11 +0200 Subject: [PATCH 1/2] Arm backend: Add TOSA dialect op for MATMUL Adds TOSA backend dialect op for MATMUL and associating pass to rewrite edge.aten.bmm to tosa.MATMUL. Also renames the files of node_visttor that lowers a TOSA backend dialect op to op_tosa_*.py, e.g. op_tosa_matmul.py. Signed-off-by: Oscar Andersson Change-Id: I578e5f7333922e02402dabc24ef1b12adf383b18 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 4 + .../arm/_passes/fuse_constant_ops_pass.py | 1 + backends/arm/_passes/rewrite_matmul.py | 97 ++++++++++++ backends/arm/operators/__init__.py | 10 +- backends/arm/operators/op_bmm.py | 143 ------------------ backends/arm/operators/op_tosa_matmul.py | 94 ++++++++++++ .../{op_rescale.py => op_tosa_rescale.py} | 0 .../{op_resize.py => op_tosa_resize.py} | 0 .../{op_table.py => op_tosa_table.py} | 0 .../{op_transpose.py => op_tosa_transpose.py} | 0 backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/matmul.py | 56 +++++++ 13 files changed, 259 insertions(+), 148 deletions(-) create mode 100644 backends/arm/_passes/rewrite_matmul.py delete mode 100644 backends/arm/operators/op_bmm.py create mode 100644 backends/arm/operators/op_tosa_matmul.py rename backends/arm/operators/{op_rescale.py => op_tosa_rescale.py} (100%) rename backends/arm/operators/{op_resize.py => op_tosa_resize.py} (100%) rename backends/arm/operators/{op_table.py => op_tosa_table.py} (100%) rename backends/arm/operators/{op_transpose.py => op_tosa_transpose.py} (100%) create mode 100644 backends/arm/tosa/dialect/ops/matmul.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 008bc305aad..1374ed8a3d3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -91,6 +91,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, ) +from .rewrite_matmul import RewriteMatmulPass # noqa from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa from .size_adjust_input_pass import SizeAdjustInputPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 1a0f4e4d384..c86eaa25962 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -92,6 +92,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, RetraceFoldedDtypesPass, + RewriteMatmulPass, RewriteUpsamplePass, ScalarsToAttributePass, SizeAdjustInputPass, @@ -211,6 +212,8 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteUpsamplePass(exported_program)) self.add_pass(AddBiasPass(exported_program)) + self.add_pass(InsertTableOpsPass(exported_program)) + self.add_pass(RewriteMatmulPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) @@ -297,6 +300,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteUpsamplePass(exported_program)) self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) + self.add_pass(RewriteMatmulPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 07d8288b5f1..c48fc008b5d 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -114,6 +114,7 @@ def call(self, graph_module): if node.op != "call_function": continue if node.target in [ + exir_ops.backend.tosa.MATMUL.default, exir_ops.backend.tosa.RESCALE.default, exir_ops.backend.tosa.RESIZE.default, exir_ops.backend.tosa.TABLE.default, diff --git a/backends/arm/_passes/rewrite_matmul.py b/backends/arm/_passes/rewrite_matmul.py new file mode 100644 index 00000000000..28ff800792b --- /dev/null +++ b/backends/arm/_passes/rewrite_matmul.py @@ -0,0 +1,97 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RewriteMatmulPass(ArmPass): + """Rewrites aten.bmm to tosa.MATMUL and inserts a tosa.RESCALE op if needed.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype): + input_qparams = get_input_qparams(node) + output_qparams = get_output_qparams(node)[0] + scale = ( + input_qparams[0].get_scale_per_tensor() + * input_qparams[1].get_scale_per_tensor() + ) / output_qparams.get_scale_per_tensor() + + with graph_module.graph.inserting_after(tosa_matmul_node): + # If the input is int8, we need to cast the output to int32 + rescale_node = create_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.RESCALE.default, + from_node=tosa_matmul_node, + ) + tosa_matmul_node.replace_all_uses_with(rescale_node) + rescale_node.args = ( + tosa_matmul_node, + dtype, + scale, + 0, + output_qparams.get_zp_per_tensor(), + ) + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if ( + node.op != "call_function" + or node.target != exir_ops.edge.aten.bmm.default + ): + continue + modified = True + + x1, x2 = node.args + tosa_matmul_target = exir_ops.backend.tosa.MATMUL.default + with graph_module.graph.inserting_before(node): + tosa_matmul_node = create_node( + graph_module.graph, + op_target=tosa_matmul_target, + args=(x1, x2), + kwargs={}, + from_node=node, + ) + node.replace_all_uses_with(tosa_matmul_node) + graph_module.graph.erase_node(node) + + x1_fake_tensor = get_first_fake_tensor(x1) + x2_fake_tensor = get_first_fake_tensor(x2) + output_fake_tensor = tosa_matmul_target(x1_fake_tensor, x2_fake_tensor) + node_output_fake_tensor = get_first_fake_tensor(node) + if ( + output_fake_tensor.dtype == torch.int32 + and node_output_fake_tensor.dtype in (torch.int8, torch.int16) + ): + self._insert_output_rescale( + graph_module, + node, + tosa_matmul_node, + dtype=node_output_fake_tensor.dtype, + ) + if x1_fake_tensor.dtype == torch.int16: + tosa_matmul_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) + + if modified: + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index d8b371570f6..9278d25959f 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -14,7 +14,6 @@ op_any, op_avg_pool2d, op_bitwise_not, - op_bmm, op_cat, op_ceil, op_clamp, @@ -42,8 +41,6 @@ op_pow, op_reciprocal, op_repeat, - op_rescale, - op_resize, op_rshift_tensor, op_rsqrt, op_sigmoid, @@ -51,10 +48,13 @@ op_slice, op_sub, op_sum, - op_table, op_tanh, op_to_dim_order_copy, - op_transpose, + op_tosa_matmul, + op_tosa_rescale, + op_tosa_resize, + op_tosa_table, + op_tosa_transpose, op_view, op_where, ops_binary, diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py deleted file mode 100644 index 9bebc3597ca..00000000000 --- a/backends/arm/operators/op_bmm.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -"""Provide a visitor for lowering batched matmul (BMM) to TOSA.""" - -from typing import Any, List - -import torch - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, - get_output_qparams, -) -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa import TosaSpecification -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from tosa.RoundingMode import RoundingMode # type: ignore - - -@register_node_visitor -class BMMVisitor(NodeVisitor): - """Provide a visitor that lowers ``aten.bmm`` to TOSA ``MATMUL``. - - INT8 accumulates into INT32; add a rescale to INT8 using SINGLE_ROUND - rounding and output zero-point. - - """ - - target = "aten.bmm.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - """Define the TOSA ``MATMUL`` operator and optional rescale.""" - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - validate_valid_dtype( - self.target, - [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], - output.tosa_spec, - ) - - # aten.bmm maps directly to MATMUL - - # For INT8, we need to get the zero points and add an intermediate tensor - # for a later rescale. - - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - input0_zp = input_qparams[0].get_zp_per_tensor() - input1_zp = input_qparams[1].get_zp_per_tensor() - bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - bmm_output_name = bmm_result.name - elif inputs[0].dtype == ts.DType.INT16: - input_qparams = get_input_qparams(node) - input0_zp = input_qparams[0].get_zp_per_tensor() - input1_zp = input_qparams[1].get_zp_per_tensor() - bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48) - bmm_output_name = bmm_result.name - else: - bmm_output_name = output.name - input0_zp, input1_zp = 0, 0 - - tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=f"{node.name}_A_ZP") - tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=f"{node.name}_B_ZP") - - # Add the MATMUL to the TOSA graph. - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().MATMUL, - [ - inputs[0].name, - inputs[1].name, - f"{node.name}_A_ZP", - f"{node.name}_B_ZP", - ], - [bmm_output_name], - ) - - # As INT8 accumulates into INT32, we need to rescale it back to INT8 - if output.dtype == ts.DType.INT8: - output_qparams = get_output_qparams(node)[0] - final_output_scale = ( - input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] - ) / output_qparams.get_scale_per_tensor() - - build_rescale( - tosa_fb=tosa_graph, - scale=[final_output_scale], - # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined. - input_node=bmm_result, # type: ignore[possibly-undefined] - output_name=output.name, - output_type=ts.DType.INT8, - input_zp=[0], - output_zp=[output_qparams.get_zp_per_tensor()], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) - elif output.dtype == ts.DType.INT16: - output_qparams = get_output_qparams(node)[0] - final_output_scale = ( - input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore[61] - ) / output_qparams.get_scale_per_tensor() - - build_rescale( - tosa_fb=tosa_graph, - scale=[final_output_scale], - # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined. - input_node=bmm_result, # type: ignore[possibly-undefined] - output_name=output.name, - output_type=ts.DType.INT16, - input_zp=[0], - output_zp=[output_qparams.get_zp_per_tensor()], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) diff --git a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py new file mode 100644 index 00000000000..b177fd2ba37 --- /dev/null +++ b/backends/arm/operators/op_tosa_matmul.py @@ -0,0 +1,94 @@ +# Copyright 2024-2025 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +"""Provide a visitor for lowering batched matmul (BMM) to TOSA.""" + +from typing import Any, List + +import torch + +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class MatmulVisitor(NodeVisitor): + """Provide a visitor that serializes TOSA ``MATMUL``.""" + + target = "tosa.MATMUL.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + """Define the TOSA ``MATMUL`` operator.""" + import serializer.tosa_serializer as ts # type: ignore + + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [*inputs], ts) + validate_valid_dtype( + self.target, + [*inputs], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], + output.tosa_spec, + ) + validate_valid_dtype( + self.target, + [output], + [ts.DType.INT32, ts.DType.INT48, ts.DType.FP32], + output.tosa_spec, + ) + + # We need to get the zero points and add an intermediate tensor for INT16 case + if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): + input_qparams = get_input_qparams(node) + input0_zp = input_qparams[0].get_zp_per_tensor() + input1_zp = input_qparams[1].get_zp_per_tensor() + else: + input0_zp, input1_zp = 0, 0 + + input_A_ZP_name = f"{node.name}_A_ZP" + input_B_ZP_name = f"{node.name}_B_ZP" + tosa_graph.addConst([1], inputs[0].dtype, [input0_zp], name=input_A_ZP_name) + tosa_graph.addConst([1], inputs[1].dtype, [input1_zp], name=input_B_ZP_name) + + # Add the MATMUL to the TOSA graph. + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().MATMUL, + [ + inputs[0].name, + inputs[1].name, + input_A_ZP_name, + input_B_ZP_name, + ], + [output.name], + ) diff --git a/backends/arm/operators/op_rescale.py b/backends/arm/operators/op_tosa_rescale.py similarity index 100% rename from backends/arm/operators/op_rescale.py rename to backends/arm/operators/op_tosa_rescale.py diff --git a/backends/arm/operators/op_resize.py b/backends/arm/operators/op_tosa_resize.py similarity index 100% rename from backends/arm/operators/op_resize.py rename to backends/arm/operators/op_tosa_resize.py diff --git a/backends/arm/operators/op_table.py b/backends/arm/operators/op_tosa_table.py similarity index 100% rename from backends/arm/operators/op_table.py rename to backends/arm/operators/op_tosa_table.py diff --git a/backends/arm/operators/op_transpose.py b/backends/arm/operators/op_tosa_transpose.py similarity index 100% rename from backends/arm/operators/op_transpose.py rename to backends/arm/operators/op_tosa_transpose.py diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index f1e3a29ac22..897de70279f 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 + matmul, rescale, resize, table, diff --git a/backends/arm/tosa/dialect/ops/matmul.py b/backends/arm/tosa/dialect/ops/matmul.py new file mode 100644 index 00000000000..1ba3821f674 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/matmul.py @@ -0,0 +1,56 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +@register_fake_tosa_op( + "MATMUL(Tensor input1, Tensor input2) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + ), # target TOSA specifications +) +def MATMUL(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + tosa_spec = get_context_spec() + """Performs matrix multiplication on two input tensors. + Additionally validates TOSA constraints of a MATMUL op. + """ + if x1.dtype != x2.dtype: + raise TosaValueError( + f"Input tensors must have the same dtype, got {x1.dtype} and {x2.dtype}", + op="MATMUL", + ) + if x1.dtype in (torch.int8, torch.int16): + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support integers", op="MATMUL" + ) + else: + dtype = torch.int32 + elif x1.dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support float", op="MATMUL" + ) + else: + # float16 supports float16 accumulation as well + dtype = torch.float32 + else: + raise TosaValueError( + f"Input tensors must be of type int8, float16 or float32, got {x1.dtype}", + op="MATMUL", + ) + + aten_fake_tensor = exir_ops.edge.aten.bmm.default(x1, x2) + + return torch.empty_like(aten_fake_tensor, dtype=dtype) From b52fed9d4ec533aecfd0c0af2cec9b1d9316a77a Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 8 Oct 2025 15:58:19 +0200 Subject: [PATCH 2/2] Arm backend: Remove duplicate InsertTableOpsPass Signed-off-by: Oscar Andersson Change-Id: Iaa2c17c2f2a984970a91022cad2c49fbb7f3202e --- backends/arm/_passes/arm_pass_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c86eaa25962..ef6d6e6810a 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -212,7 +212,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(RewriteUpsamplePass(exported_program)) self.add_pass(AddBiasPass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(RewriteMatmulPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program))