From fd63e703eb05dc84ccf921e01a6090dc5c1e5d6f Mon Sep 17 00:00:00 2001 From: Oscar Andersson <87121123+oscarandersson8218@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:27:05 +0100 Subject: [PATCH] Revert "Remove unused functions for quantization handling (#7700)" This reverts commit ffc20208dae8f4900da11bfffb76f749e7514132. --- .../annotate_channels_last_dim_order_pass.py | 7 +- backends/arm/operators/__init__.py | 2 + backends/arm/operators/op_dequant.py | 35 +++ backends/arm/operators/op_hardtanh.py | 7 +- backends/arm/operators/op_quant.py | 35 +++ backends/arm/operators/op_relu.py | 8 +- backends/arm/process_node.py | 22 +- backends/arm/tosa_quant_utils.py | 270 +++++++++++++++++- backends/arm/tosa_utils.py | 28 ++ examples/arm/aot_arm_compiler.py | 6 +- 10 files changed, 399 insertions(+), 21 deletions(-) create mode 100644 backends/arm/operators/op_dequant.py create mode 100644 backends/arm/operators/op_quant.py diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 4aff46de67d..80c5f3c442d 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -15,7 +15,7 @@ get_node_arg, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,6 +43,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index a21bde535ec..157e5ec0923 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -13,6 +13,7 @@ op_bmm, op_cat, op_conv2d, + op_dequant, op_exp, op_full, op_get_item, @@ -23,6 +24,7 @@ op_min, op_mul, op_permute, + op_quant, op_reciprocal, op_relu, op_repeat, diff --git a/backends/arm/operators/op_dequant.py b/backends/arm/operators/op_dequant.py new file mode 100644 index 00000000000..022f4e45ceb --- /dev/null +++ b/backends/arm/operators/op_dequant.py @@ -0,0 +1,35 @@ +# Copyright 2023-2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class DequantVisitor(NodeVisitor): + target = "quantized_decomposed.dequantize_per_tensor.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + item_name = inputs[0].name + ## Simply add an identityOp + tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index c971b50b665..bfbab55b922 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -19,6 +19,7 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import quantize_value from serializer.tosa_serializer import TosaOp @@ -43,8 +44,8 @@ def define_node( input_qparams = get_input_qparams(node) # pyre-ignore[16] qargs = input_qparams[0] # Convert to quantized representation - clamp_min_qs = qargs.quantize_value(inputs[1].number).item() - clamp_max_qs = qargs.quantize_value(inputs[2].number).item() + clamp_min_qs = quantize_value(inputs[1].number, qargs) + clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_quant.py b/backends/arm/operators/op_quant.py new file mode 100644 index 00000000000..fcf9372c113 --- /dev/null +++ b/backends/arm/operators/op_quant.py @@ -0,0 +1,35 @@ +# Copyright 2023-2024 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import serializer.tosa_serializer as ts +import torch +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from serializer.tosa_serializer import TosaOp + + +@register_node_visitor +class QuantVisitor(NodeVisitor): + target = "quantized_decomposed.quantize_per_tensor.default" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + item_name = inputs[0].name + ## Simply add an identityOp + tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index b5ffa2aa70c..4df13e71b7c 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -1,10 +1,11 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. # pyre-unsafe +import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx @@ -42,8 +43,9 @@ def define_node( clamp_max_qs = 0 if inputs[0].dtype == ts.DType.INT8: out_qargs = get_output_qparams(node) # pyre-ignore[16] - clamp_min_qs = out_qargs[0].quantize_value(0).item() - clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item() + clamp_min_qs = tqutils.quantize_value(0, out_qargs[0]) + clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0]) + else: clamp_min_fp = 0 clamp_max_fp = float("inf") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 36a1567df93..9ab9c49044c 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,7 +12,12 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_quant_utils import ( + dq_op, + get_quantized_node_output_dtype, + is_node_quantized, +) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch.export.exported_program import ExportedProgram @@ -30,8 +35,15 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_dq_node = node.target == dq_op + if is_dq_node: + output_dtype = ts.DType.INT8 + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, tosa_shape(output.shape, output.dim_order), output.dtype + output.name, + tosa_shape(output.shape, output.dim_order), + output_dtype, ) # Visiting each Node @@ -67,7 +79,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 9869a08c0ba..dff7b12cddd 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,7 +8,9 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import cast, NamedTuple +from typing import Callable, cast, NamedTuple, Sequence + +import numpy as np import serializer.tosa_serializer as ts import torch.fx @@ -22,6 +24,22 @@ q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) def insert_rescale_ops_to_int32( @@ -35,7 +53,8 @@ def insert_rescale_ops_to_int32( This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. + in the node meta dict as opposed to 'rescale_nodes_to_int32' which search + the graph upstream for DQ nodes. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' @@ -81,12 +100,13 @@ def insert_rescale_op_to_int8( Parameters: node: The original node that is being handled by the rescales. last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_op_to_int32' + scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict. + in the node meta dict as opposed to 'rescale_node_back_to_int8' which search + the graph downstream for Q nodes. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -128,6 +148,17 @@ def quantize_value(self, x): def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: return (qx - self.zp) * self.scale + def __eq__(self, other): + if isinstance(other, QuantArgs): + return ( + self.scale == other.scale + and self.zp == other.zp + and self.qmin == other.qmin + and self.qmax == other.qmax + and self.dtype == other.dtype + ) + return False + @classmethod def from_operator(cls, op, args): if op in dq_q_ops: @@ -143,6 +174,172 @@ def from_operator(cls, op, args): raise NotImplementedError +def quantize_value(x, qargs: QuantArgs, dtype=np.int8): + return np.clip( + np.round(x / qargs.scale) + qargs.zp, + qargs.qmin, + qargs.qmax, + ).astype(dtype) + + +def dequantize_value(qx, qargs: QuantArgs): + return (np.int64(qx) - qargs.zp) * qargs.scale + + +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." + + return QuantArgs.from_operator(node.target, node.args) + + +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] + + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes + + +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True + + +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True + + user_q_args, input_q_args = get_neighbour_quant_args(node) + + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False + + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." + + return True + + +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. + """ + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(consumer_qargs) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_downstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_downstream(node) + assert qargs, f"Did not find QuantArgs downstream for node {node}" + return qargs + + +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. + """ + + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(input_qargs) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: + """Calls search_quant_arg_upstream and asserts that QuantArgs are found, + meaning return value can't be None. + """ + qargs = search_quant_arg_upstream(node) + assert qargs, f"Did not find QuantArgs upstream for node {node}" + return qargs + + +def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: + if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys(): + # Check if the node has had it's quantization parameters folded + # and retrieve the dtype from the meta dict in that case. + assert len(node.meta["output_qparams"]) == 1 + qargs = cast(QuantArgs, node.meta["output_qparams"][0]) + return qargs.dtype + + if node.target in dq_q_ops: + return cast(torch.dtype, node.args[5]) + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args) > 0: + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") + + # Check if scale32 mode is used for given output element type def is_scale32(type): return type == ts.DType.INT8 @@ -279,6 +476,69 @@ def build_rescale_from_int32( return +def rescale_nodes_to_int32( + nodes: Sequence[Node], tosa_graph: ts.TosaSerializer +) -> tuple[list[TosaSerializerTensor], float]: + """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. + The scales are adjusted using the smallest scale of all 'nodes'. + + Returns a list of the rescaled nodes and the scale factor used, + needed by rescale_node_back_to_int8. + """ + + tensors = [TosaArg(node) for node in nodes] + + # Reshape tensor according to tosa dim order + for tensor in tensors: + dim_order = tensor.dim_order + tensor.shape = [tensor.shape[i] for i in dim_order] + + qargs = [get_quant_arg_upstream(node) for node in nodes] + + # Scale the int8 quantized input to a common scale in the integer + # domain + min_scale = min([qarg.scale for qarg in qargs]) + scales = [qarg.scale / min_scale for qarg in qargs] + + rescaled_nodes: list[TosaSerializerTensor] = [] + for tensor, qarg, scale in zip(tensors, qargs, scales): + rescaled_nodes.append( + build_rescale_to_int32( + tosa_graph, + tensor, + qarg.zp, + scale, + ) + ) + return rescaled_nodes, min_scale + + +def rescale_node_back_to_int8( + node: Node, + last_tensor: TosaSerializerTensor, + scale: float, + tosa_graph: ts.TosaSerializer, +): + """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. + Parameters: + node: The original node that is being handled by the rescales. + last_tensor:the tosa tensor to rescale back. + scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' + tosa_graph: the tosa_graph to manipulate. + """ + qargs_out = get_quant_arg_downstream(list(node.users)[0]) + output_rescale_scale = scale / qargs_out.scale + + # Rescale Back to INT8 + build_rescale_from_int32( + tosa_graph, + last_tensor.name, + node.name, + qargs_out.zp, + output_rescale_scale, + ) + + """ Creates a TOSA rescale op based on conv2d parameters. """ diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 9fefdbb3ff3..c03e0ef0bb2 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -115,6 +115,10 @@ def getNodeArgs(node: Node) -> list[TosaArg]: return [TosaArg(arg) for arg in node.args] +def get_input_tensor(node: Node) -> TosaArg: + return TosaArg(node.args[0]) + + def get_output_node(node: Node) -> Node: return list(node.users)[0] @@ -142,6 +146,30 @@ def is_consumer_node_depthwise_conv2d(node): return False +def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: + """Returns two input nodes to 'node' in order. If 'node' only has one input, + it is returned twice. + + Fails if there are no input nodes. + Fails if there are >2 input nodes and 'check' is True, + """ + + num_inputs = len(node.all_input_nodes) + assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}." + + input1 = node.all_input_nodes[0] + if num_inputs == 1: + input2 = node.all_input_nodes[0] + else: + input2 = node.all_input_nodes[1] + if check: + assert ( + num_inputs <= 2 + ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}." + + return input1, input2 + + def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order]) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index a49436193bb..9563be93aad 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -264,11 +264,7 @@ def get_compile_spec( ) -> list[CompileSpec]: spec_builder = None if target == "TOSA": - spec_builder = ( - ArmCompileSpecBuilder() - .tosa_compile_spec("TOSA-0.80+BI") - .set_quantize_io(True) - ) + spec_builder = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80+BI") elif "ethos-u55" in target: spec_builder = ArmCompileSpecBuilder().ethosu_compile_spec( target,