From 094975d8cf232385aa11e44707dd4a15e6bf2fde Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 16 Sep 2025 13:26:54 +0800 Subject: [PATCH 1/2] Solve concat op --- backends/qualcomm/builders/op_cat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 9f6eb6676cf..8f16f262549 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -33,10 +33,11 @@ def define_node( list_of_tensor_wrappers = [] for tensor_input in list_of_tensors: - input_tensor = self.get_tensor(self.get_node(tensor_input), node) + input_node = self.get_node(tensor_input) + input_tensor = self.get_tensor(input_node, node) list_of_tensor_wrappers.append( self.define_tensor( - tensor_input, + input_node, node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, From 7d4953c23805de70bb28e1184d746c212cb9f273 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Wed, 17 Sep 2025 10:10:02 +0800 Subject: [PATCH 2/2] code review --- backends/qualcomm/builders/op_cat.py | 18 +++++++++--------- backends/qualcomm/tests/models.py | 9 +++++++++ backends/qualcomm/tests/test_qnn_delegate.py | 4 ++-- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/backends/qualcomm/builders/op_cat.py b/backends/qualcomm/builders/op_cat.py index 8f16f262549..644b087ab9c 100644 --- a/backends/qualcomm/builders/op_cat.py +++ b/backends/qualcomm/builders/op_cat.py @@ -29,15 +29,15 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: - list_of_tensors = cast(List[torch.fx.Node], node.args[0]) - list_of_tensor_wrappers = [] + input_nodes = cast(List[torch.fx.Node], node.args[0]) + input_tensor_wrappers = [] - for tensor_input in list_of_tensors: - input_node = self.get_node(tensor_input) - input_tensor = self.get_tensor(input_node, node) - list_of_tensor_wrappers.append( + for input_node in input_nodes: + source_input_node = self.get_node(input_node) + input_tensor = self.get_tensor(source_input_node, node) + input_tensor_wrappers.append( self.define_tensor( - input_node, + source_input_node, node, input_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, @@ -45,7 +45,7 @@ def define_node( ) ) - if len(list_of_tensors) != len(list_of_tensor_wrappers): + if len(input_nodes) != len(input_tensor_wrappers): warnings.warn( "[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.", stacklevel=1, @@ -77,7 +77,7 @@ def define_node( QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) - concat_op.AddInputTensors(list_of_tensor_wrappers) + concat_op.AddInputTensors(input_tensor_wrappers) concat_op.AddOutputTensors([output_tensor_wrapper]) concat_op.AddScalarParam( diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 77ff1be4562..2de2cd098aa 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -274,6 +274,15 @@ def forward(self, x, y): return torch.cat((y, y, x, x), axis=2) +class Cat5(torch.nn.Module): + def __init__(self): + super().__init__() + self.const_tensor = torch.randn(1, 1, 2, 2) + + def forward(self, x, y): + return torch.cat((x, y, self.const_tensor), axis=2) + + class CausalMask(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 5a86d5f286d..0e75cf2844a 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -232,7 +232,7 @@ def test_qnn_backend_cast(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_cat(self): - modules = [Cat2(), Cat3(), Cat4()] # noqa: F405 + modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405 sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2)) for i, module in enumerate(modules): with self.subTest(i=i): @@ -1699,7 +1699,7 @@ def test_qnn_backend_cast(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_cat(self): - modules = [Cat2(), Cat3(), Cat4()] # noqa: F405 + modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405 sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2)) for i, module in enumerate(modules): with self.subTest(i=i):