From 063a196b2819004ae83bd339bd43b377607145f0 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Thu, 19 Jun 2025 16:13:36 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - Ensure that math-invariant ops
 don't change scale and offset

Summary:
- Since QNN 2.35, op validation checks that math-invariant ops do not change
  scale and offset, so replace annotate_single_in_single_out with
  annotate_single_in_share_out to guarantee the output shares the input's
  quantization encoding.
- Fix the error in internal CI
- Fix a bug in the batch norm op
- Fix a bug in the index_put op
---
 .../qualcomm/_passes/remove_redundancy.py    |   1 +
 backends/qualcomm/builders/op_batch_norm.py  |   2 +-
 backends/qualcomm/builders/op_index_put.py   | 228 ++++++++++++++++--
 backends/qualcomm/quantizer/annotators.py    |  45 ++--
 backends/qualcomm/tests/test_qnn_delegate.py |  10 +
 5 files changed, 246 insertions(+), 40 deletions(-)

diff --git a/backends/qualcomm/_passes/remove_redundancy.py b/backends/qualcomm/_passes/remove_redundancy.py
index 22d476ef21b..2ec8161613b 100644
--- a/backends/qualcomm/_passes/remove_redundancy.py
+++ b/backends/qualcomm/_passes/remove_redundancy.py
@@ -29,6 +29,7 @@ def __init__(self, quantization_capture=False):
             # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True
             exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition,
             torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
+            torch.ops.aten._assert_scalar.default: self._default_condition,
         }
         self.redundant_ops_annotation = {
             torch.ops.aten._assert_tensor_metadata.default: self._default_condition,
diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py
index 48c5d5d1b51..25a9c2b123e 100644
--- a/backends/qualcomm/builders/op_batch_norm.py
+++ b/backends/qualcomm/builders/op_batch_norm.py
@@ -128,7 +128,7 @@ def define_node(
         bias_tensor = self.try_dequantize(
             bias_node, get_parameter(bias_node, self.edge_program)
         )
-        amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps)
+        amount = filter_tensor * mean_tensor
         bias_tensor = bias_tensor - amount
         self.update_encoding(bias_node, bias_tensor, eps)
         bias_tensor_wrapper = self.define_tensor(
diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py
index de59b1a0489..9972066e165 100644
--- a/backends/qualcomm/builders/op_index_put.py
+++ b/backends/qualcomm/builders/op_index_put.py
@@ -1,13 +1,22 @@
+import warnings
 from typing import Dict

 import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
 import torch
-from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS
+from executorch.exir.dialects._ops import ops as exir_ops

-from .node_visitor import NodeVisitor
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP
 from .node_visitor_manager import register_node_visitor
-from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .qnn_constants import (
+    OpConcat,
+    OpReshape,
+    OpScatterNd,
+    OpTile,
+    QNN_OP_PACKAGE_NAME_QTI_AISW,
+)


 @register_node_visitor
@@ -22,6 +31,7 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
+        op_wrapper_list = []
         input_node = self.get_node(node.args[0])
         # Because the args[0] of index_put op doesn't annotate, need to fill in the quant_attr with the node here.
         if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
@@ -35,31 +45,198 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        indicies_node = node.args[1]
-        indices_list = [
-            self.get_tensor(idx, idx) for idx in indicies_node if idx is not None
-        ]
-
-        # Unpack the tuple
-        indices_unpacked = [torch.flatten(idx) for idx in indices_list]
-
-        # Convert to 2-D tensor
-        indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
-        indice_node = [n for n in indicies_node if isinstance(n, torch.fx.Node)]
-        # TODO consider to write a pass to combine to one input tensor for indices
-        assert len(indice_node) == 1, "Not support multiple indices tensor"
+        indicies_node = node.args[1]
+        index_node_dim = None
+        index_nodes = []
+        index_tensors = []
+        target_index = []
+        # A None in the list means the full range at that dimension.
+        # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+        if isinstance(indicies_node, list):
+            for index, idx_node in enumerate(indicies_node):
+                # First, collect the index nodes and the positions of the None entries to construct the shape of the index tensor.
+                # E.g., shape of input: [1, 1024, 12, 64]
+                # For the "None" axes (assume indicies_node: [None, None, aten__to_copy_default_1]),
+                # target_index: [1, 1024, x], where x is the shape of index_tensor, and index_node_dim: 2
+                if isinstance(idx_node, torch.fx.Node):
+                    index_nodes.append(idx_node)
+                    index_tensors.append(self.get_tensor(idx_node, idx_node))
+                    target_index.extend(index_tensors[-1].size())
+                    index_node_dim = index
+                elif idx_node is None and index_node_dim is None:
+                    # E.g., indicies_node: [None, aten__to_copy_default_1, None]
+                    # "None" entries after the index node don't need to be considered.
+                    target_index.append(input_tensor.size(index))
+                else:
+                    warnings.warn(
+                        f"[QNN Delegate Op Builder]: Got index {idx_node}, which is neither a node nor None",
+                        stacklevel=1,
+                    )
+                    return
+        # Assume that there is only one node in the list
+        assert len(index_nodes) == 1, "Multiple indices tensors are not supported"
+        indice_node = index_nodes[0]
+        indice_tensor = index_tensors[0]
         indices_tensor_wrapper = self.define_tensor(
-            indice_node[0],
+            indice_node,
             node,
-            indices_qnn,
+            indice_tensor,
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
-        value_node = self.get_node(node.args[2])
-        value_tensor = self.get_tensor(value_node, node)
+        # Need to reconstruct the index tensor.
+        # E.g., based on the ScatterNd op definition in the QNN docs:
+        # Given that
+        #   shape of input: [1, 12, 1024, 64]
+        #   indicies_node: [None, None, aten__to_copy_default_1]
+        #   shape of aten__to_copy_default_1: [1]
+        # the shape of the index tensor should be [1, 12, 1, 3].
+        # The index tensor is treated as a 4-dimensional tensor of 3-tuples,
+        # where each 3-tuple is a partial index into the input.
+        # Reference code for QNN ScatterNd:
+        #   output = np.copy(input)
+        #   update_indices = indices.shape[:-1]
+        #   for idx in np.ndindex(update_indices):
+        #       output[indices[idx]] = updates[idx]

+        # Append one dimension to specify the x-tuple
+        index_shape = target_index + [1]
+        # Reshape the index node for the tile op
+        reshape_shape = [
+            shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape)
+        ]
+        reshape_output_tensor = indice_tensor.reshape(reshape_shape)
+        reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+            node_name=node.name + "_reshape",
+            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype],
+            quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+            quant_configs={},
+            dims=reshape_output_tensor.size(),
+            tensor=reshape_output_tensor,
+            is_fake_tensor=True,
+            nodes_to_wrappers=nodes_to_wrappers,
+        )
+        reshape_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReshape.op_name,
+        )
+        reshape_op.AddInputTensors([indices_tensor_wrapper])
+        reshape_op.AddOutputTensors([reshape_output_tensor_wrapper])
+        op_wrapper_list.append(reshape_op)
+        index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper
+
+        # Tile the index node and concat with the target index
+        if None in indicies_node:
+            tile_output_tensor = reshape_output_tensor.expand(index_shape)
+            # Tile the index node to align with the shape of target_index.
+            # Only the "None" axes need to be tiled.
+            # E.g., indicies_node: [None, None, aten__to_copy_default_1]
+            # should tile the first two dimensions.
+            multiples = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(index_shape)
+            ]
+            multiples_shape = [len(index_shape)]
+            tile_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_tile",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=tile_output_tensor.size(),
+                tensor=tile_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            tile_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpTile.op_name,
+            )
+            tile_op.AddInputTensors([reshape_output_tensor_wrapper])
+            tile_op.AddOutputTensors([tile_output_tensor_wrapper])
+            tile_op.AddTensorParam(
+                OpTile.param_multiples,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                len(multiples_shape),
+                multiples_shape,
+                np.array(multiples, dtype=np.uint32),
+                True,
+            )
+            op_wrapper_list.append(tile_op)
+
+            # Repeat the index for the "None" axes in indicies_node
+            ranges = [
+                torch.arange(dim, dtype=indice_tensor.dtype)
+                for dim in target_index[:-1]
+            ]
+            target_index_shape = target_index + [len(ranges)]
+            target_index_tensor = torch.cartesian_prod(*ranges)
+            reshape_target_index_shape = [
+                shape if id != index_node_dim else 1
+                for id, shape in enumerate(target_index_shape)
+            ]
+            target_index_tensor = target_index_tensor.reshape(
+                reshape_target_index_shape
+            )
+            target_index_tensor = target_index_tensor.expand(
+                target_index_shape
+            ).contiguous()
+            target_index_node = torch.fx.Node(
+                node.graph,
+                node.name + "_target_index",
+                "call_function",
+                exir_ops.edge.aten.tensor.default,
+                (),  # args
+                {},  # kwargs
+            )
+            target_index_tensor_wrapper = self.define_tensor(
+                target_index_node,
+                node,
+                target_index_tensor,
+                PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
+                nodes_to_wrappers,
+            )
+            # Concat target_index and the tile output to reconstruct the index tensor
+            # Cannot use QNN Pack (stack) since QNN Pack does not support int32 dtype
+            concat_output_tensor = torch.concat(
+                (target_index_tensor, tile_output_tensor), dim=-1
+            )
+            concat_output_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_concat",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[concat_output_tensor.dtype],
+                quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
+                quant_configs={},
+                dims=concat_output_tensor.size(),
+                tensor=concat_output_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            concat_op = PyQnnWrapper.PyQnnOpWrapper(
+                node.name,
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpConcat.op_name,
+            )
+            concat_op.AddInputTensors(
+                [target_index_tensor_wrapper, tile_output_tensor_wrapper]
+            )
+            concat_op.AddOutputTensors([concat_output_tensor_wrapper])
+            concat_op.AddScalarParam(
+                OpConcat.param_axis,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+                {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)},
+            )
+            op_wrapper_list.append(concat_op)
+            index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper
+
+        value_node = self.get_node(node.args[2])
+        value_tensor = self.get_tensor(value_node, node)
         value_tensor_wrapper = self.define_tensor(
             value_node,
             node,
@@ -67,6 +244,7 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        output_tensor = self.get_tensor(node, node)
         output_tensor_wrapper = self.define_tensor(
             node,
@@ -82,8 +260,12 @@ def define_node(
             OpScatterNd.op_name,
         )
         index_put_op.AddInputTensors(
-            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
+            [
+                input_tensor_wrapper,
+                index_put_index_input_tensor_wrapper,
+                value_tensor_wrapper,
+            ]
         )
         index_put_op.AddOutputTensors([output_tensor_wrapper])
-
-        return index_put_op
+        op_wrapper_list.append(index_put_op)
+        return op_wrapper_list
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index e1e2ca6dff6..6233abb01e1 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -108,6 +108,26 @@ def annotate_in_out_obs_sharing_op(
     )


+def annotate_single_in_share_out(
+    node: Node, quantization_config: QuantizationConfig
+) -> None:
+    if _is_annotated([node]):
+        return
+
+    input_qspec_map = {}
+    if _is_float_tensor(node.args[0]):
+        input_act = node.args[0]
+        assert isinstance(input_act, Node)
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+    if _is_float_tensor(node):
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=SharedQuantizationSpec((input_act, node)),
+            _annotated=True,
+        )
+
+
 def annotate_single_in(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return
@@ -505,7 +525,7 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N
 def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.prelu.default])
@@ -523,7 +543,7 @@ def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
 def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.pixel_shuffle.default])
@@ -641,7 +661,7 @@ def annotate_scaled_dot_product_attention(
 def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.rms_norm.default])
@@ -757,7 +777,7 @@ def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
 def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator(
@@ -770,14 +790,14 @@ def annotate_unsqueeze_copy(
 ) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.transpose.int])
 def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.elu.default])
@@ -803,14 +823,7 @@ def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> N
 def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        input_qspec_map = {}
-        input = node.args[0]
-        input_qspec_map[input] = quantization_config.input_activation
-        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
-            input_qspec_map=input_qspec_map,
-            output_qspec=SharedQuantizationSpec((input, node)),
-            _annotated=True,
-        )
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator(
@@ -856,7 +869,7 @@ def annotate_exp(node: Node, quantization_config: QuantizationConfig) -> None:
 def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.group_norm.default])
@@ -896,7 +909,7 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) ->
 def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
     if not _is_annotated([node]):
-        annotate_single_in_single_out(node, quantization_config)
+        annotate_single_in_share_out(node, quantization_config)


 @register_annotator([torch.ops.aten.stack.default])
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 747a6804957..ace8ee36213 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -351,6 +351,7 @@ def test_qnn_backend_element_wise_and(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_element_wise_ceil(self):
+        torch.manual_seed(8)
         module = Ceil()  # noqa: F405
         sample_input = (torch.randn([2, 5, 1, 3]),)
         self.lower_module_and_test_output(module, sample_input)
@@ -715,6 +716,7 @@ def test_qnn_backend_layer_norm(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_leaky_relu(self):
+        torch.manual_seed(8)
         test_comb = [
             {
                 QCOM_MODULE: [LeakyReLUDefault()],  # noqa: F405
@@ -3085,6 +3087,10 @@ def test_qnn_backend_draw_graph(self):
         ), "Generated .dot file does not match the golden file."

     def test_qnn_backend_generate_optrace(self):
+        if self.enable_x86_64:
+            self.skipTest(
+                "Currently, this test is only conducted on device."
+            )
         module = SimpleModel()  # noqa: F405
         sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
         backend_options = generate_htp_compiler_spec(use_fp16=True)
@@ -3791,6 +3797,10 @@ def test_qnn_backend_draw_graph(self):
         ), "Generated .dot file does not match the golden file."

     def test_qnn_backend_generate_optrace(self):
+        if self.enable_x86_64:
+            self.skipTest(
+                "Currently, this test is only conducted on device."
+            )
         module = SimpleModel()  # noqa: F405
         sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
         module = self.get_qdq_module(module, sample_input)
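
Reviewer note: the sketch below is not part of this patch, and quantize() is a hypothetical stand-in helper rather than an ExecuTorch API. It illustrates why math-invariant ops such as permute, view, and flatten should share the input's quantization encoding: quantization is element-wise, so applying the op before or after quantization with the same scale/offset gives identical values, and assigning the output a fresh encoding can only introduce requantization error, which QNN >= 2.35 now rejects at op validation.

import torch

def quantize(t, scale, offset):
    # Hypothetical per-tensor affine quantization to uint8.
    return torch.clamp(torch.round(t / scale) + offset, 0, 255).to(torch.uint8)

x = torch.randn(2, 3, 4)
scale, offset = 0.05, 128  # assumed encoding observed for x

# Quantize-then-permute and permute-then-quantize (same encoding) agree
# exactly, because a math-invariant op only rearranges elements.
assert torch.equal(
    quantize(x, scale, offset).permute(2, 0, 1),
    quantize(x.permute(2, 0, 1), scale, offset),
)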
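Also for reference: a small numpy sketch, illustrative only, of the index reconstruction the op_index_put builder performs with Reshape/Tile/Concat. Shapes follow the comments in the diff: for indicies_node = [None, None, idx] on an input of shape [1, 12, 1024, 64], the "None" axes are enumerated (the builder uses torch.cartesian_prod) and concatenated with the tiled index tensor along the last dimension, producing the [1, 12, 1, 3] tensor of 3-tuples that ScatterNd expects.

import numpy as np

input_shape = (1, 12, 1024, 64)      # example from the comments above
idx = np.array([5])                  # index tensor for dim 2; dims 0-1 are None
target_index = [1, 12, len(idx)]     # [None, None, idx] -> [1, 12, 1]

# Enumerate every coordinate of the "None" axes.
grid = np.stack(
    np.meshgrid(*[np.arange(d) for d in target_index[:-1]], indexing="ij"),
    axis=-1,
)                                                        # (1, 12, 2)
grid = np.repeat(grid[..., None, :], len(idx), axis=-2)  # (1, 12, 1, 2)

# Tile the actual index values and append them as the last tuple element
# (Reshape + Tile + Concat in the QNN graph).
tiled = np.broadcast_to(idx.reshape(1, 1, -1, 1), (1, 12, len(idx), 1))
indices = np.concatenate([grid, tiled], axis=-1)         # (1, 12, 1, 3)

# ScatterNd reference semantics from the QNN docs:
updates = np.ones((1, 12, len(idx), 64))
output = np.zeros(input_shape)
for i in np.ndindex(indices.shape[:-1]):
    output[tuple(indices[i])] = updates[i]
assert output[0, :, 5].sum() == 12 * 64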