60 changes: 14 additions & 46 deletions extension/llm/export/README.md
@@ -23,9 +23,9 @@ The LLM export process transforms a model from its original format to an optimized

## Usage

The export API supports two configuration approaches:
The export API supports a Hydra-style CLI where you can configure the export through a YAML file, CLI arguments, or both.

### Option 1: Hydra CLI Arguments
### Hydra CLI Arguments

Use structured configuration arguments directly on the command line:

@@ -41,7 +41,7 @@ python -m extension.llm.export.export_llm \
quantization.qmode=8da4w
```

### Option 2: Configuration File
### Configuration File

Create a YAML configuration file and reference it:

@@ -78,53 +78,21 @@ debug:
verbose: true
```

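A minimal config along these lines might look like the following sketch; the field names mirror the CLI examples in this README, and the values are only illustrative:

```yaml
base:
  model_class: qwen3_0_6b
  params: examples/models/qwen3/0_6b_config.json
model:
  use_kv_cache: true
  dtype_override: FP32
export:
  max_seq_length: 512
  output_name: qwen3_0_6b.pte
quantization:
  qmode: 8da4w
backend:
  xnnpack:
    enabled: true
    extended_ops: true
debug:
  verbose: true
```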
**Important**: You cannot mix both approaches. Use either CLI arguments OR a config file, not both.
You can also still provide additional overrides via CLI args:

## Example Commands

### Export Qwen3 0.6B with XNNPACK backend and quantization
```bash
python -m extension.llm.export.export_llm \
base.model_class=qwen3_0_6b \
base.params=examples/models/qwen3/0_6b_config.json \
base.metadata='{"get_bos_id": 151644, "get_eos_ids":[151645]}' \
model.use_kv_cache=true \
model.use_sdpa_with_kv_cache=true \
model.dtype_override=FP32 \
export.max_seq_length=512 \
export.output_name=qwen3_0_6b.pte \
quantization.qmode=8da4w \
backend.xnnpack.enabled=true \
backend.xnnpack.extended_ops=true \
debug.verbose=true
python -m extension.llm.export.export_llm
--config my_config.yaml
base.model_class="llama2"
+export.max_context_length=1024
```

### Export Phi-4-Mini with custom checkpoint
```bash
python -m extension.llm.export.export_llm \
base.model_class=phi_4_mini \
base.checkpoint=/path/to/phi4_checkpoint.pth \
base.params=examples/models/phi-4-mini/config.json \
base.metadata='{"get_bos_id":151643, "get_eos_ids":[151643]}' \
model.use_kv_cache=true \
model.use_sdpa_with_kv_cache=true \
export.max_seq_length=256 \
export.output_name=phi4_mini.pte \
backend.xnnpack.enabled=true \
debug.verbose=true
```
Note that if a config file is specified and you want to pass a CLI arg that is not present in the config, you need to prefix it with `+`, as in the example below. You can read more about this in the Hydra [docs](https://hydra.cc/docs/advanced/override_grammar/basic/).
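For instance, assuming `my_config.yaml` above sets `debug.verbose` but does not set `export.max_context_length`:

```bash
# Override a key that already exists in my_config.yaml - no '+' needed.
python -m extension.llm.export.export_llm --config my_config.yaml debug.verbose=false

# Add a key that is not in my_config.yaml - prefix it with '+'.
python -m extension.llm.export.export_llm --config my_config.yaml +export.max_context_length=1024
```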

### Export with CoreML backend (iOS optimization)
```bash
python -m extension.llm.export.export_llm \
base.model_class=llama3 \
model.use_kv_cache=true \
export.max_seq_length=128 \
backend.coreml.enabled=true \
backend.coreml.compute_units=ALL \
quantization.pt2e_quantize=coreml_c4w \
debug.verbose=true
```

## Example Commands

Please refer to the docs for some of our supported example models ([Llama](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md), [Qwen3](https://github.com/pytorch/executorch/tree/main/examples/models/qwen3/README.md), [Phi-4-mini](https://github.com/pytorch/executorch/tree/main/examples/models/phi_4_mini/README.md)).

## Configuration Options

@@ -134,4 +102,4 @@ For a complete reference of all available configuration options, see the [LlmConfig

- [Llama Examples](../../../examples/models/llama/README.md) - Comprehensive Llama export guide
- [LLM Runner](../runner/) - Running exported models
- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
- [ExecuTorch Documentation](https://pytorch.org/executorch/) - Framework overview
49 changes: 31 additions & 18 deletions extension/llm/export/export_llm.py
@@ -30,6 +30,7 @@
"""

import argparse
import os
import sys
from typing import Any, List, Tuple

@@ -45,7 +46,6 @@


def parse_config_arg() -> Tuple[str, List[Any]]:
"""First parse out the arg for whether to use Hydra or the old CLI."""
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument("--config", type=str, help="Path to the LlmConfig file")
args, remaining = parser.parse_known_args()
@@ -56,37 +56,50 @@ def pop_config_arg() -> str:
"""
Removes '--config' and its value from sys.argv.
Assumes --config is specified and argparse has already validated the args.
Returns the config file path.
"""
idx = sys.argv.index("--config")
value = sys.argv[idx + 1]
del sys.argv[idx : idx + 2]
return value


@hydra.main(version_base=None, config_name="llm_config")
def add_hydra_config_args(config_file_path: str) -> None:
"""
Breaks down the config file path into directory and filename,
resolves the directory to an absolute path, and adds the
--config_path and --config_name arguments to sys.argv.
"""
config_dir = os.path.dirname(config_file_path)
config_name = os.path.basename(config_file_path)

# Resolve to absolute path
config_dir_abs = os.path.abspath(config_dir)

# Add the hydra config arguments to sys.argv
sys.argv.extend(["--config-path", config_dir_abs, "--config-name", config_name])


@hydra.main(version_base=None, config_name="llm_config", config_path=None)
def hydra_main(llm_config: LlmConfig) -> None:
export_llama(OmegaConf.to_object(llm_config))
structured = OmegaConf.structured(LlmConfig)
merged = OmegaConf.merge(structured, llm_config)
llm_config_obj = OmegaConf.to_object(merged)
export_llama(llm_config_obj)


def main() -> None:
# First parse out the arg for whether to use Hydra or the old CLI.
config, remaining_args = parse_config_arg()
if config:
# Check if there are any remaining hydra CLI args when --config is specified
# This might change in the future to allow overriding config file values
if remaining_args:
raise ValueError(
"Cannot specify additional CLI arguments when using --config. "
f"Found: {remaining_args}. Use either --config file or hydra CLI args, not both."
)

# Pop out --config and its value so that they are not parsed by
# Hydra's main.
config_file_path = pop_config_arg()
default_llm_config = LlmConfig()
llm_config_from_file = OmegaConf.load(config_file_path)
# Override defaults with values specified in the .yaml provided by --config.
merged_llm_config = OmegaConf.merge(default_llm_config, llm_config_from_file)
export_llama(merged_llm_config)
else:
hydra_main()

# Add hydra config_path and config_name arguments to sys.argv.
add_hydra_config_args(config_file_path)

hydra_main()


if __name__ == "__main__":
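A minimal sketch (not part of the change itself) of how `main()` ends up rewriting `sys.argv` when `--config` is passed; the paths below are illustrative:

```python
import os
import sys

# Suppose the script is invoked as:
sys.argv = ["export_llm.py", "--config", "configs/my_config.yaml", "+export.max_context_length=1024"]

# pop_config_arg(): drop --config and its value, remembering the path.
idx = sys.argv.index("--config")
config_file_path = sys.argv[idx + 1]
del sys.argv[idx : idx + 2]
# sys.argv is now ["export_llm.py", "+export.max_context_length=1024"]

# add_hydra_config_args(): point Hydra at the yaml's directory and file name.
sys.argv.extend([
    "--config-path", os.path.abspath(os.path.dirname(config_file_path)),
    "--config-name", os.path.basename(config_file_path),
])
# Hydra then composes my_config.yaml with the remaining overrides and invokes hydra_main().
```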
112 changes: 56 additions & 56 deletions extension/llm/export/test/test_export_llm.py
@@ -21,27 +21,37 @@ class TestExportLlm(unittest.TestCase):
def test_parse_config_arg_with_config(self) -> None:
"""Test parse_config_arg when --config is provided."""
# Mock sys.argv to include --config
test_argv = ["script.py", "--config", "test_config.yaml", "extra", "args"]
test_argv = ["export_llm.py", "--config", "test_config.yaml", "extra", "args"]
with patch.object(sys, "argv", test_argv):
config_path, remaining = parse_config_arg()
self.assertEqual(config_path, "test_config.yaml")
self.assertEqual(remaining, ["extra", "args"])

def test_parse_config_arg_without_config(self) -> None:
"""Test parse_config_arg when --config is not provided."""
test_argv = ["script.py", "debug.verbose=True"]
test_argv = ["export_llm.py", "debug.verbose=True"]
with patch.object(sys, "argv", test_argv):
config_path, remaining = parse_config_arg()
self.assertIsNone(config_path)
self.assertEqual(remaining, ["debug.verbose=True"])

def test_pop_config_arg(self) -> None:
"""Test pop_config_arg removes --config and its value from sys.argv."""
test_argv = ["script.py", "--config", "test_config.yaml", "other", "args"]
test_argv = ["export_llm.py", "--config", "test_config.yaml", "other", "args"]
with patch.object(sys, "argv", test_argv):
config_path = pop_config_arg()
self.assertEqual(config_path, "test_config.yaml")
self.assertEqual(sys.argv, ["script.py", "other", "args"])
self.assertEqual(sys.argv, ["export_llm.py", "other", "args"])

def test_with_cli_args(self) -> None:
"""Test main function with only hydra CLI args."""
test_argv = ["export_llm.py", "debug.verbose=True"]
with patch.object(sys, "argv", test_argv):
with patch(
"executorch.extension.llm.export.export_llm.hydra_main"
) as mock_hydra:
main()
mock_hydra.assert_called_once()

@patch("executorch.extension.llm.export.export_llm.export_llama")
def test_with_config(self, mock_export_llama: MagicMock) -> None:
@@ -70,83 +80,73 @@ def test_with_config(self, mock_export_llama: MagicMock) -> None:
config_file = f.name

try:
test_argv = ["script.py", "--config", config_file]
test_argv = ["export_llm.py", "--config", config_file]
with patch.object(sys, "argv", test_argv):
main()

# Verify export_llama was called with config
mock_export_llama.assert_called_once()
called_config = mock_export_llama.call_args[0][0]
self.assertEqual(
called_config["base"]["tokenizer_path"], "/path/to/tokenizer.json"
)
self.assertEqual(called_config["base"]["model_class"], "llama2")
self.assertEqual(called_config["base"]["preq_mode"].value, "8da4w")
self.assertEqual(called_config["model"]["dtype_override"].value, "fp16")
self.assertEqual(called_config["export"]["max_seq_length"], 256)
self.assertEqual(
called_config["quantization"]["pt2e_quantize"].value, "xnnpack_dynamic"
)
self.assertEqual(
called_config["quantization"]["use_spin_quant"].value, "cuda"
called_config.base.tokenizer_path, "/path/to/tokenizer.json"
)
self.assertEqual(called_config.base.model_class, "llama2")
self.assertEqual(called_config.base.preq_mode.value, "8da4w")
self.assertEqual(called_config.model.dtype_override.value, "fp16")
self.assertEqual(called_config.export.max_seq_length, 256)
self.assertEqual(
called_config["backend"]["coreml"]["quantize"].value, "c4w"
called_config.quantization.pt2e_quantize.value, "xnnpack_dynamic"
)
self.assertEqual(called_config.quantization.use_spin_quant.value, "cuda")
self.assertEqual(called_config.backend.coreml.quantize.value, "c4w")
self.assertEqual(
called_config["backend"]["coreml"]["compute_units"].value, "cpu_and_gpu"
called_config.backend.coreml.compute_units.value, "cpu_and_gpu"
)
finally:
os.unlink(config_file)

def test_with_cli_args(self) -> None:
"""Test main function with only hydra CLI args."""
test_argv = ["script.py", "debug.verbose=True"]
with patch.object(sys, "argv", test_argv):
with patch(
"executorch.extension.llm.export.export_llm.hydra_main"
) as mock_hydra:
main()
mock_hydra.assert_called_once()

def test_config_with_cli_args_error(self) -> None:
"""Test that --config rejects additional CLI arguments to prevent mixing approaches."""
@patch("executorch.extension.llm.export.export_llm.export_llama")
def test_with_config_and_cli(self, mock_export_llama: MagicMock) -> None:
"""Test main function with --config file and no hydra args."""
# Create a temporary config file
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write("base:\n checkpoint: /path/to/checkpoint.pth")
config_file = f.name

try:
test_argv = ["script.py", "--config", config_file, "debug.verbose=True"]
with patch.object(sys, "argv", test_argv):
with self.assertRaises(ValueError) as cm:
main()

error_msg = str(cm.exception)
self.assertIn(
"Cannot specify additional CLI arguments when using --config",
error_msg,
)
finally:
os.unlink(config_file)

def test_config_rejects_multiple_cli_args(self) -> None:
"""Test that --config rejects multiple CLI arguments (not just single ones)."""
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write("export:\n max_seq_length: 128")
f.write(
"""
base:
model_class: llama2
model:
dtype_override: fp16
backend:
xnnpack:
enabled: False
"""
)
config_file = f.name

try:
test_argv = [
"script.py",
"export_llm.py",
"--config",
config_file,
"debug.verbose=True",
"export.output_dir=/tmp",
"base.model_class=stories110m",
"backend.xnnpack.enabled=True",
]
with patch.object(sys, "argv", test_argv):
with self.assertRaises(ValueError):
main()
main()

# Verify export_llama was called with config
mock_export_llama.assert_called_once()
called_config = mock_export_llama.call_args[0][0]
self.assertEqual(
called_config.base.model_class, "stories110m"
) # Override from CLI.
self.assertEqual(
called_config.model.dtype_override.value, "fp16"
) # From yaml.
self.assertEqual(
called_config.backend.xnnpack.enabled,
True, # Override from CLI.
)
finally:
os.unlink(config_file)
