diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 1b9a02fbb6a..fdb4dc62abf 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -335,19 +335,24 @@ def ops_to_not_decompose( function that returns True when an op should not be decomposed. """ - ops_to_not_decompose_if_quant_op = [ + ops_to_not_decompose_if_quant_op = { torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.linear.default, - ] + } + ops_to_not_decompose_if_fp = { + torch.ops.aten.linear.default, + } + ops_to_not_decompose_always = { + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + torch.ops.aten.logit.default, + } def filter_fn(node: torch.fx.Node) -> bool: - """Return True to keep selected ops intact inside quantized regions. - - The predicate holds when the target is in - ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are - quantize/dequantize ops, indicating a quantized activation that - should not be decomposed. + """Filter function applied to ops in 'ops_to_not_decompose'. + Returns True if the op should not be decomposed. + If this function returns True, the partitioner *must* accept the node, or the lowering fails. Args: node (torch.fx.Node): FX node to evaluate. @@ -356,6 +361,12 @@ def filter_fn(node: torch.fx.Node) -> bool: bool: True to keep the op intact; otherwise, False. """ + if ( + self.tosa_spec.support_float() + and node.target in ops_to_not_decompose_if_fp + ): + return True + dq = ( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, @@ -394,11 +405,11 @@ def filter_fn(node: torch.fx.Node) -> bool: # By default, do not decompose the operator return True - ops_to_not_decompose = [ - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + ops_to_not_decompose = list( + ops_to_not_decompose_always + | ops_to_not_decompose_if_quant_op + | ops_to_not_decompose_if_fp + ) if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d