From c6362fb23503a887119219cfbe962aa989796bc7 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Tue, 24 Jun 2025 11:57:41 -0700
Subject: [PATCH 1/4] Allow CLI overrides

---
 extension/llm/export/export_llm.py           | 43 ++++++----
 extension/llm/export/test/test_export_llm.py | 90 ++++++++++----------
 2 files changed, 74 insertions(+), 59 deletions(-)

diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py
index e995b329f30..7abe7bf3e91 100644
--- a/extension/llm/export/export_llm.py
+++ b/extension/llm/export/export_llm.py
@@ -38,14 +38,19 @@
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
+from hydra.core.hydra_config import HydraConfig
 from omegaconf import OmegaConf
 
 cs = ConfigStore.instance()
 cs.store(name="llm_config", node=LlmConfig)
 
 
+# Need this global variable to pass an llm_config from yaml
+# into the hydra-wrapped main function.
+llm_config_from_yaml = None
+
+
 def parse_config_arg() -> Tuple[str, List[Any]]:
-    """First parse out the arg for whether to use Hydra or the old CLI."""
     parser = argparse.ArgumentParser(add_help=True)
     parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
     args, remaining = parser.parse_known_args()
@@ -65,28 +70,34 @@ def pop_config_arg() -> str:
 
 @hydra.main(version_base=None, config_name="llm_config")
 def hydra_main(llm_config: LlmConfig) -> None:
-    export_llama(OmegaConf.to_object(llm_config))
+    global llm_config_from_yaml
+
+    # Override the LlmConfig constructed from the provided yaml config file
+    # with the CLI overrides.
+    if llm_config_from_yaml:
+        # Get CLI overrides (excluding defaults list).
+        overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", []))
+        override_cfg = OmegaConf.from_dotlist(overrides_list)
+        merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg)
+        export_llama(merged_config)
+    else:
+        export_llama(OmegaConf.to_object(llm_config))
 
 
 def main() -> None:
+    # First parse out the arg for whether to use Hydra or the old CLI.
     config, remaining_args = parse_config_arg()
     if config:
-        # Check if there are any remaining hydra CLI args when --config is specified
-        # This might change in the future to allow overriding config file values
-        if remaining_args:
-            raise ValueError(
-                "Cannot specify additional CLI arguments when using --config. "
-                f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
-            )
-
+        global llm_config_from_yaml
+        # Pop out --config and its value so that they are not parsed by
+        # Hyra's main.
        config_file_path = pop_config_arg()
         default_llm_config = LlmConfig()
-        llm_config_from_file = OmegaConf.load(config_file_path)
-        # Override defaults with values specified in the .yaml provided by --config.
-        merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
-        export_llama(merged_llm_config)
-    else:
-        hydra_main()
+        # Construct the LlmConfig from the config yaml file.
+        default_llm_config = LlmConfig()
+        from_yaml = OmegaConf.load(config_file_path)
+        llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml)
+        hydra_main()
 
 
 if __name__ == "__main__":
diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py
index 7d17b7819d3..7ae98f97c5b 100644
--- a/extension/llm/export/test/test_export_llm.py
+++ b/extension/llm/export/test/test_export_llm.py
@@ -21,7 +21,7 @@ class TestExportLlm(unittest.TestCase):
     def test_parse_config_arg_with_config(self) -> None:
         """Test parse_config_arg when --config is provided."""
         # Mock sys.argv to include --config
-        test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
+        test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"]
         with patch.object(sys, "argv", test_argv):
             config_path, remaining = parse_config_arg()
             self.assertEqual(config_path, "test_config.yaml")
@@ -29,7 +29,7 @@ def test_parse_config_arg_with_config(self) -> None:
 
     def test_parse_config_arg_without_config(self) -> None:
         """Test parse_config_arg when --config is not provided."""
-        test_argv = ["script.py", "debug.verbose=True"]
+        test_argv = ["export_llm.py", "debug.verbose=True"]
         with patch.object(sys, "argv", test_argv):
             config_path, remaining = parse_config_arg()
             self.assertIsNone(config_path)
@@ -37,11 +37,21 @@ def test_parse_config_arg_without_config(self) -> None:
 
     def test_pop_config_arg(self) -> None:
         """Test pop_config_arg removes --config and its value from sys.argv."""
-        test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
+        test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"]
         with patch.object(sys, "argv", test_argv):
             config_path = pop_config_arg()
             self.assertEqual(config_path, "test_config.yaml")
-            self.assertEqual(sys.argv, ["script.py", "other", "args"])
+            self.assertEqual(sys.argv, ["export_llm.py", "other", "args"])
+
+    def test_with_cli_args(self) -> None:
+        """Test main function with only hydra CLI args."""
+        test_argv = ["export_llm.py", "debug.verbose=True"]
+        with patch.object(sys, "argv", test_argv):
+            with patch(
+                "executorch.extension.llm.export.export_llm.hydra_main"
+            ) as mock_hydra:
+                main()
+                mock_hydra.assert_called_once()
 
     @patch("executorch.extension.llm.export.export_llm.export_llama")
     def test_with_config(self, mock_export_llama: MagicMock) -> None:
@@ -70,7 +80,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         config_file = f.name
 
         try:
-            test_argv = ["script.py", "--config", config_file]
+            test_argv = ["export_llm.py", "--config", config_file]
             with patch.object(sys, "argv", test_argv):
                 main()
 
@@ -99,54 +109,48 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
         finally:
             os.unlink(config_file)
 
-    def test_with_cli_args(self) -> None:
-        """Test main function with only hydra CLI args."""
-        test_argv = ["script.py", "debug.verbose=True"]
-        with patch.object(sys, "argv", test_argv):
-            with patch(
-                "executorch.extension.llm.export.export_llm.hydra_main"
-            ) as mock_hydra:
-                main()
-                mock_hydra.assert_called_once()
-
-    def test_config_with_cli_args_error(self) -> None:
-        """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
+    @patch("executorch.extension.llm.export.export_llm.export_llama")
+    def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
+        """Test main function with --config file and hydra CLI override args."""
         # Create a temporary config file
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
suffix=".yaml", delete=False) as f: - f.write("base:\n checkpoint: /path/to/checkpoint.pth") - config_file = f.name - - try: - test_argv = ["script.py", "--config", config_file, "debug.verbose=True"] - with patch.object(sys, "argv", test_argv): - with self.assertRaises(ValueError) as cm: - main() - - error_msg = str(cm.exception) - self.assertIn( - "Cannot specify additional CLI arguments when using --config", - error_msg, - ) - finally: - os.unlink(config_file) - - def test_config_rejects_multiple_cli_args(self) -> None: - """Test that --config rejects multiple CLI arguments (not just single ones).""" - with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: - f.write("export:\n max_seq_length: 128") + f.write( + """ +base: + model_class: llama2 +model: + dtype_override: fp16 +backend: + xnnpack: + enabled: False +""" + ) config_file = f.name try: test_argv = [ - "script.py", + "export_llm.py", "--config", config_file, - "debug.verbose=True", - "export.output_dir=/tmp", + "base.model_class=stories110m", + "backend.xnnpack.enabled=True", ] with patch.object(sys, "argv", test_argv): - with self.assertRaises(ValueError): - main() + main() + + # Verify export_llama was called with config + mock_export_llama.assert_called_once() + called_config = mock_export_llama.call_args[0][0] + self.assertEqual( + called_config["base"]["model_class"], "stories110m" + ) # Override from CLI. + self.assertEqual( + called_config["model"]["dtype_override"].value, "fp16" + ) # From yaml. + self.assertEqual( + called_config["backend"]["xnnpack"]["enabled"], + True, # Override from CLI. + ) finally: os.unlink(config_file) From 26078ae89e92a8575c18ff1edd07da8a772697f5 Mon Sep 17 00:00:00 2001 From: "Jack Zhang (aider)" <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 24 Jun 2025 17:23:01 -0700 Subject: [PATCH 2/4] Try splitting config into path and name --- extension/llm/export/export_llm.py | 52 ++++++++++++++++-------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py index 7abe7bf3e91..73ce9fc0ad7 100644 --- a/extension/llm/export/export_llm.py +++ b/extension/llm/export/export_llm.py @@ -30,6 +30,7 @@ """ import argparse +import os import sys from typing import Any, List, Tuple @@ -45,11 +46,6 @@ cs.store(name="llm_config", node=LlmConfig) -# Need this global variable to pass an llm_config from yaml -# into the hydra-wrapped main function. -llm_config_from_yaml = None - - def parse_config_arg() -> Tuple[str, List[Any]]: parser = argparse.ArgumentParser(add_help=True) parser.add_argument("--config", type=str, help="Path to the LlmConfig file") @@ -61,6 +57,7 @@ def pop_config_arg() -> str: """ Removes '--config' and its value from sys.argv. Assumes --config is specified and argparse has already validated the args. + Returns the config file path. """ idx = sys.argv.index("--config") value = sys.argv[idx + 1] @@ -68,20 +65,28 @@ def pop_config_arg() -> str: return value -@hydra.main(version_base=None, config_name="llm_config") -def hydra_main(llm_config: LlmConfig) -> None: - global llm_config_from_yaml +def add_hydra_config_args(config_file_path: str) -> None: + """ + Breaks down the config file path into directory and filename, + resolves the directory to an absolute path, and adds the + --config_path and --config_name arguments to sys.argv. 
+ """ + config_dir = os.path.dirname(config_file_path) + config_name = os.path.basename(config_file_path) + + # Resolve to absolute path + config_dir_abs = os.path.abspath(config_dir) + + # Add the hydra config arguments to sys.argv + sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name]) + - # Override the LlmConfig constructed from the provide yaml config file - # with the CLI overrides. - if llm_config_from_yaml: - # Get CLI overrides (excluding defaults list). - overrides_list: List[str] = list(HydraConfig.get().overrides.get("task", [])) - override_cfg = OmegaConf.from_dotlist(overrides_list) - merged_config = OmegaConf.merge(llm_config_from_yaml, override_cfg) - export_llama(merged_config) - else: - export_llama(OmegaConf.to_object(llm_config)) +@hydra.main(version_base=None, config_name="llm_config", config_path=None) +def hydra_main(llm_config: LlmConfig) -> None: + structured = OmegaConf.structured(LlmConfig) + merged = OmegaConf.merge(structured, llm_config) + llm_config_obj = OmegaConf.to_object(merged) + export_llama(llm_config_obj) def main() -> None: @@ -90,13 +95,12 @@ def main() -> None: if config: global llm_config_from_yaml # Pop out --config and its value so that they are not parsed by - # Hyra's main. + # Hydra's main. config_file_path = pop_config_arg() - default_llm_config = LlmConfig() - # Construct the LlmConfig from the config yaml file. - default_llm_config = LlmConfig() - from_yaml = OmegaConf.load(config_file_path) - llm_config_from_yaml = OmegaConf.merge(default_llm_config, from_yaml) + + # Add hydra config_path and config_name arguments to sys.argv. + add_hydra_config_args(config_file_path) + hydra_main() From 4dfd43fe4d716a2c68f866cad6aa7083eef73adb Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Wed, 25 Jun 2025 14:21:06 -0700 Subject: [PATCH 3/4] Update README --- extension/llm/export/README.md | 60 +++++++----------------------- extension/llm/export/export_llm.py | 8 ++-- 2 files changed, 17 insertions(+), 51 deletions(-) diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md index 96f36acc1b4..e97b9e10462 100644 --- a/extension/llm/export/README.md +++ b/extension/llm/export/README.md @@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimiz ## Usage -The export API supports two configuration approaches: +The export API supports a Hydra-style CLI where you can you configure using yaml and also CLI args. -### Option 1: Hydra CLI Arguments +### Hydra CLI Arguments Use structured configuration arguments directly on the command line: @@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \ quantization.qmode=8da4w ``` -### Option 2: Configuration File +### Configuration File Create a YAML configuration file and reference it: @@ -78,53 +78,21 @@ debug: verbose: true ``` -**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both. 
From 4dfd43fe4d716a2c68f866cad6aa7083eef73adb Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:21:06 -0700
Subject: [PATCH 3/4] Update README

---
 extension/llm/export/README.md     | 60 +++++++-----------------
 extension/llm/export/export_llm.py |  8 ++--
 2 files changed, 17 insertions(+), 51 deletions(-)

diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md
index 96f36acc1b4..e97b9e10462 100644
--- a/extension/llm/export/README.md
+++ b/extension/llm/export/README.md
@@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimiz
 
 ## Usage
 
-The export API supports two configuration approaches:
+The export API supports a Hydra-style CLI where you can configure the export using a yaml file, CLI args, or both.
 
-### Option 1: Hydra CLI Arguments
+### Hydra CLI Arguments
 
 Use structured configuration arguments directly on the command line:
 
@@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \
   quantization.qmode=8da4w
 ```
 
-### Option 2: Configuration File
+### Configuration File
 
 Create a YAML configuration file and reference it:
 
@@ -78,53 +78,21 @@ debug:
   verbose: true
 ```
 
-**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.
+You can still provide additional overrides using CLI args as well:
 
-## Example Commands
-
-### Export Qwen3 0.6B with XNNPACK backend and quantization
 ```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=qwen3_0_6b \
-  base.params=examples/models/qwen3/0_6b_config.json \
-  base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  model.dtype_override=FP32 \
-  export.max_seq_length=512 \
-  export.output_name=qwen3_0_6b.pte \
-  quantization.qmode=8da4w \
-  backend.xnnpack.enabled=true \
-  backend.xnnpack.extended_ops=true \
-  debug.verbose=true
+python -m extension.llm.export.export_llm \
+    --config my_config.yaml \
+    base.model_class="llama2" \
+    +export.max_context_length=1024
 ```
 
-### Export Phi-4-Mini with custom checkpoint
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=phi_4_mini \
-  base.checkpoint=/path/to/phi4_checkpoint.pth \
-  base.params=examples/models/phi-4-mini/config.json \
-  base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  export.max_seq_length=256 \
-  export.output_name=phi4_mini.pte \
-  backend.xnnpack.enabled=true \
-  debug.verbose=true
-```
+Note that if a config file is specified and you want to pass a CLI arg that is not in the config file, you need to prepend it with a `+`. You can read more about this in the Hydra [docs](https://hydra.cc/docs/advanced/override_grammar/basic/).
 
-### Export with CoreML backend (iOS optimization)
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=llama3 \
-  model.use_kv_cache=true \
-  export.max_seq_length=128 \
-  backend.coreml.enabled=true \
-  backend.coreml.compute_units=ALL \
-  quantization.pt2e_quantize=coreml_c4w \
-  debug.verbose=true
-```
+
+## Example Commands
+
+Please refer to the docs for some of our supported example models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)).
 
 ## Configuration Options
@@ -134,4 +102,4 @@ For a complete reference of all available configuration options, see the [LlmCon
 
 - [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
 - [LLM Runner](../runner/) - Running exported models
-- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
\ No newline at end of file
+- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py
index 73ce9fc0ad7..e0467250a28 100644
--- a/extension/llm/export/export_llm.py
+++ b/extension/llm/export/export_llm.py
@@ -39,7 +39,6 @@
 from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.export_llama_lib import export_llama
 from hydra.core.config_store import ConfigStore
-from hydra.core.hydra_config import HydraConfig
 from omegaconf import OmegaConf
 
 cs = ConfigStore.instance()
@@ -73,10 +72,10 @@ def add_hydra_config_args(config_file_path: str) -> None:
     """
     config_dir = os.path.dirname(config_file_path)
     config_name = os.path.basename(config_file_path)
-    
+
     # Resolve to absolute path
     config_dir_abs = os.path.abspath(config_dir)
-    
+
     # Add the hydra config arguments to sys.argv
     sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])
 
@@ -93,11 +92,10 @@ def main() -> None:
     # First parse out the arg for whether to use Hydra or the old CLI.
     config, remaining_args = parse_config_arg()
     if config:
-        global llm_config_from_yaml
         # Pop out --config and its value so that they are not parsed by
         # Hydra's main.
         config_file_path = pop_config_arg()
-        
+
         # Add hydra config_path and config_name arguments to sys.argv.
         add_hydra_config_args(config_file_path)
 
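Note: in the flow the README now documents, `hydra_main` re-validates whatever hydra composed against the `LlmConfig` schema via `OmegaConf.structured` and then hands `export_llama` a plain object. A self-contained sketch of that validation step, using a stand-in dataclass rather than the real `LlmConfig`:

```python
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class ExportCfg:
    max_seq_length: int = 128

@dataclass
class FakeLlmConfig:  # stand-in for the real LlmConfig dataclass
    export: ExportCfg = field(default_factory=ExportCfg)

structured = OmegaConf.structured(FakeLlmConfig)  # typed schema carrying the defaults
composed = OmegaConf.create({"export": {"max_seq_length": 256}})  # what hydra hands in
merged = OmegaConf.merge(structured, composed)  # validated against the schema
obj = OmegaConf.to_object(merged)  # back to a plain dataclass instance

assert isinstance(obj, FakeLlmConfig)
assert obj.export.max_seq_length == 256
```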
From 06164b62a172ee738946842282bec8d22a526f67 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Wed, 25 Jun 2025 16:03:26 -0700
Subject: [PATCH 4/4] Fix test

---
 extension/llm/export/test/test_export_llm.py | 28 +++++++++-----------
 1 file changed, 12 insertions(+), 16 deletions(-)

diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py
index 7ae98f97c5b..e6f7160d4af 100644
--- a/extension/llm/export/test/test_export_llm.py
+++ b/extension/llm/export/test/test_export_llm.py
@@ -88,23 +88,19 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
             mock_export_llama.assert_called_once()
             called_config = mock_export_llama.call_args[0][0]
             self.assertEqual(
-                called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
+                called_config.base.tokenizer_path, "/path/to/tokenizer.json"
             )
-            self.assertEqual(called_config["base"]["model_class"], "llama2")
-            self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
-            self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
-            self.assertEqual(called_config["export"]["max_seq_length"], 256)
+            self.assertEqual(called_config.base.model_class, "llama2")
+            self.assertEqual(called_config.base.preq_mode.value, "8da4w")
+            self.assertEqual(called_config.model.dtype_override.value, "fp16")
+            self.assertEqual(called_config.export.max_seq_length, 256)
             self.assertEqual(
-                called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
+                called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
             )
+            self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda")
+            self.assertEqual(called_config.backend.coreml.quantize.value, "c4w")
             self.assertEqual(
-                called_config["quantization"]["use_spin_quant"].value, "cuda"
-            )
-            self.assertEqual(
-                called_config["backend"]["coreml"]["quantize"].value, "c4w"
-            )
-            self.assertEqual(
-                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+                called_config.backend.coreml.compute_units.value, "cpu_and_gpu"
             )
         finally:
             os.unlink(config_file)
@@ -142,13 +138,13 @@ def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
                 mock_export_llama.assert_called_once()
                 called_config = mock_export_llama.call_args[0][0]
                 self.assertEqual(
-                    called_config["base"]["model_class"], "stories110m"
+                    called_config.base.model_class, "stories110m"
                 )  # Override from CLI.
                 self.assertEqual(
-                    called_config["model"]["dtype_override"].value, "fp16"
+                    called_config.model.dtype_override.value, "fp16"
                 )  # From yaml.
                 self.assertEqual(
-                    called_config["backend"]["xnnpack"]["enabled"],
+                    called_config.backend.xnnpack.enabled,
                     True,  # Override from CLI.
                 )
         finally:
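Note: patch 4 follows from `hydra_main` calling `OmegaConf.to_object` before `export_llama` — the tests now receive a real dataclass instance instead of a `DictConfig`, so dict-style subscripting has to become attribute access. A minimal illustration with a stand-in dataclass:

```python
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class BaseCfg:  # stand-in for llm_config.base
    model_class: str = "llama2"

cfg = OmegaConf.to_object(OmegaConf.structured(BaseCfg))
assert cfg.model_class == "llama2"  # attribute access works on the dataclass
# cfg["model_class"] would raise TypeError: 'BaseCfg' object is not subscriptable
```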