Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.mean.dim,
Expand Down Expand Up @@ -144,5 +145,6 @@ def ops_to_not_decompose(
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_to_not_decompose = [
torch.ops.aten.linear.default,
torch.ops.aten.upsample_nearest2d.vec,
]
return (ops_to_not_decompose, None)
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@
op_tanh,
op_transpose,
op_unsqueeze,
op_upsample_nearest2d,
op_view,
)
68 changes: 68 additions & 0 deletions backends/arm/operators/op_upsample_nearest2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
from serializer.tosa_serializer import TosaOp

from tosa.ResizeMode import ResizeMode


@register_node_visitor
class UpsampleNearest2dVisitor(NodeVisitor):
target = "aten.upsample_nearest2d.vec"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"

# tosa_shape output is NHWC, take HW
input_size_yx = torch.tensor(
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
)
# Ignore scale and size parameters, directly use the output size as
# we only support static shapes currently
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])

scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
)

def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
offset=offset_yx.tolist(),
border=border_yx.tolist(),
mode=ResizeMode.NEAREST,
)

tosa_graph.addOperator(
TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
)
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ class ArmQuantizer(Quantizer):
"mm",
"one_to_one",
"generic",
"upsample_nearest2d",
]

def __init__(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,5 @@ def decorator(annotator: AnnotatorType):
mul_annotator,
one_to_one_annotator,
sub_annotator,
upsample_nearest2d_annotator,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright 2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
from typing import Callable, List, Optional

import torch
from executorch.backends.arm.quantizer.quantization_annotation import register_annotator
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
SharedQuantizationSpec,
)
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


def _filter_upsample_nearest2d(filter_fn: Optional[Callable[[Node], bool]] = None):
def filter(node: Node):
is_upsample = node.target == torch.ops.aten.upsample_nearest2d.vec
if filter_fn is None:
return is_upsample
else:
return is_upsample and filter_fn(node)

return filter


@register_annotator("upsample_nearest2d")
def _annotate_upsample_nearest2d(
gm: torch.fx.GraphModule,
quantization_config: QuantizationConfig,
filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
module_partitions = get_source_partitions(
gm.graph,
[
torch.nn.UpsamplingNearest2d,
torch.nn.Upsample,
torch.nn.functional.interpolate,
],
_filter_upsample_nearest2d(filter_fn),
)
upsample_partitions = list(
itertools.chain.from_iterable(module_partitions.values())
)
annotated_partitions = []

for upsample_partition in upsample_partitions:
annotated_partitions.append(upsample_partition.nodes)

assert len(upsample_partition.nodes) == 1
upsample_node = upsample_partition.nodes[0]

input_act = upsample_node.args[0]
assert isinstance(input_act, Node)

input_act_qspec = quantization_config.get_input_act_qspec()
output_act_qspec = SharedQuantizationSpec((input_act, upsample_node))

upsample_node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act: input_act_qspec,
},
output_qspec=output_act_qspec,
_annotated=True,
)

return annotated_partitions
165 changes: 165 additions & 0 deletions backends/arm/test/ops/test_upsample_nearest2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Optional, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized


test_data_suite = [
# (test_name, test_data, size, scale_factor, compare_outputs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️ these test. Awesome!

("rand_double_scale", torch.rand(2, 4, 8, 3), None, 2.0, True),
("rand_double_scale_one_dim", torch.rand(2, 4, 8, 3), None, (1.0, 2.0), True),
("rand_double_size", torch.rand(2, 4, 8, 3), (16, 6), None, True),
("rand_one_double_scale", torch.rand(2, 4, 1, 1), None, 2.0, True),
("rand_one_double_size", torch.rand(2, 4, 1, 1), (2, 2), None, True),
("rand_one_same_scale", torch.rand(2, 4, 1, 1), None, 1.0, True),
("rand_one_same_size", torch.rand(2, 4, 1, 1), (1, 1), None, True),
# Can't compare outputs as the rounding when selecting the nearest pixel is
# different between PyTorch and TOSA. Just check the legalization went well.
# TODO Improve the test infrastructure to support more in depth verification
# of the TOSA legalization results.
("rand_half_scale", torch.rand(2, 4, 8, 6), None, 0.5, False),
("rand_half_size", torch.rand(2, 4, 8, 6), (4, 3), None, False),
("rand_one_and_half_scale", torch.rand(2, 4, 8, 3), None, 1.5, False),
("rand_one_and_half_size", torch.rand(2, 4, 8, 3), (12, 4), None, False),
]


