diff --git a/extension/llm/export/README.md b/extension/llm/export/README.md
index 96f36acc1b4..e97b9e10462 100644
--- a/extension/llm/export/README.md
+++ b/extension/llm/export/README.md
@@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimiz
 
 ## Usage
 
-The export API supports two configuration approaches:
+The export API supports a Hydra-style CLI: you can configure the export with a YAML file, with CLI arguments, or with both.
 
-### Option 1: Hydra CLI Arguments
+### Hydra CLI Arguments
 
 Use structured configuration arguments directly on the command line:
 
@@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \
   quantization.qmode=8da4w
 ```
 
-### Option 2: Configuration File
+### Configuration File
 
 Create a YAML configuration file and reference it:
 
@@ -78,53 +78,21 @@ debug:
   verbose: true
 ```
 
-**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.
+You can also provide additional CLI overrides on top of the config file:
 
-## Example Commands
-
-### Export Qwen3 0.6B with XNNPACK backend and quantization
 ```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=qwen3_0_6b \
-  base.params=examples/models/qwen3/0_6b_config.json \
-  base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  model.dtype_override=FP32 \
-  export.max_seq_length=512 \
-  export.output_name=qwen3_0_6b.pte \
-  quantization.qmode=8da4w \
-  backend.xnnpack.enabled=true \
-  backend.xnnpack.extended_ops=true \
-  debug.verbose=true
+python -m extension.llm.export.export_llm \
+    --config my_config.yaml \
+    base.model_class="llama2" \
+    +export.max_context_length=1024
 ```
 
-### Export Phi-4-Mini with custom checkpoint
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=phi_4_mini \
-  base.checkpoint=/path/to/phi4_checkpoint.pth \
-  base.params=examples/models/phi-4-mini/config.json \
-  base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
-  model.use_kv_cache=true \
-  model.use_sdpa_with_kv_cache=true \
-  export.max_seq_length=256 \
-  export.output_name=phi4_mini.pte \
-  backend.xnnpack.enabled=true \
-  debug.verbose=true
-```
+Note that if a config file is specified and you want to set a CLI arg that is not already present in the config file, you need to prepend it with a `+`. You can read more about this in the Hydra override grammar [docs](https://hydra.cc/docs/advanced/override_grammar/basic/). A minimal example config is sketched under Example Commands below.
 
-### Export with CoreML backend (iOS optimization)
-```bash
-python -m extension.llm.export.export_llm \
-  base.model_class=llama3 \
-  model.use_kv_cache=true \
-  export.max_seq_length=128 \
-  backend.coreml.enabled=true \
-  backend.coreml.compute_units=ALL \
-  quantization.pt2e_quantize=coreml_c4w \
-  debug.verbose=true
-```
+
+## Example Commands
+
+Please refer to the docs for some of our supported example models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)).
 
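+For illustration, a minimal `my_config.yaml` for the override example in the Usage section above might look like the following (hypothetical values; any `LlmConfig` field can appear in the file):
+
+```yaml
+base:
+  model_class: stories110m
+  tokenizer_path: /path/to/tokenizer.json
+model:
+  use_kv_cache: true
+export:
+  max_seq_length: 512
+backend:
+  xnnpack:
+    enabled: true
+debug:
+  verbose: true
+```
+
+With this file, `base.model_class="llama2"` overrides a key that already exists in the YAML, while `+export.max_context_length=1024` adds a key the file does not define, which is why it needs the leading `+`.
+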
 ## Configuration Options
 
@@ -134,4 +102,4 @@ For a complete reference of all available configuration options, see the [LlmCon
 
 - [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
 - [LLM Runner](../runner/) - Running exported models
-- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
\ No newline at end of file
+- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py
index e995b329f30..e0467250a28 100644
--- a/extension/llm/export/export_llm.py
+++ b/extension/llm/export/export_llm.py
@@ -30,6 +30,7 @@
 """
 
 import argparse
+import os
 import sys
 
 from typing import Any, List, Tuple
@@ -45,7 +46,6 @@
 
 
 def parse_config_arg() -> Tuple[str, List[Any]]:
-    """First parse out the arg for whether to use Hydra or the old CLI."""
     parser = argparse.ArgumentParser(add_help=True)
     parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
    args, remaining = parser.parse_known_args()
@@ -56,6 +56,7 @@ def pop_config_arg() -> str:
     """
     Removes '--config' and its value from sys.argv.
     Assumes --config is specified and argparse has already validated the args.
+    Returns the config file path.
     """
     idx = sys.argv.index("--config")
     value = sys.argv[idx + 1]
@@ -63,30 +64,42 @@ def pop_config_arg() -> str:
     return value
 
 
-@hydra.main(version_base=None, config_name="llm_config")
+def add_hydra_config_args(config_file_path: str) -> None:
+    """
+    Breaks down the config file path into directory and filename,
+    resolves the directory to an absolute path, and adds the
+    --config-path and --config-name arguments to sys.argv.
+    """
+    config_dir = os.path.dirname(config_file_path)
+    config_name = os.path.basename(config_file_path)
+
+    # Resolve to absolute path
+    config_dir_abs = os.path.abspath(config_dir)
+
+    # Add the hydra config arguments to sys.argv
+    sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])
+
+
+@hydra.main(version_base=None, config_name="llm_config", config_path=None)
 def hydra_main(llm_config: LlmConfig) -> None:
-    export_llama(OmegaConf.to_object(llm_config))
+    structured = OmegaConf.structured(LlmConfig)
+    merged = OmegaConf.merge(structured, llm_config)
+    llm_config_obj = OmegaConf.to_object(merged)
+    export_llama(llm_config_obj)
 
 
 def main() -> None:
+    # First parse out the --config arg, which determines whether a YAML config file is used.
     config, remaining_args = parse_config_arg()
     if config:
-        # Check if there are any remaining hydra CLI args when --config is specified
-        # This might change in the future to allow overriding config file values
-        if remaining_args:
-            raise ValueError(
-                "Cannot specify additional CLI arguments when using --config. "
-                f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
-            )
-
+        # Pop out --config and its value so that they are not parsed by
+        # Hydra's main.
        config_file_path = pop_config_arg()
-        default_llm_config = LlmConfig()
-        llm_config_from_file = OmegaConf.load(config_file_path)
-        # Override defaults with values specified in the .yaml provided by --config.
-        merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
-        export_llama(merged_llm_config)
-    else:
-        hydra_main()
+
+        # Add Hydra's --config-path and --config-name arguments to sys.argv.
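+        # As an illustrative (hypothetical) example, if the original invocation was
+        #   export_llm.py --config configs/my_config.yaml +export.max_context_length=1024
+        # then pop_config_arg() has already reduced sys.argv to
+        #   ["export_llm.py", "+export.max_context_length=1024"]
+        # and add_hydra_config_args() extends it to
+        #   ["export_llm.py", "+export.max_context_length=1024",
+        #    "--config-path", "/abs/path/to/configs", "--config-name", "my_config.yaml"],
+        # so Hydra loads the file and applies the remaining overrides on top of it.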
+ add_hydra_config_args(config_file_path) + + hydra_main() if __name__ == "__main__": diff --git a/extension/llm/export/test/test_export_llm.py b/extension/llm/export/test/test_export_llm.py index 7d17b7819d3..e6f7160d4af 100644 --- a/extension/llm/export/test/test_export_llm.py +++ b/extension/llm/export/test/test_export_llm.py @@ -21,7 +21,7 @@ class TestExportLlm(unittest.TestCase): def test_parse_config_arg_with_config(self) -> None: """Test parse_config_arg when --config is provided.""" # Mock sys.argv to include --config - test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"] + test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertEqual(config_path, "test_config.yaml") @@ -29,7 +29,7 @@ def test_parse_config_arg_with_config(self) -> None: def test_parse_config_arg_without_config(self) -> None: """Test parse_config_arg when --config is not provided.""" - test_argv = ["script.py", "debug.verbose=True"] + test_argv = ["export_llm.py", "debug.verbose=True"] with patch.object(sys, "argv", test_argv): config_path, remaining = parse_config_arg() self.assertIsNone(config_path) @@ -37,11 +37,21 @@ def test_parse_config_arg_without_config(self) -> None: def test_pop_config_arg(self) -> None: """Test pop_config_arg removes --config and its value from sys.argv.""" - test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"] + test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"] with patch.object(sys, "argv", test_argv): config_path = pop_config_arg() self.assertEqual(config_path, "test_config.yaml") - self.assertEqual(sys.argv, ["script.py", "other", "args"]) + self.assertEqual(sys.argv, ["export_llm.py", "other", "args"]) + + def test_with_cli_args(self) -> None: + """Test main function with only hydra CLI args.""" + test_argv = ["export_llm.py", "debug.verbose=True"] + with patch.object(sys, "argv", test_argv): + with patch( + "executorch.extension.llm.export.export_llm.hydra_main" + ) as mock_hydra: + main() + mock_hydra.assert_called_once() @patch("executorch.extension.llm.export.export_llm.export_llama") def test_with_config(self, mock_export_llama: MagicMock) -> None: @@ -70,7 +80,7 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: config_file = f.name try: - test_argv = ["script.py", "--config", config_file] + test_argv = ["export_llm.py", "--config", config_file] with patch.object(sys, "argv", test_argv): main() @@ -78,75 +88,65 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None: mock_export_llama.assert_called_once() called_config = mock_export_llama.call_args[0][0] self.assertEqual( - called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json" - ) - self.assertEqual(called_config["base"]["model_class"], "llama2") - self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w") - self.assertEqual(called_config["model"]["dtype_override"].value, "fp16") - self.assertEqual(called_config["export"]["max_seq_length"], 256) - self.assertEqual( - called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic" - ) - self.assertEqual( - called_config["quantization"]["use_spin_quant"].value, "cuda" + called_config.base.tokenizer_path, "/path/to/tokenizer.json" ) + self.assertEqual(called_config.base.model_class, "llama2") + self.assertEqual(called_config.base.preq_mode.value, "8da4w") + self.assertEqual(called_config.model.dtype_override.value, "fp16") + 
            self.assertEqual(called_config.export.max_seq_length, 256)
             self.assertEqual(
-                called_config["backend"]["coreml"]["quantize"].value, "c4w"
+                called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
             )
+            self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda")
+            self.assertEqual(called_config.backend.coreml.quantize.value, "c4w")
             self.assertEqual(
-                called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
+                called_config.backend.coreml.compute_units.value, "cpu_and_gpu"
             )
         finally:
             os.unlink(config_file)
 
-    def test_with_cli_args(self) -> None:
-        """Test main function with only hydra CLI args."""
-        test_argv = ["script.py", "debug.verbose=True"]
-        with patch.object(sys, "argv", test_argv):
-            with patch(
-                "executorch.extension.llm.export.export_llm.hydra_main"
-            ) as mock_hydra:
-                main()
-                mock_hydra.assert_called_once()
-
-    def test_config_with_cli_args_error(self) -> None:
-        """Test that --config rejects additional CLI arguments to prevent mixing approaches."""
+    @patch("executorch.extension.llm.export.export_llm.export_llama")
+    def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
+        """Test main function with a --config file plus additional hydra CLI overrides."""
         # Create a temporary config file
         with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("base:\n  checkpoint: /path/to/checkpoint.pth")
-            config_file = f.name
-
-        try:
-            test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
-            with patch.object(sys, "argv", test_argv):
-                with self.assertRaises(ValueError) as cm:
-                    main()
-
-                error_msg = str(cm.exception)
-                self.assertIn(
-                    "Cannot specify additional CLI arguments when using --config",
-                    error_msg,
-                )
-        finally:
-            os.unlink(config_file)
-
-    def test_config_rejects_multiple_cli_args(self) -> None:
-        """Test that --config rejects multiple CLI arguments (not just single ones)."""
-        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
-            f.write("export:\n  max_seq_length: 128")
+            f.write(
+                """
+base:
+  model_class: llama2
+model:
+  dtype_override: fp16
+backend:
+  xnnpack:
+    enabled: False
+"""
+            )
             config_file = f.name
 
         try:
             test_argv = [
-                "script.py",
+                "export_llm.py",
                 "--config",
                 config_file,
-                "debug.verbose=True",
-                "export.output_dir=/tmp",
+                "base.model_class=stories110m",
+                "backend.xnnpack.enabled=True",
             ]
             with patch.object(sys, "argv", test_argv):
-                with self.assertRaises(ValueError):
-                    main()
+                main()
+
+            # Verify export_llama was called with config
+            mock_export_llama.assert_called_once()
+            called_config = mock_export_llama.call_args[0][0]
+            self.assertEqual(
+                called_config.base.model_class, "stories110m"
+            )  # Override from CLI.
+            self.assertEqual(
+                called_config.model.dtype_override.value, "fp16"
+            )  # From yaml.
+            self.assertEqual(
+                called_config.backend.xnnpack.enabled,
+                True,  # Override from CLI.
+            )
         finally:
             os.unlink(config_file)