From 6967adea7a43b037f92f83ad484535681e438498 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Tue, 29 Oct 2024 08:46:01 +0100
Subject: [PATCH 1/7] Add functions for usage with DQ/Q folding pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Reuse the logic from the node-visitor quantization handling, but fetch
the quantization parameters from the node meta values instead.

Signed-off-by: Per Åstrand
Change-Id: I9a7bbf6384284e60118756ec5661f6b11847aba7
---
 backends/arm/tosa_quant_utils.py | 106 ++++++++++++++++++++++++++++---
 1 file changed, 97 insertions(+), 9 deletions(-)

diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py
index 19397fe6b21..6a406aa4395 100644
--- a/backends/arm/tosa_quant_utils.py
+++ b/backends/arm/tosa_quant_utils.py
@@ -42,6 +42,81 @@ def register_passable_op(op):
     passable_ops.append(op)
 
 
+def insert_rescale_ops_to_int32(
+    tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node
+) -> tuple[list[TosaSerializerTensor], float]:
+    """Rescales all 'inputs' to int32, adding suitable RESCALE ops to 'tosa_graph'.
+    The scales are adjusted using the smallest scale of all 'inputs'.
+
+    Returns a list of the rescaled nodes and the scale factor used,
+    needed by insert_rescale_node_back_to_int8.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict as opposed to 'rescale_nodes_to_int32' which searches
+    the graph upstream for DQ nodes.
+    """
+
+    tensors = inputs.copy()
+
+    # Reshape tensor according to TOSA dim order
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values())
+
+    # Scale the int8 quantized input to a common scale in the integer
+    # domain
+    min_scale = min([qarg.scale for qarg in qargs])
+    scales = [qarg.scale / min_scale for qarg in qargs]
+
+    rescaled_nodes: list[TosaSerializerTensor] = []
+    for tensor, qarg, scale in zip(tensors, qargs, scales):
+        rescaled_nodes.append(
+            build_rescale_to_int32(
+                tosa_graph,
+                tensor,
+                qarg.zp,
+                scale,
+            )
+        )
+    return rescaled_nodes, min_scale
+
+
+def insert_rescale_node_back_to_int8(
+    tosa_graph: ts.TosaSerializer,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+) -> None:
+    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
+    Parameters:
+      node: The original node that is being handled by the rescales.
+      last_tensor: the tosa tensor to rescale back.
+      scale: the scaling factor used to rescale to int32, from the function
+        'insert_rescale_ops_to_int32'
+      tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict as opposed to 'rescale_node_back_to_int8' which searches
+    the graph downstream for Q nodes.
+    """
+    assert len(node.meta["output_qparams"]) == 1
+
+    qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0]
+    output_rescale_scale = scale / qargs_out.scale
+
+    # Rescale Back to INT8
+    build_rescale_from_int32(
+        tosa_graph,
+        last_tensor.name,
+        node.name,
+        qargs_out.zp,
+        output_rescale_scale,
+    )
+
+
 class QuantArgs(NamedTuple):
     scale: float
     zp: int
@@ -61,6 +136,20 @@ def quantize_value(self, x):
     def dequantize_value(self, qx: int) -> float:
         return (qx - self.zp) * self.scale
 
+    @classmethod
+    def from_operator(cls, op, args):
+        if op in dq_q_ops:
+            return cls(
+                scale=cast(float, args[1]),
+                zp=cast(int, args[2]),
+                qmin=cast(int, args[3]),
+                qmax=cast(int, args[4]),
+                dtype=cast(torch.dtype, args[5]),
+            )
+        else:
+            # We're only handling per tensor quantization
+            raise NotImplementedError
+
 
 def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
     return np.clip(
@@ -77,13 +166,7 @@ def dequantize_value(qx, qargs: QuantArgs):
 
 def qargs_from_qnode(node: torch.fx.Node):
     assert node.target in dq_q_ops, f"Op {node} is not a quant node."
-    return QuantArgs(
-        scale=cast(float, node.args[1]),
-        zp=cast(int, node.args[2]),
-        qmin=cast(int, node.args[3]),
-        qmax=cast(int, node.args[4]),
-        dtype=cast(torch.dtype, node.args[5]),
-    )
+    return QuantArgs.from_operator(node.target, node.args)
 
 
 def get_neighbour_quant_args(
@@ -214,8 +297,13 @@ def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
 
 
 def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
-    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
-        return node.meta["val"].dtype
+    if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys():
+        # Check if the node has had its quantization parameters folded
+        # and retrieve the dtype from the meta dict in that case.
+        assert len(node.meta["output_qparams"]) == 1
+        qargs = cast(QuantArgs, node.meta["output_qparams"][0])
+        return qargs.dtype
+
     if node.target in dq_q_ops:
         return cast(torch.dtype, node.args[5])
 

From 86777b1f03e3dac750beeb70dc7ac1422bb4bae6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Tue, 5 Nov 2024 12:29:03 +0100
Subject: [PATCH 2/7] Introduce a quantization folding pass with annotations
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fold DQ/Q nodes into the target operators specified to the pass.

Signed-off-by: Per Åstrand
Change-Id: I8a09dc0b887dd5f3915ca157f578ecf51772a1a2
---
 .../fold_qdq_with_annotated_qparams_pass.py   | 105 ++++++++++++++++++
 1 file changed, 105 insertions(+)
 create mode 100644 backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
new file mode 100644
index 00000000000..f1d861cf8b3
--- /dev/null
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -0,0 +1,105 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+from typing import Callable, cast, Iterable
+
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import GraphModule, Node
+
+
+class FoldAndAnnotateQParamsPass(ExportPass):
+    """
+    A pass that walks the graph and removes any DQ and Q nodes before and after the target
+    node in the supplied list of operators.
+    The quantization parameters from the DQ/Q nodes are stored as meta values to be
+    accessible for later lowering and serialization passes.
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
+    inputs to the target node and one Q node to the output.
+
+    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
+
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
+        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    Becomes:
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
+
+    """
+
+    def __init__(self, targeted_ops: Iterable[Callable]):
+        super().__init__()
+        self.targeted_ops = targeted_ops
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
+        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            # Make sure we haven't already set qparams meta information on the node
+            assert "input_qparams" not in n.meta.keys()
+            assert "output_qparams" not in n.meta.keys()
+
+            # For the inputs and outputs, search the graph for quantization info and
+            # store the information in a dict with the order of the _tensor_ inputs as key,
+            # ignoring any other arguments to the target node.
+            n.meta["input_qparams"] = {}
+            n.meta["output_qparams"] = {}
+            for i, arg in enumerate(n.args):
+                if not isinstance(arg, Node):
+                    continue
+                if arg.target != dq_op:
+                    continue
+
+                # arg.target for argument i is a dequant node, extract the information
+                n.meta["input_qparams"][i] = QuantArgs.from_operator(
+                    arg.target, arg.args
+                )
+
+                # arg.args[0] is the tensor input, replace the input usage
+                n.replace_input_with(arg, arg.args[0])
+                graph_module.graph.erase_node(arg)
+
+            # Copy the users, since we are modifying them while iterating.
+            users_copy = copy.copy(n.users)
+            for i, user in enumerate(users_copy):
+                if user.target != q_op:
+                    continue
+
+                # Quantization node found here, store the quantization parameters in the meta value
+                n.meta["output_qparams"][i] = QuantArgs.from_operator(
+                    user.target, user.args
+                )
+
+                user.replace_all_uses_with(n)
+                graph_module.graph.erase_node(user)
+
+        # Retrace the graph to update the fake tensor types
+        graph_module = super().call(graph_module).graph_module
+
+        graph_module.recompile()
+        return PassResult(graph_module, True)

From 0386b23126839dfa833b7f9da39e8215213647bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Mon, 28 Oct 2024 09:20:48 +0100
Subject: [PATCH 3/7] Add lowering of TOSA.MIN and TOSA.MAX
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Uses the fold DQ/Q pass to encapsulate the quantization information
within the node.

Signed-off-by: Per Åstrand
Change-Id: I3adbab7e2a23a0208a03bbc423b38c15221a4959
---
 backends/arm/_passes/arm_pass_manager.py      |  12 ++
 .../tosa_supported_operators.py               |   2 +
 backends/arm/operators/__init__.py            |   2 +
 backends/arm/operators/op_max.py              |  81 +++++++++++
 backends/arm/operators/op_min.py              |  81 +++++++++++
 backends/arm/quantizer/arm_quantizer.py       |   2 +
 .../quantization_annotation/__init__.py       |   1 +
 .../min_max_annotator.py                      |  46 ++++++
 backends/arm/test/ops/test_maximum.py         | 137 ++++++++++++++++++
 backends/arm/test/ops/test_minimum.py         | 137 ++++++++++++++++++
 10 files changed, 501 insertions(+)
 create mode 100644 backends/arm/operators/op_max.py
 create mode 100644 backends/arm/operators/op_min.py
 create mode 100644 backends/arm/quantizer/quantization_annotation/min_max_annotator.py
 create mode 100644 backends/arm/test/ops/test_maximum.py
 create mode 100644 backends/arm/test/ops/test_minimum.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 25811d077bb..64288b59d10 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -29,6 +29,9 @@
     DecomposeSoftmaxesPass,
 )
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    FoldAndAnnotateQParamsPass,
+)
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
 )
@@ -50,6 +53,7 @@
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager
 
 
@@ -80,6 +84,14 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(
+            FoldAndAnnotateQParamsPass(
+                [
+                    exir_ops.edge.aten.minimum.default,
+                    exir_ops.edge.aten.maximum.default,
+                ]
+            )
+        )
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
                 memory_format = spec.value.decode()
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 7072ba6a827..2e6e0f9ad08 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 8c4aa85e579..6db9c968f09 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,7 +19,9 @@ op_get_item, op_hardtanh, op_log, + op_max, op_max_pool2d, + op_min, op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py new file mode 100644 index 00000000000..f6ac7b67eed --- /dev/null +++ b/backends/arm/operators/op_max.py @@ -0,0 +1,81 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from typing import cast, List + +import executorch.backends.arm.tosa_quant_utils as tqutils + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MaxVisitor(NodeVisitor): + target = "aten.maximum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"]) + min_output = output + + if inputs[0].dtype == ts.DType.INT8: + # insert RESCALEs to int32 + x_scale = input_qparams[0].scale + x_zp = input_qparams[0].zp + + y_scale = input_qparams[1].scale + y_zp = input_qparams[1].zp + + assert ( + x_zp == y_zp + ), "Different zp for inputs, MAX should be quantized with shared quantization!" + assert ( + x_scale == y_scale + ), "Different scale for input, MAX should be quantized with shared quantization!" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MAXIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [min_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_node_back_to_int8( + tosa_graph, min_output, scale_back, node + ) diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py new file mode 100644 index 00000000000..b48cce49f7b --- /dev/null +++ b/backends/arm/operators/op_min.py @@ -0,0 +1,81 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import cast, List + +import executorch.backends.arm.tosa_quant_utils as tqutils + +import serializer.tosa_serializer as ts +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MinVisitor(NodeVisitor): + target = "aten.minimum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"]) + min_output = output + + if inputs[0].dtype == ts.DType.INT8: + # insert RESCALEs to int32 + x_scale = input_qparams[0].scale + x_zp = input_qparams[0].zp + + y_scale = input_qparams[1].scale + y_zp = input_qparams[1].zp + + assert ( + x_zp == y_zp + ), "Different zp for inputs, MIN should be quantized with shared quantization!" + assert ( + x_scale == y_scale + ), "Different scale for input, MIN should be quantized with shared quantization!" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MINIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [min_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_node_back_to_int8( + tosa_graph, min_output, scale_back, node + ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 6f2a5689d39..8815d40b0b0 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -77,6 +77,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern ], "mul": [[torch.mul]], "sub": [[torch.sub]], + "min_max": [[torch.min], [torch.max]], } return copy.deepcopy(supported_operators) @@ -267,6 +268,7 @@ class ArmQuantizer(Quantizer): "add", "sub", "mul", + "min_max", "mm", "one_to_one", "generic", diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 1201df51adc..d9d27cee2ac 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -55,6 +55,7 @@ def decorator(annotator: AnnotatorType): generic_annotator, linear_annotator, max_pool2d_annotator, + min_max_annotator, mm_annotator, mul_annotator, one_to_one_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py new file mode 100644 index 00000000000..43c4d20c134 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import Callable, List, Optional + +import torch +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.fx import GraphModule, Node + + +@register_annotator("min_max") +def _annotate_min_max( + gm: GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.target not in ( + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + ): + continue + annotated_partitions.append(node) + min_max_node = node + if arm_quantizer_utils.is_annotated(min_max_node): + continue + + input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( + min_max_node, gm, quantization_config + ) + if input_qspec_map is not None: + min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + return annotated_partitions diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py new file mode 100644 index 00000000000..61e1cccd0be --- /dev/null +++ b/backends/arm/test/ops/test_maximum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMaximum(unittest.TestCase): + """Tests a single maximum op""" + + class Maximum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.maximum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. 
+ ) + + def _test_maximum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_maximum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_maximum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_MI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_BI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_maximum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py new file mode 100644 index 00000000000..b63bf80f69c --- /dev/null +++ b/backends/arm/test/ops/test_minimum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMinimum(unittest.TestCase): + """Tests a single minimum op""" + + class Minimum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.minimum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. + ) + + def _test_minimum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_minimum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_minimum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_MI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_BI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_minimum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Minimum.test_parameters) + def 
test_minimum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) From a8daea510fdd7e274a20bdd44099e5a195a0190f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Tue, 29 Oct 2024 08:45:20 +0100 Subject: [PATCH 4/7] Add ADD to qdq pass handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I9230209ed3d6cc0b5ec7a35512248648bb8380ee --- backends/arm/_passes/arm_pass_manager.py | 1 + backends/arm/operators/op_add.py | 53 +++++++++++++----------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 64288b59d10..f16a34a211d 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -89,6 +89,7 @@ def transform_to_backend_pipeline( [ exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.add.Tensor, ] ) ) diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index e52f3eddae7..f1056ec6ec0 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,7 +11,6 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts -import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -41,33 +40,27 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input_nodes = tutils.get_two_inputs(node) - - if not is_quant_node and not all( - tensor.meta["val"].dtype in (torch.int8, torch.int32) - for tensor in input_nodes - ): - raise RuntimeError( - f"Unexpected non quantized {AddVisitor_080_BI.target} node." 
-            )
-
-        needs_rescale = not (
-            all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
-            and node.meta["val"].dtype == torch.int32
-        )
-
-        if needs_rescale:
-            # Rescale inputs to 32 bit
-            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
-                input_nodes, tosa_graph
+        # Specification (0.80.0) states that input and output types
+        # should all be the same
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+        # Handle int8 (quantized) and int32
+        assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
+
+        if inputs[0].dtype == ts.DType.INT8:
+            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
+                tosa_graph, inputs, node
             )
+        else:
+            # inputs[0].dtype == ts.DType.INT32
+            # Non quantized input, natively supported by TOSA.ADD
+            rescaled_inputs = inputs
 
-            # Prepare add output tensor
+        if output.dtype == ts.DType.INT8:
             broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
             add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
         else:
+            # output.dtype == ts.DType.INT32
             add_output = output
-            rescaled_inputs = inputs
 
         # Do the INT32 Add
         tosa_graph.addOperator(
@@ -80,10 +73,12 @@ def define_node(
             None,
         )
 
-        if needs_rescale:
+        if output.dtype == ts.DType.INT8:
             # Scale output back to 8 bit
             # pyre-ignore
-            tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
+            tqutils.insert_rescale_node_back_to_int8(
+                tosa_graph, add_output, scale_back, node
+            )
 
 
 @register_node_visitor
@@ -105,11 +100,19 @@ def define_node(
         output: TosaArg,
         is_quant_node: bool,
     ) -> None:
-        if is_quant_node:
+        # Specification (0.80.0) states that input and output types
+        # should all be the same
+        assert inputs[0].dtype == inputs[1].dtype == output.dtype
+
+        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
             # Call the inherited define_node for handling integers
             super().define_node(node, tosa_graph, inputs, output, is_quant_node)
         else:
             # FP32 Add lowering
+            assert inputs[0].dtype == ts.DType.FP32
+            assert output.dtype == ts.DType.FP32
+
+            # MI lowering
             tosa_graph.addOperator(
                 TosaOp.Op().ADD,
                 [inputs[0].name, inputs[1].name],

From 2cbf05a39370c2c321c06366228200f74d3ec6b1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Tue, 10 Dec 2024 09:51:58 +0100
Subject: [PATCH 5/7] Add test for fold qdq pass annotation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Per Åstrand
Change-Id: I6154e13a5a6b75549862709d632ee6dd5c8b0e7f
---
 .../arm/test/passes/test_fold_qdq_pass.py     | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)
 create mode 100644 backends/arm/test/passes/test_fold_qdq_pass.py

diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py
new file mode 100644
index 00000000000..4ff43e34382
--- /dev/null
+++ b/backends/arm/test/passes/test_fold_qdq_pass.py
@@ -0,0 +1,72 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    FoldAndAnnotateQParamsPass,
+)
+
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+from executorch.backends.xnnpack.test.tester.tester import RunPasses
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class SimpleQuantizeModel(torch.nn.Module):
+    def forward(self, x):
+        return x + x
+
+    def get_inputs(self):
+        return (torch.rand(1, 1280, 7, 7),)
+
+
+class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass):
+    def __init__(self):
+        super(FoldAndAnnotateQParamsPassTestClass, self).__init__(
+            [exir_ops.edge.aten.add.Tensor]
+        )
+
+
+class TestFoldAndAnnotateQParamsPass(unittest.TestCase):
+    """
+    Tests the FoldAndAnnotateQParamsPass, which folds dq/q nodes into
+    the node and stores the quantization parameters in meta.
+    """
+
+    def test_fold_qdq_pass(self):
+        """
+        Check that the pass runs for the add operation and that one q node and one dq node
+        are removed from the representation.
+        """
+        module = SimpleQuantizeModel()
+        test_pass_stage = RunPasses([FoldAndAnnotateQParamsPassTestClass])
+        (
+            ArmTester(
+                module,
+                example_inputs=module.get_inputs(),
+                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
+            )
+            .quantize()
+            .export()
+            .to_edge()
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2,
+                    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2,
+                }
+            )
+            .run_passes(test_pass_stage)
+            .check_count(
+                {
+                    "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1,
+                    "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1,
+                }
+            )
+        )

From ed236c350c4bfbebb1cbb94de22a0aaac1599105 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Thu, 12 Dec 2024 13:27:39 +0100
Subject: [PATCH 6/7] Add helper functions for Q/DQ folding pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds helper functions to retrieve QuantArgs from node.meta and cleans up
the handling a bit by introducing the __eq__ operator for QuantArgs.

Signed-off-by: Per Åstrand
Change-Id: I519a9a286a36a278f40ffb6c679192a54d9f940d
---
 .../fold_qdq_with_annotated_qparams_pass.py   | 26 ++++++++++++++
 backends/arm/operators/op_add.py              |  4 +--
 backends/arm/operators/op_max.py              | 35 ++++++++-----------
 backends/arm/operators/op_min.py              | 28 ++++++---------
 backends/arm/tosa_quant_utils.py              | 29 ++++++++++++---
 5 files changed, 77 insertions(+), 45 deletions(-)

diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index f1d861cf8b3..6c86db8a0bf 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -16,6 +16,32 @@
 from torch.fx import GraphModule, Node
 
 
+def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
+    """
+    Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'.
+    Raises a ValueError if the node doesn't have any parameters set.
+ """ + if "input_qparams" not in node.meta.keys(): + raise ValueError(f"No input quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No input quantization parameter found in node {node}") + return input_qparams + + +def get_output_qparams(node: Node) -> dict[int, QuantArgs]: + """ + Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. + Raises a ValueError if the node doesn't have any parameters set. + """ + if "output_qparams" not in node.meta.keys(): + raise ValueError(f"No output quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No output quantization parameter found in node {node}") + return input_qparams + + class FoldAndAnnotateQParamsPass(ExportPass): """ A pass that walks the graph and removes any DQ and Q nodes before and after the target diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index f1056ec6ec0..bdae16fffbd 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -76,9 +76,7 @@ def define_node( if output.dtype == ts.DType.INT8: # Scale output back to 8 bit # pyre-ignore - tqutils.insert_rescale_node_back_to_int8( - tosa_graph, add_output, scale_back, node - ) + tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) @register_node_visitor diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py index f6ac7b67eed..61d889e0db7 100644 --- a/backends/arm/operators/op_max.py +++ b/backends/arm/operators/op_max.py @@ -5,11 +5,13 @@ # pyre-unsafe -from typing import cast, List +from typing import List import executorch.backends.arm.tosa_quant_utils as tqutils - import serializer.tosa_serializer as ts +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,30 +40,23 @@ def define_node( ) -> None: assert inputs[0].dtype == inputs[1].dtype - input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"]) - min_output = output - + max_output = output if inputs[0].dtype == ts.DType.INT8: - # insert RESCALEs to int32 - x_scale = input_qparams[0].scale - x_zp = input_qparams[0].zp - - y_scale = input_qparams[1].scale - y_zp = input_qparams[1].zp - + input_qparams = get_input_qparams(node) assert ( - x_zp == y_zp - ), "Different zp for inputs, MAX should be quantized with shared quantization!" + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 assert ( - x_scale == y_scale - ), "Different scale for input, MAX should be quantized with shared quantization!" 
+ input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MAX" operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( tosa_graph, inputs, node ) output.shape = tosa_shape(output.shape, output.dim_order) - min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) else: operand_inputs = inputs @@ -71,11 +66,9 @@ def define_node( operand_inputs[0].name, operand_inputs[1].name, ], - [min_output.name], + [max_output.name], ) if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_node_back_to_int8( - tosa_graph, min_output, scale_back, node - ) + tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node) diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py index b48cce49f7b..6750ddd41fc 100644 --- a/backends/arm/operators/op_min.py +++ b/backends/arm/operators/op_min.py @@ -5,11 +5,14 @@ # pyre-unsafe -from typing import cast, List +from typing import List import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -38,23 +41,16 @@ def define_node( ) -> None: assert inputs[0].dtype == inputs[1].dtype - input_qparams = cast(dict[int, tqutils.QuantArgs], node.meta["input_qparams"]) min_output = output - if inputs[0].dtype == ts.DType.INT8: - # insert RESCALEs to int32 - x_scale = input_qparams[0].scale - x_zp = input_qparams[0].zp - - y_scale = input_qparams[1].scale - y_zp = input_qparams[1].zp - + input_qparams = get_input_qparams(node) assert ( - x_zp == y_zp - ), "Different zp for inputs, MIN should be quantized with shared quantization!" + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 assert ( - x_scale == y_scale - ), "Different scale for input, MIN should be quantized with shared quantization!" + input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MIN" operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( tosa_graph, inputs, node @@ -76,6 +72,4 @@ def define_node( if output.dtype == ts.DType.INT8: # insert RESCALE from int32 back to int8 - tqutils.insert_rescale_node_back_to_int8( - tosa_graph, min_output, scale_back, node - ) + tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 6a406aa4395..c3b4493b7c4 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -57,6 +57,10 @@ def insert_rescale_ops_to_int32( the graph upstream for DQ nodes. 
""" + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + ) + tensors = inputs.copy() # Reshape tensor according to TOSA dim order @@ -64,7 +68,8 @@ def insert_rescale_ops_to_int32( dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = list(cast(dict[int, QuantArgs], node.meta["input_qparams"]).values()) + input_qparams = get_input_qparams(node) + qargs = input_qparams.values() # Scale the int8 quantized input to a common scale in the integer # domain @@ -84,7 +89,7 @@ def insert_rescale_ops_to_int32( return rescaled_nodes, min_scale -def insert_rescale_node_back_to_int8( +def insert_rescale_op_to_int8( tosa_graph: ts.TosaSerializer, last_tensor: TosaArg, scale: float, @@ -102,9 +107,14 @@ def insert_rescale_node_back_to_int8( in the node meta dict as opposed to 'rescale_node_back_to_int8' which search the graph downstream for Q nodes. """ - assert len(node.meta["output_qparams"]) == 1 + from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_output_qparams, + ) + + output_qparams = get_output_qparams(node) + assert len(output_qparams) == 1, "More than one output not supported" - qargs_out = cast(dict[int, QuantArgs], node.meta["output_qparams"])[0] + qargs_out = output_qparams[0] output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -136,6 +146,17 @@ def quantize_value(self, x): def dequantize_value(self, qx: int) -> float: return (qx - self.zp) * self.scale + def __eq__(self, other): + if isinstance(other, QuantArgs): + return ( + self.scale == other.scale + and self.zp == other.zp + and self.qmin == other.qmin + and self.qmax == other.qmax + and self.dtype == other.dtype + ) + return False + @classmethod def from_operator(cls, op, args): if op in dq_q_ops: From 2a03d6fca978345844dfd24da381a1efc51c87eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 12 Dec 2024 13:57:50 +0100 Subject: [PATCH 7/7] Update Q/DQ Folding pass test to sequence of ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I2d133f4347d9999c770e5337162c222368c212f2 --- backends/arm/test/passes/test_fold_qdq_pass.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py index 4ff43e34382..222070b4223 100644 --- a/backends/arm/test/passes/test_fold_qdq_pass.py +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -20,17 +20,20 @@ class SimpleQuantizeModel(torch.nn.Module): - def forward(self, x): - return x + x + def forward(self, x, y): + return x + torch.max((x + x), (y + y)) def get_inputs(self): - return (torch.rand(1, 1280, 7, 7),) + return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)) class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass): def __init__(self): super(FoldAndAnnotateQParamsPassTestClass, self).__init__( - [exir_ops.edge.aten.add.Tensor] + [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.maximum.default, + ] ) @@ -58,15 +61,15 @@ def test_fold_qdq_pass(self): .to_edge() .check_count( { - "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 2, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + 
"executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, } ) .run_passes(test_pass_stage) .check_count( { "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, - "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, } ) )