50 changes: 49 additions & 1 deletion .ci/scripts/gather_benchmark_configs.py
@@ -10,7 +10,7 @@
import os
import re
import sys
from typing import Any, Dict, List
from typing import Any, Dict, List, NamedTuple

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from examples.models import MODEL_NAME_TO_MODEL
@@ -47,6 +47,46 @@
}


class DisabledConfig(NamedTuple):
    config_name: str
    github_issue: str  # Link to the GitHub issue


# Benchmark configs disabled per model, each with a tracking GitHub issue
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    "resnet50": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7892",
        ),
    ],
    "w2l": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7634",
        ),
    ],
    "mobilebert": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7904",
        ),
    ],
    "edsr": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7905",
        ),
    ],
    "llama": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7907",
        ),
    ],
}


def extract_all_configs(data, target_os=None):
    if isinstance(data, dict):
        # If target_os is specified, include "xplat" and the specified branch
@@ -117,6 +157,14 @@ def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
        # Skip unknown models with a warning
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

    # Remove disabled configs for the given model
    disabled_configs = DISABLED_CONFIGS.get(model_name, [])
    disabled_config_names = {disabled.config_name for disabled in disabled_configs}
    for disabled in disabled_configs:
        print(
            f"Excluding disabled config: '{disabled.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled.github_issue}"
        )
    configs = [config for config in configs if config not in disabled_config_names]
    return configs


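For illustration, here is a minimal, self-contained sketch of the filtering step added above. The candidate config list for resnet50 on Android is an assumption for the example (inferred from the mv2 expectations in the tests below), not read from the real BENCHMARK_CONFIGS:

from typing import Dict, List, NamedTuple

class DisabledConfig(NamedTuple):
    config_name: str
    github_issue: str

# Same resnet50 entry as in DISABLED_CONFIGS above
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    "resnet50": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7892",
        ),
    ],
}

configs = ["xnnpack_q8", "qnn_q8"]  # hypothetical candidates before filtering
disabled_config_names = {d.config_name for d in DISABLED_CONFIGS.get("resnet50", [])}
configs = [c for c in configs if c not in disabled_config_names]
assert configs == ["xnnpack_q8"]  # "qnn_q8" is dropped, tracked by issue 7892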
110 changes: 89 additions & 21 deletions .ci/scripts/tests/test_gather_benchmark_configs.py
@@ -1,36 +1,41 @@
import importlib.util
import os
import re
import subprocess
import sys
import unittest
from unittest.mock import mock_open, patch

import pytest

# Dynamically import the script
script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
spec = importlib.util.spec_from_file_location("gather_benchmark_configs", script_path)
gather_benchmark_configs = importlib.util.module_from_spec(spec)
spec.loader.exec_module(gather_benchmark_configs)


@pytest.mark.skipif(
    sys.platform != "linux", reason="The script under test runs on Linux runners only"
)
class TestGatherBenchmarkConfigs(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        # Dynamically import the script
        script_path = os.path.join(".ci", "scripts", "gather_benchmark_configs.py")
        spec = importlib.util.spec_from_file_location(
            "gather_benchmark_configs", script_path
        )
        cls.gather_benchmark_configs = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(cls.gather_benchmark_configs)

    def test_extract_all_configs_android(self):
        android_configs = gather_benchmark_configs.extract_all_configs(
            gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
        android_configs = self.gather_benchmark_configs.extract_all_configs(
            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "android"
        )
        self.assertIn("xnnpack_q8", android_configs)
        self.assertIn("qnn_q8", android_configs)
        self.assertIn("llama3_spinquant", android_configs)
        self.assertIn("llama3_qlora", android_configs)

    def test_extract_all_configs_ios(self):
        ios_configs = gather_benchmark_configs.extract_all_configs(
            gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
        ios_configs = self.gather_benchmark_configs.extract_all_configs(
            self.gather_benchmark_configs.BENCHMARK_CONFIGS, "ios"
        )

        self.assertIn("xnnpack_q8", ios_configs)
@@ -40,51 +45,114 @@ def test_extract_all_configs_ios(self):
self.assertIn("llama3_spinquant", ios_configs)
self.assertIn("llama3_qlora", ios_configs)

    def test_skip_disabled_configs(self):
        # Patch DISABLED_CONFIGS and BENCHMARK_CONFIGS via context managers so the
        # real dicts are restored automatically when the test exits
        with patch.dict(
            self.gather_benchmark_configs.DISABLED_CONFIGS,
            {
                "mv3": [
                    self.gather_benchmark_configs.DisabledConfig(
                        config_name="disabled_config1",
                        github_issue="https://github.com/org/repo/issues/123",
                    ),
                    self.gather_benchmark_configs.DisabledConfig(
                        config_name="disabled_config2",
                        github_issue="https://github.com/org/repo/issues/124",
                    ),
                ]
            },
        ), patch.dict(
            self.gather_benchmark_configs.BENCHMARK_CONFIGS,
            {
                "ios": [
                    "disabled_config1",
                    "disabled_config2",
                    "enabled_config1",
                    "enabled_config2",
                ]
            },
        ):
            result = self.gather_benchmark_configs.generate_compatible_configs(
                "mv3", target_os="ios"
            )

            # Assert that disabled configs are excluded
            self.assertNotIn("disabled_config1", result)
            self.assertNotIn("disabled_config2", result)
            # Assert enabled configs are included
            self.assertIn("enabled_config1", result)
            self.assertIn("enabled_config2", result)

    def test_disabled_configs_have_github_links(self):
        github_issue_regex = re.compile(r"https://github\.com/.+/.+/issues/\d+")

        for (
            model_name,
            disabled_configs,
        ) in self.gather_benchmark_configs.DISABLED_CONFIGS.items():
            for disabled in disabled_configs:
                with self.subTest(model_name=model_name, config=disabled.config_name):
                    # Assert that disabled is an instance of DisabledConfig
                    self.assertIsInstance(
                        disabled, self.gather_benchmark_configs.DisabledConfig
                    )

                    # Assert that github_issue is provided and matches the expected pattern
                    self.assertTrue(
                        disabled.github_issue
                        and github_issue_regex.match(disabled.github_issue),
                        f"Invalid or missing GitHub issue link for '{disabled.config_name}' in model '{model_name}'.",
                    )

    def test_generate_compatible_configs_llama_model(self):
        model_name = "meta-llama/Llama-3.2-1B"
        target_os = "ios"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["llama3_fb16", "llama3_coreml_ane"]
        self.assertEqual(result, expected)

        target_os = "android"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["llama3_fb16"]
        self.assertEqual(result, expected)

    def test_generate_compatible_configs_quantized_llama_model(self):
        model_name = "meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8"
        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, None
        )
        expected = ["llama3_spinquant"]
        self.assertEqual(result, expected)

        model_name = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
        result = gather_benchmark_configs.generate_compatible_configs(model_name, None)
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, None
        )
        expected = ["llama3_qlora"]
        self.assertEqual(result, expected)

    def test_generate_compatible_configs_non_genai_model(self):
        model_name = "mv2"
        target_os = "xplat"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8"]
        self.assertEqual(result, expected)

        target_os = "android"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8", "qnn_q8"]
        self.assertEqual(result, expected)

        target_os = "ios"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        expected = ["xnnpack_q8", "coreml_fp16", "mps"]
@@ -93,22 +161,22 @@ def test_generate_compatible_configs_non_genai_model(self):
    def test_generate_compatible_configs_unknown_model(self):
        model_name = "unknown_model"
        target_os = "ios"
        result = gather_benchmark_configs.generate_compatible_configs(
        result = self.gather_benchmark_configs.generate_compatible_configs(
            model_name, target_os
        )
        self.assertEqual(result, [])

    def test_is_valid_huggingface_model_id_valid(self):
        valid_model = "meta-llama/Llama-3.2-1B"
        self.assertTrue(
            gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
            self.gather_benchmark_configs.is_valid_huggingface_model_id(valid_model)
        )

    @patch("builtins.open", new_callable=mock_open)
    @patch("os.getenv", return_value=None)
    def test_set_output_no_github_env(self, mock_getenv, mock_file):
        with patch("builtins.print") as mock_print:
            gather_benchmark_configs.set_output("test_name", "test_value")
            self.gather_benchmark_configs.set_output("test_name", "test_value")
            mock_print.assert_called_with("::set-output name=test_name::test_value")

    def test_device_pools_contains_all_devices(self):
@@ -120,7 +188,7 @@ def test_device_pools_contains_all_devices(self):
"google_pixel_8_pro",
]
for device in expected_devices:
self.assertIn(device, gather_benchmark_configs.DEVICE_POOLS)
self.assertIn(device, self.gather_benchmark_configs.DEVICE_POOLS)

    def test_gather_benchmark_configs_cli(self):
        args = {
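To try the updated suite locally, an invocation along these lines should work (assuming pytest is installed and the command is run from the repository root, since the tests resolve the script by relative path; the skipif marker above restricts execution to Linux):

python -m pytest .ci/scripts/tests/test_gather_benchmark_configs.py -v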