From a918f627b6dbe2d46dd9e5ccd490983ddc2c7f38 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 4 Nov 2025 15:09:41 +0100 Subject: [PATCH] Arm backend: Fix decomposition logic for FP If the tosa_spec supports floating point, incorrect quantization is not a problem, and we can always leave the ops_to_not_decompose undecomposed. Signed-off-by: Erik Lundell Change-Id: I409005dd7e005ea8469492ed234440f0e41b739e --- backends/arm/tosa/partitioner.py | 37 +++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 1b9a02fbb6a..fdb4dc62abf 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -335,19 +335,24 @@ def ops_to_not_decompose( function that returns True when an op should not be decomposed. """ - ops_to_not_decompose_if_quant_op = [ + ops_to_not_decompose_if_quant_op = { torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, torch.ops.aten.linear.default, - ] + } + ops_to_not_decompose_if_fp = { + torch.ops.aten.linear.default, + } + ops_to_not_decompose_always = { + torch.ops.aten.eye.default, + torch.ops.aten.linspace.default, + torch.ops.aten.logit.default, + } def filter_fn(node: torch.fx.Node) -> bool: - """Return True to keep selected ops intact inside quantized regions. - - The predicate holds when the target is in - ``ops_to_not_decompose_if_quant_op`` and all inputs/outputs are - quantize/dequantize ops, indicating a quantized activation that - should not be decomposed. + """Filter function applied to ops in 'ops_to_not_decompose'. + Returns True if the op should not be decomposed. + If this function returns True, the partitioner *must* accept the node, or the lowering fails. Args: node (torch.fx.Node): FX node to evaluate. @@ -356,6 +361,12 @@ def filter_fn(node: torch.fx.Node) -> bool: bool: True to keep the op intact; otherwise, False. 
""" + if ( + self.tosa_spec.support_float() + and node.target in ops_to_not_decompose_if_fp + ): + return True + dq = ( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, @@ -394,11 +405,11 @@ def filter_fn(node: torch.fx.Node) -> bool: # By default, do not decompose the operator return True - ops_to_not_decompose = [ - torch.ops.aten.eye.default, - torch.ops.aten.linspace.default, - torch.ops.aten.logit.default, - ] + ops_to_not_decompose_if_quant_op + ops_to_not_decompose = list( + ops_to_not_decompose_always + | ops_to_not_decompose_if_quant_op + | ops_to_not_decompose_if_fp + ) if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d