diff --git a/backends/arm/operator_support/minmax_support.py b/backends/arm/operator_support/minmax_support.py index 68433819f4b..8ba5d9335dc 100644 --- a/backends/arm/operator_support/minmax_support.py +++ b/backends/arm/operator_support/minmax_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for min/max along a dimension in TOSA. + +Provide support checks ensuring that argmax/argmin indices are not consumed, +restricting to float profiles until index quantization is supported. + +""" import torch.fx as fx from executorch.backends.arm.operator_support.tosa_supported_operators import ( @@ -14,6 +20,8 @@ @register_tosa_support_check class MinMaxSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``aten.max.dim`` and ``aten.min.dim``.""" + targets = [ exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim, @@ -24,7 +32,16 @@ class MinMaxSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + + Allow max/min when the argmax/argmin output is unused or dropped (i.e., + only the value is consumed). Disallow cases where arg indices are + further used. + + """ if node.target in [exir_ops.edge.aten.max.dim, exir_ops.edge.aten.min.dim]: no_argmax = len(node.users) == 1 no_argmax_users = (len(node.users) == 2) and (