diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index b3ddecbc298..a6c9cf1d06b 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -43,6 +43,7 @@
 from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
     UnsqueezeScalarPlaceholdersPass,
 )
+from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.pass_manager import PassManager
@@ -58,6 +59,7 @@ def transform_to_backend_pipeline(
     ):
         """Apply passes before transforming program to backend"""
         self.add_pass(CastInt64ToInt32Pass(exported_program))
+        self.add_pass(RemoveGetItemPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(RemoveClonePass())
diff --git a/backends/arm/arm_partitioner.py b/backends/arm/arm_partitioner.py
index 73092879986..bdd4b80f292 100644
--- a/backends/arm/arm_partitioner.py
+++ b/backends/arm/arm_partitioner.py
@@ -55,6 +55,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
             exir_ops.edge.aten.native_layer_norm.default,
             exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.max_pool2d_with_indices.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.repeat.default,
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index a8ddf1c8f02..5e188aea771 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -20,6 +20,7 @@
     op_get_item,
     op_hardtanh,
     op_log,
+    op_max_pool2d,
     op_mm,
     op_mul,
     op_permute,
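
Why `RemoveGetItemPass` is borrowed from the XNNPACK backend: in the edge dialect, `aten.max_pool2d` is decomposed into `aten.max_pool2d_with_indices`, which returns a `(values, indices)` tuple, followed by a `getitem` node that keeps only the values. TOSA's MAX_POOL2D has a single output, so the partitioner admits `max_pool2d_with_indices.default` while the pass folds the pair back into a single-output `max_pool2d` node that the new visitor below can lower. A minimal sketch of the pattern, assuming a plain `torch.export` flow (the module is illustrative, not the pass implementation):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)


ep = torch.export.export(M(), (torch.rand(1, 3, 8, 8),))
print(ep.graph)
# Once lowered to the edge dialect, the graph holds roughly:
#   %pool = max_pool2d_with_indices(%x, [2, 2], [2, 2])
#   %out  = getitem(%pool, 0)   # the indices are never used
# RemoveGetItemPass rewrites this pair into a single-output max_pool2d
# node, which MaxPool2dVisitor then maps to TOSA MAX_POOL2D.
```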
diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py
new file mode 100644
index 00000000000..0752d8242f7
--- /dev/null
+++ b/backends/arm/operators/op_max_pool2d.py
@@ -0,0 +1,80 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+from typing import cast, List
+
+import serializer.tosa_serializer as ts
+import torch
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_utils import get_quant_node_args
+
+from serializer.tosa_serializer import TosaOp
+
+
+@register_node_visitor
+class MaxPool2dVisitor(NodeVisitor):
+    target = "aten.max_pool2d.default"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+
+        input_tensor = inputs[0]
+        kernel_size = inputs[1].special
+        stride = inputs[2].special
+
+        try:
+            # aten padding is [pad_h, pad_w]; TOSA expects
+            # [pad_top, pad_bottom, pad_left, pad_right].
+            pad_h, pad_w = inputs[3].special
+            padding = [pad_h, pad_h, pad_w, pad_w]
+        except IndexError:
+            padding = [0, 0, 0, 0]
+
+        accumulator_type = input_tensor.dtype
+
+        if is_quant_node:
+            # The accumulator type is always INT8 when the input tensor has an integer type.
+            accumulator_type = ts.DType.INT8
+
+        # Initialize the zero points to zero.
+        input_zp = 0
+        output_zp = 0
+
+        if is_quant_node:
+            input_zp = get_quant_node_args(
+                cast(torch.fx.Node, node.all_input_nodes[0])
+            ).zp
+            output_zp = get_quant_node_args(list(node.users)[0]).zp
+
+        attr = ts.TosaSerializerAttribute()
+        attr.PoolAttribute(
+            kernel=kernel_size,
+            stride=stride,
+            pad=padding,
+            input_zp=input_zp,
+            output_zp=output_zp,
+            accum_dtype=accumulator_type,
+        )
+
+        tosa_graph.addOperator(
+            TosaOp.Op().MAX_POOL2D,
+            [input_tensor.name],
+            [output.name],
+            attr,
+        )
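
For orientation, the argument layout `define_node` unpacks mirrors `torch.nn.MaxPool2d`: `inputs[1]` is `kernel_size`, `inputs[2]` is `stride`, and the optional `inputs[3]` is `padding`. A quick, purely illustrative shape check using the `ones` case from the test suite added below:

```python
import torch

# kernel_size=4, stride=2, padding=0 matches the "ones" test case.
pool = torch.nn.MaxPool2d(kernel_size=4, stride=2, padding=0)
x = torch.ones(1, 16, 50, 32)
# H_out = (50 - 4) // 2 + 1 = 24, W_out = (32 - 4) // 2 + 1 = 15
print(pool(x).shape)  # torch.Size([1, 16, 24, 15])
```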
diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py
index a1d7bfe296d..4d52b7ddf16 100644
--- a/backends/arm/quantizer/arm_quantizer_utils.py
+++ b/backends/arm/quantizer/arm_quantizer_utils.py
@@ -147,6 +147,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
         # TODO: remove?
         torch.ops.aten.adaptive_avg_pool2d.default,
         torch.ops.aten.avg_pool2d.default,
+        torch.ops.aten.max_pool2d.default,
         torch.ops.aten.full.default,
         torch.ops.aten.flatten.using_ints,
         torch.ops.aten.dropout.default,
diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py
index af44fa44742..b0e2a7f0bb7 100644
--- a/backends/arm/test/common.py
+++ b/backends/arm/test/common.py
@@ -91,6 +91,17 @@ def pytest_sessionfinish(session, exitstatus):
 # ==== End of Pytest hooks =====
 
 
+# ==== Custom Pytest decorators =====
+
+
+def expectedFailureOnFVP(test_item):
+    if is_option_enabled("corstone300"):
+        test_item.__unittest_expecting_failure__ = True
+    return test_item
+
+
+# ==== End of Custom Pytest decorators =====
+
 
 def load_libquantized_ops_aot_lib():
     so_ext = {
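
`expectedFailureOnFVP` attaches unittest's expected-failure flag only when the `corstone300` option is enabled, so the same test still runs (and must pass) in a TOSA-only session. A hypothetical usage sketch (the class and test names are illustrative):

```python
import unittest

from executorch.backends.arm.test import common


class MyFVPTests(unittest.TestCase):
    @common.expectedFailureOnFVP  # tracked FVP-only failure
    def test_known_fvp_regression(self):
        ...
```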
diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py
new file mode 100644
index 00000000000..5c48afa3ce1
--- /dev/null
+++ b/backends/arm/test/ops/test_max_pool.py
@@ -0,0 +1,248 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import unittest
+
+from typing import Tuple
+
+import torch
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    ArmQuantizer,
+    get_symmetric_quantization_config,
+)
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+
+from executorch.backends.xnnpack.test.tester.tester import Quantize
+from executorch.exir.backend.backend_details import CompileSpec
+from parameterized import parameterized
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+test_data_suite = [
+    # (test_name, test_data, [kernel_size, stride, padding])
+    ("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1]),
+    ("ones", torch.ones(1, 16, 50, 32), [4, 2, 0]),
+    ("rand", torch.rand(1, 16, 52, 16), [4, 3, 0]),
+]
+
+test_data_suite_mult_batches = [
+    ("randn", torch.randn(5, 16, 50, 32), [4, 2, 0]),
+]
+
+
+class TestMaxPool2d(unittest.TestCase):
+    """Tests MaxPool2d."""
+
+    class MaxPool2d(torch.nn.Module):
+        def __init__(
+            self,
+            kernel_size: int | Tuple[int, int],
+            stride: int | Tuple[int, int],
+            padding: int | Tuple[int, int],
+        ):
+            super().__init__()
+            self.max_pool_2d = torch.nn.MaxPool2d(
+                kernel_size=kernel_size, stride=stride, padding=padding
+            )
+
+        def forward(self, x):
+            return self.max_pool_2d(x)
+
+    def _test_maxpool2d_tosa_MI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
+            )
+            .export()
+            .check(["torch.ops.aten.max_pool2d.default"])
+            .check_not(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+        )
+
+    def _test_maxpool2d_tosa_BI_pipeline(
+        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
+    ):
+        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
+        (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.max_pool2d.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
+                ]
+            )
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
+        )
+
+    def _test_maxpool2d_tosa_ethos_BI_pipeline(
+        self,
+        module: torch.nn.Module,
+        compile_spec: CompileSpec,
+        test_data: Tuple[torch.Tensor],
+    ):
+        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
+        tester = (
+            ArmTester(
+                module,
+                example_inputs=test_data,
+                compile_spec=compile_spec,
+            )
+            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
+            .export()
+            .check_count({"torch.ops.aten.max_pool2d.default": 1})
+            .check(["torch.ops.quantized_decomposed"])
+            .to_edge()
+            .partition()
+            .check_not(["executorch_exir_dialects_edge__ops_aten_max_pool2d_default"])
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .to_executorch()
+            .serialize()
+        )
+
+        return tester
+
+    @parameterized.expand(test_data_suite)
+    def test_maxpool2d_tosa_MI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        self._test_maxpool2d_tosa_MI_pipeline(
+            self.MaxPool2d(*model_params), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_maxpool2d_tosa_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        self._test_maxpool2d_tosa_BI_pipeline(
+            self.MaxPool2d(*model_params), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite)
+    def test_maxpool2d_tosa_u55_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
+            self.MaxPool2d(*model_params),
+            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
+            (test_data,),
+        )
+        if common.is_option_enabled("corstone300"):
+            tester.run_method_and_compare_outputs(
+                qtol=1, inputs=(test_data,), target_board="corstone-300"
+            )
+
+    @parameterized.expand(test_data_suite)
+    def test_maxpool2d_tosa_u85_BI(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
+            self.MaxPool2d(*model_params),
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
+            (test_data,),
+        )
+        if common.is_option_enabled("corstone300"):
+            tester.run_method_and_compare_outputs(
+                qtol=1, inputs=(test_data,), target_board="corstone-320"
+            )
+
+    @parameterized.expand(test_data_suite_mult_batches)
+    def test_maxpool2d_tosa_MI_mult_batches(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        self._test_maxpool2d_tosa_MI_pipeline(
+            self.MaxPool2d(*model_params), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite_mult_batches)
+    def test_maxpool2d_tosa_BI_mult_batches(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        self._test_maxpool2d_tosa_BI_pipeline(
+            self.MaxPool2d(*model_params), (test_data,)
+        )
+
+    @parameterized.expand(test_data_suite_mult_batches)
+    @common.expectedFailureOnFVP  # TODO: MLETORCH-433
+    def test_maxpool2d_tosa_u55_BI_mult_batches(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
+            self.MaxPool2d(*model_params),
+            common.get_u55_compile_spec(permute_memory_to_nhwc=True),
+            (test_data,),
+        )
+        if common.is_option_enabled("corstone300"):
+            tester.run_method_and_compare_outputs(
+                qtol=1, inputs=(test_data,), target_board="corstone-300"
+            )
+
+    @parameterized.expand(test_data_suite_mult_batches)
+    @common.expectedFailureOnFVP  # TODO: MLETORCH-433
+    def test_maxpool2d_tosa_u85_BI_mult_batches(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
+            self.MaxPool2d(*model_params),
+            common.get_u85_compile_spec(permute_memory_to_nhwc=True),
+            (test_data,),
+        )
+        if common.is_option_enabled("corstone300"):
+            tester.run_method_and_compare_outputs(
+                qtol=1, inputs=(test_data,), target_board="corstone-320"
+            )
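
As a reading aid for the parameterized cases: each suite entry expands to `(test_name, test_data, model_params)`, and `model_params` is splatted into the `MaxPool2d` wrapper, so `[2, 2, 1]` means `kernel_size=2, stride=2, padding=1`. Working the `zeros` case by hand:

```python
import torch

test_name, test_data, model_params = ("zeros", torch.zeros(1, 1, 4, 8), [2, 2, 1])
pool = torch.nn.MaxPool2d(*model_params)  # kernel_size=2, stride=2, padding=1
# H_out = (4 + 2*1 - 2) // 2 + 1 = 3, W_out = (8 + 2*1 - 2) // 2 + 1 = 5
print(pool(test_data).shape)  # torch.Size([1, 1, 3, 5])
```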