diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index ced9b7c5afc..3cc587d99d3 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -2,6 +2,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``_to_dim_order_copy`` in TOSA. + +Provide dtype-compatibility checks for casting when converting to a specific +dimension order. Supported input/output dtype pairs depend on the active TOSA +profile (integer and/or float). + +""" # pyre-unsafe import copy @@ -25,6 +32,16 @@ @register_tosa_support_check class ToCopySupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for ``_to_dim_order_copy``. + + Attributes: + SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]): + Allowed output dtypes for each integer input dtype. + SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]): + Allowed output dtypes for each floating input dtype. + + """ + targets = [ exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] @@ -40,21 +57,31 @@ def _merge_supported_types( dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict, ) -> SupportedTypeDict: + """Return a merged mapping of supported dtype transitions. + + Args: + dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping. + dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in. + + Returns: + dict[torch.dtype, list[torch.dtype]]: Combined mapping. + + """ merged_dtypes = copy.deepcopy( dtypes1 - ) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES + ) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES for k, v in dtypes2.items(): merged_dtypes[k] = merged_dtypes.get(k, []) + v return merged_dtypes - SUPPORTED_INT_TYPES: SupportedTypeDict = { + SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = { torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32], torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32], } - SUPPORTED_FLOAT_TYPES: SupportedTypeDict = { + SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = { torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32], torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32], torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32], @@ -92,22 +119,25 @@ def _merge_supported_types( torch.float32, ], } - ALL_SUPPORTED_TYPES = _merge_supported_types( - SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES - ) def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: + """Return True if the node is supported by TOSA. + + Check FakeTensor metadata, validate input dtype is supported for the + active profile, and ensure the output dtype is allowed for the given + input dtype. + """ supported_dtypes: SupportedTypeDict = {} if tosa_spec.support_integer(): supported_dtypes = self._merge_supported_types( - self.SUPPORTED_INT_TYPES, supported_dtypes + self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes ) if tosa_spec.support_float(): supported_dtypes = self._merge_supported_types( - self.SUPPORTED_FLOAT_TYPES, supported_dtypes + self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes ) if len(node.all_input_nodes) != 1: