From 9e0e4aa4c559c38dd0eb44d81552388c9dcd8810 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Tue, 12 Aug 2025 11:12:04 +0200 Subject: [PATCH] Arm backend: Replace asserts with report_reject in operator_support Remove redundant check `node.target in self.targets`, as well as replacing asserts with proper report_reject. This way the graph won't be stopped from lowering, but the operators will instead end up on CPU. Signed-off-by: Sebastian Larsson Change-Id: I9a9dcdd61eda269a328dd8f28ff6d3c238dc2ba5 --- .../arm/operator_support/to_copy_support.py | 57 ++++++++++++++----- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/backends/arm/operator_support/to_copy_support.py b/backends/arm/operator_support/to_copy_support.py index 4841728fe12..12c3d90b52b 100644 --- a/backends/arm/operator_support/to_copy_support.py +++ b/backends/arm/operator_support/to_copy_support.py @@ -76,8 +76,6 @@ def _merge_supported_types( def is_node_tosa_supported( self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: - assert node.target in self.targets - supported_dtypes = ( self.ALL_SUPPORTED_TYPES if tosa_spec.support_float() @@ -90,10 +88,27 @@ def is_node_tosa_supported( if v in supported_dtypes ) - # Check input type - assert len(node.all_input_nodes) == 1 + if len(node.all_input_nodes) != 1: + self.reporter.report_reject( + node, + ( + "Expected exactly one input node, " + f"got {len(node.all_input_nodes)} for {node.target}." + ), + ) + return False input_val = node.all_input_nodes[0].meta["val"] - assert isinstance(input_val, torch._subclasses.FakeTensor) + if not isinstance(input_val, torch._subclasses.FakeTensor): + self.reporter.report_reject( + node, + ( + "Invalid or missing meta: expected FakeTensor input, got " + f"{type(input_val).__name__} for {node.target}." 
+ ), + ) + return False + + # Check input type input_dtype = input_val.dtype if input_dtype not in supported_dtypes: self.reporter.report_reject( @@ -104,14 +119,24 @@ def is_node_tosa_supported( # Check output type output_val = node.meta["val"] - assert isinstance(output_val, torch._subclasses.FakeTensor) + if not isinstance(output_val, torch._subclasses.FakeTensor): + self.reporter.report_reject( + node, + ( + "Invalid or missing meta: expected FakeTensor output, got " + f"{type(output_val).__name__} for {node.target}." + ), + ) + return False if output_val.dtype not in supported_dtypes[input_dtype]: self.reporter.report_reject( node, - f"Output dtype {output_val.dtype} is not supported in " - f"{node.target} for input dtype {input_dtype}. " - f"Supported output types: " - f"{''.join(str(t) for t in supported_dtypes[input_dtype])}", + ( + f"Output dtype {output_val.dtype} is not supported in " + f"{node.target} for input dtype {input_dtype}. " + f"Supported output types: " + f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}" + ), ) return False @@ -120,8 +145,10 @@ def is_node_tosa_supported( if node.kwargs["memory_format"] in (torch.preserve_format,): self.reporter.report_reject( node, - f"Argument 'memory_format' is not supported for " - f"{node.target} right now.", + ( + "Argument 'memory_format' is not supported for " + f"{node.target} right now." + ), ) return False @@ -132,8 +159,10 @@ def is_node_tosa_supported( if dim_order != list(range(len(dim_order))): # type: ignore[arg-type] self.reporter.report_reject( node, - f"Argument {dim_order=} is not supported for " - f"{node.target} right now.", + ( + f"Argument {dim_order=} is not supported for " + f"{node.target} right now." + ), ) return False