diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py
index 12c3d90b52b..e357efe3f24 100644
--- a/backends/arm/operator_support/to_copy_support.py
+++ b/backends/arm/operator_support/to_copy_support.py
@@ -20,6 +20,8 @@
 
 logger = logging.getLogger(__name__)
 
+SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
+
 
 @register_tosa_support_check
 class ToCopySupported(SupportedTOSAOperatorCheck):
@@ -33,8 +35,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]
 
-    SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
-
     @staticmethod
     def _merge_supported_types(
         # pyre-ignore[11]
@@ -53,11 +53,22 @@ def _merge_supported_types(
         torch.int8: [torch.bool, torch.int16, torch.int32],
         torch.int16: [torch.bool, torch.int8, torch.int32],
         torch.int32: [torch.bool, torch.int8, torch.int16],
+        torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
     }
     SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
         torch.int8: [torch.float16, torch.bfloat16, torch.float32],
         torch.int16: [torch.float16, torch.bfloat16, torch.float32],
         torch.int32: [torch.float16, torch.bfloat16, torch.float32],
+        # INT64 inputs to casts *should* be ok, since they should be rejected by
+        # CheckInt64InputsAndOutputs if the cast can't be done AOT.
+        torch.int64: [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ],
         torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
         torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
         torch.float32: [
@@ -71,22 +82,20 @@ def _merge_supported_types(
     ALL_SUPPORTED_TYPES = _merge_supported_types(
         SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
     )
-    POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}
 
     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
-        supported_dtypes = (
-            self.ALL_SUPPORTED_TYPES
-            if tosa_spec.support_float()
-            else self.SUPPORTED_INT_TYPES
-        )
-        # Take into account possible type conversions
-        supported_dtypes.update(
-            (k, supported_dtypes[v])
-            for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
-            if v in supported_dtypes
-        )
+
+        supported_dtypes: SupportedTypeDict = {}
+        if tosa_spec.support_integer():
+            supported_dtypes = self._merge_supported_types(
+                self.SUPPORTED_INT_TYPES, supported_dtypes
+            )
+        if tosa_spec.support_float():
+            supported_dtypes = self._merge_supported_types(
+                self.SUPPORTED_FLOAT_TYPES, supported_dtypes
+            )
 
         if len(node.all_input_nodes) != 1:
             self.reporter.report_reject(
@@ -156,7 +165,7 @@ def is_node_tosa_supported(
 
         if "dim_order" in node.kwargs:
             dim_order = node.kwargs["dim_order"]  # pyre-ignore[6]
-            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+            if dim_order is not None and dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
                 self.reporter.report_reject(
                     node,
                     (
diff --git a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
index 0628f010f08..953f14e0b57 100644
--- a/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
+++ b/backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
@@ -33,8 +33,9 @@ class TestT5EncoderModel(unittest.TestCase):
     # .to_executorch step, i.e. after Arm partitioner.
     ops_after_partitioner = {
         "executorch_exir_dialects_edge__ops_aten__to_copy_default": 2,
+        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
         "executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
-        "torch.ops.higher_order.executorch_call_delegate": 2,
+        "torch.ops.higher_order.executorch_call_delegate": 3,
     }
 
     def _prepare_inputs(
diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py
index db04b9425c2..21197aa14fa 100644
--- a/backends/arm/test/ops/test_to_copy.py
+++ b/backends/arm/test/ops/test_to_copy.py
@@ -70,6 +70,16 @@ def test_copy_tosa_FP(test_data: Tuple):
         aten_op=[],
         exir_op=[],
     )
+    # int to int cast is not supported in TOSA+FP profile
+    if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
+        pipeline.change_args(
+            "check_count.exir",
+            {
+                "torch.ops.higher_order.executorch_call_delegate": 0,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
+            },
+        )
+        pipeline.pop_stage("run_method_and_compare_outputs")
     pipeline.run()
 
 
@@ -84,6 +94,15 @@ def test_copy_vgf_FP(test_data: Tuple):
         exir_op=[],
         tosa_version="TOSA-1.0+FP",
     )
+    # int to int cast is not supported in TOSA+FP profile
+    if not new_dtype.is_floating_point and not torch.is_floating_point(test_tensor):
+        pipeline.change_args(
+            "check_count.exir",
+            {
+                "torch.ops.higher_order.executorch_call_delegate": 0,
+                "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1,
+            },
+        )
     pipeline.run()
 
 
diff --git a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py
index cfed4245eed..ea7e03f8e21 100644
--- a/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py
+++ b/backends/arm/test/passes/test_convert_int64_output_ops_to_int32.py
@@ -32,17 +32,8 @@ def forward(self, x: torch.Tensor):
 test_data_suite_convert = {
     "fp32_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float32), torch.int64),
     "fp16_input": lambda: (torch.rand((1, 2, 3, 4), dtype=torch.float16), torch.int64),
-    "int16_input": lambda: (
-        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16),
-        torch.int64,
-    ),
-    "int8_input": lambda: (
-        torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8),
-        torch.int64,
-    ),
 }
 
-
 test_data_suite_remove = {
     "int32_input": lambda: (
         torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32),
@@ -52,7 +43,7 @@ def forward(self, x: torch.Tensor):
 
 
 @common.parametrize("test_data", test_data_suite_convert)
-def test_convert_or_remove_casting_to_int64_covnert_tosa_FP(test_data: Tuple):
+def test_convert_or_remove_casting_to_int64_convert_tosa_FP(test_data: Tuple):
     test_tensor, target_dtype = test_data()
     module = CastingToInt64Model(target_dtype)
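
Not part of the patch: a minimal, self-contained sketch of how the new profile-gated dtype lookup in `is_node_tosa_supported` is expected to behave, assuming `_merge_supported_types` simply unions the per-input-dtype target lists (its body is not shown in this diff). The tables are abbreviated from the diff, and `merge_supported_types` / `supported_dtypes_for` are hypothetical stand-ins for the class methods.

```python
import torch

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

# Abbreviated copies of the tables from the diff (int8 and int64 rows only).
SUPPORTED_INT_TYPES: SupportedTypeDict = {
    torch.int8: [torch.bool, torch.int16, torch.int32],
    torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
}
SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
    torch.int8: [torch.float16, torch.bfloat16, torch.float32],
    torch.int64: [torch.int8, torch.int16, torch.int32,
                  torch.float16, torch.bfloat16, torch.float32],
}


def merge_supported_types(new: SupportedTypeDict, base: SupportedTypeDict) -> SupportedTypeDict:
    # Assumed behaviour of _merge_supported_types: union the target lists key by key.
    merged: SupportedTypeDict = {k: list(v) for k, v in base.items()}
    for src, targets in new.items():
        merged.setdefault(src, [])
        merged[src].extend(t for t in targets if t not in merged[src])
    return merged


def supported_dtypes_for(support_integer: bool, support_float: bool) -> SupportedTypeDict:
    # Mirrors the new profile-gated construction in is_node_tosa_supported.
    supported: SupportedTypeDict = {}
    if support_integer:
        supported = merge_supported_types(SUPPORTED_INT_TYPES, supported)
    if support_float:
        supported = merge_supported_types(SUPPORTED_FLOAT_TYPES, supported)
    return supported


# FP-only profile (TOSA-1.0+FP): int8 -> int16 is not in the float table, so the
# int-to-int cast is rejected, matching the new expectation in test_copy_tosa_FP.
print(torch.int16 in supported_dtypes_for(support_integer=False, support_float=True)[torch.int8])  # False
# INT+FP profile: the int table contributes int8 -> int16, so the cast is accepted.
print(torch.int16 in supported_dtypes_for(support_integer=True, support_float=True)[torch.int8])   # True
```

Under this reading, an FP-only spec no longer inherits the pure int-to-int entries, which is why the updated `test_to_copy.py` cases expect zero `executorch_call_delegate` nodes (and a leftover `to_dim_order_copy`) for int-to-int casts on the FP profile.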