diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index f16a34a211d..b4bb809b851 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -31,6 +31,7 @@
 from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     FoldAndAnnotateQParamsPass,
+    QuantizeFullArgument,
 )
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
@@ -84,12 +85,16 @@ def transform_to_backend_pipeline(
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(QuantizeFullArgument())
         self.add_pass(
             FoldAndAnnotateQParamsPass(
                 [
                     exir_ops.edge.aten.minimum.default,
                     exir_ops.edge.aten.maximum.default,
                     exir_ops.edge.aten.add.Tensor,
+                    exir_ops.edge.aten.avg_pool2d.default,
+                    exir_ops.edge.aten.convolution.default,
+                    exir_ops.edge.aten.full.default,
                 ]
             )
         )
diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
index 6c86db8a0bf..6ba72eb1022 100644
--- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
+++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
@@ -15,6 +15,9 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule, Node
 
+q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+
 
 def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
     """
@@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]):
         self.targeted_ops = targeted_ops
 
     def call(self, graph_module: GraphModule) -> PassResult:
-        q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
-        dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
 
         # Loop over the graph nodes and find any node in the 'targeted_ops' list.
         for n in graph_module.graph.nodes:
@@ -98,6 +99,22 @@ def call(self, graph_module: GraphModule) -> PassResult:
             for i, arg in enumerate(n.args):
                 if not isinstance(arg, Node):
                     continue
+
+                # Make sure arg has requires_grad set to False.
+                # For parameters that are not quantized, sometimes (e.g. for convolution)
+                # the Parameter(FakeTensor(...)) has requires_grad set to True, which
+                # causes the retracing of the graph to fail with:
+                #
+                # E       RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
+                # E
+                # E       While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
+                # E       Original traceback:
+                # E         File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
+                # E           x = conv(x)
+                #
+                if arg.op == "placeholder":
+                    arg.meta["val"].requires_grad = False
+
                 if arg.target != dq_op:
                     continue
@@ -129,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
         graph_module.recompile()
         return PassResult(graph_module, True)
+
+
+class QuantizeFullArgument(ExportPass):
+    """
+    Make sure the fill_value for full.default is quantized. This pass needs to be run
+    before the folding pass above to make sure that the retraced output of the
+    full.default op has the right dtype.
+    """
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        modified = False
+        # Loop over the graph nodes and find all full.default nodes.
+        for n in graph_module.graph.nodes:
+            n = cast(Node, n)
+            if n.target != exir_ops.edge.aten.full.default:
+                continue
+
+            # Make sure we have a quantized operator
+            user = list(n.users)[0]
+            if user.target != q_op:
+                continue
+
+            qargs = QuantArgs.from_operator(user.target, user.args)
+            if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
+                # Replace the node arg with its quantized value and also set dtype
+                # to get the right output according to the Edge IR specification:
+                # exir/dialects/edge/edge.yaml:3596
+                quantized_full_value = qargs.quantize_value(n.args[1]).item()
+                n.update_arg(1, quantized_full_value)
+                n.update_kwarg("dtype", qargs.dtype)
+                modified = True
+
+        return PassResult(graph_module, modified)
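The new pass hinges on QuantArgs.quantize_value mapping the float fill value into the quantized domain before the q/dq pair is folded away. A minimal sketch of the affine scheme it relies on; the quantize_value helper below is a hypothetical stand-in, assuming the usual round-and-clamp formula:

import torch

# Hypothetical stand-in for QuantArgs.quantize_value, assuming the standard
# affine per-tensor scheme: q = clamp(round(x / scale) + zp, qmin, qmax).
def quantize_value(x: float, scale: float, zp: int,
                   qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    q = round(x / scale) + zp
    return torch.tensor(min(max(q, qmin), qmax), dtype=torch.int8)

# A full.default(fill_value=0.5) feeding q(scale=0.01, zp=0) is rewritten to
# emit the already-quantized constant 50 with dtype int8.
print(quantize_value(0.5, scale=0.01, zp=0))  # tensor(50, dtype=torch.int8)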
diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py
index 4caaad92028..6665a99a7bf 100644
--- a/backends/arm/operators/op_avg_pool2d.py
+++ b/backends/arm/operators/op_avg_pool2d.py
@@ -8,30 +8,41 @@
 
 import serializer.tosa_serializer as ts
 import torch
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+    get_output_qparams,
+)
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
+from executorch.backends.arm.tosa_specification import TosaSpecification
 
 
 @register_node_visitor
-class AvgPool2dVisitor(NodeVisitor):
+class AvgPool2dVisitor_0_80_BI(NodeVisitor):
     target = "aten.avg_pool2d.default"
 
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
+    ]
+
     def __init__(self, *args):
         super().__init__(*args)
 
-    def define_node(
+    def _build_generic_avgpool2d(
         self,
         node: torch.fx.Node,
         tosa_graph: ts.TosaSerializer,
         inputs: List[TosaArg],
         output: TosaArg,
-        is_quant_node: bool,
+        input_zp: int,
+        output_zp: int,
+        accumulator_type,
     ) -> None:
         input_tensor = inputs[0]
+
         kernel_size_list = inputs[1].special
         stride_size_list = inputs[2].special
         try:
@@ -39,13 +50,76 @@ def define_node(
         except IndexError:
             pad_size_list = [0, 0, 0, 0]
 
-        build_avg_pool_2d_common(
-            node,
-            tosa_graph,
-            input_tensor,
-            kernel_size_list,
-            stride_size_list,
-            pad_size_list,
-            is_quant_node,
-            output,
+        attr = ts.TosaSerializerAttribute()
+        attr.PoolAttribute(
+            kernel=kernel_size_list,
+            stride=stride_size_list,
+            pad=pad_size_list,
+            input_zp=input_zp,
+            output_zp=output_zp,
+            accum_dtype=accumulator_type,
+        )
+
+        tosa_graph.addOperator(
+            ts.TosaOp.Op().AVG_POOL2D,
+            [input_tensor.name],
+            [output.name],
+            attr,
+        )
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        input_tensor = inputs[0]
+        assert input_tensor.dtype == ts.DType.INT8
+
+        accumulator_type = ts.DType.INT32
+
+        input_qargs = get_input_qparams(node)
+        input_zp = input_qargs[0].zp
+
+        output_qargs = get_output_qparams(node)
+        output_zp = output_qargs[0].zp
+
+        self._build_generic_avgpool2d(
+            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
         )
+
+
+@register_node_visitor
+class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
+    # 'target' is inherited from the BI class
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
+    ]
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        assert (
+            inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
+        ), "Only FP32 and INT8 supported"
+
+        if inputs[0].dtype == ts.DType.INT8:
+            super().define_node(node, tosa_graph, inputs, output, is_quant_node)
+
+        if inputs[0].dtype == ts.DType.FP32:
+            accumulator_type = ts.DType.FP32
+            # Initialize zero points to zero.
+            input_zp = 0
+            output_zp = 0
+
+            self._build_generic_avgpool2d(
+                node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
+            )
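The zero points passed to PoolAttribute above feed the usual integer averaging scheme: subtract the input zero point, accumulate in INT32, then re-add the output zero point. A simplified scalar model of that arithmetic (this ignores the exact rounding and shift sequence the TOSA reference model uses):

import numpy as np

def avg_pool_int8(window: np.ndarray, input_zp: int, output_zp: int) -> np.int8:
    # Accumulate in int32, matching accum_dtype=ts.DType.INT32 above.
    acc = int(np.sum(window.astype(np.int32) - input_zp))
    avg = round(acc / window.size) + output_zp
    return np.int8(np.clip(avg, -128, 127))

window = np.array([[10, 12], [14, 16]], dtype=np.int8)
print(avg_pool_int8(window, input_zp=0, output_zp=0))  # 13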
diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py
index d17c3a1b81f..ee773949d1e 100644
--- a/backends/arm/operators/op_batch_norm.py
+++ b/backends/arm/operators/op_batch_norm.py
@@ -13,6 +13,7 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
 from serializer.tosa_serializer import TosaOp
 
@@ -21,6 +22,10 @@ class BatchNormVisitor(NodeVisitor):
     target = "aten._native_batch_norm_legit_no_training.default"
 
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
+    ]
+
     def __init__(self, *args):
         super().__init__(*args)
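The tosa_specs attribute used here, and in the avg_pool split above, lets the backend select a visitor per active TOSA specification string. The mini-registry below is a hypothetical illustration of that (spec, target) keyed lookup, not the real register_node_visitor implementation:

from collections import defaultdict

_visitors: dict = defaultdict(dict)  # spec string -> target -> visitor class

def register(target: str, specs: list[str]):
    def wrap(cls):
        for spec in specs:
            _visitors[spec][target] = cls
        return cls
    return wrap

@register("aten.avg_pool2d.default", ["TOSA-0.80.0+BI"])
class AvgPoolBI: ...

@register("aten.avg_pool2d.default", ["TOSA-0.80.0+MI"])
class AvgPoolMI(AvgPoolBI): ...

# Each specification resolves to its own visitor for the same target.
print(_visitors["TOSA-0.80.0+BI"]["aten.avg_pool2d.default"].__name__)  # AvgPoolBI
print(_visitors["TOSA-0.80.0+MI"]["aten.avg_pool2d.default"].__name__)  # AvgPoolMI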
diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py
index ffbeee7306d..dc64e169364 100644
--- a/backends/arm/operators/op_conv2d.py
+++ b/backends/arm/operators/op_conv2d.py
@@ -8,16 +8,16 @@
 
 import serializer.tosa_serializer as ts
 import torch
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+    get_output_qparams,
+)
 from executorch.backends.arm.operators.node_visitor import (
     NodeVisitor,
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import (
-    build_rescale_conv_output,
-    get_quant_arg_downstream,
-    get_quant_arg_upstream,
-)
+from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
 from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
 from serializer.tosa_serializer import TosaOp
 
@@ -57,9 +57,6 @@ def define_node(
     ) -> None:
         input, weight, bias, stride, pad, dilation, _, _, group = inputs
 
-        # Currently only int8 is supported in quantized types.
-        actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype
-
         # Get the attributes of convolution.
         attr = ts.TosaSerializerAttribute()
         pad_attr = [val for val in pad.special for _ in (0, 1)]
@@ -82,9 +79,11 @@ def define_node(
             dilation_attr[1],
         )
 
-        input_zp = (
-            get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
-        )
+        input_zp = 0
+        if inputs[0].dtype == ts.DType.INT8:
+            # int8 input requires quantization information
+            input_qparams = get_input_qparams(node)
+            input_zp = input_qparams[0].zp
 
         attr.ConvAttribute(
             pad=pad_attr,
@@ -100,16 +99,22 @@ def define_node(
         # Create a zero bias tensor if not presented
         out_channels = weight.shape[0]
         bias_name = "bias" + node.name.split("default", 1)[1]
+        bias_type = output.dtype
+        if output.dtype == ts.DType.INT8:
+            # Conv is quantized to int8, but the TOSA operator has
+            # output type int32, and the bias must be the same type
+            # as the TOSA output type
+            bias_type = ts.DType.INT32
         bias = tosa_graph.addConst(
             [out_channels],
-            ts.DType.INT32 if is_quant_node else output.dtype,
+            bias_type,
             [0] * out_channels,
             name=bias_name,
         )
 
         # The output type is int32 when input type is int8.
         conv2d_output_name = output.name
-        if is_quant_node:
+        if output.dtype == ts.DType.INT8:
             conv2d_res = tosa_graph.addIntermediate(
                 tosa_shape(output.shape, output.dim_order), ts.DType.INT32
             )
@@ -132,7 +137,7 @@ def define_node(
 
             weight_reshaped = tosa_graph.addIntermediate(
                 weight_post_shape,
-                ts.DType.INT8 if is_quant_node else weight.dtype,
+                weight.dtype,
             )
             build_reshape(
                 tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
@@ -157,20 +162,19 @@ def define_node(
 
         # For quantized convolution, rescale the output value back to the same
        # integer value domain of the next op. Otherwise return float32 output.
-        if is_quant_node:
+        if inputs[0].dtype == ts.DType.INT8:
             # Get scale_factor from input, weight, and output.
-            input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
-            weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
-            output_qargs = get_quant_arg_downstream(list(node.users)[0])
-
+            input_scale = input_qparams[0].scale
+            weight_scale = input_qparams[1].scale
+            output_qargs = get_output_qparams(node)
             build_rescale_conv_output(
                 tosa_graph,
                 # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
                 conv2d_res,
                 output.name,
-                actual_out_type,
+                output.dtype,
                 input_scale,
                 weight_scale,
-                output_qargs.scale,
-                output_qargs.zp,
+                output_qargs[0].scale,
+                output_qargs[0].zp,
             )
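The rescale at the end of the INT8 path applies the standard requantization identity: the INT32 accumulator carries an effective scale of input_scale * weight_scale, which must be converted to the output's scale and zero point. A scalar sketch of what build_rescale_conv_output computes; the real op uses TOSA RESCALE with fixed-point multipliers rather than float math:

def requantize(acc_int32: int, input_scale: float, weight_scale: float,
               output_scale: float, output_zp: int) -> int:
    rescale = (input_scale * weight_scale) / output_scale
    q = round(acc_int32 * rescale) + output_zp
    return max(-128, min(127, q))  # clamp to int8 range

# An accumulator value of 1234 with input/weight scales 0.02/0.01 and
# output scale 0.05 lands on int8 value 5 (1234 * 0.004 = 4.936).
print(requantize(1234, 0.02, 0.01, 0.05, output_zp=0))  # 5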
diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py
index 0857e0ed32a..339833c329c 100644
--- a/backends/arm/operators/op_div.py
+++ b/backends/arm/operators/op_div.py
@@ -13,6 +13,7 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.backends.arm.tosa_utils import tosa_shape
 from serializer.tosa_serializer import TosaOp
 
@@ -21,6 +22,11 @@ class DivVisitor(NodeVisitor):
     target = "aten.div.Tensor"
 
+    # Only supported for MI
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
+    ]
+
     def __init__(self, *args):
         super().__init__(*args)
diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py
index d2bc1377ce7..23a13dd4869 100644
--- a/backends/arm/operators/op_full.py
+++ b/backends/arm/operators/op_full.py
@@ -14,10 +14,6 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import (
-    get_quant_arg_downstream,
-    quantize_value,
-)
 from executorch.backends.arm.tosa_utils import tosa_shape
 from torch.fx import Node
 
@@ -41,19 +37,14 @@ def define_node(
         shape = tosa_shape(inputs[0].special, output.dim_order)
 
         value = inputs[1].number
-        if is_quant_node:
-            qargs = get_quant_arg_downstream(list(node.users)[0])
-            qvalue = quantize_value(value, qargs)
-            dtype = ts.DType.INT8
-            data = np.full(shape, qvalue, dtype=np.int8)
+
+        if output.dtype == ts.DType.INT8:
+            fill_dtype = np.int8
         else:
-            assert (
-                output.dtype == ts.DType.FP32
-            ), "'Full' currently only supports FP32 for unquantized models."
-            dtype = ts.DType.FP32
-            data = np.full(shape, value, dtype=np.float32)
+            fill_dtype = np.float32
+        data = np.full(shape, value, dtype=fill_dtype)
 
-        tosa_graph.addConst(shape, dtype, data, node.name + "full-const")
+        tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const")
         tosa_graph.addOperator(
             ts.TosaOp.Op().IDENTITY, [node.name + "full-const"], [output.name]
         )
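With QuantizeFullArgument from the first file in place, the visitor above no longer quantizes anything itself: by the time it runs, inputs[1].number already holds the quantized value and output.dtype is INT8, so only the numpy dtype needs to match. A quick illustration of the numpy side (the fill value 50 is assumed to have been quantized upstream):

import numpy as np

# The fill value was already quantized by the pass, so the constant
# tensor is created directly in int8.
data = np.full((1, 2, 2), 50, dtype=np.int8)
print(data.dtype, data.max())  # int8 50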
diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py
index 74e33ddb02c..0a4092e3a9a 100644
--- a/backends/arm/operators/op_max_pool2d.py
+++ b/backends/arm/operators/op_max_pool2d.py
@@ -13,7 +13,7 @@
     register_node_visitor,
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_utils import (
+from executorch.backends.arm.tosa_quant_utils import (
     get_quant_arg_downstream,
     get_quant_arg_upstream,
 )
diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py
index 2d3a0c2786c..3b1ea9d70fe 100644
--- a/backends/arm/process_node.py
+++ b/backends/arm/process_node.py
@@ -11,10 +11,12 @@
 import serializer.tosa_serializer as ts
 import torch
 import torch.fx
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_input_qparams,
+)
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
-    get_quant_arg_upstream,
     get_quantized_node_output_dtype,
     is_node_quantized,
 )
@@ -110,8 +112,10 @@ def process_quantized_bias(
         _,
     ) = consumer_node.all_input_nodes
 
-    input_node_scale = get_quant_arg_upstream(input_node).scale
-    weight_node_scale = get_quant_arg_upstream(weight_node).scale
+    input_qargs = get_input_qparams(consumer_node)
+
+    input_node_scale = input_qargs[0].scale
+    weight_node_scale = input_qargs[1].scale
 
     bias_values_quantized = (
         (parameter_values / (input_node_scale * weight_node_scale))
         .round()
diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py
index 4de84ed3458..9ae1a27cf7e 100644
--- a/backends/arm/test/runner_utils.py
+++ b/backends/arm/test/runner_utils.py
@@ -127,7 +127,7 @@ def _get_output_node(program: ExportedProgram) -> Node:
 
 def _get_output_quantization_params(
     program: ExportedProgram, output_node: Node
-) -> QuantizationParams:
+) -> Optional[QuantizationParams]:
     """
     Get output QuantizationParams from a program.
     Args:
@@ -153,8 +153,6 @@ def _get_output_quantization_params(
                 dtype=node.args[5],
             )
             break  # break early, there's only one output node
-    if quant_params is None:
-        raise RuntimeError("No Quantization parameters not found in exported model.")
     return quant_params
 
 
@@ -485,13 +483,17 @@ def run_tosa_ref_model(
             if tosa_ref_output.dtype == np.int8:
                 tosa_ref_output = tosa_ref_output.astype(np.int32)
                 quant_param = self.qp_output
-                assert (
-                    quant_param is not None
-                ), "There are no quantization parameters, check output parameters"
-                tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
+                if quant_param is not None:
+                    # quant_param can be None, e.g. for bool outputs of quantized models
+                    tosa_ref_output = (
+                        tosa_ref_output - quant_param.zp
+                    ) * quant_param.scale
 
             if tosa_ref_output.dtype == np.double:
                 tosa_ref_output = tosa_ref_output.astype("float32")
+            elif tosa_ref_output.dtype == bool:
+                # keep the bool output as bool for boolean comparisons
+                tosa_ref_output = tosa_ref_output.astype("bool")
 
             # tosa_output is a numpy array, convert to torch tensor for comparison
             tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))
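The relaxed output handling above amounts to: dequantize int8 outputs only when quantization parameters exist, and pass bool outputs through untouched. A compact model of that behaviour, with hypothetical standalone arguments standing in for self.qp_output:

from typing import Optional

import numpy as np

def postprocess(out: np.ndarray, zp: Optional[int], scale: Optional[float]) -> np.ndarray:
    if out.dtype == np.int8 and scale is not None:
        # Standard affine dequantization: (q - zp) * scale.
        return ((out.astype(np.int32) - zp) * scale).astype(np.float32)
    if out.dtype == bool:
        return out  # e.g. outputs of comparison ops stay bool
    return out

print(postprocess(np.array([10, -3], dtype=np.int8), zp=2, scale=0.5))  # [ 4.  -2.5]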
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index 4f9eae64be8..7b129a98877 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -12,7 +12,6 @@
 
 import executorch.backends.xnnpack.test.tester.tester as tester
 
-import numpy as np
 import serializer.tosa_serializer as ts
 import torch.fx
 
@@ -319,12 +318,15 @@ def run_method_and_compare_outputs(
             target_board,
         )
 
+        quantization_scale = None
         if is_quantized:
             reference_stage = self.stages[self.stage_name(tester.Quantize)]
-            quantization_scale = self.runner_util.qp_output.scale
+            # bool outputs are not quantized, so allow
+            # self.runner_util.qp_output to be None
+            if self.runner_util.qp_output is not None:
+                quantization_scale = self.runner_util.qp_output.scale
         else:
             reference_stage = self.stages[self.stage_name(InitialModel)]
-            quantization_scale = None
 
         logger.info(
             f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'"
         )
@@ -504,7 +506,7 @@ def transpose_data_format(
         inputs_transposed = list(data)
         for i in range(len(data)):
             if hasattr(data[i], "shape") and len(data[i].shape) == 4:
-                inputs_transposed[i] = np.transpose(data[i], dim_order)
+                inputs_transposed[i] = torch.permute(data[i], dim_order)
         return tuple(inputs_transposed)
 
     def _compare_outputs(
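The np.transpose to torch.permute swap keeps the transposed test inputs as torch tensors instead of silently converting them to numpy arrays; the two functions agree on the resulting layout:

import numpy as np
import torch

data = torch.arange(24.0).reshape(1, 2, 3, 4)
dim_order = (0, 2, 3, 1)  # e.g. NCHW -> NHWC

assert torch.permute(data, dim_order).shape == (1, 3, 4, 2)
assert np.transpose(data.numpy(), dim_order).shape == (1, 3, 4, 2)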
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 1ae319e0cd7..dd28105a63a 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -7,18 +7,13 @@
 
 import logging
 import os
-from typing import Any, cast
+from typing import Any
 
 import numpy as np
 import serializer.tosa_serializer as ts
 import torch
 from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import (
-    get_quant_arg_downstream,
-    get_quant_arg_upstream,
-    q_op,
-)
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp
 from torch.fx import Node
@@ -140,10 +135,15 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name):
 
 def is_bias_node_for_quantized_conv(node):
     consumer_node = list(node.users)[0]
-    return (
+
+    if (
         consumer_node.target == exir_ops.edge.aten.convolution.default
-        and list(consumer_node.users)[0].target == q_op
-    )
+        and consumer_node.args[2] == node
+        and consumer_node.meta["val"].dtype == torch.int8
+    ):
+        return True
+
+    return False
 
 
 def is_consumer_node_depthwise_conv2d(node):
@@ -159,48 +159,6 @@ def is_consumer_node_depthwise_conv2d(node):
     return False
 
 
-def build_avg_pool_2d_common(
-    node: torch.fx.Node,
-    tosa_graph: ts.TosaSerializer,
-    input_tensor: TosaArg,
-    kernel_size: list,
-    stride: list,
-    padding: list,
-    is_quant_node: bool,
-    output: TosaArg,
-):
-    accumulator_type = input_tensor.dtype
-
-    if is_quant_node:
-        # Accumulator type always is int32 when input tensor is an integer type.
-        accumulator_type = ts.DType.INT32
-
-    # Initilize zero point to zero.
-    input_zp = 0
-    output_zp = 0
-
-    if is_quant_node:
-        input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp
-        output_zp = get_quant_arg_downstream(list(node.users)[0]).zp
-
-    attr = ts.TosaSerializerAttribute()
-    attr.PoolAttribute(
-        kernel=kernel_size,
-        stride=stride,
-        pad=padding,
-        input_zp=input_zp,
-        output_zp=output_zp,
-        accum_dtype=accumulator_type,
-    )
-
-    tosa_graph.addOperator(
-        TosaOp.Op().AVG_POOL2D,
-        [input_tensor.name],
-        [output.name],
-        attr,
-    )
-
-
 def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]:
     """Returns two input nodes to 'node' in order. If 'node' only
     has one input, it is returned twice.