diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 6fe70aa696c..b67bded7fb9 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -310,11 +310,11 @@ def is_node_supported(
         if not input_quantized:
             return False
 
-        output_quantized = output_quantized or all(
-            (output_node.target == self.q_op)
-            or (not get_first_fake_tensor(output_node).dtype.is_floating_point)
-            for output_node in node.users
+        all_q_users = all(
+            (output_node.target == self.q_op) for output_node in node.users
         )
+        is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
+        output_quantized = output_quantized or all_q_users or not is_floating_point
 
         if not output_quantized:
             return False
diff --git a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
index 3fe339e0f9e..5bb692ebcaf 100644
--- a/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
+++ b/backends/arm/test/misc/test_partition_decomposed_quantized_ops.py
@@ -19,21 +19,39 @@
 )
 
 input_t1 = Tuple[torch.Tensor]
-aten_op: list[str] = ["torch.ops.aten.add.Tensor", "torch.ops.aten.softplus.default"]
-exir_op: list[str] = [
+softplus_aten_op: list[str] = [
+    "torch.ops.aten.add.Tensor",
+    "torch.ops.aten.softplus.default",
+]
+softplus_exir_op: list[str] = [
     "executorch_exir_dialects_edge__ops_aten_add_Tensor",
     "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
     "executorch_exir_dialects_edge__ops_aten_exp_default",
     "executorch_exir_dialects_edge__ops_aten_div_Tensor",
 ]
+linear_residual_aten_op: list[str] = [
+    "torch.ops.aten.linear.default",
+    "torch.ops.aten.gelu.default",
+    "torch.ops.aten.dropout.default",
+    "torch.ops.aten.add.Tensor",
+]
+linear_residual_exir_op: list[str] = [
+    "executorch_exir_dialects_edge__ops_aten_gelu_default",
+    "executorch_exir_dialects_edge__ops_aten_clone_default",
+    "executorch_exir_dialects_edge__ops_aten_linear_default",
+    "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+]
+
 test_data: dict[input_t1] = {
     "3d_rand": (torch.rand(1, 5, 5),),
 }
 
 
-class Module(torch.nn.Module):
+class SoftplusModule(torch.nn.Module):
+    """Module containing an addition followed by a Softplus. Softplus is currently not supported by TosaBackend."""
+
     def __init__(self):
         super().__init__()
         self.softplus = torch.nn.Softplus()
@@ -42,10 +60,35 @@ def forward(self, x: torch.Tensor):
         return self.softplus(x + x)
 
 
+class LinearResidualModule(torch.nn.Module):
+    """Module containing a residual and a linear layer followed by GELU and a Dropout.
+    GELU is currently not supported by TosaBackend or TosaQuantizer.
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+        self.linear = torch.nn.Linear(in_features=5, out_features=3)
+        self.gelu = torch.nn.GELU()
+        self.dropout = torch.nn.Dropout(0.5)
+
+    def forward(self, x: torch.Tensor):
+        x1 = self.linear(x)
+        x2 = self.gelu(x1)
+        x3 = self.dropout(x2)
+        return x1 + x3
+
+
+# Softplus is decomposed, which messes up the quantization. This test verifies that CheckProperQuantization does not
+# partition nodes where quantization is not as expected.
 @common.parametrize("test_data", test_data)
 def test_softplus_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](
-        Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
+        SoftplusModule(),
+        test_data=test_data,
+        aten_op=softplus_aten_op,
+        exir_op=softplus_exir_op,
     )
     # remove check_count.exir as there will be more than one delegate
     pipeline.pop_stage("check_count.exir")
@@ -55,14 +98,76 @@ def test_softplus_tosa_MI(test_data: input_t1):
 @common.parametrize("test_data", test_data)
 def test_softplus_tosa_BI(test_data: input_t1):
     pipeline = TosaPipelineBI[input_t1](
-        Module(), test_data=test_data, aten_op=aten_op, exir_op=exir_op
+        SoftplusModule(),
+        test_data=test_data,
+        aten_op=softplus_aten_op,
+        exir_op=softplus_exir_op,
+    )
+    pipeline.pop_stage("check_not.exir")
+    # check that all ops in softplus_exir_op except add are rejected
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check,
+        softplus_exir_op[1:],
+        suffix="exir_post_partition",
+    )
+    pipeline.run()
+
+
+# Since GELU will not be quantized by TosaQuantizer, the Dropout's input will not be quantized either.
+# If so, the Dropout should not be partitioned by TosaPartitioner for the TOSA BI profile. This test verifies that
+# the partitioner indeed does not partition the Dropout (clone) for TOSA BI.
+@common.parametrize("test_data", test_data)
+def test_linear_residual_tosa_MI(test_data: input_t1):
+    pipeline = TosaPipelineMI[input_t1](
+        LinearResidualModule(),
+        test_data=test_data,
+        aten_op=linear_residual_aten_op,
+        exir_op=linear_residual_exir_op,
+        use_to_edge_transform_and_lower=True,
+    )
+    # remove check_count.exir as there will be more than one delegate
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage("check_not.exir")
+    # check that all ops in linear_residual_exir_op except GELU are partitioned
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check_not,
+        linear_residual_exir_op[1:],
+        suffix="exir_post_partition",
+    )
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check,
+        linear_residual_exir_op[:1],
+        suffix="exir_post_partition",
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data)
+def test_linear_residual_tosa_BI(test_data: input_t1):
+    pipeline = TosaPipelineBI[input_t1](
+        LinearResidualModule(),
+        test_data=test_data,
+        aten_op=linear_residual_aten_op,
+        exir_op=linear_residual_exir_op,
+        use_to_edge_transform_and_lower=True,
     )
+    # remove check_count.exir as there will be more than one delegate
+    pipeline.pop_stage("check_count.exir")
     pipeline.pop_stage("check_not.exir")
-    # check that all ops in exir_op except add are rejected
+    # check that all ops in linear_residual_exir_op except GELU and Dropout are partitioned
+    pipeline.add_stage_after(
+        "to_edge_transform_and_lower",
+        pipeline.tester.check_not,
+        linear_residual_exir_op[2:],
+        suffix="exir_post_partition",
+    )
     pipeline.add_stage_after(
         "to_edge_transform_and_lower",
         pipeline.tester.check,
-        exir_op[1:],
+        linear_residual_exir_op[:2],
         suffix="exir_post_partition",
     )
     pipeline.run()
diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py
index c2bc48d98d7..ab36a5a6f1e 100644
--- a/backends/arm/tosa_partitioner.py
+++ b/backends/arm/tosa_partitioner.py
@@ -14,6 +14,7 @@
     get_tosa_spec,
     is_tosa,
 )  # usort: skip
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
 from executorch.backends.arm.operator_support.tosa_supported_operators import (
     tosa_support_factory,
 )
@@ -66,7 +67,7 @@ def __init__(
         self.delegation_spec = DelegationSpec(TOSABackend.__name__, compile_spec)
         self.additional_checks = additional_checks
 
-    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
+    def partition(self, exported_program: ExportedProgram) -> PartitionResult:  # noqa
 
         # Run the CapabilityBasedPartitioner to return the largest possible
         # subgraphs containing the nodes with the tags
@@ -110,6 +111,20 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
                             del node.meta["delegation_tag"]
                             break
 
+                if tosa_spec.support_float():
+                    continue
+
+                if is_partitioned(node):
+                    for input in node.all_input_nodes:
+                        if is_partitioned(input):
+                            continue
+                        if get_first_fake_tensor(input).dtype.is_floating_point:
+                            logger.info(
+                                f"Not partitioning {node.name} because input {input.name} has floating point dtype."
+                            )
+                            del node.meta["delegation_tag"]
+                            break
+
         tag_constant_data(exported_program)
 
         return PartitionResult(