From 20f906b24f9b0e838956019ec35e04faa6f1b851 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 20 Aug 2025 13:35:03 +0200 Subject: [PATCH] Arm backend: Add TOSA dialect op for RESIZE Add TOSA backend dialect op for TOSA RESIZE. The dialect op replaces upsample_nearest2d and upsample_bilinear_2d in RewriteUpsamplePass. Also the Nodevisitors of upsample_nearest2d and upsample_bilinear2d are replaced by one NodeVisitor for the resize backend dialect op. Signed-off-by: Oscar Andersson Change-Id: I3a1737428707d767c19fc127d92366ea7786a30f --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 3 + .../arm/_passes/fuse_constant_ops_pass.py | 3 +- backends/arm/_passes/rewrite_upsample.py | 84 ++++++++++ backends/arm/operators/__init__.py | 3 +- ...{op_upsample_nearest2d.py => op_resize.py} | 20 ++- .../arm/operators/op_upsample_bilinear2d.py | 148 ------------------ backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/resize.py | 60 +++++++ 9 files changed, 165 insertions(+), 158 deletions(-) create mode 100644 backends/arm/_passes/rewrite_upsample.py rename backends/arm/operators/{op_upsample_nearest2d.py => op_resize.py} (82%) delete mode 100644 backends/arm/operators/op_upsample_bilinear2d.py create mode 100644 backends/arm/tosa/dialect/ops/resize.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index a5d8e17f0cd..93bf20e69c1 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -91,6 +91,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, ) +from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa from .size_adjust_input_pass import SizeAdjustInputPass # noqa from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 70470890317..96b0d7a5572 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -91,6 +91,7 @@ ReplaceScalarWithTensorArgPassTOSABI, ReplaceScalarWithTensorArgPassTOSAMI, RetraceFoldedDtypesPass, + RewriteUpsamplePass, ScalarsToAttributePass, SizeAdjustInputPass, ToTosaMemoryFormatPass, @@ -204,6 +205,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: # needs to happen before AddBiasPass, but after the table ops are inserted # to be able to validate that conv2d has right dtype arguments. self.add_pass(DecomposeConv2dWithInt16ActivationPass()) + self.add_pass(RewriteUpsamplePass(exported_program)) self.add_pass(AddBiasPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) @@ -288,6 +290,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) + self.add_pass(RewriteUpsamplePass(exported_program)) self.add_pass(AddBiasPass(exported_program)) self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) diff --git a/backends/arm/_passes/fuse_constant_ops_pass.py b/backends/arm/_passes/fuse_constant_ops_pass.py index 07f3a4af245..13a33ff66e6 100644 --- a/backends/arm/_passes/fuse_constant_ops_pass.py +++ b/backends/arm/_passes/fuse_constant_ops_pass.py @@ -111,8 +111,9 @@ def call(self, graph_module): if node.op != "call_function": continue if node.target in [ - exir_ops.backend.tosa.TABLE.default, exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.RESIZE.default, + exir_ops.backend.tosa.TABLE.default, exir_ops.backend.tosa.TRANSPOSE.default, ]: continue diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py new file mode 100644 index 00000000000..c9f25a1e845 --- /dev/null +++ b/backends/arm/_passes/rewrite_upsample.py @@ -0,0 +1,84 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import ( + create_node, + get_first_fake_tensor, +) +from executorch.backends.arm.tosa.utils import get_resize_parameters +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class RewriteUpsamplePass(ArmPass): + """Rewrite upsample2d nodes to TOSA.RESIZE nodes.""" + + targeted_ops = ( + exir_ops.edge.aten.upsample_nearest2d.vec, + exir_ops.edge.aten.upsample_bilinear2d.vec, + ) + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call(self, graph_module): + modified = False + for node in graph_module.graph.nodes: + if node.op != "call_function" or node.target not in self.targeted_ops: + continue + modified = True + + if node.target == exir_ops.edge.aten.upsample_bilinear2d.vec: + x, output_size, align_corners, scale_factors = node.args + resize_mode = "bilinear" + else: + x, output_size, scale_factors = node.args + align_corners = False + resize_mode = "nearest" + + with graph_module.graph.inserting_before(node): + tosa_resize_node = create_node( + graph_module.graph, + op_target=exir_ops.backend.tosa.RESIZE.default, + args=(x, output_size, align_corners, scale_factors), + kwargs={"resize_mode": resize_mode}, + from_node=node, + ) + node.replace_all_uses_with(tosa_resize_node) + graph_module.graph.erase_node(node) + input_dtype = get_first_fake_tensor(x).dtype + if input_dtype == torch.int8 and resize_mode == "bilinear": + input_size = get_first_fake_tensor(x).shape + input_size_xy = input_size[2:] + output_size = get_first_fake_tensor(node).shape + output_size_xy = output_size[2:] + scale_n_yx, _, _, _ = get_resize_parameters( + input_size_xy=input_size_xy, + output_size_xy=output_size_xy, + resize_mode=1, + align_corners=align_corners, + ) + output_dtype = get_first_fake_tensor(node).dtype + output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) + with graph_module.graph.inserting_after(tosa_resize_node): + rescale_node = create_node( + graph_module.graph, + exir_ops.backend.tosa.RESCALE.default, + ) + tosa_resize_node.replace_all_uses_with(rescale_node) + rescale_node.args = ( + tosa_resize_node, + output_dtype, + output_scale, + 0, # zero point + 0, # zero point + ) + + if modified: + graph_module = super().call(graph_module).graph_module + return PassResult(graph_module, modified) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index e2bda4b7641..d8b371570f6 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -43,6 +43,7 @@ op_reciprocal, op_repeat, op_rescale, + op_resize, op_rshift_tensor, op_rsqrt, op_sigmoid, @@ -54,8 +55,6 @@ op_tanh, op_to_dim_order_copy, op_transpose, - op_upsample_bilinear2d, - op_upsample_nearest2d, op_view, op_where, ops_binary, diff --git a/backends/arm/operators/op_upsample_nearest2d.py b/backends/arm/operators/op_resize.py similarity index 82% rename from backends/arm/operators/op_upsample_nearest2d.py rename to backends/arm/operators/op_resize.py index 3c3ca67c9f5..020395ee7c2 100644 --- a/backends/arm/operators/op_upsample_nearest2d.py +++ b/backends/arm/operators/op_resize.py @@ -24,8 +24,8 @@ @register_node_visitor -class UpsampleNearest2dVisitor(NodeVisitor): - target = "aten.upsample_nearest2d.vec" +class ResizeVisitor(NodeVisitor): + target = "tosa.RESIZE.default" tosa_specs = NodeVisitor.tosa_specs @@ -41,12 +41,18 @@ def define_node( ) -> None: import serializer.tosa_serializer as ts - validate_num_inputs(self.target, inputs, 3) - validate_same_dtype(self.target, [inputs[0], output], ts) + validate_num_inputs(self.target, inputs, [3, 4]) + if node.kwargs.get("resize_mode") == "bilinear": + resize_mode = ResizeMode.BILINEAR + align_corners = bool(node.args[2]) + else: + resize_mode = ResizeMode.NEAREST + align_corners = False + validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT32, ts.DType.FP16, ts.DType.FP32], output.tosa_spec, ) @@ -59,7 +65,7 @@ def define_node( # Align corners shouldn't make a difference for nearest upsampling. We set to False so # half pixel centers are used for resize parameter logic. scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=False + input_size_yx, output_size_yx, resize_mode, align_corners=align_corners ) def in_int16_range(x): @@ -86,7 +92,7 @@ def in_int16_range(x): ) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute( - mode=ResizeMode.NEAREST, + mode=resize_mode, ) self._serialize_operator( diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py deleted file mode 100644 index 3cc620727e0..00000000000 --- a/backends/arm/operators/op_upsample_bilinear2d.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2025 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-unsafe -from typing import Any, List - -import torch - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg -from executorch.backends.arm.tosa.quant_utils import build_rescale -from executorch.backends.arm.tosa.utils import get_resize_parameters, tosa_shape - - -@register_node_visitor -class UpsampleBilinear2dVisitor(NodeVisitor): - - target = "aten.upsample_bilinear2d.vec" - tosa_specs = NodeVisitor.tosa_specs - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - import serializer.tosa_serializer as ts - from tosa.ResizeMode import ResizeMode # type: ignore - from tosa.RoundingMode import RoundingMode # type: ignore - - validate_num_inputs(self.target, inputs, 4) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], - output.tosa_spec, - ) - - if inputs[0].shape is None or output.shape is None: - raise ValueError("Only static shapes are supported") - - input_dtype = inputs[0].dtype - - # tosa_shape output is NHWC, take HW - input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ - 1:3 - ] - output_size_yx = tuple([output.shape[dim] for dim in output.dim_order])[1:3] - - # Get align_corners value from the node arguments. - align_corners = bool(node.args[2]) - scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( - input_size_yx, - output_size_yx, - ResizeMode.NEAREST, - align_corners=align_corners, - ) - - def in_int16_range(x): - return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) - - if not in_int16_range(scale_n_yx): - raise ValueError("scale_n_yx is out of the int16 range") - if not in_int16_range(scale_d_yx): - raise ValueError("scale_d_yx is out of the int16 range") - if not in_int16_range(border_yx): - raise ValueError("border_yx is out of the int16 range") - - scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]] - - attr = ts.TosaSerializerAttribute() - attr.ResizeAttribute(mode=ResizeMode.BILINEAR) - - scales_tensor = tosa_graph.addConst( - [len(scales)], ts.DType.SHAPE, scales, node.name + "_scales" - ) - offset = offset_yx.tolist() - offset_tensor = tosa_graph.addConst( - [len(offset)], ts.DType.SHAPE, offset, node.name + "_offset" - ) - border = border_yx.tolist() - border_tensor = tosa_graph.addConst( - [len(border)], ts.DType.SHAPE, border, node.name + "_border" - ) - if input_dtype == output.dtype == ts.DType.FP32: - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [output.name], - attr, - ) - return - elif input_dtype == output.dtype == ts.DType.INT8: - intermediate = tosa_graph.addIntermediate( - tosa_shape(output.shape, output.dim_order), ts.DType.INT32 - ) - self._serialize_operator( - node, - tosa_graph, - ts.TosaOp.Op().RESIZE, - [ - inputs[0].name, - scales_tensor.name, - offset_tensor.name, - border_tensor.name, - ], - [intermediate.name], - attr, - ) - - final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) - - build_rescale( - tosa_fb=tosa_graph, - scale=[final_output_scale], - input_node=intermediate, - output_name=output.name, - output_type=ts.DType.INT8, - input_zp=[0], - output_zp=[0], - rounding_mode=RoundingMode.SINGLE_ROUND, - ) - else: - raise ValueError( - "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" - ) diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 136f59beb62..f1e3a29ac22 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -5,6 +5,7 @@ from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 rescale, + resize, table, transpose, ) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py new file mode 100644 index 00000000000..1f976d0f5e0 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -0,0 +1,60 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Literal, Optional + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op + +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +# Add kwarg instead? +@register_fake_tosa_op( + "RESIZE(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, str resize_mode) -> Tensor", # schema + ( + TosaSpecification.create_from_string("TOSA-1.0+INT"), + TosaSpecification.create_from_string("TOSA-1.0+FP"), + ), # target TOSA specifications +) +def RESIZE( + x: torch.Tensor, + output_size: list[int] | None = None, + align_corners: Optional[bool] = False, + scale_factors: list[float] | None = None, + *, + resize_mode: Literal["nearest", "bilinear"], +) -> torch.Tensor: + tosa_spec = get_context_spec() + + if resize_mode not in ("nearest", "bilinear"): + raise TosaValueError(f"Unsupported resize mode {resize_mode} for TOSA RESIZE") + if x.dtype == torch.int8: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support integers", op="RESIZE" + ) + bilinear = resize_mode == "bilinear" + output_dtype = torch.int32 if bilinear else torch.int8 + elif x.dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support float", op="RESIZE" + ) + output_dtype = x.dtype + else: + raise TosaValueError(f"Unsupported input dtype {x.dtype} for TOSA RESIZE") + + # Does it matter which one to use for fake tracing? + fake_aten_tensor = exir_ops.edge.aten.upsample_nearest2d.vec( + x, output_size, scale_factors + ) + + return fake_aten_tensor.to(output_dtype)