diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 80c5f3c442d..4aff46de67d 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -1,4 +1,4 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the @@ -15,7 +15,7 @@ get_node_arg, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -43,9 +43,6 @@ def _transpose_impl(*args, **kwargs): return args[0] -register_passable_op(torch.ops.passthrough_to_tosa._transpose) - - class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 157e5ec0923..a21bde535ec 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -13,7 +13,6 @@ op_bmm, op_cat, op_conv2d, - op_dequant, op_exp, op_full, op_get_item, @@ -24,7 +23,6 @@ op_min, op_mul, op_permute, - op_quant, op_reciprocal, op_relu, op_repeat, diff --git a/backends/arm/operators/op_dequant.py b/backends/arm/operators/op_dequant.py deleted file mode 100644 index 022f4e45ceb..00000000000 --- a/backends/arm/operators/op_dequant.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class DequantVisitor(NodeVisitor): - target = "quantized_decomposed.dequantize_per_tensor.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - item_name = inputs[0].name - ## Simply add an identityOp - tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index bfbab55b922..c971b50b665 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. +# Copyright 2023-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -19,7 +19,6 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import quantize_value from serializer.tosa_serializer import TosaOp @@ -44,8 +43,8 @@ def define_node( input_qparams = get_input_qparams(node) # pyre-ignore[16] qargs = input_qparams[0] # Convert to quantized representation - clamp_min_qs = quantize_value(inputs[1].number, qargs) - clamp_max_qs = quantize_value(inputs[2].number, qargs) + clamp_min_qs = qargs.quantize_value(inputs[1].number).item() + clamp_max_qs = qargs.quantize_value(inputs[2].number).item() # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_quant.py b/backends/arm/operators/op_quant.py deleted file mode 100644 index fcf9372c113..00000000000 --- a/backends/arm/operators/op_quant.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -# pyre-unsafe -from typing import List - -import serializer.tosa_serializer as ts -import torch -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.tosa_mapping import TosaArg -from serializer.tosa_serializer import TosaOp - - -@register_node_visitor -class QuantVisitor(NodeVisitor): - target = "quantized_decomposed.quantize_per_tensor.default" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - item_name = inputs[0].name - ## Simply add an identityOp - tosa_graph.addOperator(TosaOp.Op().IDENTITY, [item_name], [output.name]) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 4df13e71b7c..b5ffa2aa70c 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -1,11 +1,10 @@ -# Copyright 2024 Arm Limited and/or its affiliates. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
# pyre-unsafe -import executorch.backends.arm.tosa_quant_utils as tqutils import serializer.tosa_serializer as ts import torch.fx @@ -43,9 +42,8 @@ def define_node( clamp_max_qs = 0 if inputs[0].dtype == ts.DType.INT8: out_qargs = get_output_qparams(node) # pyre-ignore[16] - clamp_min_qs = tqutils.quantize_value(0, out_qargs[0]) - clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs[0]) - + clamp_min_qs = out_qargs[0].quantize_value(0).item() + clamp_max_qs = out_qargs[0].quantize_value(float("inf")).item() else: clamp_min_fp = 0 clamp_max_fp = float("inf") diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 9ab9c49044c..36a1567df93 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,12 +12,7 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - dq_op, - get_quantized_node_output_dtype, - is_node_quantized, -) +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape from torch.export.exported_program import ExportedProgram @@ -35,15 +30,8 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) - is_dq_node = node.target == dq_op - if is_dq_node: - output_dtype = ts.DType.INT8 - else: - output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( - output.name, - tosa_shape(output.shape, output.dim_order), - output_dtype, + output.name, tosa_shape(output.shape, output.dim_order), output.dtype ) # Visiting each Node @@ -79,11 +67,7 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - ( - map_dtype(get_quantized_node_output_dtype(node)) - if is_node_quantized(node) - else 
inputs[0].dtype - ), + inputs[0].dtype, data=None, placeholderFilename=inputs[0].name + ".npy", ) diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index dff7b12cddd..9869a08c0ba 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 Arm Limited and/or its affiliates. +# Copyright 2023-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -8,9 +8,7 @@ # Utiliy functions for TOSA quantized lowerings import math -from typing import Callable, cast, NamedTuple, Sequence - -import numpy as np +from typing import cast, NamedTuple import serializer.tosa_serializer as ts import torch.fx @@ -24,22 +22,6 @@ q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default dq_q_ops = (q_op, dq_op) -passable_ops = [ - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.squeeze_copy.dims, - exir_ops.edge.aten.unsqueeze_copy.default, - exir_ops.edge.aten.split_with_sizes_copy.default, - exir_ops.edge.aten.repeat.default, - exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.slice_copy.Tensor, - exir_ops.edge.aten.cat.default, -] - - -def register_passable_op(op): - """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" - passable_ops.append(op) def insert_rescale_ops_to_int32( @@ -53,8 +35,7 @@ def insert_rescale_ops_to_int32( This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict as opposed to 'rescale_nodes_to_int32' which search - the graph upstream for DQ nodes. + in the node meta dict.
""" # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' @@ -100,13 +81,12 @@ def insert_rescale_op_to_int8( Parameters: node: The original node that is being handled by the rescales. last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' + scale: the scaling factor used to rescale to int32, from the function 'insert_rescale_ops_to_int32' tosa_graph: the tosa_graph to manipulate. This functions is used in serialization to TOSA for target ops that are handled by the DQ/D folding pass, which stores the quantization parameters - in the node meta dict as opposed to 'rescale_node_back_to_int8' which search - the graph downstream for Q nodes. + in the node meta dict. """ # pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.' from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( @@ -148,17 +128,6 @@ def quantize_value(self, x): def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: return (qx - self.zp) * self.scale - def __eq__(self, other): - if isinstance(other, QuantArgs): - return ( - self.scale == other.scale - and self.zp == other.zp - and self.qmin == other.qmin - and self.qmax == other.qmax - and self.dtype == other.dtype - ) - return False - @classmethod def from_operator(cls, op, args): if op in dq_q_ops: @@ -174,172 +143,6 @@ def from_operator(cls, op, args): raise NotImplementedError -def quantize_value(x, qargs: QuantArgs, dtype=np.int8): - return np.clip( - np.round(x / qargs.scale) + qargs.zp, - qargs.qmin, - qargs.qmax, - ).astype(dtype) - - -def dequantize_value(qx, qargs: QuantArgs): - return (np.int64(qx) - qargs.zp) * qargs.scale - - -def qargs_from_qnode(node: torch.fx.Node): - assert node.target in dq_q_ops, f"Op {node} is not a quant node."
- - return QuantArgs.from_operator(node.target, node.args) - - -def get_neighbour_quant_args( - node: torch.fx.Node, -) -> tuple[list[QuantArgs], list[QuantArgs]]: - user_q_args = [] - - for user in node.users: - q_args = search_quant_arg_downstream(user) - if q_args: - user_q_args.append(q_args) - - input_q_nodes = [] - for input_node in node.all_input_nodes: - q_args = search_quant_arg_upstream(input_node) - if q_args: - input_q_nodes.append(q_args) - return user_q_args, input_q_nodes - - -def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: - first_q_arg = q_arg_list[0] - for q_arg in q_arg_list: - if q_arg != first_q_arg: - return False - return True - - -def is_node_quantized(node: torch.fx.Node) -> bool: - if node.target in dq_q_ops: - return True - - user_q_args, input_q_args = get_neighbour_quant_args(node) - - # If we did not find any neighbouring quant nodes, we are not quantized. - if len(input_q_args) == 0 and len(user_q_args) == 0: - return False - - if node.target in passable_ops: - assert all_q_args_equal( - user_q_args + input_q_args - ), f"Node {node} needs same quantization parameters on all inputs and outputs." - - return True - - -def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple consumers is encountered, - find QuantArgs for all consumers and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without consumers is encountered, return None. 
- """ - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - consumer_nodes = list(node.users) - if len(consumer_nodes) == 0: - return None - elif len(consumer_nodes) == 1: - return search_quant_arg_downstream(consumer_nodes[0]) - else: - consumer_qargs: list[QuantArgs] = [] - for input in consumer_nodes: - quant_args = search_quant_arg_downstream(input) - if quant_args: - consumer_qargs.append(quant_args) - if len(consumer_qargs) == 0: - return None - assert all_q_args_equal( - consumer_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." - return consumer_qargs[0] - - -def get_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs: - """Calls search_quant_arg_downstream and asserts that QuantArgs are found, - meaning return value can't be None. - """ - qargs = search_quant_arg_downstream(node) - assert qargs, f"Did not find QuantArgs downstream for node {node}" - return qargs - - -def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: - """ - Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, - starting with 'node'. - If a passable node with multiple inputs is encountered, - find QuantArgs for all inputs and assert that they are equal. - If a node not in passable_ops is encountered, return None. - If a node without inputs is encountered, return None. 
- """ - - if node.target in dq_q_ops: - return qargs_from_qnode(node) - if node.target not in passable_ops: - return None - input_nodes = list(node.all_input_nodes) - if len(input_nodes) == 0: - return None - elif len(input_nodes) == 1: - return search_quant_arg_upstream(input_nodes[0]) - else: - input_qargs: list[QuantArgs] = [] - for input in input_nodes: - quant_args = search_quant_arg_upstream(input) - if quant_args: - input_qargs.append(quant_args) - if len(input_qargs) == 0: - return None - assert all_q_args_equal( - input_qargs - ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." - return input_qargs[0] - - -def get_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs: - """Calls search_quant_arg_upstream and asserts that QuantArgs are found, - meaning return value can't be None. - """ - qargs = search_quant_arg_upstream(node) - assert qargs, f"Did not find QuantArgs upstream for node {node}" - return qargs - - -def get_quantized_node_output_dtype(node: torch.fx.Node) -> torch.dtype: - if isinstance(node.target, Callable) and "output_qparams" in node.meta.keys(): - # Check if the node has had it's quantization parameters folded - # and retrieve the dtype from the meta dict in that case. 
- assert len(node.meta["output_qparams"]) == 1 - qargs = cast(QuantArgs, node.meta["output_qparams"][0]) - return qargs.dtype - - if node.target in dq_q_ops: - return cast(torch.dtype, node.args[5]) - - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - user_q_args, input_q_args = get_neighbour_quant_args(node) - if len(user_q_args) > 0: - return user_q_args[0].dtype - elif node.target in passable_ops and len(input_q_args) > 0: - return input_q_args[0].dtype - else: - raise RuntimeError("No quantized node found in graph") - - # Check if scale32 mode is used for given output element type def is_scale32(type): return type == ts.DType.INT8 @@ -476,69 +279,6 @@ def build_rescale_from_int32( return -def rescale_nodes_to_int32( - nodes: Sequence[Node], tosa_graph: ts.TosaSerializer -) -> tuple[list[TosaSerializerTensor], float]: - """Rescales all 'nodes' to int32, adding suitable RESCALE ops to 'tosa_graph'. - The scales are adjusted using the smallest scale of all 'nodes'. - - Returns a list of the rescaled nodes and the scale factor used, - needed by rescale_node_back_to_int8. 
- """ - - tensors = [TosaArg(node) for node in nodes] - - # Reshape tensor according to tosa dim order - for tensor in tensors: - dim_order = tensor.dim_order - tensor.shape = [tensor.shape[i] for i in dim_order] - - qargs = [get_quant_arg_upstream(node) for node in nodes] - - # Scale the int8 quantized input to a common scale in the integer - # domain - min_scale = min([qarg.scale for qarg in qargs]) - scales = [qarg.scale / min_scale for qarg in qargs] - - rescaled_nodes: list[TosaSerializerTensor] = [] - for tensor, qarg, scale in zip(tensors, qargs, scales): - rescaled_nodes.append( - build_rescale_to_int32( - tosa_graph, - tensor, - qarg.zp, - scale, - ) - ) - return rescaled_nodes, min_scale - - -def rescale_node_back_to_int8( - node: Node, - last_tensor: TosaSerializerTensor, - scale: float, - tosa_graph: ts.TosaSerializer, -): - """Rescales the node back to int8, adding a suitable RESCALE op to 'tosa_graph'. - Parameters: - node: The original node that is being handled by the rescales. - last_tensor:the tosa tensor to rescale back. - scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' - tosa_graph: the tosa_graph to manipulate. - """ - qargs_out = get_quant_arg_downstream(list(node.users)[0]) - output_rescale_scale = scale / qargs_out.scale - - # Rescale Back to INT8 - build_rescale_from_int32( - tosa_graph, - last_tensor.name, - node.name, - qargs_out.zp, - output_rescale_scale, - ) - - """ Creates a TOSA rescale op based on conv2d parameters. 
""" diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index c03e0ef0bb2..9fefdbb3ff3 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -115,10 +115,6 @@ def getNodeArgs(node: Node) -> list[TosaArg]: return [TosaArg(arg) for arg in node.args] -def get_input_tensor(node: Node) -> TosaArg: - return TosaArg(node.args[0]) - - def get_output_node(node: Node) -> Node: return list(node.users)[0] @@ -146,30 +142,6 @@ def is_consumer_node_depthwise_conv2d(node): return False -def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: - """Returns two input nodes to 'node' in order. If 'node' only has one input, - it is returned twice. - - Fails if there are no input nodes. - Fails if there are >2 input nodes and 'check' is True, - """ - - num_inputs = len(node.all_input_nodes) - assert num_inputs > 0, f"Node '{node.name}' requires >0 input, got {num_inputs}." - - input1 = node.all_input_nodes[0] - if num_inputs == 1: - input2 = node.all_input_nodes[0] - else: - input2 = node.all_input_nodes[1] - if check: - assert ( - num_inputs <= 2 - ), f"Node '{node.name}' requires <=2 inputs, got {num_inputs}." - - return input1, input2 - - def tosa_shape(shape, dim_order): return tuple([shape[dim] for dim in dim_order]) diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index e842cde6bb7..09434df2abf 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -265,7 +265,11 @@ def get_compile_spec( ) -> list[CompileSpec]: spec_builder = None if target == "TOSA": - spec_builder = ArmCompileSpecBuilder().tosa_compile_spec("TOSA-0.80+BI") + spec_builder = ( + ArmCompileSpecBuilder() + .tosa_compile_spec("TOSA-0.80+BI") + .set_quantize_io(True) + ) elif "ethos-u55" in target: spec_builder = ( ArmCompileSpecBuilder()