From 2f8cb4e870f0066fb0604bf57255febf263adb20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Wed, 20 Nov 2024 15:11:10 +0100 Subject: [PATCH 1/4] Set requires_grad to avoid check for differentiable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The node.args that comes from nn.Parameters in the state_dict sometimes have the requires_grad property set to True. It's seems to stem already from the export stage and .eval() doesn't change the parameter. Address it here in the pass for now. Signed-off-by: Per Åstrand Change-Id: Ie7425764f1f1865de0fc66e1d020f804c7e936b1 --- .../fold_qdq_with_annotated_qparams_pass.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 6c86db8a0bf..24d1a03395d 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -98,6 +98,22 @@ def call(self, graph_module: GraphModule) -> PassResult: for i, arg in enumerate(n.args): if not isinstance(arg, Node): continue + + # Make sure arg has requires_grad set to False + # For parameters that are not quantized, sometimes (i.e. convolution) + # the Parameter(FakeTensor(...)) has requires_grad set to True, which + # causes the retracing of the graph to fail with: + # + # E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch. + # E + # E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {}) + # E Original traceback: + # E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward + # E x = conv(x) + # + if arg.op == "placeholder": + arg.meta["val"].requires_grad = False + if arg.target != dq_op: continue From 2a90a34a8d39bf2018c0d0cb9d4dda57b63aa028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 2 Dec 2024 09:05:51 +0100 Subject: [PATCH 2/4] Allow TOSA tests to not have quant info MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Quantized models might have output without quantization parameters attached to them. The assert for parameters not being None are removed and handled in order to allow for that case. numpy transpose is removed in favor of torch.permute to keep the type of the output after the operation. Signed-off-by: Per Åstrand Change-Id: I0e404062154cefa39f18b5706d72d19cac0e6d73 --- backends/arm/test/runner_utils.py | 16 +++++++++------- backends/arm/test/tester/arm_tester.py | 10 ++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 4de84ed3458..9ae1a27cf7e 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -127,7 +127,7 @@ def _get_output_node(program: ExportedProgram) -> Node: def _get_output_quantization_params( program: ExportedProgram, output_node: Node -) -> QuantizationParams: +) -> Optional[QuantizationParams]: """ Get output QuantizationParams from a program. Args: @@ -153,8 +153,6 @@ def _get_output_quantization_params( dtype=node.args[5], ) break # break early, there's only one output node - if quant_params is None: - raise RuntimeError("No Quantization parameters not found in exported model.") return quant_params @@ -485,13 +483,17 @@ def run_tosa_ref_model( if tosa_ref_output.dtype == np.int8: tosa_ref_output = tosa_ref_output.astype(np.int32) quant_param = self.qp_output - assert ( - quant_param is not None - ), "There are no quantization parameters, check output parameters" - tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale + if quant_param is not None: + # I.e. bool output is possible for quantized models + tosa_ref_output = ( + tosa_ref_output - quant_param.zp + ) * quant_param.scale if tosa_ref_output.dtype == np.double: tosa_ref_output = tosa_ref_output.astype("float32") + elif tosa_ref_output.dtype == bool: + # retain the bool output though for boolean related comparisons + tosa_ref_output = tosa_ref_output.astype("bool") # tosa_output is a numpy array, convert to torch tensor for comparison tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output)) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 4f9eae64be8..7b129a98877 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -12,7 +12,6 @@ import executorch.backends.xnnpack.test.tester.tester as tester -import numpy as np import serializer.tosa_serializer as ts import torch.fx @@ -319,12 +318,15 @@ def run_method_and_compare_outputs( target_board, ) + quantization_scale = None if is_quantized: reference_stage = self.stages[self.stage_name(tester.Quantize)] - quantization_scale = self.runner_util.qp_output.scale + # bool output is quantized with none quantized output so allow + # self.runner_util.qp_output to be none + if self.runner_util.qp_output is not None: + quantization_scale = self.runner_util.qp_output.scale else: reference_stage = self.stages[self.stage_name(InitialModel)] - quantization_scale = None logger.info( f"Comparing Stage '{self.stage_name(test_stage)}' with Stage '{self.stage_name(reference_stage)}'" @@ -504,7 +506,7 @@ def transpose_data_format( inputs_transposed = list(data) for i in range(len(data)): if hasattr(data[i], "shape") and len(data[i].shape) == 4: - inputs_transposed[i] = np.transpose(data[i], dim_order) + inputs_transposed[i] = torch.permute(data[i], dim_order) return tuple(inputs_transposed) def _compare_outputs( From e34c40ae296e45ffc53094cfe490c7f44e0b8428 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 18 Nov 2024 14:20:35 +0100 Subject: [PATCH 3/4] Convert more NodeVisitors to folding DQ/Q pass usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I9201d8bafd543204b697c7276d6929ad3aa09f25 --- backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/operators/op_avg_pool2d.py | 100 ++++++++++++++++++++--- backends/arm/operators/op_batch_norm.py | 5 ++ backends/arm/operators/op_conv2d.py | 48 ++++++----- backends/arm/operators/op_div.py | 6 ++ backends/arm/operators/op_max_pool2d.py | 2 +- backends/arm/process_node.py | 10 ++- backends/arm/tosa_utils.py | 60 ++------------ 8 files changed, 143 insertions(+), 90 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index f16a34a211d..e1c903302c3 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -90,6 +90,8 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.minimum.default, exir_ops.edge.aten.maximum.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.avg_pool2d.default, + exir_ops.edge.aten.convolution.default, ] ) ) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 4caaad92028..6665a99a7bf 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -8,30 +8,41 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common +from executorch.backends.arm.tosa_specification import TosaSpecification @register_node_visitor -class AvgPool2dVisitor(NodeVisitor): +class AvgPool2dVisitor_0_80_BI(NodeVisitor): target = "aten.avg_pool2d.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+BI"), + ] + def __init__(self, *args): super().__init__(*args) - def define_node( + def _build_generic_avgpool2d( self, node: torch.fx.Node, tosa_graph: ts.TosaSerializer, inputs: List[TosaArg], output: TosaArg, - is_quant_node: bool, + input_zp: int, + output_zp: int, + accumulator_type, ) -> None: input_tensor = inputs[0] + kernel_size_list = inputs[1].special stride_size_list = inputs[2].special try: @@ -39,13 +50,76 @@ def define_node( except IndexError: pad_size_list = [0, 0, 0, 0] - build_avg_pool_2d_common( - node, - tosa_graph, - input_tensor, - kernel_size_list, - stride_size_list, - pad_size_list, - is_quant_node, - output, + attr = ts.TosaSerializerAttribute() + attr.PoolAttribute( + kernel=kernel_size_list, + stride=stride_size_list, + pad=pad_size_list, + input_zp=input_zp, + output_zp=output_zp, + accum_dtype=accumulator_type, + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().AVG_POOL2D, + [input_tensor.name], + [output.name], + attr, + ) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + input_tensor = inputs[0] + assert input_tensor.dtype == ts.DType.INT8 + + accumulator_type = ts.DType.INT32 + + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].zp + + output_qargs = get_output_qparams(node) + output_zp = output_qargs[0].zp + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type ) + + +@register_node_visitor +class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI): + # inheriting 'target' from BI class + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + is_quant_node: bool, + ) -> None: + assert ( + inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32 + ), "Only FP32 and INT8 supported" + + if inputs[0].dtype == ts.DType.INT8: + super().define_node(node, tosa_graph, inputs, output, is_quant_node) + + if inputs[0].dtype == ts.DType.FP32: + accumulator_type = ts.DType.FP32 + # Initilize zero point to zero. + input_zp = 0 + output_zp = 0 + + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) diff --git a/backends/arm/operators/op_batch_norm.py b/backends/arm/operators/op_batch_norm.py index d17c3a1b81f..ee773949d1e 100644 --- a/backends/arm/operators/op_batch_norm.py +++ b/backends/arm/operators/op_batch_norm.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,10 @@ class BatchNormVisitor(NodeVisitor): target = "aten._native_batch_norm_legit_no_training.default" + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index ffbeee7306d..dc64e169364 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -8,16 +8,16 @@ import serializer.tosa_serializer as ts import torch +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, + get_output_qparams, +) from executorch.backends.arm.operators.node_visitor import ( NodeVisitor, register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - build_rescale_conv_output, - get_quant_arg_downstream, - get_quant_arg_upstream, -) +from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -57,9 +57,6 @@ def define_node( ) -> None: input, weight, bias, stride, pad, dilation, _, _, group = inputs - # Currently only int8 is supported in quantized types. - actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype - # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -82,9 +79,11 @@ def define_node( dilation_attr[1], ) - input_zp = ( - get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0 - ) + input_zp = 0 + if inputs[0].dtype == ts.DType.INT8: + # int8 input requires quantization information + input_qparams = get_input_qparams(node) + input_zp = input_qparams[0].zp attr.ConvAttribute( pad=pad_attr, @@ -100,16 +99,22 @@ def define_node( # Create a zero bias tensor if not presented out_channels = weight.shape[0] bias_name = "bias" + node.name.split("default", 1)[1] + bias_type = output.dtype + if output.dtype == ts.DType.INT8: + # Conv is quantized to int8, but the TOSA operator has + # output type int32, and the bias must be the same type + # as the TOSA output type + bias_type = ts.DType.INT32 bias = tosa_graph.addConst( [out_channels], - ts.DType.INT32 if is_quant_node else output.dtype, + bias_type, [0] * out_channels, name=bias_name, ) # The output type is int32 when input type is int8. conv2d_output_name = output.name - if is_quant_node: + if output.dtype == ts.DType.INT8: conv2d_res = tosa_graph.addIntermediate( tosa_shape(output.shape, output.dim_order), ts.DType.INT32 ) @@ -132,7 +137,7 @@ def define_node( weight_reshaped = tosa_graph.addIntermediate( weight_post_shape, - ts.DType.INT8 if is_quant_node else weight.dtype, + weight.dtype, ) build_reshape( tosa_graph, weight.name, weight_post_shape, weight_reshaped.name @@ -157,20 +162,19 @@ def define_node( # For quantized convolution, rescale the output value back to the same # integer value domain of the next op. Otherwise return float32 output. - if is_quant_node: + if inputs[0].dtype == ts.DType.INT8: # Get scale_factor from input, weight, and output. - input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale - weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale - output_qargs = get_quant_arg_downstream(list(node.users)[0]) - + input_scale = input_qparams[0].scale + weight_scale = input_qparams[1].scale + output_qargs = get_output_qparams(node) build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. conv2d_res, output.name, - actual_out_type, + output.dtype, input_scale, weight_scale, - output_qargs.scale, - output_qargs.zp, + output_qargs[0].scale, + output_qargs[0].zp, ) diff --git a/backends/arm/operators/op_div.py b/backends/arm/operators/op_div.py index 0857e0ed32a..339833c329c 100644 --- a/backends/arm/operators/op_div.py +++ b/backends/arm/operators/op_div.py @@ -13,6 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.arm.tosa_utils import tosa_shape from serializer.tosa_serializer import TosaOp @@ -21,6 +22,11 @@ class DivVisitor(NodeVisitor): target = "aten.div.Tensor" + # Only supported for MI + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-0.80.0+MI"), + ] + def __init__(self, *args): super().__init__(*args) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 74e33ddb02c..0a4092e3a9a 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -13,7 +13,7 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import ( +from executorch.backends.arm.tosa_quant_utils import ( get_quant_arg_downstream, get_quant_arg_upstream, ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 2d3a0c2786c..3b1ea9d70fe 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -11,10 +11,12 @@ import serializer.tosa_serializer as ts import torch import torch.fx +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( + get_input_qparams, +) from executorch.backends.arm.operators.node_visitor import NodeVisitor from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_upstream, get_quantized_node_output_dtype, is_node_quantized, ) @@ -110,8 +112,10 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_arg_upstream(input_node).scale - weight_node_scale = get_quant_arg_upstream(weight_node).scale + input_qargs = get_input_qparams(consumer_node) + + input_node_scale = input_qargs[0].scale + weight_node_scale = input_qargs[1].scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index 1ae319e0cd7..dd28105a63a 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -7,18 +7,13 @@ import logging import os -from typing import Any, cast +from typing import Any import numpy as np import serializer.tosa_serializer as ts import torch from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - get_quant_arg_upstream, - q_op, -) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -140,10 +135,15 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name): def is_bias_node_for_quantized_conv(node): consumer_node = list(node.users)[0] - return ( + + if ( consumer_node.target == exir_ops.edge.aten.convolution.default - and list(consumer_node.users)[0].target == q_op - ) + and consumer_node.args[2] == node + and consumer_node.meta["val"].dtype == torch.int8 + ): + return True + + return False def is_consumer_node_depthwise_conv2d(node): @@ -159,48 +159,6 @@ def is_consumer_node_depthwise_conv2d(node): return False -def build_avg_pool_2d_common( - node: torch.fx.Node, - tosa_graph: ts.TosaSerializer, - input_tensor: TosaArg, - kernel_size: list, - stride: list, - padding: list, - is_quant_node: bool, - output: TosaArg, -): - accumulator_type = input_tensor.dtype - - if is_quant_node: - # Accumulator type always is int32 when input tensor is an integer type. - accumulator_type = ts.DType.INT32 - - # Initilize zero point to zero. - input_zp = 0 - output_zp = 0 - - if is_quant_node: - input_zp = get_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_arg_downstream(list(node.users)[0]).zp - - attr = ts.TosaSerializerAttribute() - attr.PoolAttribute( - kernel=kernel_size, - stride=stride, - pad=padding, - input_zp=input_zp, - output_zp=output_zp, - accum_dtype=accumulator_type, - ) - - tosa_graph.addOperator( - TosaOp.Op().AVG_POOL2D, - [input_tensor.name], - [output.name], - attr, - ) - - def get_two_inputs(node: Node, check: bool = False) -> tuple[Node, Node]: """Returns two input nodes to 'node' in order. If 'node' only has one input, it is returned twice. From 8e6570eda701eda02d6f9242720a107c99245d50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 25 Nov 2024 18:05:53 +0100 Subject: [PATCH 4/4] Add full operator to fold dq/q handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I39d11cff0ef78df08e67f216b8e0bb86af9fac26 --- backends/arm/_passes/arm_pass_manager.py | 3 ++ .../fold_qdq_with_annotated_qparams_pass.py | 38 ++++++++++++++++++- backends/arm/operators/op_full.py | 21 +++------- 3 files changed, 45 insertions(+), 17 deletions(-) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index e1c903302c3..b4bb809b851 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -31,6 +31,7 @@ from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( FoldAndAnnotateQParamsPass, + QuantizeFullArgument, ) from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import ( KeepDimsFalseToSqueezePass, @@ -84,6 +85,7 @@ def transform_to_backend_pipeline( self.add_pass(Conv1dUnsqueezePass(exported_program)) self.add_pass(DecomposeSoftmaxesPass()) self.add_pass(DecomposeLinearPass()) + self.add_pass(QuantizeFullArgument()) self.add_pass( FoldAndAnnotateQParamsPass( [ @@ -92,6 +94,7 @@ def transform_to_backend_pipeline( exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.convolution.default, + exir_ops.edge.aten.full.default, ] ) ) diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 24d1a03395d..6ba72eb1022 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -15,6 +15,9 @@ from executorch.exir.pass_base import ExportPass, PassResult from torch.fx import GraphModule, Node +q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default +dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + def get_input_qparams(node: Node) -> dict[int, QuantArgs]: """ @@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]): self.targeted_ops = targeted_ops def call(self, graph_module: GraphModule) -> PassResult: - q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default - dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default # Loop over the graph nodes and find any node in the 'targeted_ops' list. for n in graph_module.graph.nodes: @@ -145,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.recompile() return PassResult(graph_module, True) + + +class QuantizeFullArgument(ExportPass): + """ + Make sure the fill_value for full.default is quantized. This pass needs to be run before + the folding pass above to make sure that the retraced output of the full.default op is + the right dtype. + """ + + def call(self, graph_module: GraphModule) -> PassResult: + modified = False + # Loop over the graph nodes and find any node in the 'targeted_ops' list. + for n in graph_module.graph.nodes: + n = cast(Node, n) + if n.target != exir_ops.edge.aten.full.default: + continue + + # Make sure we have a quantized operator + user = list(n.users)[0] + if user.target != q_op: + continue + + qargs = QuantArgs.from_operator(user.target, user.args) + if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype: + # replace the node arg with a quantized dito and also set dtype + # to get the right output according to the Edge IR specification: + # exir/dialects/edge/edge.yaml:3596 + quantized_full_value = qargs.quantize_value(n.args[1]).item() + n.update_arg(1, quantized_full_value) + n.update_kwarg("dtype", qargs.dtype) + modified = True + + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index d2bc1377ce7..23a13dd4869 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,10 +14,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_downstream, - quantize_value, -) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -41,19 +37,14 @@ def define_node( shape = tosa_shape(inputs[0].special, output.dim_order) value = inputs[1].number - if is_quant_node: - qargs = get_quant_arg_downstream(list(node.users)[0]) - qvalue = quantize_value(value, qargs) - dtype = ts.DType.INT8 - data = np.full(shape, qvalue, dtype=np.int8) + + if output.dtype == ts.DType.INT8: + fill_dtype = np.int8 else: - assert ( - output.dtype == ts.DType.FP32 - ), "'Full' currently only supports FP32 for unquantized models." - dtype = ts.DType.FP32 - data = np.full(shape, value, dtype=np.float32) + fill_dtype = np.float32 + data = np.full(shape, value, dtype=fill_dtype) - tosa_graph.addConst(shape, dtype, data, node.name + "full-const") + tosa_graph.addConst(shape, output.dtype, data, node.name + "full-const") tosa_graph.addOperator( ts.TosaOp.Op.IDENTITY, [node.name + "full-const"], [output.name] )