From 1326fc03a46fe79fd353b5244ef38e55ee9b2e99 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 11 Sep 2025 08:51:04 +0200 Subject: [PATCH] Arm backend: Add docstrings for operator_support/index_select_support.py Change-Id: I20f8e606a153a6095fa340614edb8d4de2cf52ff Signed-off-by: Sebastian Larsson --- .../arm/operator_support/index_select_support.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/backends/arm/operator_support/index_select_support.py b/backends/arm/operator_support/index_select_support.py index 79f1d154a14..a83151adab7 100644 --- a/backends/arm/operator_support/index_select_support.py +++ b/backends/arm/operator_support/index_select_support.py @@ -2,7 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.index_select`` in TOSA. +Accept int32 indices and restrict supported weight shapes to 2D or 3D with a +unit batch dimension. + +""" import torch import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -15,6 +20,8 @@ @register_tosa_support_check class IndexSelectSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.index_select``.""" + targets = [exir_ops.edge.aten.index_select.default] tosa_specs = [ @@ -25,7 +32,12 @@ class IndexSelectSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: # type: ignore[override, misc] + """Return True if the node is supported by TOSA. + + Require int32 indices and limit weight shapes to 2D or 3D with a leading + dimension of 1. + """ weights_shape = node.all_input_nodes[0].meta["val"].shape indices_val = node.all_input_nodes[1].meta["val"] indices_dtype = indices_val.dtype