diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 52bcc937c96..e3937f8c44a 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -34,5 +34,14 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - assert inputs[0].dtype == output.dtype == ts.DType.FP32 + if len(node.all_input_nodes) != 1: + raise ValueError( + f"Expected 1 input for {self.target}, got {len(node.all_input_nodes)}" + ) + if inputs[0].dtype != ts.DType.FP32 or output.dtype != ts.DType.FP32: + raise ValueError( + f"Input and output for {self.target} need to be FP32, got " + f"{inputs[0].dtype=} and {output.dtype=}" + ) + tosa_graph.addOperator(ts.TosaOp.Op().RSQRT, [inputs[0].name], [output.name])