diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 86b7e957628..d2caccd5897 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -85,6 +85,7 @@ runtime.python_binary(
         ":export_library",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",
+        "//executorch/extension/llm/export:export_llm_lib",
     ],
 )
 
@@ -133,8 +134,6 @@ runtime.python_library(
     name = "export_library",
     srcs = [
         "export_llama.py",
-        "export_llama_args.py",
-        "export_llama_hydra.py",
         "export_llama_lib.py",
         "model.py",
     ],
diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py
index a5c486a8c1e..034d8af7562 100644
--- a/examples/models/llama/config/llm_config.py
+++ b/examples/models/llama/config/llm_config.py
@@ -86,7 +86,7 @@ class BaseConfig:
     checkpoint_dir: Optional[str] = None
     tokenizer_path: Optional[str] = None
     metadata: Optional[str] = None
-    use_lora: int = int
+    use_lora: int = 0
     fairseq2: bool = False
     preq_mode: Optional[PreqMode] = None
     preq_group_size: int = 32
@@ -214,7 +214,7 @@ class ExportConfig:
 
     max_seq_length: int = 128
     max_context_length: int = 128
-    output_dir: Optional[str] = None
+    output_dir: str = "."
     output_name: Optional[str] = None
     so_library: Optional[str] = None
     export_only: bool = False
diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py
index 63e76e28ba9..93782b00e37 100644
--- a/examples/models/llama/export_llama.py
+++ b/examples/models/llama/export_llama.py
@@ -17,6 +17,11 @@
 
 import torch
 
+from executorch.examples.models.llama.export_llama_lib import (
+    build_args_parser,
+    export_llama,
+)
+
 sys.setrecursionlimit(4096)
 
 
@@ -39,15 +44,12 @@ def main() -> None:
         sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
         print(f"running with {sys.argv}")
         runpy.run_module(
-            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
+            "executorch.extension.llm.export.export_llm", run_name="__main__"
         )
     else:
-        # Use the legacy version of the export_llama script which uses argsparse.
-        from executorch.examples.models.llama.export_llama_args import (
-            main as export_llama_args_main,
-        )
-
-        export_llama_args_main(remaining_args)
+        parser = build_args_parser()
+        remaining_args = parser.parse_args(remaining_args)
+        export_llama(remaining_args)
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py
deleted file mode 100644
index 7a176d9b7d0..00000000000
--- a/examples/models/llama/export_llama_args.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Run export_llama with the legacy argparse setup.
-"""
-
-from .export_llama_lib import build_args_parser, export_llama
-
-
-def main(args) -> None:
-    parser = build_args_parser()
-    args = parser.parse_args(args)
-    export_llama(args)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py
deleted file mode 100644
index 4871de00e25..00000000000
--- a/examples/models/llama/export_llama_hydra.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Run export_llama using the new Hydra CLI.
-"""
-
-import hydra
-
-from executorch.examples.models.llama.config.llm_config import LlmConfig
-from executorch.examples.models.llama.export_llama_lib import export_llama
-from hydra.core.config_store import ConfigStore
-from omegaconf import OmegaConf
-
-cs = ConfigStore.instance()
-cs.store(name="llm_config", node=LlmConfig)
-
-
-@hydra.main(version_base=None, config_name="llm_config")
-def main(llm_config: LlmConfig) -> None:
-    export_llama(OmegaConf.to_object(llm_config))
-
-
-if __name__ == "__main__":
-    main()
diff --git a/extension/llm/export/TARGETS b/extension/llm/export/TARGETS
index a85370fc49c..7acf026a8da 100644
--- a/extension/llm/export/TARGETS
+++ b/extension/llm/export/TARGETS
@@ -47,6 +47,41 @@ runtime.python_library(
     ],
 )
 
+runtime.python_binary(
+    name = "export_llm",
+    srcs = [
+        "export_llm.py",
+    ],
+    main_function = "executorch.extension.llm.export.export_llm.main",
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:model_sharding_py",
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+        "//executorch/kernels/quantized:aot_lib",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/hydra-core:hydra-core",
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/extension/pybindings:aten_lib",
+    ],
+)
+
+runtime.python_library(
+    name = "export_llm_lib",
+    srcs = [
+        "export_llm.py",
+    ],
+    deps = [
+        "fbsource//third-party/pypi/hydra-core:hydra-core",
+        "fbsource//third-party/pypi/omegaconf:omegaconf",
+        "//executorch/examples/models/llama:export_library",
+    ],
+    visibility = [
+        "//executorch/examples/...",
+        "//executorch/extension/llm/...",
+    ],
+)
+
 runtime.python_test(
     name = "export_passes_test",
     srcs = [
diff --git a/extension/llm/export/export_llm.py b/extension/llm/export/export_llm.py
new file mode 100644
index 00000000000..09a15d6ab58
--- /dev/null
+++ b/extension/llm/export/export_llm.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Export an LLM with ExecuTorch. The export currently follows these steps:
+1. Instantiate our custom PyTorch transformer definition from examples/models/llama/llama_transformer.py.
+2. Load weights into the model.
+3. Apply source transformations/TorchAO quantization.
+4. Export model to intermediate IRs.
+5. Apply graph transformations/PT2E quantization.
+6. Partition graph and delegate to backend(s).
+7. Export to final ExecuTorch .pte format.
+
+Example usage with full CLI arguments:
+python -m extension.llm.export.export_llm \
+    base.model_class="llama3" \
+    model.use_sdpa_with_kv_cache=True \
+    model.use_kv_cache=True \
+    debug.verbose=True \
+    backend.xnnpack.enabled=True \
+    backend.xnnpack.extended_ops=True \
+    quantization.qmode="8da4w"
+"""
+
+import hydra
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+from executorch.examples.models.llama.export_llama_lib import export_llama
+from hydra.core.config_store import ConfigStore
+from omegaconf import OmegaConf
+
+cs = ConfigStore.instance()
+cs.store(name="llm_config", node=LlmConfig)
+
+
+@hydra.main(version_base=None, config_path=None, config_name="llm_config")
+def main(llm_config: LlmConfig) -> None:
+    export_llama(OmegaConf.to_object(llm_config))
+
+
+if __name__ == "__main__":
+    main()