Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion backends/arm/operator_support/index_tensor_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide TOSA support checks for ``aten.index.Tensor``.

Reject unsupported patterns such as high-rank index tensors, front-positioned
slice/ellipsis/None markers, and cases that exceed ``int32`` element limits.

"""

import math

Expand All @@ -18,7 +24,8 @@

@register_tosa_support_check
class IndexTensorSupported(SupportedTOSAOperatorCheck):
"""
"""Prevent partitioning of unsupported ``index.Tensor`` usages.

This support check is intended to prevent the partitioning of
currently unsupported usages of the index.Tensor operator.

Expand Down Expand Up @@ -95,6 +102,7 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
t[1:3, torch.arange(5), 2:3, torch.arange(3).reshape(3,1)]
are also possible and can result in some unintuitive behaviors
where batching and indexing are mixed together.

"""

targets = [exir_ops.edge.aten.index.Tensor]
Expand All @@ -107,6 +115,14 @@ class IndexTensorSupported(SupportedTOSAOperatorCheck):
def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool: # type: ignore[override, misc]
"""Return True if ``aten.index.Tensor`` usage fits supported patterns.

Enforces the following constraints:
- No ``None`` (unsqueeze), slice, or ellipsis before an indexing tensor.
- Indexing tensors have rank <= 3.
- The value tensor element count fits in ``int32``.

"""
indices = node.args[1]
for index in indices: # type: ignore[union-attr]
# Usage 2 guard
Expand Down
Loading