From c89d8692f1fd29b84082735d4c646840db626b4d Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 11 Sep 2025 16:50:02 +0200 Subject: [PATCH] Arm backend: Add docstrings for operator_support/index_tensor_support.py Signed-off-by: Sebastian Larsson Change-Id: If0b3ab50fbfa20fa0ba11c3dd70a3824c599b171 --- .../operator_support/index_tensor_support.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/backends/arm/operator_support/index_tensor_support.py b/backends/arm/operator_support/index_tensor_support.py index 25bc79ea938..92b0ce48a32 100644 --- a/backends/arm/operator_support/index_tensor_support.py +++ b/backends/arm/operator_support/index_tensor_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA support checks for ``aten.index.Tensor``. + +Reject unsupported patterns such as high-rank index tensors, front-positioned +slice/ellipsis/None markers, and cases that exceed ``int32`` element limits. + +""" import math @@ -18,7 +24,8 @@ @register_tosa_support_check class IndexTensorSupported(SupportedTOSAOperatorCheck): - """ + """Prevent partitioning of unsupported ``index.Tensor`` usages. + This support check is intended to prevent the partitioning of currently unsupported usages of the index.Tensor operator. @@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck): t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)] are also possible and can result in some unintuitive behaviors where batching and indexing are mixed together. + """ targets = [exir_ops.edge.aten.index.Tensor] @@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] + """Return True if ``aten.index.Tensor`` usage fits supported patterns. + + Enforces the following constraints: + - No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor. + - Indexing tensors have rank <= 3. + - The value tensor element count fits in ``int32``. + + """ indices = node.args[1] for index in indices: # type: ignore[union-attr] # Usage 2 guard