Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions backends/arm/operator_support/to_dim_order_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Declare operator support for ``_to_dim_order_copy`` in TOSA.

Provide dtype-compatibility checks for casting when converting to a specific
dimension order. Supported input/output dtype pairs depend on the active TOSA
profile (integer and/or float).

"""

# pyre-unsafe
import copy
Expand All @@ -25,6 +32,16 @@

@register_tosa_support_check
class ToCopySupported(SupportedTOSAOperatorCheck):
"""Provide TOSA support check for ``_to_dim_order_copy``.

Attributes:
SUPPORTED_INT_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
Allowed output dtypes for each integer input dtype.
SUPPORTED_FP_PROFILE_DTYPES (dict[torch.dtype, list[torch.dtype]]):
Allowed output dtypes for each floating input dtype.

"""

targets = [
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
]
Expand All @@ -40,21 +57,31 @@ def _merge_supported_types(
dtypes1: SupportedTypeDict,
dtypes2: SupportedTypeDict,
) -> SupportedTypeDict:
"""Return a merged mapping of supported dtype transitions.

Args:
dtypes1 (dict[torch.dtype, list[torch.dtype]]): Base mapping.
dtypes2 (dict[torch.dtype, list[torch.dtype]]): Mapping to merge in.

Returns:
dict[torch.dtype, list[torch.dtype]]: Combined mapping.

"""
merged_dtypes = copy.deepcopy(
dtypes1
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_TYPES
) # Use deepcopy to avoid unintentionally modifying SUPPORTED_INT_PROFILE_DTYPES
for k, v in dtypes2.items():
merged_dtypes[k] = merged_dtypes.get(k, []) + v
return merged_dtypes

SUPPORTED_INT_TYPES: SupportedTypeDict = {
SUPPORTED_INT_PROFILE_DTYPES: SupportedTypeDict = {
torch.bool: [torch.bool, torch.int8, torch.int16, torch.int32],
torch.int8: [torch.bool, torch.int8, torch.int16, torch.int32],
torch.int16: [torch.bool, torch.int8, torch.int16, torch.int32],
torch.int32: [torch.bool, torch.int8, torch.int16, torch.int32],
torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
}
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
SUPPORTED_FP_PROFILE_DTYPES: SupportedTypeDict = {
torch.int8: [torch.int8, torch.float16, torch.bfloat16, torch.float32],
torch.int16: [torch.int16, torch.float16, torch.bfloat16, torch.float32],
torch.int32: [torch.int32, torch.float16, torch.bfloat16, torch.float32],
Expand Down Expand Up @@ -92,22 +119,25 @@ def _merge_supported_types(
torch.float32,
],
}
ALL_SUPPORTED_TYPES = _merge_supported_types(
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
)

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
"""Return True if the node is supported by TOSA.

Check FakeTensor metadata, validate input dtype is supported for the
active profile, and ensure the output dtype is allowed for the given
input dtype.

"""
supported_dtypes: SupportedTypeDict = {}
if tosa_spec.support_integer():
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_INT_TYPES, supported_dtypes
self.SUPPORTED_INT_PROFILE_DTYPES, supported_dtypes
)
if tosa_spec.support_float():
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_FLOAT_TYPES, supported_dtypes
self.SUPPORTED_FP_PROFILE_DTYPES, supported_dtypes
)

if len(node.all_input_nodes) != 1:
Expand Down
Loading