diff --git a/.ci/scripts/gather_benchmark_configs.py b/.ci/scripts/gather_benchmark_configs.py index 7cc9d7e1361..908c363a415 100755 --- a/.ci/scripts/gather_benchmark_configs.py +++ b/.ci/scripts/gather_benchmark_configs.py @@ -10,7 +10,7 @@ import os import re import sys -from typing import Any, Dict, List +from typing import Any, Dict, List, NamedTuple sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) from examples.models import MODEL_NAME_TO_MODEL @@ -47,6 +47,46 @@ } +class DisabledConfig(NamedTuple): + config_name: str + github_issue: str # Link to the GitHub issue + + +# Updated DISABLED_CONFIGS +DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = { + "resnet50": [ + DisabledConfig( + config_name="qnn_q8", + github_issue="https://github.com/pytorch/executorch/issues/7892", + ), + ], + "w2l": [ + DisabledConfig( + config_name="qnn_q8", + github_issue="https://github.com/pytorch/executorch/issues/7634", + ), + ], + "mobilebert": [ + DisabledConfig( + config_name="mps", + github_issue="https://github.com/pytorch/executorch/issues/7904", + ), + ], + "edsr": [ + DisabledConfig( + config_name="mps", + github_issue="https://github.com/pytorch/executorch/issues/7905", + ), + ], + "llama": [ + DisabledConfig( + config_name="mps", + github_issue="https://github.com/pytorch/executorch/issues/7907", + ), + ], +} + + def extract_all_configs(data, target_os=None): if isinstance(data, dict): # If target_os is specified, include "xplat" and the specified branch @@ -117,6 +157,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]: # Skip unknown models with a warning logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.") + # Remove disabled configs for the given model + disabled_configs = DISABLED_CONFIGS.get(model_name, []) + disabled_config_names = {disabled.config_name for disabled in disabled_configs} + for disabled in disabled_configs: + print( + f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}" + ) + configs = [config for config in configs if config not in disabled_config_names] return configs diff --git a/.ci/scripts/tests/test_gather_benchmark_configs.py b/.ci/scripts/tests/test_gather_benchmark_configs.py index 855f8153609..03d735096ea 100644 --- a/.ci/scripts/tests/test_gather_benchmark_configs.py +++ b/.ci/scripts/tests/test_gather_benchmark_configs.py @@ -1,5 +1,6 @@ import importlib.util import os +import re import subprocess import sys import unittest @@ -7,21 +8,25 @@ import pytest -# Dynamically import the script -script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py") -spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path) -gather_benchmark_configs = importlib.util.module_from_spec(spec) -spec.loader.exec_module(gather_benchmark_configs) - @pytest.mark.skipif( sys.platform != "linux", reason="The script under test runs on Linux runners only" ) class TestGatehrBenchmarkConfigs(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Dynamically import the script + script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py") + spec = importlib.util.spec_from_file_location( + "gather_benchmark_configs", script_path + ) + cls.gather_benchmark_configs = importlib.util.module_from_spec(spec) + spec.loader.exec_module(cls.gather_benchmark_configs) + def test_extract_all_configs_android(self): - android_configs = gather_benchmark_configs.extract_all_configs( - gather_benchmark_configs.BENCHMARK_CONFIGS, "android" + android_configs = self.gather_benchmark_configs.extract_all_configs( + self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android" ) self.assertIn("xnnpack_q8", android_configs) self.assertIn("qnn_q8", android_configs) @@ -29,8 +34,8 @@ def test_extract_all_configs_android(self): self.assertIn("llama3_qlora", android_configs) def test_extract_all_configs_ios(self): - ios_configs = gather_benchmark_configs.extract_all_configs( - gather_benchmark_configs.BENCHMARK_CONFIGS, "ios" + ios_configs = self.gather_benchmark_configs.extract_all_configs( + self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios" ) self.assertIn("xnnpack_q8", ios_configs) @@ -40,17 +45,76 @@ def test_extract_all_configs_ios(self): self.assertIn("llama3_spinquant", ios_configs) self.assertIn("llama3_qlora", ios_configs) + def test_skip_disabled_configs(self): + # Use patch as a context manager to avoid modifying DISABLED_CONFIGS and BENCHMARK_CONFIGS + with patch.dict( + self.gather_benchmark_configs.DISABLED_CONFIGS, + { + "mv3": [ + self.gather_benchmark_configs.DisabledConfig( + config_name="disabled_config1", + github_issue="https://github.com/org/repo/issues/123", + ), + self.gather_benchmark_configs.DisabledConfig( + config_name="disabled_config2", + github_issue="https://github.com/org/repo/issues/124", + ), + ] + }, + ), patch.dict( + self.gather_benchmark_configs.BENCHMARK_CONFIGS, + { + "ios": [ + "disabled_config1", + "disabled_config2", + "enabled_config1", + "enabled_config2", + ] + }, + ): + result = self.gather_benchmark_configs.generate_compatible_configs( + "mv3", target_os="ios" + ) + + # Assert that disabled configs are excluded + self.assertNotIn("disabled_config1", result) + self.assertNotIn("disabled_config2", result) + # Assert enabled configs are included + self.assertIn("enabled_config1", result) + self.assertIn("enabled_config2", result) + + def test_disabled_configs_have_github_links(self): + github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+") + + for ( + model_name, + disabled_configs, + ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items(): + for disabled in disabled_configs: + with self.subTest(model_name=model_name, config=disabled.config_name): + # Assert that disabled is an instance of DisabledConfig + self.assertIsInstance( + disabled, self.gather_benchmark_configs.DisabledConfig + ) + + # Assert that github_issue is provided and matches the expected pattern + self.assertTrue( + disabled.github_issue + and github_issue_regex.match(disabled.github_issue), + f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.", + ) + def test_generate_compatible_configs_llama_model(self): model_name = "meta-llama/Llama-3.2-1B" target_os = "ios" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) expected = ["llama3_fb16", "llama3_coreml_ane"] self.assertEqual(result, expected) target_os = "android" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) expected = ["llama3_fb16"] @@ -58,33 +122,37 @@ def test_generate_compatible_configs_llama_model(self): def test_generate_compatible_configs_quantized_llama_model(self): model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8" - result = gather_benchmark_configs.generate_compatible_configs(model_name, None) + result = self.gather_benchmark_configs.generate_compatible_configs( + model_name, None + ) expected = ["llama3_spinquant"] self.assertEqual(result, expected) model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8" - result = gather_benchmark_configs.generate_compatible_configs(model_name, None) + result = self.gather_benchmark_configs.generate_compatible_configs( + model_name, None + ) expected = ["llama3_qlora"] self.assertEqual(result, expected) def test_generate_compatible_configs_non_genai_model(self): model_name = "mv2" target_os = "xplat" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) expected = ["xnnpack_q8"] self.assertEqual(result, expected) target_os = "android" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) expected = ["xnnpack_q8", "qnn_q8"] self.assertEqual(result, expected) target_os = "ios" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) expected = ["xnnpack_q8", "coreml_fp16", "mps"] @@ -93,7 +161,7 @@ def test_generate_compatible_configs_non_genai_model(self): def test_generate_compatible_configs_unknown_model(self): model_name = "unknown_model" target_os = "ios" - result = gather_benchmark_configs.generate_compatible_configs( + result = self.gather_benchmark_configs.generate_compatible_configs( model_name, target_os ) self.assertEqual(result, []) @@ -101,14 +169,14 @@ def test_generate_compatible_configs_unknown_model(self): def test_is_valid_huggingface_model_id_valid(self): valid_model = "meta-llama/Llama-3.2-1B" self.assertTrue( - gather_benchmark_configs.is_valid_huggingface_model_id(valid_model) + self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model) ) @patch("builtins.open", new_callable=mock_open) @patch("os.getenv", return_value=None) def test_set_output_no_github_env(self, mock_getenv, mock_file): with patch("builtins.print") as mock_print: - gather_benchmark_configs.set_output("test_name", "test_value") + self.gather_benchmark_configs.set_output("test_name", "test_value") mock_print.assert_called_with("::set-output name=test_name::test_value") def test_device_pools_contains_all_devices(self): @@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self): "google_pixel_8_pro", ] for device in expected_devices: - self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS) + self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS) def test_gather_benchmark_configs_cli(self): args = {