Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions backends/qualcomm/builders/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,23 @@ def define_node(
node: torch.fx.Node,
nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
) -> PyQnnWrapper.PyQnnOpWrapper:
list_of_tensors = cast(List[torch.fx.Node], node.args[0])
list_of_tensor_wrappers = []
input_nodes = cast(List[torch.fx.Node], node.args[0])
input_tensor_wrappers = []

for tensor_input in list_of_tensors:
input_tensor = self.get_tensor(self.get_node(tensor_input), node)
list_of_tensor_wrappers.append(
for input_node in input_nodes:
source_input_node = self.get_node(input_node)
input_tensor = self.get_tensor(source_input_node, node)
input_tensor_wrappers.append(
self.define_tensor(
tensor_input,
source_input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
)
)

if len(list_of_tensors) != len(list_of_tensor_wrappers):
if len(input_nodes) != len(input_tensor_wrappers):
warnings.warn(
"[QNN Delegate Op Builder]: The number or input tensors is not equal to the number of input tensor wrappers.",
stacklevel=1,
Expand Down Expand Up @@ -76,7 +77,7 @@ def define_node(
QNN_OP_PACKAGE_NAME_QTI_AISW,
OpConcat.op_name,
)
concat_op.AddInputTensors(list_of_tensor_wrappers)
concat_op.AddInputTensors(input_tensor_wrappers)
concat_op.AddOutputTensors([output_tensor_wrapper])

concat_op.AddScalarParam(
Expand Down
9 changes: 9 additions & 0 deletions backends/qualcomm/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,15 @@ def forward(self, x, y):
return torch.cat((y, y, x, x), axis=2)


class Cat5(torch.nn.Module):
def __init__(self):
super().__init__()
self.const_tensor = torch.randn(1, 1, 2, 2)

def forward(self, x, y):
return torch.cat((x, y, self.const_tensor), axis=2)


class CausalMask(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
4 changes: 2 additions & 2 deletions backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def test_qnn_backend_cast(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_cat(self):
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405
sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
for i, module in enumerate(modules):
with self.subTest(i=i):
Expand Down Expand Up @@ -1699,7 +1699,7 @@ def test_qnn_backend_cast(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_cat(self):
modules = [Cat2(), Cat3(), Cat4()] # noqa: F405
modules = [Cat2(), Cat3(), Cat4(), Cat5()] # noqa: F405
sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2))
for i, module in enumerate(modules):
with self.subTest(i=i):
Expand Down
Loading