From 2cda2ffd1cc4410a236d0924c7821a67d70f04ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Wed, 14 May 2025 09:40:11 +0200 Subject: [PATCH 1/5] Arm backend: Move rescale ops out of comparison visitors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some TOSA ops do not support INT8 as inputs and outputs. Instead, only INT32 is supported as a whole number type. Prior to this patch, affected node visitors inserted rescale ops between the data types INT8 and INT32 before and after the operator such that it will accept its input and output. Change this by moving the insertion of the rescale ops to a new pass called InsertRescaleInt32Pass. This will further enable optimizations to the graph by fusing the rescale nodes. Only comparison operators are handled in this patch; the remaining ones are left out to be done in another patch. Signed-off-by: Martin Lindström Change-Id: I6bb8a10a0b453ae9fd8b8604d64cc5103a4da050 --- backends/arm/_passes/__init__.py | 2 +- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/insert_rescales_pass.py | 191 +++++++++++++++++- backends/arm/operators/op_eq.py | 15 +- backends/arm/operators/op_ge.py | 15 +- backends/arm/operators/op_gt.py | 15 +- backends/arm/operators/op_le.py | 15 +- backends/arm/operators/op_lt.py | 15 +- .../passes/test_insert_rescale_i32_pass.py | 75 +++++++ 9 files changed, 272 insertions(+), 73 deletions(-) create mode 100644 backends/arm/test/passes/test_insert_rescale_i32_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 93bf20e69c1..008bc305aad 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -81,7 +81,7 @@ from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) -from .insert_rescales_pass import InsertRescalePass # noqa +from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa from .insert_table_ops import InsertTableOpsPass # noqa from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b7c511bbe0b..1a0f4e4d384 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -81,6 +81,7 @@ FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, InsertInt32CastsAfterInt64PlaceholdersPass, + InsertRescaleInt32Pass, InsertRescalePass, InsertTableOpsPass, MatchArgDtypePass, @@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.add_pass(InsertRescaleInt32Pass()) self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 100ac03c2b0..dac3a12f9a0 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,9 +4,14 @@ # LICENSE file in the root directory of this source tree. from copy import copy -from typing import cast, Set, Type +from typing import cast, Dict, Optional, Set, Tuple, Type -from executorch.backends.arm._passes.arm_pass_utils import create_node +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -65,3 +70,185 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module = super().call(graph_module).graph_module graph_module.recompile() return PassResult(graph_module, modified) + + +class InsertRescaleInt32Pass(ArmPass): + """ + Numerous TOSA ops require inputs and outputs to be 32-bit integers in their + quantized implementations. This pass treats such operator nodes by + inserting rescale ops before and after them if needed. Note that extra logic + that handles the scales and zero points must be in place because the affected + TOSA have naive implementations that do not account for the quantization + parameters. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + included_targets = [ + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ] + + def _get_rescale_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]: + """ + Get the quantization parameters of the Int32 inputs/outputs that will + surround the node. + """ + + # Helper creator function for Int32-based QuantArgs + def int32_qargs(s): + return QuantArgs( + scale=s, + zp=0, + qmin=torch.iinfo(torch.int32).min, + qmax=torch.iinfo(torch.int32).max, + dtype=torch.int32, + ) + + if target in [ + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ]: + # Use the lowest scale of the operands since that yields the best numerical precision. + min_scale = min( + [qp.get_scale_per_tensor() for qp in input_qparams.values()] + ) + inputs_rescale_qparams = { + i: int32_qargs(min_scale) for i in range(len(input_qparams)) + } + + # Return None as output quant args since the output is not quantized (bool dtype) + return (inputs_rescale_qparams, None) + else: + raise ValueError(f"Unknown target: {target}") + + def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool: + qargs = node.meta["input_qparams"] + + args_copy = list(node.args) + seen_args = set() + modified = False + for i in qargs: + qp = qargs[i] + if qp.dtype != torch.int8: + continue + + arg_node = args_copy[i] + if arg_node in seen_args: + continue + seen_args.add(arg_node) + + with graph.inserting_after(arg_node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + arg_node, + torch.int32, + qp.get_scale_per_tensor() + / rescale_qargs[ + i + ].get_scale_per_tensor(), # Old scale / new scale + qp.get_zp_per_tensor(), # Old zero point + rescale_qargs[i].get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + node.replace_input_with(arg_node, rescale_node) + modified = True + + return modified + + def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool: + if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0: + return False + + qargs = get_output_qparams(node) + assert len(qargs) == 1 + assert rescale_qargs is not None + + qarg = qargs[0] + if qarg.dtype != torch.int8: + return False + + users_copy = list(node.users) + + with graph.inserting_after(node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + node, + torch.int8, + rescale_qargs.get_scale_per_tensor() + / qarg.get_scale_per_tensor(), # Old scale / new scale + rescale_qargs.get_zp_per_tensor(), # Old zero point + qarg.get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + for user in users_copy: + user.replace_input_with(node, rescale_node) + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + + modified = False + for node in list(graph.nodes): + node = cast(Node, node) + + if node.op != "call_function" or node.target not in self.included_targets: + continue + + if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0: + continue + input_qparams = node.meta["input_qparams"] + + inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams( + node.target, input_qparams + ) + + inputs_was_rescaled = self._rescale_inputs( + graph, node, inputs_rescale_qargs + ) + outputs_was_rescaled = False + if inputs_was_rescaled: + outputs_was_rescaled = self._rescale_outputs( + graph, node, output_rescale_qargs + ) + modified = True + + # Update node metadata + + if inputs_was_rescaled: + assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"]) + node.meta["input_qparams"] = inputs_rescale_qargs + + if outputs_was_rescaled: + assert len(node.meta["output_qparams"]) == 1 + node.meta["output_qparams"] = {0: output_rescale_qargs} + + # If the output type is specified in the node, change it such + # that it matches the subsequent rescale node(s) that this node + # now has output edges to. + if "dtype" in node.kwargs: + set_node_arg(node, "dtype", torch.int32) + + if modified: + # Retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 2136fe2e946..76b6e67cd8d 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,23 +54,12 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - # Do the equal comparison self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().EQUAL, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index c538e735880..4bb20cac77f 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index d407e28c1b6..c25c959681e 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 403c6c233d3..e62d669814f 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[1].name, input_nodes[0].name], + [inputs[1].name, inputs[0].name], [output.name], None, ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5132dd4feb..cccb0abd5d7 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER, - [input_nodes[1].name, input_nodes[0].name], + [inputs[1].name, inputs[0].name], [output.name], None, ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py new file mode 100644 index 00000000000..9ff6e3af42f --- /dev/null +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -0,0 +1,75 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class NeedsRescaleOps(torch.nn.Module): + """A module containing ops that require INT32 inputs/outputs.""" + + input_t = Tuple[torch.Tensor, torch.Tensor] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + a = x > y + return a + + def get_inputs(self, dtype) -> input_t: + if dtype == torch.float32: + return (torch.rand(1, 3, 5, 6), torch.rand(1, 3, 5, 6)) + elif dtype == torch.int32: + return ( + torch.randint(3, 5, (3,), dtype=torch.int32), + torch.randint(3, 5, (3,), dtype=torch.int32), + ) + else: + raise ValueError("Not a valid input dtype for model") + + +def test_insert_rescales(): + module = NeedsRescaleOps() + input_t = Tuple[torch.Tensor, torch.Tensor] + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + ops_after = { + # "number of op nodes with i8 output" + "number of i8 node inputs" + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 0 + + 2, + } + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.float32), + quantize=True, + ops_not_before_pass=ops_not_before, + ops_after_pass=ops_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_dont_insert_rescales(): + module = NeedsRescaleOps() + input_t = Tuple[torch.Tensor, torch.Tensor] + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + # All inputs are already i32. Rescales should not be added. + ops_not_after = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.int32), + ops_not_before_pass=ops_not_before, + ops_not_after_pass=ops_not_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() From a647bc37e71c9318cad926ef47d19ae3d8f74be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 21 Aug 2025 14:24:27 +0200 Subject: [PATCH 2/5] Arm backend: Move rescales from ABS visitor to pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martin Lindström Co-authored-by: Oscar Andersson Change-Id: I62fdc5bea75361d6c32711968bdc1c9d03677ccc --- backends/arm/_passes/insert_rescales_pass.py | 85 +++++++++++++----- backends/arm/operators/op_abs.py | 90 ++----------------- .../passes/test_insert_rescale_i32_pass.py | 9 +- 3 files changed, 76 insertions(+), 108 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index dac3a12f9a0..97549897836 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -85,6 +85,7 @@ class InsertRescaleInt32Pass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() included_targets = [ + exir_ops.edge.aten.abs.default, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gt.Tensor, @@ -92,43 +93,85 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.lt.Tensor, ] - def _get_rescale_qparams( + def _int32_qargs(self, s): + """Helper creator function for INT32-based QuantArgs""" + + return QuantArgs( + scale=s, + zp=0, + qmin=torch.iinfo(torch.int32).min, + qmax=torch.iinfo(torch.int32).max, + dtype=torch.int32, + ) + + def _get_inputs_rescaled_qparams( self, target, input_qparams: Dict[int, QuantArgs] - ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]: - """ - Get the quantization parameters of the Int32 inputs/outputs that will - surround the node. - """ + ) -> Dict[int, QuantArgs]: + """Get the qparams for the INT32 operands to the op ``target`` - # Helper creator function for Int32-based QuantArgs - def int32_qargs(s): - return QuantArgs( - scale=s, - zp=0, - qmin=torch.iinfo(torch.int32).min, - qmax=torch.iinfo(torch.int32).max, - dtype=torch.int32, - ) + Inputs to the INT32-based operator must be rescaled from INT8 to INT32. + This function computes the ``QuantArgs`` for each of the operands and returns + it as a dict, mapping tensor index to ``QuantArgs``. + """ if target in [ + exir_ops.edge.aten.abs.default, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.lt.Tensor, ]: - # Use the lowest scale of the operands since that yields the best numerical precision. + # For these ops, use the smallest scale among the INT8 operands. min_scale = min( [qp.get_scale_per_tensor() for qp in input_qparams.values()] ) - inputs_rescale_qparams = { - i: int32_qargs(min_scale) for i in range(len(input_qparams)) + qparams = { + i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) } + else: + raise ValueError(f"Not a valid target: {target}") + + return qparams + + def _get_output_qparams( + self, target, inputs_qparams: Dict[int, QuantArgs] + ) -> Optional[QuantArgs]: + """Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute + the scale of the output based on how the operator itself affects it.""" - # Return None as output quant args since the output is not quantized (bool dtype) - return (inputs_rescale_qparams, None) + if target in [ + exir_ops.edge.aten.abs.default, + ]: + # The op has not altered the scale; the output scale is equal to + # the operands' scales. + return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor()) + elif target in [ + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ]: + # Output is bool for these ops and thus no qparams are present + return None else: - raise ValueError(f"Unknown target: {target}") + raise ValueError(f"Not a valid target: {target}") + + def _get_rescale_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]: + """ + Get the quantization parameters of the INT32 inputs/outputs that will + surround the node after the new RESCALE ops have been inserted. + """ + + inputs_rescaled_qparams = self._get_inputs_rescaled_qparams( + target, input_qparams + ) + output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams) + + return (inputs_rescaled_qparams, output_qparams) def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool: qargs = node.meta["input_qparams"] diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index ec76eb5517f..943c4778867 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -6,9 +6,6 @@ # pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -18,22 +15,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AbsVisitor_INT(NodeVisitor): +class AbsVisitor(NodeVisitor): target = "aten.abs.default" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -47,89 +42,18 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) - # Handle int8 (quantized) and int32 validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.abs - rescaled_inputs = inputs - - if output.dtype == ts.DType.INT8: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - abs_output = output - - # Do the INT32 Abs - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().ABS, [ - rescaled_inputs[0].name, + inputs[0].name, ], - [abs_output.name], + [output.name], None, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, abs_output, scale_back, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class AbsVisitor_FP(AbsVisitor_INT): - # inheriting 'target' from BI class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Abs lowering - - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().ABS, - [inputs[0].name], - [output.name], - None, - ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 9ff6e3af42f..a898ee70514 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -22,8 +22,9 @@ def __init__(self): super().__init__() def forward(self, x, y): - a = x > y - return a + a = torch.abs(x) + b = a > y + return b def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -43,8 +44,8 @@ def test_insert_rescales(): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 0 - + 2, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 1 + + 3, } pipeline = PassPipeline[input_t]( module, From 463ed4a2700c64a38c5af1418fda68eddb29dbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 21 Aug 2025 14:24:33 +0200 Subject: [PATCH 3/5] Arm backend: Move rescales from MAXIMUM/MINIMUM visitor to pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Martin Lindström Change-Id: I5ee8f97590ce599a9dfc60ced0775654fa565c4e --- backends/arm/_passes/insert_rescales_pass.py | 6 +++ backends/arm/operators/op_maximum.py | 48 +++---------------- backends/arm/operators/op_minimum.py | 45 ++--------------- .../passes/test_insert_rescale_i32_pass.py | 11 +++-- 4 files changed, 23 insertions(+), 87 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 97549897836..d56e70e78b3 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -91,6 +91,8 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, ] def _int32_qargs(self, s): @@ -121,6 +123,8 @@ def _get_inputs_rescaled_qparams( exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.le.Tensor, exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, ]: # For these ops, use the smallest scale among the INT8 operands. min_scale = min( @@ -142,6 +146,8 @@ def _get_output_qparams( if target in [ exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 66437f8af1d..50c6e06a4bb 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -7,12 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -22,9 +16,8 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @@ -56,35 +49,12 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - max_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MAX" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_maximum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default + # Set to PROPAGATE as default attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( @@ -92,15 +62,9 @@ def define_node( tosa_graph, ts.TosaOp.Op().MAXIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [max_output.name], + [output.name], attr_maximum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, max_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 518366d5463..d5b97f186d3 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -7,11 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -23,7 +18,6 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape from torch.fx import Node @@ -55,35 +49,12 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - min_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MIN" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_minimum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default + # Set to PROPAGATE as default attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( @@ -91,15 +62,9 @@ def define_node( tosa_graph, ts.TosaOp.Op().MINIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [min_output.name], + [output.name], attr_minimum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, min_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index a898ee70514..096c90d330d 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -22,9 +22,10 @@ def __init__(self): super().__init__() def forward(self, x, y): - a = torch.abs(x) - b = a > y - return b + a = torch.maximum(x, y) + b = torch.abs(a) + c = a > b + return c def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -44,8 +45,8 @@ def test_insert_rescales(): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 1 - + 3, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2 + + 5, } pipeline = PassPipeline[input_t]( module, From f21cf7ffd100e6a83bdc677c69d0dc8f7c5792d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Thu, 21 Aug 2025 14:24:36 +0200 Subject: [PATCH 4/5] Arm backend: Move rescales from SUB visitor to pass Signed-off-by: Martin Lindstroem Co-authored-by: Oscar Andersson Change-Id: I38d63015e03e59c267338c84d64731b050854d06 --- backends/arm/_passes/insert_rescales_pass.py | 35 +++++- backends/arm/operators/op_sub.py | 106 ++---------------- .../passes/test_insert_rescale_i32_pass.py | 13 ++- 3 files changed, 46 insertions(+), 108 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index d56e70e78b3..0081265d31c 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -93,6 +93,7 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.sub.Tensor, ] def _int32_qargs(self, s): @@ -133,6 +134,33 @@ def _get_inputs_rescaled_qparams( qparams = { i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) } + elif target in [ + exir_ops.edge.aten.sub.Tensor, + ]: + if input_qparams[0].dtype != input_qparams[1].dtype: + raise ValueError( + "Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}" + ) + + # We are handling two INT8 or two INT16 numbers. For INT8, if the + # zero point is non-null, the result will be in the range [-255; + # 255], therefore we need 9 bits for the result. We have a 32-bit + # accumulator, so we can divide the scale by (1 << 20) which is + # equivalent to shifting the INT8 operands 20 bits to the left + # before rescaling them both to 2 * max(lhs, rhs). + # + # For INT16, similary logic can be applied, but we instead end up + # with a left shift of 12. + lhs_scale, rhs_scale = ( + qp.get_scale_per_tensor() for qp in input_qparams.values() + ) + max_scale_2x = 2 * max(lhs_scale, rhs_scale) + + # Select shift based on input dtype. + shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20 + + scale = max_scale_2x / (1 << shift_bits) + qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))} else: raise ValueError(f"Not a valid target: {target}") @@ -148,6 +176,7 @@ def _get_output_qparams( exir_ops.edge.aten.abs.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.sub.Tensor, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. @@ -187,7 +216,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b modified = False for i in qargs: qp = qargs[i] - if qp.dtype != torch.int8: + if qp.dtype not in (torch.int8, torch.int16): continue arg_node = args_copy[i] @@ -226,7 +255,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b assert rescale_qargs is not None qarg = qargs[0] - if qarg.dtype != torch.int8: + if qarg.dtype not in (torch.int8, torch.int16): return False users_copy = list(node.users) @@ -237,7 +266,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b exir_ops.backend.tosa.RESCALE.default, ( node, - torch.int8, + qarg.dtype, rescale_qargs.get_scale_per_tensor() / qarg.get_scale_per_tensor(), # Old scale / new scale rescale_qargs.get_zp_per_tensor(), # Old zero point diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index 5f037dc3d1c..b3d1a0fb4d7 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -7,9 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -19,22 +16,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class SubVisitor_INT(NodeVisitor): +class SubVisitor(NodeVisitor): target = "aten.sub.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -50,105 +45,18 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - elif inputs[0].dtype == ts.DType.INT16: - rescaled_inputs, scale_back = ( - tqutils.insert_rescale_ops_int16_to_int32_maxscale( - tosa_graph, inputs, node, self.tosa_spec - ) - ) - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.SUB - rescaled_inputs = inputs - - if output.dtype in [ts.DType.INT8, ts.DType.INT16]: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - sub_output = output - - # Do the INT32 Sub self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().SUB, [ - rescaled_inputs[0].name, - rescaled_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [sub_output.name], + [output.name], None, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, - sub_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - elif output.dtype == ts.DType.INT16: - tqutils.insert_rescale_op_to_int16( - tosa_graph, - sub_output, - scale_back, - node, - compute_rescale=False, - tosa_spec=self.tosa_spec, - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class SubVisitor_FP(SubVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Sub lowering - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().SUB, - [inputs[0].name, inputs[1].name], - [output.name], - None, - ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 096c90d330d..02223990217 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -22,10 +22,11 @@ def __init__(self): super().__init__() def forward(self, x, y): - a = torch.maximum(x, y) - b = torch.abs(a) - c = a > b - return c + a = x - y + c = torch.maximum(a, y) + d = torch.abs(c) + e = d > c + return e def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -45,8 +46,8 @@ def test_insert_rescales(): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2 - + 5, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3 + + 7, } pipeline = PassPipeline[input_t]( module, From 0b65f38c3ac7b0a9f1aa18dec4ec71f22ffa647c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Lindstr=C3=B6m?= Date: Tue, 7 Oct 2025 10:38:30 +0200 Subject: [PATCH 5/5] Revert "Arm backend: Move rescales from SUB visitor to pass" This reverts commit f21cf7ffd100e6a83bdc677c69d0dc8f7c5792d6. --- backends/arm/_passes/insert_rescales_pass.py | 35 +----- backends/arm/operators/op_sub.py | 106 ++++++++++++++++-- .../passes/test_insert_rescale_i32_pass.py | 13 +-- 3 files changed, 108 insertions(+), 46 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 0081265d31c..d56e70e78b3 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -93,7 +93,6 @@ class InsertRescaleInt32Pass(ArmPass): exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.sub.Tensor, ] def _int32_qargs(self, s): @@ -134,33 +133,6 @@ def _get_inputs_rescaled_qparams( qparams = { i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) } - elif target in [ - exir_ops.edge.aten.sub.Tensor, - ]: - if input_qparams[0].dtype != input_qparams[1].dtype: - raise ValueError( - "Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}" - ) - - # We are handling two INT8 or two INT16 numbers. For INT8, if the - # zero point is non-null, the result will be in the range [-255; - # 255], therefore we need 9 bits for the result. We have a 32-bit - # accumulator, so we can divide the scale by (1 << 20) which is - # equivalent to shifting the INT8 operands 20 bits to the left - # before rescaling them both to 2 * max(lhs, rhs). - # - # For INT16, similary logic can be applied, but we instead end up - # with a left shift of 12. - lhs_scale, rhs_scale = ( - qp.get_scale_per_tensor() for qp in input_qparams.values() - ) - max_scale_2x = 2 * max(lhs_scale, rhs_scale) - - # Select shift based on input dtype. - shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20 - - scale = max_scale_2x / (1 << shift_bits) - qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))} else: raise ValueError(f"Not a valid target: {target}") @@ -176,7 +148,6 @@ def _get_output_qparams( exir_ops.edge.aten.abs.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.minimum.default, - exir_ops.edge.aten.sub.Tensor, ]: # The op has not altered the scale; the output scale is equal to # the operands' scales. @@ -216,7 +187,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b modified = False for i in qargs: qp = qargs[i] - if qp.dtype not in (torch.int8, torch.int16): + if qp.dtype != torch.int8: continue arg_node = args_copy[i] @@ -255,7 +226,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b assert rescale_qargs is not None qarg = qargs[0] - if qarg.dtype not in (torch.int8, torch.int16): + if qarg.dtype != torch.int8: return False users_copy = list(node.users) @@ -266,7 +237,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b exir_ops.backend.tosa.RESCALE.default, ( node, - qarg.dtype, + torch.int8, rescale_qargs.get_scale_per_tensor() / qarg.get_scale_per_tensor(), # Old scale / new scale rescale_qargs.get_zp_per_tensor(), # Old zero point diff --git a/backends/arm/operators/op_sub.py b/backends/arm/operators/op_sub.py index b3d1a0fb4d7..5f037dc3d1c 100644 --- a/backends/arm/operators/op_sub.py +++ b/backends/arm/operators/op_sub.py @@ -7,6 +7,9 @@ from typing import Any, List +import executorch.backends.arm.tosa.quant_utils as tqutils +import executorch.backends.arm.tosa.utils as tutils + from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -16,20 +19,22 @@ validate_same_dtype, validate_valid_dtype, ) +from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class SubVisitor(NodeVisitor): +class SubVisitor_INT(NodeVisitor): target = "aten.sub.Tensor" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), ] + def __init__(self, *args): + super().__init__(*args) + def define_node( self, node: Node, @@ -45,18 +50,105 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32], output.tosa_spec, ) + scale_back = 1.0 + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + elif inputs[0].dtype == ts.DType.INT16: + rescaled_inputs, scale_back = ( + tqutils.insert_rescale_ops_int16_to_int32_maxscale( + tosa_graph, inputs, node, self.tosa_spec + ) + ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.SUB + rescaled_inputs = inputs + + if output.dtype in [ts.DType.INT8, ts.DType.INT16]: + broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) + sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) + else: + # output.dtype == ts.DType.INT32 + sub_output = output + + # Do the INT32 Sub self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().SUB, [ - inputs[0].name, - inputs[1].name, + rescaled_inputs[0].name, + rescaled_inputs[1].name, ], - [output.name], + [sub_output.name], None, ) + + if output.dtype == ts.DType.INT8: + # Scale output back to 8 bit + # pyre-ignore + tqutils.insert_rescale_op_to_int8( + tosa_graph, + sub_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + elif output.dtype == ts.DType.INT16: + tqutils.insert_rescale_op_to_int16( + tosa_graph, + sub_output, + scale_back, + node, + compute_rescale=False, + tosa_spec=self.tosa_spec, + ) # type: ignore[possibly-undefined] + + +@register_node_visitor +class SubVisitor_FP(SubVisitor_INT): + # inheriting 'target' from INT class + + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + import serializer.tosa_serializer as ts # type: ignore + + validate_num_inputs(self.target, inputs, 2) + validate_same_dtype(self.target, [*inputs, output], ts) + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: + # Call the inherited define_node for handling integers + super().define_node(node, tosa_graph, inputs, output) + else: + # FP32 Sub lowering + validate_valid_dtype( + self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec + ) + + # MI lowering + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().SUB, + [inputs[0].name, inputs[1].name], + [output.name], + None, + ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py index 02223990217..096c90d330d 100644 --- a/backends/arm/test/passes/test_insert_rescale_i32_pass.py +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -22,11 +22,10 @@ def __init__(self): super().__init__() def forward(self, x, y): - a = x - y - c = torch.maximum(a, y) - d = torch.abs(c) - e = d > c - return e + a = torch.maximum(x, y) + b = torch.abs(a) + c = a > b + return c def get_inputs(self, dtype) -> input_t: if dtype == torch.float32: @@ -46,8 +45,8 @@ def test_insert_rescales(): ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} ops_after = { # "number of op nodes with i8 output" + "number of i8 node inputs" - "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3 - + 7, + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2 + + 5, } pipeline = PassPipeline[input_t]( module,