Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/quantizer/arm_quantizer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
torch.ops.aten.unsqueeze.default,
# TODO: remove?
torch.ops.aten.adaptive_avg_pool2d.default,
torch.ops.aten.avg_pool2d.default,
torch.ops.aten.view_copy.default,
torch.ops.aten.view.default,
torch.ops.aten.slice.Tensor,
Expand Down
21 changes: 9 additions & 12 deletions backends/arm/test/ops/test_avg_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import unittest

from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
ArmQuantizer,
get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

test_data_suite = [
# (test_name, test_data, [kernel_size, stride, padding])
Expand Down Expand Up @@ -69,13 +71,14 @@ def _test_avgpool2d_tosa_MI_pipeline(
def _test_avgpool2d_tosa_BI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.tensor]
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=common.get_tosa_compile_spec(permute_memory_to_nhwc=True),
)
.quantize()
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
Expand All @@ -93,13 +96,14 @@ def _test_avgpool2d_tosa_ethos_BI_pipeline(
compile_spec: CompileSpec,
test_data: Tuple[torch.tensor],
):
quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
(
ArmTester(
module,
example_inputs=test_data,
compile_spec=compile_spec,
)
.quantize()
.quantize(Quantize(quantizer, get_symmetric_quantization_config()))
.export()
.check_count({"torch.ops.aten.avg_pool2d.default": 1})
.check(["torch.ops.quantized_decomposed"])
Expand All @@ -121,10 +125,7 @@ def test_avgpool2d_tosa_MI(
self.AvgPool2d(*model_params), (test_data,)
)

# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
# TODO(MLETORCH-93)
@parameterized.expand(test_data_suite)
@unittest.expectedFailure
def test_avgpool2d_tosa_BI(
self,
test_name: str,
Expand All @@ -135,10 +136,7 @@ def test_avgpool2d_tosa_BI(
self.AvgPool2d(*model_params), (test_data,)
)

# Expected to fail since ArmQuantizer cannot quantize a AvgPool2D layer
# TODO(MLETORCH-93)
@parameterized.expand(test_data_suite)
@unittest.expectedFailure
def test_avgpool2d_tosa_u55_BI(
self,
test_name: str,
Expand All @@ -152,7 +150,6 @@ def test_avgpool2d_tosa_u55_BI(
)

@parameterized.expand(test_data_suite)
@unittest.expectedFailure
def test_avgpool2d_tosa_u85_BI(
self,
test_name: str,
Expand Down
61 changes: 61 additions & 0 deletions backends/arm/test/ops/test_conv_combos.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,32 @@ def forward(self, x):
return x


class ComboConvAvgPool2d(torch.nn.Module):
edge_op_list = [
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_avg_pool2d_default",
]

test_data = [
(20 * torch.randn(1, 3, 64, 32),),
(torch.randn(1, 3, 100, 200),),
(5 * torch.randn(1, 3, 256, 256),),
(torch.rand(1, 3, 512, 128),),
]

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
)
self.avg_pool2d = torch.nn.AvgPool2d(kernel_size=(2, 2))

def forward(self, x):
x = self.conv2d(x)
x = self.avg_pool2d(x)
return x


class TestConvCombos(unittest.TestCase):
"""Tests conv combined with other ops."""

Expand Down Expand Up @@ -334,3 +360,38 @@ def test_block_bottleneck_residual_u85_BI(self):
common.get_u85_compile_spec(permute_memory_to_nhwc=True),
model.get_inputs(),
)

######################
## Conv + AvgPool2d ##
######################
@parameterized.expand(ComboConvAvgPool2d.test_data)
def test_conv_avgpool2d_tosa_MI(self, test_data: torch.Tensor):
model = ComboConvAvgPool2d()
test_data = (test_data,)
self._test_conv_combo_tosa_MI_pipeline(model, test_data)

@parameterized.expand(ComboConvAvgPool2d.test_data)
def test_conv_avgpool2d_tosa_BI(self, test_data: torch.Tensor):
model = ComboConvAvgPool2d()
test_data = (test_data,)
self._test_conv_combo_tosa_BI_pipeline(model, test_data)

@parameterized.expand(ComboConvAvgPool2d.test_data)
def test_conv_avgpool2d_u55_BI(self, test_data: torch.Tensor):
model = ComboConvAvgPool2d()
test_data = (test_data,)
self._test_conv_combo_ethos_BI_pipeline(
model,
common.get_u55_compile_spec(),
test_data,
)

@parameterized.expand(ComboConvAvgPool2d.test_data)
def test_conv_avgpool2d_u85_BI(self, test_data: torch.Tensor):
model = ComboConvAvgPool2d()
test_data = (test_data,)
self._test_conv_combo_ethos_BI_pipeline(
model,
common.get_u85_compile_spec(),
test_data,
)
Loading