diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 25811d077bb..f16a34a211d 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -29,6 +29,9 @@ DecomposeSoftmaxesPass, ) from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, ) @@ -50,6 +53,7 @@ from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass from executorch.exir import ExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_manager import PassManager @@ -80,6 +84,15 @@ def transform_to_backend_pipeline( self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) + self.add_pass( + FoldAndAnnotateQParamsPass( + [ + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, + exir_ops.edge.aten.add.Tensor, + ] + ) + ) for spec in compile_spec: if spec.key == "permute_memory_format": memory_format = spec.value.decode() diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py new file mode 100644 index 00000000000..6c86db8a0bf --- /dev/null +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -0,0 +1,131 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +from typing import Callable, cast, Iterable + +from executorch.backends.arm.tosa_quant_utils import QuantArgs + +from executorch.exir.dialects._ops import ops as exir_ops + +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import GraphModule, Node + + +def get_input_qparams(node: Node) -> dict[int, QuantArgs]: + """ + Get the input quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. + Raises a ValueError if the node doesn't have any parameters set. + """ + if "input_qparams" not in node.meta.keys(): + raise ValueError(f"No input quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["input_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No input quantization parameter found in node {node}") + return input_qparams + + +def get_output_qparams(node: Node) -> dict[int, QuantArgs]: + """ + Get the output quantization parameters from a node, set by the 'FoldAndAnnotateQParamsPass'. + Raises a ValueError if the node doesn't have any parameters set. + """ + if "output_qparams" not in node.meta.keys(): + raise ValueError(f"No output quantization parameter found in node {node}") + input_qparams = cast(dict[int, QuantArgs], node.meta["output_qparams"]) + if len(input_qparams) == 0: + raise ValueError(f"No output quantization parameter found in node {node}") + return input_qparams + + +class FoldAndAnnotateQParamsPass(ExportPass): + """ + A pass that walks the graph and removes any DQ and Q nodes before and after the target + node in the supplied list of operators. 
+    The quantization parameters from the DQ/Q nodes are stored as meta values to be
+    accessible for later lowering and serialization passes.
+    The assumption is that the quantization annotation adds DQ nodes for all tensor
+    inputs to the target node and one Q node to the output.
+
+    Example ('executorch_exir_dialects_edge__ops_' prefix removed from operators for readability):
+
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        x_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(x_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+        aten_add_tensor: "f32[5]" = ops_aten_add_Tensor(x_dq, x_dq)
+        aten_add_tensor_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor_q, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    Becomes:
+        x_q: "i8[5]" = quantized_decomposed_quantize_per_tensor_default(x, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+        aten_add_tensor: "i8[5]" = aten_add_Tensor(x_q, x_q)
+
+        output_dq: "f32[5]" = quantized_decomposed_dequantize_per_tensor_default(aten_add_tensor, 0.05487706884741783, -128, -128, 127, torch.int8)
+
+    The quantization parameters for x_dq and aten_add_tensor_q are stored in meta for the aten_add_tensor node.
+
+    """
+
+    def __init__(self, targeted_ops: Iterable[Callable]):
+        super().__init__()
+        self.targeted_ops = targeted_ops
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
+        # Loop over the graph nodes and find any node in the 'targeted_ops' list.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.op != "call_function" or n.target not in self.targeted_ops:
+                continue
+
+            # Make sure we haven't already set qparams meta information on the node.
+            assert "input_qparams" not in n.meta.keys()
+            assert "output_qparams" not in n.meta.keys()
+
+            # For the inputs and outputs, search the graph for quantization info and
+            # store the information in a dict keyed by the index of the _tensor_ inputs,
+            # ignoring any other arguments to the target node.
+            n.meta["input_qparams"] = {}
+            n.meta["output_qparams"] = {}
+            for i, arg in enumerate(n.args):
+                if not isinstance(arg, Node):
+                    continue
+                if arg.target != dq_op:
+                    continue
+
+                # arg.target for argument i is a dequant node, extract the information.
+                n.meta["input_qparams"][i] = QuantArgs.from_operator(
+                    arg.target, arg.args
+                )
+
+                # arg.args[0] is the tensor input, replace the input usage.
+                n.replace_input_with(arg, arg.args[0])
+                graph_module.graph.erase_node(arg)
+
+            # Copy the users, since we are modifying the dict while iterating over it.
+ users_copy = copy.copy(n.users) + for i, user in enumerate(users_copy): + if user.target != q_op: + continue + + # quantization node found here, store the quantization parameters in meta value + n.meta["output_qparams"][i] = QuantArgs.from_operator( + user.target, user.args + ) + + user.replace_all_uses_with(n) + graph_module.graph.erase_node(user) + + # retrace the graph to update the fake tensor types + graph_module = super().call(graph_module).graph_module + + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 7072ba6a827..2e6e0f9ad08 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -94,6 +94,8 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool: exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.minimum.default, + exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.repeat.default, exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.relu.default, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 8c4aa85e579..6db9c968f09 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -19,7 +19,9 @@ op_get_item, op_hardtanh, op_log, + op_max, op_max_pool2d, + op_min, op_mm, op_mul, op_permute, diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py index e52f3eddae7..bdae16fffbd 100644 --- a/backends/arm/operators/op_add.py +++ b/backends/arm/operators/op_add.py @@ -11,7 +11,6 @@ import executorch.backends.arm.tosa_utils as tutils import serializer.tosa_serializer as ts -import torch from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, @@ -41,33 +40,27 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - input_nodes = tutils.get_two_inputs(node) - - if not is_quant_node and not all( - tensor.meta["val"].dtype in (torch.int8, torch.int32) - for tensor in input_nodes - ): - raise RuntimeError( - f"Unexpected non quantized {AddVisitor_080_BI.target} node." 
- ) - - needs_rescale = not ( - all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes) - and node.meta["val"].dtype == torch.int32 - ) - - if needs_rescale: - # Rescale inputs to 32 bit - rescaled_inputs, scale = tqutils.rescale_nodes_to_int32( - input_nodes, tosa_graph + # Specification (0.80.0) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + # Handle int8 (quantized) and int32 + assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32] + + if inputs[0].dtype == ts.DType.INT8: + rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node ) + else: + # input[0].dtype == ts.DType.INT32 + # Non quantized input, natively support by TOSA.ADD + rescaled_inputs = inputs - # Prepare add output tensor + if output.dtype == ts.DType.INT8: broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order) add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32) else: + # output.dtype == ts.DType.INT32 add_output = output - rescaled_inputs = inputs # Do the INT32 Add tosa_graph.addOperator( @@ -80,10 +73,10 @@ def define_node( None, ) - if needs_rescale: + if output.dtype == ts.DType.INT8: # Scale output back to 8 bit # pyre-ignore - tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph) + tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) @register_node_visitor @@ -105,11 +98,19 @@ def define_node( output: TosaArg, is_quant_node: bool, ) -> None: - if is_quant_node: + # Specification (0.80.0) states that input and output types + # should all be the same + assert inputs[0].dtype == inputs[1].dtype == output.dtype + + if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]: # Call the inherited define_node for handling integers super().define_node(node, tosa_graph, inputs, output, is_quant_node) else: # FP32 Add lowering + assert inputs[0].dtype == ts.DType.FP32 + assert output.dtype == ts.DType.FP32 + + # MI lowering tosa_graph.addOperator( TosaOp.Op().ADD, [inputs[0].name, inputs[1].name], diff --git a/backends/arm/operators/op_max.py b/backends/arm/operators/op_max.py new file mode 100644 index 00000000000..61d889e0db7 --- /dev/null +++ b/backends/arm/operators/op_max.py @@ -0,0 +1,74 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import List + +import executorch.backends.arm.tosa_quant_utils as tqutils +import serializer.tosa_serializer as ts +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MaxVisitor(NodeVisitor): + target = "aten.maximum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + max_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + assert ( + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 + assert ( + input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MAX" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + max_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MAXIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [max_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8(tosa_graph, max_output, scale_back, node) diff --git a/backends/arm/operators/op_min.py b/backends/arm/operators/op_min.py new file mode 100644 index 00000000000..6750ddd41fc --- /dev/null +++ b/backends/arm/operators/op_min.py @@ -0,0 +1,75 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import List + +import executorch.backends.arm.tosa_quant_utils as tqutils + +import serializer.tosa_serializer as ts +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_utils import tosa_shape + +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class MinVisitor(NodeVisitor): + target = "aten.minimum.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert inputs[0].dtype == inputs[1].dtype + + min_output = output + if inputs[0].dtype == ts.DType.INT8: + input_qparams = get_input_qparams(node) + assert ( + len(input_qparams) == 2 + ), f"Both inputs needs to have quantization information for {node}" + # insert RESCALEs to int32 + assert ( + input_qparams[0] == input_qparams[1] + ), "Both inputs must have same quantization for MIN" + + operand_inputs, scale_back = tqutils.insert_rescale_ops_to_int32( + tosa_graph, inputs, node + ) + + output.shape = tosa_shape(output.shape, output.dim_order) + min_output = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) + else: + operand_inputs = inputs + + tosa_graph.addOperator( + TosaOp.Op().MINIMUM, + [ + operand_inputs[0].name, + operand_inputs[1].name, + ], + [min_output.name], + ) + + if output.dtype == ts.DType.INT8: + # insert RESCALE from int32 back to int8 + tqutils.insert_rescale_op_to_int8(tosa_graph, min_output, scale_back, node) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 6f2a5689d39..8815d40b0b0 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -77,6 +77,7 @@ def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPattern ], "mul": [[torch.mul]], "sub": [[torch.sub]], + "min_max": [[torch.min], [torch.max]], } return copy.deepcopy(supported_operators) @@ -267,6 +268,7 @@ class ArmQuantizer(Quantizer): "add", "sub", "mul", + "min_max", "mm", "one_to_one", "generic", diff --git a/backends/arm/quantizer/quantization_annotation/__init__.py b/backends/arm/quantizer/quantization_annotation/__init__.py index 1201df51adc..d9d27cee2ac 100644 --- a/backends/arm/quantizer/quantization_annotation/__init__.py +++ b/backends/arm/quantizer/quantization_annotation/__init__.py @@ -55,6 +55,7 @@ def decorator(annotator: AnnotatorType): generic_annotator, linear_annotator, max_pool2d_annotator, + min_max_annotator, mm_annotator, mul_annotator, one_to_one_annotator, diff --git a/backends/arm/quantizer/quantization_annotation/min_max_annotator.py b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py new file mode 100644 index 00000000000..43c4d20c134 --- /dev/null +++ b/backends/arm/quantizer/quantization_annotation/min_max_annotator.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import Callable, List, Optional + +import torch +from executorch.backends.arm.quantizer import arm_quantizer_utils +from executorch.backends.arm.quantizer.quantization_annotation import register_annotator +from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig +from torch.ao.quantization.quantizer import QuantizationAnnotation +from torch.fx import GraphModule, Node + + +@register_annotator("min_max") +def _annotate_min_max( + gm: GraphModule, + quantization_config: QuantizationConfig, + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> Optional[List[List[Node]]]: + annotated_partitions = [] + for node in gm.graph.nodes: + if node.target not in ( + torch.ops.aten.minimum.default, + torch.ops.aten.maximum.default, + ): + continue + annotated_partitions.append(node) + min_max_node = node + if arm_quantizer_utils.is_annotated(min_max_node): + continue + + input_qspec_map, output_qspec = arm_quantizer_utils.get_shared_qspec( + min_max_node, gm, quantization_config + ) + if input_qspec_map is not None: + min_max_node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + return annotated_partitions diff --git a/backends/arm/test/ops/test_maximum.py b/backends/arm/test/ops/test_maximum.py new file mode 100644 index 00000000000..61e1cccd0be --- /dev/null +++ b/backends/arm/test/ops/test_maximum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMaximum(unittest.TestCase): + """Tests a single maximum op""" + + class Maximum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.maximum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. 
+ ) + + def _test_maximum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_maximum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.maximum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_maximum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_MI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_maximum_tosa_BI_pipeline(self.Maximum(), test_data) + + @parameterized.expand(Maximum.test_parameters) + @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_maximum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Maximum.test_parameters) + def test_maximum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_maximum_ethos_BI_pipeline( + self.Maximum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/ops/test_minimum.py b/backends/arm/test/ops/test_minimum.py new file mode 100644 index 00000000000..b63bf80f69c --- /dev/null +++ b/backends/arm/test/ops/test_minimum.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.exir import EdgeCompileConfig +from executorch.exir.backend.compile_spec_schema import CompileSpec +from parameterized import parameterized + + +class TestMinimum(unittest.TestCase): + """Tests a single minimum op""" + + class Minimum(torch.nn.Module): + test_parameters = [ + ( + torch.FloatTensor([1, 2, 3, 5, 7]), + (torch.FloatTensor([2, 1, 2, 1, 10])), + ), + (torch.ones(1, 10, 4, 6), 2 * torch.ones(1, 10, 4, 6)), + (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)), + (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)), + (10000 * torch.randn(1, 1, 4, 4), torch.randn(1, 1, 4, 1)), + ] + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.minimum(x, y) + + _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( + _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. + ) + + def _test_minimum_tosa_MI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"), + ) + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check_not(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data) + ) + + def _test_minimum_tosa_BI_pipeline( + self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] + ): + ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .check_count({"torch.ops.aten.minimum.default": 1}) + .check(["torch.ops.quantized_decomposed"]) + .to_edge(config=self._edge_compile_config) + .partition() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + .to_executorch() + .run_method_and_compare_outputs(inputs=test_data, qtol=1) + ) + + def _test_minimum_ethos_BI_pipeline( + self, + module: torch.nn.Module, + compile_spec: CompileSpec, + test_data: Tuple[torch.Tensor], + ): + tester = ( + ArmTester( + module, + example_inputs=test_data, + compile_spec=compile_spec, + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + .serialize() + ) + + return tester + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_MI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + def test_minimum_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_minimum_tosa_BI_pipeline(self.Minimum(), test_data) + + @parameterized.expand(Minimum.test_parameters) + @unittest.expectedFailure # Bug in Vela, disabled until pin changes, bug MLETORCH-513 + def test_minimum_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u55_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-300" + ) + + @parameterized.expand(Minimum.test_parameters) + def 
test_minimum_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + tester = self._test_minimum_ethos_BI_pipeline( + self.Minimum(), common.get_u85_compile_spec(), test_data + ) + if common.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs( + qtol=1, inputs=test_data, target_board="corstone-320" + ) diff --git a/backends/arm/test/passes/test_fold_qdq_pass.py b/backends/arm/test/passes/test_fold_qdq_pass.py new file mode 100644 index 00000000000..222070b4223 --- /dev/null +++ b/backends/arm/test/passes/test_fold_qdq_pass.py @@ -0,0 +1,75 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + FoldAndAnnotateQParamsPass, +) + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import RunPasses + +from executorch.exir.dialects._ops import ops as exir_ops + + +class SimpleQuantizeModel(torch.nn.Module): + def forward(self, x, y): + return x + torch.max((x + x), (y + y)) + + def get_inputs(self): + return (torch.rand(1, 1280, 7, 7), torch.rand(1, 1280, 7, 7)) + + +class FoldAndAnnotateQParamsPassTestClass(FoldAndAnnotateQParamsPass): + def __init__(self): + super(FoldAndAnnotateQParamsPassTestClass, self).__init__( + [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.maximum.default, + ] + ) + + +class TestFoldAndAnnotateQParamsPass(unittest.TestCase): + """ + Tests the FoldAndAnnotateQParamsPass which folds dq/q nodes into + the node and stores the quantization parameters in meta. + """ + + def test_fold_qdq_pass(self): + """ + Check that the pass runs for add operation and that one q node and one dq node + is removed from the representation. + """ + module = SimpleQuantizeModel() + test_pass_stage = RunPasses([FoldAndAnnotateQParamsPassTestClass]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"), + ) + .quantize() + .export() + .to_edge() + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 7, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 6, + } + ) + .run_passes(test_pass_stage) + .check_count( + { + "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 1, + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 2, + } + ) + ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 19397fe6b21..c3b4493b7c4 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -42,6 +42,91 @@ def register_passable_op(op): passable_ops.append(op) +def insert_rescale_ops_to_int32( + tosa_graph: ts.TosaSerializer, inputs: list[TosaArg], node: Node +) -> tuple[list[TosaSerializerTensor], float]: + """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. + The scales are adjusted using the smallest scale of all 'nodes'. + + Returns a list of the rescaled nodes and the scale factor used, + needed by rescale_node_back_to_int8. 
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_nodes_to_int32' which searches
+    the graph upstream for DQ nodes.
+    """
+
+    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+        get_input_qparams,
+    )
+
+    tensors = inputs.copy()
+
+    # Reshape tensors according to the TOSA dim order.
+    for tensor in tensors:
+        dim_order = tensor.dim_order
+        tensor.shape = [tensor.shape[i] for i in dim_order]
+
+    input_qparams = get_input_qparams(node)
+    qargs = input_qparams.values()
+
+    # Scale the int8 quantized inputs to a common scale in the integer
+    # domain.
+    min_scale = min([qarg.scale for qarg in qargs])
+    scales = [qarg.scale / min_scale for qarg in qargs]
+
+    rescaled_nodes: list[TosaSerializerTensor] = []
+    for tensor, qarg, scale in zip(tensors, qargs, scales):
+        rescaled_nodes.append(
+            build_rescale_to_int32(
+                tosa_graph,
+                tensor,
+                qarg.zp,
+                scale,
+            )
+        )
+    return rescaled_nodes, min_scale
+
+
+def insert_rescale_op_to_int8(
+    tosa_graph: ts.TosaSerializer,
+    last_tensor: TosaArg,
+    scale: float,
+    node: Node,
+) -> None:
+    """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'.
+    Parameters:
+        node: The original node that is being handled by the rescales.
+        last_tensor: the TOSA tensor to rescale back.
+        scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32'.
+        tosa_graph: the tosa_graph to manipulate.
+
+    This function is used in serialization to TOSA for target ops that are
+    handled by the DQ/Q folding pass, which stores the quantization parameters
+    in the node meta dict, as opposed to 'rescale_node_back_to_int8' which searches
+    the graph downstream for Q nodes.
+    """
+    from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+        get_output_qparams,
+    )
+
+    output_qparams = get_output_qparams(node)
+    assert len(output_qparams) == 1, "More than one output not supported"
+
+    qargs_out = output_qparams[0]
+    output_rescale_scale = scale / qargs_out.scale
+
+    # Rescale back to INT8.
+    build_rescale_from_int32(
+        tosa_graph,
+        last_tensor.name,
+        node.name,
+        qargs_out.zp,
+        output_rescale_scale,
+    )
+
+
 class QuantArgs(NamedTuple):
     scale: float
     zp: int
@@ -61,6 +146,31 @@ def quantize_value(self, x):
     def dequantize_value(self, qx: int) -> float:
         return (qx - self.zp) * self.scale
 
+    def __eq__(self, other):
+        if isinstance(other, QuantArgs):
+            return (
+                self.scale == other.scale
+                and self.zp == other.zp
+                and self.qmin == other.qmin
+                and self.qmax == other.qmax
+                and self.dtype == other.dtype
+            )
+        return False
+
+    @classmethod
+    def from_operator(cls, op, args):
+        if op in dq_q_ops:
+            return cls(
+                scale=cast(float, args[1]),
+                zp=cast(int, args[2]),
+                qmin=cast(int, args[3]),
+                qmax=cast(int, args[4]),
+                dtype=cast(torch.dtype, args[5]),
+            )
+        else:
+            # We only handle per-tensor quantization.
+            raise NotImplementedError
+
 
 def quantize_value(x, qargs: QuantArgs, dtype=np.int8):
     return np.clip(
@@ -77,13 +187,7 @@ def dequantize_value(qx, qargs: QuantArgs):
 
 def qargs_from_qnode(node: torch.fx.Node):
     assert node.target in dq_q_ops, f"Op {node} is not a quant node."
-    return QuantArgs(
-        scale=cast(float, node.args[1]),
-        zp=cast(int, node.args[2]),
-        qmin=cast(int, node.args[3]),
-        qmax=cast(int, node.args[4]),
-        dtype=cast(torch.dtype, node.args[5]),
-    )
+    return QuantArgs.from_operator(node.target, node.args)
 
 
 def get_neighbour_quant_args(
@@ -214,8 +318,13 @@ def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs:
 
 
 def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype:
-    if isinstance(node.target, Callable) and "tosa" in node.target.__name__:
-        return node.meta["val"].dtype
+    if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys():
+        # Check if the node has had its quantization parameters folded
+        # and retrieve the dtype from the meta dict in that case.
+        assert len(node.meta["output_qparams"]) == 1
+        qargs = cast(QuantArgs, node.meta["output_qparams"][0])
+        return qargs.dtype
+
    if node.target in dq_q_ops:
        return cast(torch.dtype, node.args[5])
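
For reviewers, the following is a minimal sketch (illustrative only, not part of the diff) of how the quantization parameters folded into node.meta by FoldAndAnnotateQParamsPass can be read back. get_input_qparams, get_output_qparams and QuantArgs are the helpers introduced in this change; the helper name describe_folded_qparams is hypothetical.

# Illustrative sketch only; not part of the change above.
from torch.fx import Node

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
    get_input_qparams,
    get_output_qparams,
)


def describe_folded_qparams(node: Node) -> str:
    """Summarize the folded quantization parameters stored in node.meta.

    Raises ValueError (via the getters) if the node has not been handled by
    FoldAndAnnotateQParamsPass.
    """
    lines = [f"{node.name}:"]
    # Keys are the positional indices of the node's tensor inputs.
    for i, qarg in get_input_qparams(node).items():
        lines.append(f"  input[{i}]: scale={qarg.scale}, zp={qarg.zp}, dtype={qarg.dtype}")
    # For the targeted ops a single output entry is stored at index 0.
    for i, qarg in get_output_qparams(node).items():
        lines.append(f"  output[{i}]: scale={qarg.scale}, zp={qarg.zp}, dtype={qarg.dtype}")
    return "\n".join(lines)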