Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

logger = logging.getLogger(__name__)

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]


@register_tosa_support_check
class ToCopySupported(SupportedTOSAOperatorCheck):
Expand All @@ -33,8 +35,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

@staticmethod
def _merge_supported_types(
# pyre-ignore[11]
Expand All @@ -53,11 +53,22 @@ def _merge_supported_types(
torch.int8: [torch.bool, torch.int16, torch.int32],
torch.int16: [torch.bool, torch.int8, torch.int32],
torch.int32: [torch.bool, torch.int8, torch.int16],
torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
}
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
torch.int8: [torch.float16, torch.bfloat16, torch.float32],
torch.int16: [torch.float16, torch.bfloat16, torch.float32],
torch.int32: [torch.float16, torch.bfloat16, torch.float32],
# INT64 inputs to casts *should* be ok, since they should be rejected by
# CheckInt64InputsAndOutputs if the cast can't be done AOT.
torch.int64: [
torch.int8,
torch.int16,
torch.int32,
torch.float16,
torch.bfloat16,
torch.float32,
],
torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
torch.float32: [
Expand All @@ -71,22 +82,20 @@ def _merge_supported_types(
ALL_SUPPORTED_TYPES = _merge_supported_types(
SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
)
POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}

def is_node_tosa_supported(
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
supported_dtypes = (
self.ALL_SUPPORTED_TYPES
if tosa_spec.support_float()
else self.SUPPORTED_INT_TYPES
)
# Take into account possible type conversions
supported_dtypes.update(
(k, supported_dtypes[v])
for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
if v in supported_dtypes
)

supported_dtypes: SupportedTypeDict = {}
if tosa_spec.support_integer():
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_INT_TYPES, supported_dtypes
)
if tosa_spec.support_float():
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_FLOAT_TYPES, supported_dtypes
)

if len(node.all_input_nodes) != 1:
self.reporter.report_reject(
Expand Down Expand Up @@ -156,7 +165,7 @@ def is_node_tosa_supported(
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
if dim_order is not None and dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
self.reporter.report_reject(
node,
(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class TestT5EncoderModel(unittest.TestCase):
# .to_executorch step, i.e. after Arm partitioner.
ops_after_partitioner = {
"executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
"torch.ops.higher_order.executorch_call_delegate": 2,
"torch.ops.higher_order.executorch_call_delegate": 3,
}

def _prepare_inputs(
Expand Down
19 changes: 19 additions & 0 deletions backends/arm/test/ops/test_to_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ def test_copy_tosa_FP(test_data: Tuple):
aten_op=[],
exir_op=[],
)
# int-to-int casts are not supported in the TOSA-1.0+FP profile
if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
},
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()


Expand All @@ -84,6 +94,15 @@ def test_copy_vgf_FP(test_data: Tuple):
exir_op=[],
tosa_version="TOSA-1.0+FP",
)
# int-to-int casts are not supported in the TOSA-1.0+FP profile
if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
pipeline.change_args(
"check_count.exir",
{
"torch.ops.higher_order.executorch_call_delegate": 0,
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
},
)
pipeline.run()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,8 @@ def forward(self, x: torch.Tensor):
test_data_suite_convert = {
"fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64),
"fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64),
"int16_input": lambda: (
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
torch.int64,
),
"int8_input": lambda: (
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
torch.int64,
),
}


test_data_suite_remove = {
"int32_input": lambda: (
torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
Expand All @@ -52,7 +43,7 @@ def forward(self, x: torch.Tensor):


@common.parametrize("test_data", test_data_suite_convert)
def test_convert_or_remove_casting_to_int64_covnert_tosa_FP(test_data: Tuple):
def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple):
test_tensor, target_dtype = test_data()
module = CastingToInt64Model(target_dtype)

Expand Down
Loading