class TestUpsampleNearest2d(unittest.TestCase):
class UpsamplingNearest2d(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = torch.nn.UpsamplingNearest2d( # noqa: TOR101
size=size, scale_factor=scale_factor
)

def forward(self, x):
return self.upsample(x)

class Upsample(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = torch.nn.Upsample(
size=size, scale_factor=scale_factor, mode="nearest"
)

def forward(self, x):
return self.upsample(x)

class Interpolate(torch.nn.Module):
def __init__(
self,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
):
super().__init__()
self.upsample = lambda x: torch.nn.functional.interpolate(
x, size=size, scale_factor=scale_factor, mode="nearest"
)

def forward(self, x):
return self.upsample(x)

def _test_upsample_nearest_2d_tosa_MI_pipeline(
self,
module: torch.nn.Module,
test_data: Tuple[torch.tensor],
compare_outputs: bool,
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
)
.export()
.check(["torch.ops.aten.upsample_nearest2d.vec"])
.check_not(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)

if compare_outputs:
tester.run_method_and_compare_outputs(inputs=test_data)

def _test_upsample_nearest_2d_tosa_BI_pipeline(
self,
module: torch.nn.Module,
test_data: Tuple[torch.tensor],
compare_outputs: bool,
):
tester = (
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
)
.quantize()
.export()
.check(["torch.ops.aten.upsample_nearest2d.vec"])
.check(["torch.ops.quantized_decomposed"])
.to_edge_transform_and_lower()
.check_not(["torch.ops.aten.upsample_nearest2d.vec"])
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.to_executorch()
)

if compare_outputs:
tester.run_method_and_compare_outputs(inputs=test_data)

@parameterized.expand(test_data_suite)
def test_upsample_nearest_2d_tosa_MI(
self,
test_name: str,
test_data: torch.Tensor,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
compare_outputs: bool,
):
self._test_upsample_nearest_2d_tosa_MI_pipeline(
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
)
self._test_upsample_nearest_2d_tosa_MI_pipeline(
self.Upsample(size, scale_factor), (test_data,), compare_outputs
)
self._test_upsample_nearest_2d_tosa_MI_pipeline(
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
)

@parameterized.expand(test_data_suite)
def test_upsample_nearest_2d_tosa_BI(
self,
test_name: str,
test_data: torch.Tensor,
size: Optional[Tuple[int]],
scale_factor: Optional[float | Tuple[float]],
compare_outputs: bool,
):
self._test_upsample_nearest_2d_tosa_BI_pipeline(
self.UpsamplingNearest2d(size, scale_factor), (test_data,), compare_outputs
)
self._test_upsample_nearest_2d_tosa_BI_pipeline(
self.Upsample(size, scale_factor), (test_data,), compare_outputs
)
self._test_upsample_nearest_2d_tosa_BI_pipeline(
self.Interpolate(size, scale_factor), (test_data,), compare_outputs
)
1 change: 0 additions & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def run_method_and_compare_outputs(
inputs (Optional[Tuple[torch.Tensor]]): Allows you to input custom input data.
The default is random data.
"""

edge_stage = self.stages[self.stage_name(tester.ToEdge)]
if edge_stage is None:
edge_stage = self.stages[self.stage_name(tester.ToEdgeTransformAndLower)]
Expand Down
45 changes: 45 additions & 0 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,48 @@ def expand_dims(
build_reshape(tosa_graph, input_node.name, new_shape, intermediate.name)

return intermediate


def get_resize_parameters(
input_size: torch.Tensor,
output_size: torch.Tensor,
resize_mode: int,
align_corners: bool,
):
"""Get the tosa.resize parameters based on the input and output size.

Args:
input_size (torch.Tensor): Size of the input
output_size (torch.Tensor): Size of the output
resize_mode (tosa.ResizeMode): The TOSA resize mode
align_corners (bool): Align the corners pixels of the input and output

Returns:
scale_n (torch.Tensor), scale_d (torch.Tensor),
offset (torch.Tensor), border (torch.Tensor)
"""
assert torch.all(input_size > 0)
assert torch.all(output_size > 0)

scale_n = torch.tensor(
[
so - 1 if align_corners and si > 1 and so > 1 else so
for si, so in zip(input_size, output_size)
]
)
scale_d = torch.tensor(
[
si - 1 if align_corners and si > 1 and so > 1 else si
for si, so in zip(input_size, output_size)
]
)

gcd = torch.gcd(scale_n, scale_d)
scale_n = scale_n // gcd
scale_d = scale_d // gcd

# No half-pixel centre support in PyTorch, no offset needed
offset = torch.zeros_like(input_size)
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset

return scale_n, scale_d, offset, border
Loading