diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 93bf20e69c1..008bc305aad 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -81,7 +81,7 @@ from .insert_int32_casts_after_int64_placeholders import ( # noqa InsertInt32CastsAfterInt64PlaceholdersPass, ) -from .insert_rescales_pass import InsertRescalePass # noqa +from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa from .insert_table_ops import InsertTableOpsPass # noqa from .match_arg_dtype_pass import MatchArgDtypePass # noqa from .match_arg_ranks_pass import MatchArgRanksPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b7c511bbe0b..1a0f4e4d384 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -81,6 +81,7 @@ FuseEqualPlaceholdersPass, FuseQuantizedActivationPass, InsertInt32CastsAfterInt64PlaceholdersPass, + InsertRescaleInt32Pass, InsertRescalePass, InsertTableOpsPass, MatchArgDtypePass, @@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) self.add_pass(InsertRescalePass()) + self.add_pass(InsertRescaleInt32Pass()) self.validate_constraints_mandatory() return self._transform(exported_program.graph_module) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 100ac03c2b0..d56e70e78b3 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,9 +4,14 @@ # LICENSE file in the root directory of this source tree. from copy import copy -from typing import cast, Set, Type +from typing import cast, Dict, Optional, Set, Tuple, Type -from executorch.backends.arm._passes.arm_pass_utils import create_node +import torch +from executorch.backends.arm._passes.arm_pass import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, +) from executorch.backends.arm._passes.quant_args import QuantArgs from executorch.backends.arm.constants import DQ_OPS, Q_OPS from executorch.exir.dialects._ops import ops as exir_ops @@ -65,3 +70,234 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module = super().call(graph_module).graph_module graph_module.recompile() return PassResult(graph_module, modified) + + +class InsertRescaleInt32Pass(ArmPass): + """ + Numerous TOSA ops require inputs and outputs to be 32-bit integers in their + quantized implementations. This pass treats such operator nodes by + inserting rescale ops before and after them if needed. Note that extra logic + that handles the scales and zero points must be in place because the affected + TOSA have naive implementations that do not account for the quantization + parameters. + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + included_targets = [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + ] + + def _int32_qargs(self, s): + """Helper creator function for INT32-based QuantArgs""" + + return QuantArgs( + scale=s, + zp=0, + qmin=torch.iinfo(torch.int32).min, + qmax=torch.iinfo(torch.int32).max, + dtype=torch.int32, + ) + + def _get_inputs_rescaled_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Dict[int, QuantArgs]: + """Get the qparams for the INT32 operands to the op ``target`` + + Inputs to the INT32-based operator must be rescaled from INT8 to INT32. + This function computes the ``QuantArgs`` for each of the operands and returns + it as a dict, mapping tensor index to ``QuantArgs``. + """ + + if target in [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, + ]: + # For these ops, use the smallest scale among the INT8 operands. + min_scale = min( + [qp.get_scale_per_tensor() for qp in input_qparams.values()] + ) + qparams = { + i: self._int32_qargs(min_scale) for i in range(len(input_qparams)) + } + else: + raise ValueError(f"Not a valid target: {target}") + + return qparams + + def _get_output_qparams( + self, target, inputs_qparams: Dict[int, QuantArgs] + ) -> Optional[QuantArgs]: + """Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute + the scale of the output based on how the operator itself affects it.""" + + if target in [ + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.minimum.default, + ]: + # The op has not altered the scale; the output scale is equal to + # the operands' scales. + return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor()) + elif target in [ + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + ]: + # Output is bool for these ops and thus no qparams are present + return None + else: + raise ValueError(f"Not a valid target: {target}") + + def _get_rescale_qparams( + self, target, input_qparams: Dict[int, QuantArgs] + ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]: + """ + Get the quantization parameters of the INT32 inputs/outputs that will + surround the node after the new RESCALE ops have been inserted. + """ + + inputs_rescaled_qparams = self._get_inputs_rescaled_qparams( + target, input_qparams + ) + output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams) + + return (inputs_rescaled_qparams, output_qparams) + + def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool: + qargs = node.meta["input_qparams"] + + args_copy = list(node.args) + seen_args = set() + modified = False + for i in qargs: + qp = qargs[i] + if qp.dtype != torch.int8: + continue + + arg_node = args_copy[i] + if arg_node in seen_args: + continue + seen_args.add(arg_node) + + with graph.inserting_after(arg_node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + arg_node, + torch.int32, + qp.get_scale_per_tensor() + / rescale_qargs[ + i + ].get_scale_per_tensor(), # Old scale / new scale + qp.get_zp_per_tensor(), # Old zero point + rescale_qargs[i].get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + node.replace_input_with(arg_node, rescale_node) + modified = True + + return modified + + def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool: + if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0: + return False + + qargs = get_output_qparams(node) + assert len(qargs) == 1 + assert rescale_qargs is not None + + qarg = qargs[0] + if qarg.dtype != torch.int8: + return False + + users_copy = list(node.users) + + with graph.inserting_after(node): + rescale_node = create_node( + graph, + exir_ops.backend.tosa.RESCALE.default, + ( + node, + torch.int8, + rescale_qargs.get_scale_per_tensor() + / qarg.get_scale_per_tensor(), # Old scale / new scale + rescale_qargs.get_zp_per_tensor(), # Old zero point + qarg.get_zp_per_tensor(), # New zero point + ), + from_node=node, + ) + + for user in users_copy: + user.replace_input_with(node, rescale_node) + + return True + + def call(self, graph_module: GraphModule) -> PassResult: + graph = graph_module.graph + + modified = False + for node in list(graph.nodes): + node = cast(Node, node) + + if node.op != "call_function" or node.target not in self.included_targets: + continue + + if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0: + continue + input_qparams = node.meta["input_qparams"] + + inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams( + node.target, input_qparams + ) + + inputs_was_rescaled = self._rescale_inputs( + graph, node, inputs_rescale_qargs + ) + outputs_was_rescaled = False + if inputs_was_rescaled: + outputs_was_rescaled = self._rescale_outputs( + graph, node, output_rescale_qargs + ) + modified = True + + # Update node metadata + + if inputs_was_rescaled: + assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"]) + node.meta["input_qparams"] = inputs_rescale_qargs + + if outputs_was_rescaled: + assert len(node.meta["output_qparams"]) == 1 + node.meta["output_qparams"] = {0: output_rescale_qargs} + + # If the output type is specified in the node, change it such + # that it matches the subsequent rescale node(s) that this node + # now has output edges to. + if "dtype" in node.kwargs: + set_node_arg(node, "dtype", torch.int32) + + if modified: + # Retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + graph_module.recompile() + + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/op_abs.py b/backends/arm/operators/op_abs.py index ec76eb5517f..943c4778867 100644 --- a/backends/arm/operators/op_abs.py +++ b/backends/arm/operators/op_abs.py @@ -6,9 +6,6 @@ # pyre-unsafe from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils -import executorch.backends.arm.tosa.utils as tutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -18,22 +15,20 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @register_node_visitor -class AbsVisitor_INT(NodeVisitor): +class AbsVisitor(NodeVisitor): target = "aten.abs.default" tosa_specs = [ TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def __init__(self, *args): - super().__init__(*args) - def define_node( self, node: Node, @@ -47,89 +42,18 @@ def define_node( validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [*inputs, output], ts) - # Handle int8 (quantized) and int32 validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - if inputs[0].dtype == ts.DType.INT8: - rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - else: - # input[0].dtype == ts.DType.INT32 - # Non quantized input, natively support by TOSA.abs - rescaled_inputs = inputs - - if output.dtype == ts.DType.INT8: - broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) - abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) - else: - # output.dtype == ts.DType.INT32 - abs_output = output - - # Do the INT32 Abs - self._serialize_operator( - node, - tosa_graph, + tosa_graph.addOperator( ts.TosaOp.Op().ABS, [ - rescaled_inputs[0].name, + inputs[0].name, ], - [abs_output.name], + [output.name], None, ) - - if output.dtype == ts.DType.INT8: - # Scale output back to 8 bit - # pyre-ignore - tqutils.insert_rescale_op_to_int8( - tosa_graph, abs_output, scale_back, node, self.tosa_spec - ) # type: ignore[possibly-undefined] - - -@register_node_visitor -class AbsVisitor_FP(AbsVisitor_INT): - # inheriting 'target' from BI class - - tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - - import serializer.tosa_serializer as ts # type: ignore - - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) - - if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: - # Call the inherited define_node for handling integers - super().define_node(node, tosa_graph, inputs, output) - else: - # FP32 Abs lowering - - validate_valid_dtype( - self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec - ) - - # MI lowering - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().ABS, - [inputs[0].name], - [output.name], - None, - ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 2136fe2e946..76b6e67cd8d 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,23 +54,12 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - # Do the equal comparison self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().EQUAL, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index c538e735880..4bb20cac77f 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index d407e28c1b6..c25c959681e 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER, - [input_nodes[0].name, input_nodes[1].name], + [inputs[0].name, inputs[1].name], [output.name], None, ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 403c6c233d3..e62d669814f 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER_EQUAL, - [input_nodes[1].name, input_nodes[0].name], + [inputs[1].name, inputs[0].name], [output.name], None, ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5132dd4feb..cccb0abd5d7 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -7,8 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -56,22 +54,11 @@ def define_node( ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) - input_nodes = inputs - # Handle quantization - if inputs[0].dtype == ts.DType.INT8: - # Rescale inputs to 32 bit - rescaled_inputs, _ = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - # Update IO - input_nodes = rescaled_inputs - self._serialize_operator( node, tosa_graph, ts.TosaOp.Op().GREATER, - [input_nodes[1].name, input_nodes[0].name], + [inputs[1].name, inputs[0].name], [output.name], None, ) diff --git a/backends/arm/operators/op_maximum.py b/backends/arm/operators/op_maximum.py index 66437f8af1d..50c6e06a4bb 100644 --- a/backends/arm/operators/op_maximum.py +++ b/backends/arm/operators/op_maximum.py @@ -7,12 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) - from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -22,9 +16,8 @@ validate_same_dtype, validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape +from executorch.backends.arm.tosa.specification import TosaSpecification from torch.fx import Node @@ -56,35 +49,12 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - max_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MAX" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_maximum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default + # Set to PROPAGATE as default attr_maximum.MaximumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( @@ -92,15 +62,9 @@ def define_node( tosa_graph, ts.TosaOp.Op().MAXIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [max_output.name], + [output.name], attr_maximum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, max_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/operators/op_minimum.py b/backends/arm/operators/op_minimum.py index 518366d5463..d5b97f186d3 100644 --- a/backends/arm/operators/op_minimum.py +++ b/backends/arm/operators/op_minimum.py @@ -7,11 +7,6 @@ from typing import Any, List -import executorch.backends.arm.tosa.quant_utils as tqutils - -from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( - get_input_qparams, -) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -23,7 +18,6 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.utils import tosa_shape from torch.fx import Node @@ -55,35 +49,12 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) - scale_back = 1.0 - min_output = output - if inputs[0].dtype == ts.DType.INT8: - input_qparams = get_input_qparams(node) - if len(input_qparams) != 2: - raise ValueError( - f"Both inputs need to have quantization information for {node}" - ) - if input_qparams[0] != input_qparams[1]: - raise ValueError( - "Both inputs must have the same quantization parameters for MIN" - ) - - operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( - tosa_graph, inputs, node, self.tosa_spec - ) - - output.shape = tosa_shape(output.shape, output.dim_order) - min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) - else: - operand_inputs = inputs - attr_minimum = ts.TosaSerializerAttribute() - - # Set to PROPOGATE as default + # Set to PROPAGATE as default attr_minimum.MinimumAttribute(nan_mode=NanPropagationMode.PROPAGATE) self._serialize_operator( @@ -91,15 +62,9 @@ def define_node( tosa_graph, ts.TosaOp.Op().MINIMUM, [ - operand_inputs[0].name, - operand_inputs[1].name, + inputs[0].name, + inputs[1].name, ], - [min_output.name], + [output.name], attr_minimum, ) - - if output.dtype == ts.DType.INT8: - # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_op_to_int8( - tosa_graph, min_output, scale_back, node, self.tosa_spec - ) diff --git a/backends/arm/test/passes/test_insert_rescale_i32_pass.py b/backends/arm/test/passes/test_insert_rescale_i32_pass.py new file mode 100644 index 00000000000..096c90d330d --- /dev/null +++ b/backends/arm/test/passes/test_insert_rescale_i32_pass.py @@ -0,0 +1,77 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes import ( + FoldAndAnnotateQParamsPass, + InsertRescaleInt32Pass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + + +class NeedsRescaleOps(torch.nn.Module): + """A module containing ops that require INT32 inputs/outputs.""" + + input_t = Tuple[torch.Tensor, torch.Tensor] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + a = torch.maximum(x, y) + b = torch.abs(a) + c = a > b + return c + + def get_inputs(self, dtype) -> input_t: + if dtype == torch.float32: + return (torch.rand(1, 3, 5, 6), torch.rand(1, 3, 5, 6)) + elif dtype == torch.int32: + return ( + torch.randint(3, 5, (3,), dtype=torch.int32), + torch.randint(3, 5, (3,), dtype=torch.int32), + ) + else: + raise ValueError("Not a valid input dtype for model") + + +def test_insert_rescales(): + module = NeedsRescaleOps() + input_t = Tuple[torch.Tensor, torch.Tensor] + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + ops_after = { + # "number of op nodes with i8 output" + "number of i8 node inputs" + "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2 + + 5, + } + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.float32), + quantize=True, + ops_not_before_pass=ops_not_before, + ops_after_pass=ops_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + +def test_dont_insert_rescales(): + module = NeedsRescaleOps() + input_t = Tuple[torch.Tensor, torch.Tensor] + ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + # All inputs are already i32. Rescales should not be added. + ops_not_after = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"} + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(torch.int32), + ops_not_before_pass=ops_not_before, + ops_not_after_pass=ops_not_after, + pass_list=[FoldAndAnnotateQParamsPass, InsertRescaleInt32Pass], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run()