From 7d9736f7f11e36882ed4ac5b6547c7ffd0f9751d Mon Sep 17 00:00:00 2001 From: Per Held Date: Mon, 3 Feb 2025 12:16:35 +0100 Subject: [PATCH] Arm backend: Add upsample_bilinear2d op Change-Id: Idb0c948d888d5ea543e287948080e799f77fa153 --- .../tosa_supported_operators.py | 2 + backends/arm/operators/__init__.py | 1 + .../arm/operators/op_upsample_bilinear2d.py | 100 +++++++ .../arm/quantizer/quantization_annotator.py | 1 + .../arm/test/ops/test_upsample_bilinear2d.py | 247 ++++++++++++++++++ backends/arm/tosa_partitioner.py | 1 + 6 files changed, 352 insertions(+) create mode 100644 backends/arm/operators/op_upsample_bilinear2d.py create mode 100644 backends/arm/test/ops/test_upsample_bilinear2d.py diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 952cfb17cf0..5de90bda252 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -207,6 +207,7 @@ def is_node_supported( exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.upsample_bilinear2d.vec, exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.var.correction, exir_ops.edge.aten.var.dim, @@ -365,6 +366,7 @@ def is_node_supported( exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.tanh.default, + exir_ops.edge.aten.upsample_bilinear2d.vec, exir_ops.edge.aten.upsample_nearest2d.vec, exir_ops.edge.aten.gelu.default, ): diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index da050c5994e..3ee243779e6 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -46,6 +46,7 @@ op_to_copy, op_to_dim_order_copy, op_transpose, + op_upsample_bilinear2d, op_upsample_nearest2d, op_view, op_where, diff --git a/backends/arm/operators/op_upsample_bilinear2d.py b/backends/arm/operators/op_upsample_bilinear2d.py new file mode 100644 index 00000000000..52eebf70900 --- /dev/null +++ b/backends/arm/operators/op_upsample_bilinear2d.py @@ -0,0 +1,100 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import List + +import torch + +import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.tosa_mapping import TosaArg +from executorch.backends.arm.tosa_quant_utils import build_rescale +from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape +from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore + + +@register_node_visitor +class UpsampleBilinear2dVisitor_0_80(NodeVisitor): + target = "aten.upsample_bilinear2d.vec" + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + assert ( + inputs[0].shape is not None and output.shape is not None + ), "Only static shapes are supported" + + input_dtype = inputs[0].dtype + + # tosa_shape output is NHWC, take HW + input_size_yx = torch.tensor( + tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3] + ) + # Ignore scale and size parameters, directly use the output size as + # we only support static shapes currently + output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3]) + + scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters( + input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True + ) + + def in_int16_range(x): + return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1) + + assert in_int16_range(scale_n_yx) + assert in_int16_range(scale_d_yx) + assert in_int16_range(border_yx) + + attr = ts.TosaSerializerAttribute() + attr.ResizeAttribute( + scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]], + offset=offset_yx.tolist(), + border=border_yx.tolist(), + mode=ResizeMode.BILINEAR, + ) + + if input_dtype == output.dtype == ts.DType.FP32: + tosa_graph.addOperator( + ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr + ) + return + elif input_dtype == output.dtype == ts.DType.INT8: + intermediate = tosa_graph.addIntermediate( + tosa_shape(output.shape, output.dim_order), ts.DType.INT32 + ) + + tosa_graph.addOperator( + ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr + ) + + final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1])) + + build_rescale( + tosa_fb=tosa_graph, + scale=[final_output_scale], + input_node=intermediate, + output_name=output.name, + output_type=ts.DType.INT8, + output_shape=output.shape, + input_zp=0, + output_zp=0, + is_double_round=False, + ) + else: + raise ValueError( + "Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}" + ) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 5398101fd9a..ad866fa9d13 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -215,6 +215,7 @@ def _match_pattern( torch.ops.aten.flip.default, torch.ops.aten.chunk.default, torch.ops.aten.contiguous.default, + torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.vec, torch.ops.aten.pad.default, torch.ops.aten.amax.default, diff --git a/backends/arm/test/ops/test_upsample_bilinear2d.py b/backends/arm/test/ops/test_upsample_bilinear2d.py new file mode 100644 index 00000000000..c1a1292aa4e --- /dev/null +++ b/backends/arm/test/ops/test_upsample_bilinear2d.py @@ -0,0 +1,247 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op = "torch.ops.aten.upsample_bilinear2d.vec" +input_t1 = Tuple[torch.Tensor] # Input x + +test_data_suite_tosa = { + # (test_name, test_data, size, scale_factor, compare_outputs) + "rand_double_scale": (torch.rand(2, 4, 8, 3), None, 2.0, True), + "rand_double_scale_one_dim": (torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True), + "rand_double_size": (torch.rand(2, 4, 8, 3), (16, 6), None, True), + "rand_one_double_scale": (torch.rand(2, 4, 1, 1), None, 2.0, True), + "rand_one_double_size": (torch.rand(2, 4, 1, 1), (2, 2), None, True), + "rand_one_same_scale": (torch.rand(2, 4, 1, 1), None, 1.0, True), + "rand_one_same_size": (torch.rand(2, 4, 1, 1), (1, 1), None, True), + # Can't compare outputs as the rounding when selecting the nearest pixel is + # different between PyTorch and TOSA. Just check the legalization went well. + # TODO Improve the test infrastructure to support more in depth verification + # of the TOSA legalization results. + "rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False), + "rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False), + "rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False), + "rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False), + # Use randn for a bunch of tests to get random numbers from the + # normal distribution where negative is also a possibilty + "randn_double_scale_negative": (torch.randn(2, 4, 8, 3), None, 2.0, True), + "randn_double_scale_one_dim_negative": ( + torch.randn(2, 4, 8, 3), + None, + (1.0, 2.0), + True, + ), + "randn_double_size_negative": (torch.randn(2, 4, 8, 3), (16, 6), None, True), + "randn_one_double_scale_negative": (torch.randn(2, 4, 1, 1), None, 2.0, True), + "randn_one_double_size_negative": (torch.randn(2, 4, 1, 1), (2, 2), None, True), + "randn_one_same_scale_negative": (torch.randn(2, 4, 1, 1), None, 1.0, True), + "randn_one_same_size_negative": (torch.randn(2, 4, 1, 1), (1, 1), None, True), +} + +test_data_suite_Uxx = { + "rand_half_scale": (torch.rand(2, 4, 8, 6), None, 0.5, False), + "rand_half_size": (torch.rand(2, 4, 8, 6), (4, 3), None, False), + "rand_one_and_half_scale": (torch.rand(2, 4, 8, 3), None, 1.5, False), + "rand_one_and_half_size": (torch.rand(2, 4, 8, 3), (12, 4), None, False), +} + + +class UpsamplingBilinear2d(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = torch.nn.UpsamplingBilinear2d( # noqa: TOR101 + size=size, scale_factor=scale_factor + ) + + def forward(self, x): + return self.upsample(x) + + +class Upsample(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = torch.nn.Upsample( + size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True + ) + + def forward(self, x): + return self.upsample(x) + + +class Interpolate(torch.nn.Module): + def __init__( + self, + size: Optional[Tuple[int]], + scale_factor: Optional[float | Tuple[float]], + ): + super().__init__() + self.upsample = lambda x: torch.nn.functional.interpolate( + x, size=size, scale_factor=scale_factor, mode="bilinear", align_corners=True + ) + + def forward(self, x): + return self.upsample(x) + + +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_MI_UpsamplingBilinear2d( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = TosaPipelineMI[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_MI_Upsample( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = TosaPipelineMI[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_MI_Interpolate( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = TosaPipelineMI[input_t1]( + Interpolate(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_BI_intropolate( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = TosaPipelineBI[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_BI_Upsample( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = TosaPipelineBI[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_BI_Upsample(test_data: input_t1): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineBI[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + run_on_fvp=True, + qtol=1, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_BI_Interpolate( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineBI[input_t1]( + Interpolate(size, scale_factor), + (test_data,), + aten_op, + run_on_fvp=True, + qtol=1, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_BI_UpsamplingBilinear2d( + test_data: torch.Tensor, +): + test_data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineBI[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (test_data,), + aten_op, + run_on_fvp=True, + qtol=1, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() diff --git a/backends/arm/tosa_partitioner.py b/backends/arm/tosa_partitioner.py index 06b0555bc44..738c5ab8204 100644 --- a/backends/arm/tosa_partitioner.py +++ b/backends/arm/tosa_partitioner.py @@ -170,6 +170,7 @@ def filter_fn(node: torch.fx.Node) -> bool: ops_to_not_decompose = [ torch.ops.aten.linear.default, + torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.vec, torch.ops.aten.eye.default, torch.ops.aten.linspace.default,