From 02001bf448127f08a3334531be0ae30a7a2552c1 Mon Sep 17 00:00:00 2001
From: haowhsu-quic
Date: Tue, 15 Apr 2025 14:28:50 +0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - xr model enablement (mld_f)

Summary
- add gather op support
- make cast logic / slice op more general
---
 backends/qualcomm/_passes/__init__.py         |   2 +
 .../qualcomm/_passes/fuse_consecutive_cast.py | 113 ++++++++++++++++++
 backends/qualcomm/_passes/i64_to_i32.py       |  29 +++++
 backends/qualcomm/_passes/qnn_pass_manager.py |   3 +
 .../qualcomm/_passes/remove_redundancy.py     |  17 ++-
 backends/qualcomm/builders/__init__.py        |   2 +
 backends/qualcomm/builders/op_argmin.py       |  53 ++------
 backends/qualcomm/builders/op_gather.py       |  75 ++++++++++++
 backends/qualcomm/builders/op_to.py           |  45 ++++++-
 backends/qualcomm/builders/qnn_constants.py   |   6 +
 backends/qualcomm/quantizer/annotators.py     |   2 +-
 backends/qualcomm/tests/models.py             |  36 ++++++
 backends/qualcomm/tests/test_qnn_delegate.py  |  75 ++++++++++--
 13 files changed, 396 insertions(+), 62 deletions(-)
 create mode 100644 backends/qualcomm/_passes/fuse_consecutive_cast.py
 create mode 100644 backends/qualcomm/builders/op_gather.py

diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py
index 81b86992dee..44fa3f69ed5 100644
--- a/backends/qualcomm/_passes/__init__.py
+++ b/backends/qualcomm/_passes/__init__.py
@@ -20,6 +20,7 @@ from .expand_broadcast_tensor_shape import ExpandBroadcastTensorShape
 from .fixed_linear_keep_dim import FixedLinearKeepDim
 from .fold_qdq import FoldQDQ
+from .fuse_consecutive_cast import FuseConsecutiveCast
 from .fuse_consecutive_transpose import FuseConsecutiveTranspose
 from .i64_to_i32 import I64toI32
 from .insert_io_qdq import InsertIOQDQ
@@ -54,6 +55,7 @@
     ExpandBroadcastTensorShape,
     FixedLinearKeepDim,
     FoldQDQ,
+    FuseConsecutiveCast,
     FuseConsecutiveTranspose,
     I64toI32,
     InsertIOQDQ,
diff --git a/backends/qualcomm/_passes/fuse_consecutive_cast.py b/backends/qualcomm/_passes/fuse_consecutive_cast.py
new file mode 100644
index 00000000000..6c27521a950
--- /dev/null
+++ b/backends/qualcomm/_passes/fuse_consecutive_cast.py
@@ -0,0 +1,113 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+
+class FuseConsecutiveCast(ExportPass):
+    """
+    This pass fuses consecutive cast nodes into one, or removes them entirely,
+    to reduce runtime overhead.
+    To simplify the fusion logic, we first ensure each cast node's output has
+    at most one downstream cast node by cloning casts.
+
+    Example:
+    Before cloning cast:
+        relu -> cast1 ─> cast2
+                  |────> cast3
+
+    After cloning cast:
+        relu ─> cast1 ──────> cast2
+          |───> cast4(new) ─> cast3
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.op_map = {
+            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+            exir_ops.edge.aten._to_copy.default,
+        }
+        self.visited = set()
+        self.nodes = []
+
+    def _canonicalize_cast(
+        self, graph_module: torch.fx.GraphModule
+    ) -> torch.fx.GraphModule:
+        # replace all i64 cast nodes with their i32 version
+        graph = graph_module.graph
+        for n in graph_module.graph.nodes:
+            if n.target in self.op_map and n.meta["val"].dtype == torch.int64:
+                users = list(n.users)
+                for user in users:
+                    # bypass graph output node to meet original convention
+                    if user.op == "output":
+                        continue
+
+                    with graph.inserting_after(n):
+                        cast_node = graph.create_node(
+                            "call_function",
+                            exir_ops.edge.aten._to_copy.default,
+                            n.args,
+                            kwargs={"dtype": torch.int32},
+                        )
+                        cast_node.meta = n.meta
+                        cast_node.meta["val"] = cast_node.meta["val"].to(torch.int32)
+                        user.replace_input_with(n, cast_node)
+
+        graph.eliminate_dead_code()
+
+        # clone nodes for future fusion
+        for n in graph_module.graph.nodes:
+            # make sure we're handling a cast node instead of a convert node
+            if n.target in self.op_map and n.kwargs.get("dtype", None) is not None:
+                users = [user for user in list(n.users) if user.target in self.op_map]
+                if len(users) > 1:
+                    for i in range(1, len(users)):
+                        with graph.inserting_after(n):
+                            clone_cast_node = graph.create_node(
+                                "call_function",
+                                exir_ops.edge.aten._to_copy.default,
+                                (n.args[0]),
+                            )
+                            clone_cast_node.meta = n.meta
+                            users[i].replace_input_with(n, clone_cast_node)
+
+    def _traverse(self, node):
+        if node in self.visited or node.target not in self.op_map:
+            return
+
+        self.nodes.append(node)
+        self.visited.add(node)
+        next_users = [n for n in list(node.users) if n.target in self.op_map]
+
+        assert (
+            len(next_users) <= 1
+        ), "Each cast node should have at most 1 cast output node after _canonicalize_cast"
+        if not next_users:
+            return
+        else:
+            self._traverse(next_users[0])
+
+    def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        for n in graph_module.graph.nodes:
+            self._traverse(n)
+            if len(self.nodes) > 1:
+                input_node, output_node = self.nodes[0], self.nodes[-1]
+                output_node.replace_input_with(output_node.args[0], input_node.args[0])
+
+            # clear current stack
+            self.nodes = []
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        self._canonicalize_cast(graph_module)
+        self._fuse(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
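For intuition, the fuse step can be reproduced on a plain torch.fx graph. The sketch below is illustrative only and not part of the patch: it matches call_method `to` nodes instead of the edge-dialect `_to_copy` ops the pass actually handles, and fuses a two-cast chain the same way `_fuse` does, by rewiring the tail cast onto the head cast's input and letting dead-code elimination drop the rest:

    import torch
    from torch import fx

    class TwoCasts(torch.nn.Module):
        def forward(self, x):
            # consecutive casts: input -> i32 -> f32
            return x.to(torch.int32).to(torch.float32)

    gm = fx.symbolic_trace(TwoCasts())
    casts = [n for n in gm.graph.nodes if n.op == "call_method" and n.target == "to"]
    head, tail = casts[0], casts[-1]
    # rewire the tail of the chain onto the head's input, as _fuse does
    tail.replace_input_with(head, head.args[0])
    gm.graph.eliminate_dead_code()   # the head cast now has no users
    gm.recompile()
    print(gm.code)                   # a single cast remains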
diff --git a/backends/qualcomm/_passes/i64_to_i32.py b/backends/qualcomm/_passes/i64_to_i32.py
index f13b035552c..c29d01e9393 100644
--- a/backends/qualcomm/_passes/i64_to_i32.py
+++ b/backends/qualcomm/_passes/i64_to_i32.py
@@ -31,6 +31,14 @@ class I64toI32(ExportPass):
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
     }
+    # This dict ensures that the inputs of these ops stay int64, due to PyTorch
+    # restrictions. For example, the scatter op only accepts args[2], the index,
+    # as int64.
+    # Key: op whose inputs need casting to i64
+    # Value: indices of the args that need a cast op inserted
+    I64_IN_OPS = {
+        exir_ops.edge.aten.gather.default: [2],
+        exir_ops.edge.aten.scatter.src: [2],
+    }
     copy_op = exir_ops.edge.aten._to_copy.default

     def __init__(
@@ -141,11 +149,32 @@ def _cast_constant_to_int32(self, graph_module: torch.fx.GraphModule):
                 n.replace_all_uses_with(to_dst_node)
                 to_dst_node.args = (n,)

+    def _cast_op_args_to_i64(self, graph_module: torch.fx.GraphModule):
+        # inputs will be cast to i32 during call_operator dtype propagation;
+        # insert an i64 cast node to prevent PyTorch's operator validation failure
+        for node in graph_module.graph.nodes:
+            if node.target in self.I64_IN_OPS:
+                with graph_module.graph.inserting_before(node):
+                    arg_indices = self.I64_IN_OPS[node.target]
+                    for arg_index in arg_indices:
+                        input_node = node.args[arg_index]
+                        cast_i64_node = graph_module.graph.create_node(
+                            "call_function",
+                            self.copy_op,
+                            (input_node,),
+                            {"dtype": torch.int64},
+                        )
+                        cast_i64_node.meta["val"] = node.meta["val"].to(torch.int64)
+                        args_list = list(node.args)
+                        args_list[arg_index] = cast_i64_node
+                        node.args = tuple(args_list)
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # Record original output dtype to ensure that if user expects int64 as output,
         # convert the output back to int64 if it is casted from int64->int32.
         self._record_original_output_dtype(graph_module)
         self._cast_constant_to_int32(graph_module)
+        self._cast_op_args_to_i64(graph_module)
         graph_module = super().call(graph_module).graph_module
         self._preserve_output_dtype(graph_module)
         graph_module.recompile()
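The PyTorch restriction that I64_IN_OPS works around is easy to verify standalone (the demo below is not part of the patch): after the pass downcasts the graph to i32, gather's index argument must be cast back to i64 or operator validation rejects it.

    import torch

    x = torch.randn(2, 3)
    idx_i32 = torch.zeros(2, 3, dtype=torch.int32)

    try:
        torch.gather(x, 1, idx_i32)  # rejected: gather requires an int64 index
    except RuntimeError as e:
        print(e)

    print(torch.gather(x, 1, idx_i32.to(torch.int64)))  # valid after the i64 cast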
""" - def __init__(self): + def __init__(self, quantization_capture=False): super(RemoveRedundancy, self).__init__() - self.redundant_ops = { + self.redundant_ops_general = { torch.clone: self._default_condition, torch.ops.aten.clone.default: self._default_condition, exir_ops.edge.aten.clone.default: self._default_condition, @@ -27,7 +27,16 @@ def __init__(self): exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition, # remove channel_last / contiguous _to_copy if '_skip_dim_order' is set to True exir_ops.edge.aten._to_copy.default: self._to_copy_op_condition, + torch.ops.aten._assert_tensor_metadata.default: self._default_condition, } + self.redundant_ops_annotation = { + torch.ops.aten._assert_tensor_metadata.default: self._default_condition, + } + self.redundant_ops = ( + self.redundant_ops_annotation + if quantization_capture + else self.redundant_ops_general + ) def _dim_order_op_condition(self, node): dim_order = node.kwargs.get("dim_order") @@ -49,6 +58,10 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: continue to_be_remove = n + # assert_tensor_metadata op has no user + if len(n.users.keys()) == 0: + n.args = () + # normal case for user_n in list(n.users.keys()): user_n.replace_input_with(n, n.args[0]) graph_module.graph.erase_node(to_be_remove) diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 705d5d163cd..27faa036dd5 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -32,6 +32,7 @@ op_expand, op_full, op_full_like, + op_gather, op_ge, op_gelu, op_group_norm, @@ -120,6 +121,7 @@ op_expand, op_full, op_full_like, + op_gather, op_ge, op_gelu, op_group_norm, diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py index fa3fad4a61b..0717e6489fd 100644 --- a/backends/qualcomm/builders/op_argmin.py +++ b/backends/qualcomm/builders/op_argmin.py @@ -10,8 +10,8 @@ import torch from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA -from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor -from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor @@ -26,7 +26,6 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - op_wrapper_list = [] input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) output_tensor = self.get_tensor(node, node) @@ -38,26 +37,14 @@ def define_node( nodes_to_wrappers, ) argmin_input_tensors = [argmin_inp_tensor_wrapper] - - # arg output is index, do not quantize it. 
diff --git a/backends/qualcomm/builders/op_argmin.py b/backends/qualcomm/builders/op_argmin.py
index fa3fad4a61b..0717e6489fd 100644
--- a/backends/qualcomm/builders/op_argmin.py
+++ b/backends/qualcomm/builders/op_argmin.py
@@ -10,8 +10,8 @@
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

-from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
-from .qnn_constants import OpArgmin, OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpArgmin, QNN_OP_PACKAGE_NAME_QTI_AISW


 @register_node_visitor
@@ -26,7 +26,6 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
     ) -> PyQnnWrapper.PyQnnOpWrapper:
-        op_wrapper_list = []
         input_node = self.get_node(node.args[0])
         input_tensor = self.get_tensor(input_node, node)
         output_tensor = self.get_tensor(node, node)
@@ -38,26 +37,14 @@ def define_node(
             nodes_to_wrappers,
         )
         argmin_input_tensors = [argmin_inp_tensor_wrapper]
-
-        # arg output is index, do not quantize it.
-        node.meta.pop("quant_attrs", None)
-        input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
-            input_node, node
-        )
-
-        argmin_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
-            node_name=node.name + "_cast",
-            tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
-            quant_encoding=input_quant_encoding,
-            quant_configs=input_quant_configs,
-            dims=output_tensor.size(),
-            tensor=output_tensor,
-            is_fake_tensor=True,
-            nodes_to_wrappers=nodes_to_wrappers,
+        argmin_out_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor.to(torch.int32),
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
         )
-
-        argmin_output_tensors = [argmin_intermediate_tensor_wrapper]
+        argmin_output_tensors = [argmin_out_tensor_wrapper]

         dim = cast(int, node.args[1])
         if dim < 0:
@@ -87,24 +74,4 @@ def define_node(
             {QCOM_DATA: keep_dims},
         )

-        op_wrapper_list.append(argmin_op)
-
-        cast_op = PyQnnWrapper.PyQnnOpWrapper(
-            node.name + "_cast",
-            QNN_OP_PACKAGE_NAME_QTI_AISW,
-            OpCast.op_name,
-        )
-
-        output_tensor_wrapper = self.define_tensor(
-            node,
-            node,
-            output_tensor,
-            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
-            nodes_to_wrappers,
-        )
-
-        cast_op.AddInputTensors([argmin_intermediate_tensor_wrapper])
-        cast_op.AddOutputTensors([output_tensor_wrapper])
-        op_wrapper_list.append(cast_op)
-
-        return op_wrapper_list
+        return argmin_op
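Context for the op_argmin change (a standalone check, not part of the patch): torch.argmin returns int64 indices, while QNN's Argmin emits int32. Declaring the QNN output tensor as int32 up front removes the per-op Cast; any i64 round-trip the graph still needs is handled by the generic cast logic in op_to.py and later fused away.

    import torch

    x = torch.randn(4, 8)
    idx = torch.argmin(x, dim=1, keepdim=True)
    print(idx.dtype)  # torch.int64 -- PyTorch's index convention
    # QNN's Argmin outputs int32, hence output_tensor.to(torch.int32) above.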
diff --git a/backends/qualcomm/builders/op_gather.py b/backends/qualcomm/builders/op_gather.py
new file mode 100644
index 00000000000..7eb8caed57f
--- /dev/null
+++ b/backends/qualcomm/builders/op_gather.py
@@ -0,0 +1,75 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+from executorch.backends.qualcomm.utils.constants import QCOM_DATA
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpGatherElements, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Gather(NodeVisitor):
+    target = ["aten.gather.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        dim = cast(int, node.args[1])
+
+        indices_node = node.args[2]
+        indices_tensor = self.get_tensor(indices_node, node)
+        indices_tensor_wrapper = self.define_tensor(
+            indices_node,
+            node,
+            indices_tensor.to(torch.int32),
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+        gather_output_tensors = [output_tensor_wrapper]
+
+        gather_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpGatherElements.op_name,
+        )
+        gather_op.AddInputTensors(gather_input_tensors)
+        gather_op.AddOutputTensors(gather_output_tensors)
+        gather_op.AddScalarParam(
+            OpGatherElements.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            {QCOM_DATA: np.uint32(dim)},
+        )
+
+        return gather_op
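For reference, QNN's GatherElements with an axis parameter has the same elementwise-index semantics as aten.gather: out[i][j] = x[i][index[i][j]] when gathering along axis 1. A tiny standalone check (not part of the patch):

    import torch

    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    index = torch.tensor([[1, 0], [0, 0]])
    print(torch.gather(x, 1, index))  # tensor([[2., 1.], [3., 3.]])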
diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py
index dc1062846ed..688b4857946 100644
--- a/backends/qualcomm/builders/op_to.py
+++ b/backends/qualcomm/builders/op_to.py
@@ -10,7 +10,7 @@
 import torch
 from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

-from .node_visitor import NodeVisitor, register_node_visitor
+from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP, register_node_visitor
 from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW


@@ -90,9 +90,45 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+        node_input_tensors = [input_tensor_wrapper]
+
+        # If the output / input dtype is int64, we should cast it to int32 first,
+        # since int32 is the only source dtype that can be cast into int64.
+        # This is mainly for validation purposes; redundant cast ops will be
+        # fused in the preprocess stage.
+        ops = []
+        if (
+            node.meta["val"].dtype == torch.int64
+            and input_node.meta["val"].dtype != torch.int32
+        ) or (
+            input_node.meta["val"].dtype == torch.int64
+            and node.meta["val"].dtype != torch.int32
+        ):
+            input_quant_encoding, input_quant_configs = self.get_quant_encoding_conf(
+                input_node, node
+            )
+            cast_intermediate_tensor_wrapper = self.define_custom_tensor_wrapper(
+                node_name=node.name + "_cast",
+                tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+                dtype=QNN_TENSOR_TYPE_MAP[torch.int32],
+                quant_encoding=input_quant_encoding,
+                quant_configs=input_quant_configs,
+                dims=input_tensor.size(),
+                tensor=input_tensor,
+                is_fake_tensor=True,
+                nodes_to_wrappers=nodes_to_wrappers,
+            )
+            cast_op = PyQnnWrapper.PyQnnOpWrapper(
+                f"{node.name}_cast",
+                QNN_OP_PACKAGE_NAME_QTI_AISW,
+                OpCast.op_name,
+            )
+            node_input_tensors = [cast_intermediate_tensor_wrapper]
+            cast_op.AddInputTensors([input_tensor_wrapper])
+            cast_op.AddOutputTensors([cast_intermediate_tensor_wrapper])
+            ops.append(cast_op)

         output_tensor = self.get_tensor(node, node)
-
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
@@ -105,7 +141,8 @@
         op = PyQnnWrapper.PyQnnOpWrapper(
             node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
         )
-        op.AddInputTensors([input_tensor_wrapper])
+        op.AddInputTensors(node_input_tensors)
         op.AddOutputTensors([output_tensor_wrapper])
+        ops.append(op)

-        return op
+        return ops
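The dtype condition above is compact but easy to misread; the standalone restatement below spells it out (the helper name is ours, not the patch's):

    import torch

    def needs_intermediate_i32_cast(src: torch.dtype, dst: torch.dtype) -> bool:
        # An int64 endpoint is only reachable from (or convertible to) int32,
        # so any other pairing involving int64 takes a detour through an
        # extra int32 Cast op.
        return (dst == torch.int64 and src != torch.int32) or (
            src == torch.int64 and dst != torch.int32
        )

    assert needs_intermediate_i32_cast(torch.float32, torch.int64)    # f32 -> i32 -> i64
    assert not needs_intermediate_i32_cast(torch.int32, torch.int64)  # direct cast is fine
    assert needs_intermediate_i32_cast(torch.int64, torch.float32)    # i64 -> i32 -> f32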
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 06e398f7c05..c13a126f76d 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -252,6 +252,12 @@ class OpGather:
     param_axis: str = "axis"


+@dataclass(init=False, frozen=True)
+class OpGatherElements:
+    op_name: str = "GatherElements"
+    param_axis: str = "axis"
+
+
 @dataclass(init=False, frozen=True)
 class OpGatherND:
     op_name: str = "GatherNd"
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 469a801feeb..c2b15c5f226 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -750,7 +750,7 @@ def annotate_elu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


-@register_annotator([torch.ops.aten.embedding.default])
+@register_annotator([torch.ops.aten.embedding.default, torch.ops.aten.gather.default])
 def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
     weight = node.args[0]
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index adf6e256f54..69934414a56 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -729,6 +729,34 @@ def forward(self, x):
         return torch.min(x, torch.full_like(x, self.fill))


+class Gather(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        return torch.gather(x, dim=1, index=y)
+
+
+class GatherArgmin(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        index = torch.argmin(x, dim=1, keepdim=True)
+        return torch.gather(x, dim=1, index=index)
+
+
+class GatherWhere(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        index = torch.where(y > 0, torch.Tensor([1]).int(), torch.Tensor([1]).int()).to(
+            torch.int64
+        )
+        return torch.gather(x, x.dim() - 1, index)
+
+
 class Gelu(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1398,6 +1426,14 @@ def forward(self, x, y):
         return x[:, :seq_length] + self.position_ids[:, :seq_length]


+class SliceCopyDefaultParameter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.cat([x[:1], x[1:]], dim=1)
+
+
 class SliceCopyWithStep(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 74c85b773c2..fbc01b6470c 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -478,6 +478,26 @@ def test_qnn_backend_full_like(self):
         sample_input = (torch.randn(1, 2, 3, 4),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_gather(self):
+        modules = [
+            Gather(),  # noqa: F405
+            # TODO: resolve accuracy problem
+            # GatherArgmin(),  # noqa: F405
+            GatherWhere(),  # noqa: F405
+        ]
+        shape = (2, 2, 3, 4)
+        sample_inputs = [
+            (
+                torch.arange(128, dtype=torch.float32).view(64, 2),
+                torch.ones(64, 2, dtype=torch.int64),
+            ),
+            # (torch.arange(128, dtype=torch.float32).view(64, 2),),
+            (torch.randn(shape), torch.randn(shape)),
+        ]
+        for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -821,12 +841,17 @@ def test_qnn_backend_select_copy(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),  # noqa: F405
+            SliceCopy(),  # noqa: F405
+            SliceCopyWithStep(),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_stack(self):
@@ -1596,6 +1621,27 @@ def test_qnn_backend_full_like(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_gather(self):
+        modules = [
+            Gather(),  # noqa: F405
+            # TODO: resolve accuracy problem
+            # GatherArgmin(),  # noqa: F405
+            GatherWhere(),  # noqa: F405
+        ]
+        shape = (2, 2, 3, 4)
+        sample_inputs = [
+            (
+                torch.arange(128, dtype=torch.float32).view(64, 2),
+                torch.ones(64, 2, dtype=torch.int64),
+            ),
+            # (torch.arange(128, dtype=torch.float32).view(64, 2),),
+            (torch.randn(shape), torch.randn(shape)),
+        ]
+        for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_input)
+                self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_gelu(self):
         module = Gelu()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -1994,12 +2040,17 @@ def test_qnn_backend_sin(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_slice_copy(self):
-        modules = [SliceCopy(), SliceCopyWithStep()]  # noqa: F405
-        sample_input = (
-            torch.randn([1, 512]),
-            torch.randn([1, 8]),
-        )
-        for module in modules:
+        modules = [
+            SliceCopyDefaultParameter(),  # noqa: F405
+            SliceCopy(),  # noqa: F405
+            SliceCopyWithStep(),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn([2, 1, 320, 512]),),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+            (torch.randn([1, 512]), torch.randn([1, 8])),
+        ]
+        for module, sample_input in zip(modules, sample_inputs):
             module = self.get_qdq_module(module, sample_input)
             self.lower_module_and_test_output(module, sample_input)

From dca006f308f4d65624bf3a438208a6231faf232d Mon Sep 17 00:00:00 2001
From: haowhsu-quic
Date: Wed, 7 May 2025 17:38:02 +0800
Subject: [PATCH 2/2] add test coverage of cast op

---
 .../qualcomm/_passes/fuse_consecutive_cast.py |  5 ++++-
 backends/qualcomm/tests/models.py             | 10 ++++++++++
 backends/qualcomm/tests/test_qnn_delegate.py  | 17 ++++++++++++---
 3 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/backends/qualcomm/_passes/fuse_consecutive_cast.py b/backends/qualcomm/_passes/fuse_consecutive_cast.py
index 6c27521a950..1b727ffd75b 100644
--- a/backends/qualcomm/_passes/fuse_consecutive_cast.py
+++ b/backends/qualcomm/_passes/fuse_consecutive_cast.py
@@ -74,7 +74,8 @@ def _canonicalize_cast(
                         clone_cast_node = graph.create_node(
                             "call_function",
                             exir_ops.edge.aten._to_copy.default,
-                            (n.args[0]),
+                            n.args,
+                            kwargs=n.kwargs,
                         )
                         clone_cast_node.meta = n.meta
                         users[i].replace_input_with(n, clone_cast_node)
@@ -98,6 +99,8 @@ def _traverse(self, node):
     def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         for n in graph_module.graph.nodes:
             self._traverse(n)
+            # TODO: how to handle following scenario (won't happen for quantized graph)
+            # fp -> to(i32) -> to(fp)
             if len(self.nodes) > 1:
                 input_node, output_node = self.nodes[0], self.nodes[-1]
                 output_node.replace_input_with(output_node.args[0], input_node.args[0])
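The TODO above exists because the fusion is not value-preserving for float -> int -> float chains, which a quick standalone check illustrates (not part of the patch):

    import torch

    x = torch.tensor([1.7])
    print(x.to(torch.int32).to(torch.float32))  # tensor([1.]) -- truncated
    # Fusing to(i32) -> to(f32) into a no-op would yield 1.7 and change
    # semantics; quantized graphs never produce this pattern, so the pass
    # leaves it unhandled for now.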
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 69934414a56..b7be6aeed88 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -166,6 +166,16 @@ def forward(self, x):
         return x.type(torch.IntTensor)


+class CastMultiUsers(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x, y):
+        index = x.to(torch.long)
+        res = torch.gather(y, dim=1, index=index)
+        return res + index.to(torch.int32)
+
+
 class Cat2(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index fbc01b6470c..c17af6919a6 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -165,9 +165,14 @@ def test_qnn_backend_bmm(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_cast(self):
-        module = Cast()  # noqa: F405
-        sample_input = (10 * torch.rand((9, 4, 5, 3)),)
-        self.lower_module_and_test_output(module, sample_input)
+        modules = [Cast(), CastMultiUsers()]  # noqa: F405
+        sample_inputs = [
+            (10 * torch.rand((9, 4, 5, 3)),),
+            (torch.randint(0, 3, size=(3, 3)), torch.randn(3, 3)),
+        ]
+        for i, (module, sample_input) in enumerate(zip(modules, sample_inputs)):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_cat(self):
         modules = [Cat2(), Cat3(), Cat4()]  # noqa: F405
@@ -1234,6 +1239,12 @@ def test_qnn_backend_bmm(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_cast(self):
+        module = CastMultiUsers()  # noqa: F405
+        sample_input = (torch.randint(0, 3, size=(3, 3)), torch.randn(3, 3))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_cat(self):
         modules = [Cat2(), Cat3(), Cat4()]  # noqa: F405
         sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))