From 5b244e83a7a642943c76680ca4832f26bc0bed94 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 10 Apr 2025 08:53:15 +0200 Subject: [PATCH] Arm backend: Convert asserts to raise errors in op_rsqrt Asserts are converted to proper raises to ensure graph integrity. Improve error messages and add additional check that both input and output are of data type fp32. Signed-off-by: Sebastian Larsson Change-Id: Ied5b0b13f061527d0b01cb864f6716e646bb8db3 --- backends/arm/operators/op_rsqrt.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 52bcc937c96..e3937f8c44a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -34,5 +34,14 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - assert inputs[0].dtype == output.dtype == ts.DType.FP32 + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got " + f"{inputs[0].dtype=} and {output.dtype=}" + ) + tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name])