diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 260299d6423..aed65bda812 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -15,6 +15,7 @@ op_avg_pool2d, op_bmm, op_cat, + op_ceil, op_clamp, op_constant_pad_nd, op_conv2d, @@ -22,12 +23,14 @@ op_eq, op_erf, op_exp, + op_floor, op_ge, op_gt, op_index_select, op_index_tensor, op_le, op_log, + op_logical_not, op_lt, op_max_pool2d, op_maximum, @@ -57,5 +60,4 @@ op_where, ops_binary, ops_identity, - ops_unary, ) diff --git a/backends/arm/operators/op_ceil.py b/backends/arm/operators/op_ceil.py new file mode 100644 index 00000000000..5cf89710436 --- /dev/null +++ b/backends/arm/operators/op_ceil.py @@ -0,0 +1,54 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch.fx + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa import TosaSpecification + +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class CeilVisitor(NodeVisitor): + target = "aten.ceil.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output], ts) + validate_valid_dtype( + self.target, + inputs[0], + ts.DType.FP32, + output.tosa_spec, + ) + + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().CEIL, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_floor.py b/backends/arm/operators/op_floor.py new file mode 100644 index 00000000000..77d712096fa --- /dev/null +++ b/backends/arm/operators/op_floor.py @@ -0,0 +1,54 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch.fx + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa import TosaSpecification + +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class FloorVisitor(NodeVisitor): + target = "aten.floor.default" + + # INT case should be handled by op_table + tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output], ts) + validate_valid_dtype( + self.target, + inputs[0], + ts.DType.FP32, + output.tosa_spec, + ) + + self._serialize_operator( + node, tosa_graph, ts.TosaOp.Op().FLOOR, [inputs[0].name], [output.name] + ) diff --git a/backends/arm/operators/op_logical_not.py b/backends/arm/operators/op_logical_not.py new file mode 100644 index 00000000000..640c3b4e44f --- /dev/null +++ b/backends/arm/operators/op_logical_not.py @@ -0,0 +1,59 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch.fx + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class LogicalNotVisitor(NodeVisitor): + target = "aten.logical_not.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + import serializer.tosa_serializer as ts # type: ignore # noqa: F401 + + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [*inputs, output], ts) + validate_valid_dtype( + self.target, + [*inputs, output], + [ts.DType.BOOL], + output.tosa_spec, + ) + + self._serialize_operator( + node, + tosa_graph, + ts.TosaOp.Op().LOGICAL_NOT, + [inputs[0].name], + [output.name], + ) diff --git a/backends/arm/operators/ops_unary.py b/backends/arm/operators/ops_unary.py deleted file mode 100644 index 008330f68b3..00000000000 --- a/backends/arm/operators/ops_unary.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Any, List - -import torch.fx - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) - -from executorch.backends.arm.tosa.mapping import TosaArg - - -def unary_operator_factory(unary_target: str, tosa_op): - "Creates and registers NodeVisitors for operations that have one input and map directly into a TOSA op." - - # Some TOSA unary operators only support float - fp_only_ops = ["aten.floor.default"] - - class UnaryOperator(NodeVisitor): - target = unary_target - tosa_specs = NodeVisitor.tosa_specs - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts # type: ignore # noqa: F401 - - validate_num_inputs(self.target, inputs, 1) - validate_same_dtype(self.target, [*inputs, output], ts) - - if self.target in fp_only_ops: - validate_valid_dtype( - self.target, - inputs[0], - ts.DType.FP32, - output.tosa_spec, - ) - - self._serialize_operator( - node, tosa_graph, tosa_op, [inputs[0].name], [output.name] - ) - - register_node_visitor(UnaryOperator) - - -import serializer.tosa_serializer as ts # type: ignore - -unary_operator_factory("aten.ceil.default", ts.TosaOp.Op().CEIL) -unary_operator_factory("aten.floor.default", ts.TosaOp.Op().FLOOR) -unary_operator_factory("aten.logical_not.default", ts.TosaOp.Op().LOGICAL_NOT)