diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index c62854767d2..4b10a58b97f 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -19,8 +19,6 @@ from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.sym_util import eval_shape -from .utils import dq_ops, q_ops - class LayoutTransform(ExportPass): """ @@ -91,8 +89,6 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.topk.default, exir_ops.edge.aten._to_copy.default, exir_ops.edge.aten.where.self, - *q_ops, - *dq_ops, _operator.getitem, } @@ -117,7 +113,6 @@ def __init__( super(LayoutTransform, self).__init__() self.edge_program = edge_program self.insert_permute = insert_permute - self.qdq_opset = {*q_ops, *dq_ops} self.transformed_tag = QCOM_AXIS_ORDER def mark_as_transformed(self, node: torch.fx.Node) -> None: diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py index e78284a5e32..ff039f9d7a8 100644 --- a/backends/qualcomm/builders/op_index.py +++ b/backends/qualcomm/builders/op_index.py @@ -38,11 +38,11 @@ def define_node( nodes_to_wrappers, ) - if len(node.args[1]) > 1: - # TODO consider to implement it in a recursive way. - raise NotImplementedError("Not support tuple of tensor.") - - indices_node = node.args[1][0] + # e.g. x[:, index]: + # > node.args[1] = [None, indices] + # > axis = 1 + axis = len(node.args[1]) - 1 + indices_node = node.args[1][axis] indices_tensor = self.get_tensor(indices_node, node).to(torch.int32) assert indices_tensor.size(0) != 0, "Not support empty indices list" @@ -78,7 +78,7 @@ def define_node( gather_op.AddScalarParam( OpGather.param_axis, PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, - {QCOM_DATA: np.int32(0)}, + {QCOM_DATA: np.int32(axis)}, ) return gather_op diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index da4fcb52b77..cd4c83892a5 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -746,13 +746,19 @@ def forward(self, x): class Index(torch.nn.Module): - def __init__(self): + def __init__(self, axis): super().__init__() self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32) self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32) + self.axis = axis + self.dispatcher = { + 0: lambda x: x[self.idx0] + x[self.idx1], + 1: lambda x: x[:, self.idx0] + x[:, self.idx1], + 2: lambda x: x[:, :, self.idx0] + x[:, :, self.idx1], + } def forward(self, x): - return x[self.idx0] + x[self.idx1] + return self.dispatcher[self.axis](x) class IndexPut(torch.nn.Module): diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index f3ab354543d..cfd5d5c8076 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -469,9 +469,11 @@ def test_qnn_backend_hardtanh(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_index(self): - module = Index() # noqa: F405 + modules = [Index(0), Index(1), Index(2)] # noqa: F405 sample_input = (torch.randn([8, 172, 64]),) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_index_put(self): module = IndexPut() # noqa: F405 @@ -1457,10 +1459,12 @@ def test_qnn_backend_hardtanh(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_index(self): - module = Index() # noqa: F405 + modules = [Index(0), Index(1), Index(2)] # noqa: F405 sample_input = (torch.randn([8, 172, 64]),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_index_put(self): module = IndexPut() # noqa: F405