diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py index a94adb9a897..401b89938b7 100644 --- a/backends/arm/test/conftest.py +++ b/backends/arm/test/conftest.py @@ -11,7 +11,6 @@ import shutil import subprocess import sys -from enum import auto, Enum from typing import Any import pytest @@ -22,21 +21,15 @@ """ -class arm_test_options(Enum): - quantize_io = auto() - corstone_fvp = auto() - fast_fvp = auto() - - -_test_options: dict[arm_test_options, Any] = {} - # ==== Pytest hooks ==== def pytest_configure(config): + pytest._test_options = {} + if config.option.arm_quantize_io: _load_libquantized_ops_aot_lib() - _test_options[arm_test_options.quantize_io] = True + pytest._test_options["quantize_io"] = True if config.option.arm_run_corstoneFVP: corstone300_exists = shutil.which("FVP_Corstone_SSE-300_Ethos-U55") corstone320_exists = shutil.which("FVP_Corstone_SSE-320") @@ -44,8 +37,8 @@ def pytest_configure(config): raise RuntimeError( "Tests are run with --arm_run_corstoneFVP but corstone FVP is not installed." ) - _test_options[arm_test_options.corstone_fvp] = True - _test_options[arm_test_options.fast_fvp] = config.option.fast_fvp + pytest._test_options["corstone_fvp"] = True + pytest._test_options["fast_fvp"] = config.option.fast_fvp logging.basicConfig(level=logging.INFO, stream=sys.stdout) @@ -131,9 +124,7 @@ def expectedFailureOnFVP(test_item): # ==== End of Custom Pytest decorators ===== -def is_option_enabled( - option: str | arm_test_options, fail_if_not_enabled: bool = False -) -> bool: +def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: """ Returns whether an option is successfully enabled, i.e. if the flag was given to pytest and the necessary requirements are available. @@ -144,10 +135,8 @@ def is_option_enabled( The optional parameter 'fail_if_not_enabled' makes the function raise a RuntimeError instead of returning False. """ - if isinstance(option, str): - option = arm_test_options[option.lower()] - if option in _test_options and _test_options[option]: + if option in pytest._test_options and pytest._test_options[option]: return True else: if fail_if_not_enabled: @@ -156,15 +145,15 @@ def is_option_enabled( return False -def get_option(option: arm_test_options) -> Any | None: +def get_option(option: str) -> Any | None: """ Returns the value of an pytest option if it is set, otherwise None. Args: - option (arm_test_options): The option to check for. + option (str): The option to check for. """ - if option in _test_options: - return _test_options[option] + if option in pytest._test_options: + return pytest._test_options[option] return None diff --git a/backends/arm/test/ops/test_depthwise_conv.py b/backends/arm/test/ops/test_depthwise_conv.py index 083e9aaf68e..25dc14ee948 100644 --- a/backends/arm/test/ops/test_depthwise_conv.py +++ b/backends/arm/test/ops/test_depthwise_conv.py @@ -156,14 +156,11 @@ ("two_dw_conv2d", two_dw_conv2d), ] -testsuite_conv2d_u85 = [ +testsuite_conv2d_u85_xfails = [ ("2x2_1x6x4x4_gp6_st1", dw_conv2d_2x2_1x6x4x4_gp6_st1), ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1), ("3x3_1x4x256x256_gp4_st1", dw_conv2d_3x3_1x4x256x256_gp4_st1), ("3x3_1x4x256x256_gp4_nobias", dw_conv2d_3x3_1x4x256x256_gp4_nobias), -] - -testsuite_conv2d_u85_xfails = [ ("3x3_2x8x198x198_gp8_st3", dw_conv2d_3x3_2x8x198x198_gp8_st3), ("two_dw_conv2d", two_dw_conv2d), ] @@ -287,7 +284,7 @@ def test_dw_conv1d_u55_BI( model.get_inputs(), ) - @parameterized.expand(testsuite_conv1d + testsuite_conv2d_u85) + @parameterized.expand(testsuite_conv1d[2:]) def test_dw_conv_u85_BI( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False ): @@ -299,8 +296,12 @@ def test_dw_conv_u85_BI( model.get_inputs(), ) + testsuite_conv2d_u85_xfails.remove( + ("3x3_1x3x256x256_gp3_st1", dw_conv2d_3x3_1x3x256x256_gp3_st1) + ) # Works + # All test cases except 3x3_1x3x256x256_gp3_st1 have numerical issues on FVP. MLETORCH-520 - @parameterized.expand(testsuite_conv2d_u85_xfails) + @parameterized.expand(testsuite_conv2d_u85_xfails + testsuite_conv1d[:2]) @conftest.expectedFailureOnFVP def test_dw_conv_u85_BI_xfails( self, test_name: str, model: torch.nn.Module, set_quantize_io: bool = False diff --git a/backends/arm/test/ops/test_div.py b/backends/arm/test/ops/test_div.py index eaf6a21023d..b367cab42f6 100644 --- a/backends/arm/test/ops/test_div.py +++ b/backends/arm/test/ops/test_div.py @@ -183,21 +183,8 @@ def test_div_tosa_BI( test_data = (input_, other_) self._test_div_tosa_BI_pipeline(self.Div(), test_data) - @parameterized.expand(test_data_suite[:2]) - def test_div_u55_BI( - self, - test_name: str, - input_: Union[torch.Tensor, torch.types.Number], - other_: Union[torch.Tensor, torch.types.Number], - rounding_mode: Optional[str] = None, - ): - test_data = (input_, other_) - self._test_div_ethos_BI_pipeline( - self.Div(), common.get_u55_compile_spec(), test_data - ) - # Numerical issues on FVP likely due to mul op, MLETORCH-521 - @parameterized.expand(test_data_suite[2:]) + @parameterized.expand(test_data_suite) @conftest.expectedFailureOnFVP def test_div_u55_BI_xfails( self, @@ -211,21 +198,8 @@ def test_div_u55_BI_xfails( self.Div(), common.get_u55_compile_spec(), test_data ) - @parameterized.expand(test_data_suite[:2]) - def test_div_u85_BI( - self, - test_name: str, - input_: Union[torch.Tensor, torch.types.Number], - other_: Union[torch.Tensor, torch.types.Number], - rounding_mode: Optional[str] = None, - ): - test_data = (input_, other_) - self._test_div_ethos_BI_pipeline( - self.Div(), common.get_u85_compile_spec(), test_data - ) - # Numerical issues on FVP likely due to mul op, MLETORCH-521 - @parameterized.expand(test_data_suite[2:]) + @parameterized.expand(test_data_suite) @conftest.expectedFailureOnFVP def test_div_u85_BI_xfails( self, diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py index ced71b0072b..1fa81d0eb21 100644 --- a/backends/arm/test/ops/test_mul.py +++ b/backends/arm/test/ops/test_mul.py @@ -152,7 +152,9 @@ def test_mul_tosa_BI( test_data = (input_, other_) self._test_mul_tosa_BI_pipeline(self.Mul(), test_data) + # Numerical issues on FVP, MLETORCH-521 @parameterized.expand(test_data_sute) + @conftest.expectedFailureOnFVP def test_mul_u55_BI( self, test_name: str, @@ -164,7 +166,10 @@ def test_mul_u55_BI( common.get_u55_compile_spec(), self.Mul(), test_data ) - @parameterized.expand(test_data_sute) + # Numerical issues on FVP, MLETORCH-521 + # test_data_sute[0] works on U85 + @parameterized.expand(test_data_sute[1:]) + @conftest.expectedFailureOnFVP def test_mul_u85_BI( self, test_name: str, diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index a8a113cf931..4586a9240b5 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -17,7 +17,7 @@ import numpy as np import torch -from executorch.backends.arm.test.conftest import arm_test_options, is_option_enabled +from executorch.backends.arm.test.conftest import is_option_enabled from torch.export import ExportedProgram from torch.fx.node import Node @@ -251,7 +251,7 @@ def run_corstone( cmd_line += f" -i {input_path}" ethos_u_extra_args = "" - if is_option_enabled(arm_test_options.fast_fvp): + if is_option_enabled("fast_fvp"): ethos_u_extra_args = ethos_u_extra_args + "--fast" command_args = {