diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index 872eccce872..b51e164d483 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -132,6 +132,8 @@ runtime.python_library(
     name = "export_library",
     srcs = [
         "export_llama.py",
+        "export_llama_args.py",
+        "export_llama_hydra.py",
         "export_llama_lib.py",
         "model.py",
     ],
@@ -148,6 +150,8 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/examples/models/llama/config:llm_config",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
diff --git a/examples/models/llama/config/llm_config.py b/examples/models/llama/config/llm_config.py
index b929e756c3e..d9a8a8e6192 100644
--- a/examples/models/llama/config/llm_config.py
+++ b/examples/models/llama/config/llm_config.py
@@ -12,6 +12,7 @@
 Uses dataclasses, which integrate with OmegaConf and Hydra.
 """
 
+import argparse
 import ast
 import re
 from dataclasses import dataclass, field
@@ -468,6 +469,18 @@ class LlmConfig:
     quantization: QuantizationConfig = field(default_factory=QuantizationConfig)
     backend: BackendConfig = field(default_factory=BackendConfig)
 
+    @classmethod
+    def from_args(cls, args: argparse.Namespace) -> "LlmConfig":
+        """
+        To support the legacy CLI, convert args from argparse into the
+        LlmConfig consumed by the LLM export process.
+        """
+        llm_config = cls()
+
+        # TODO: conversion code.
+
+        return llm_config
+
     def __post_init__(self):
         self._validate_low_bit()
diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py
index e25a8a007eb..63e76e28ba9 100644
--- a/examples/models/llama/export_llama.py
+++ b/examples/models/llama/export_llama.py
@@ -4,30 +4,50 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# Example script for exporting Llama2 to flatbuffer
-
-import logging
-
 # force=True to ensure logging while in debugger. Set up logger before any
 # other imports.
+import logging
+
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT, force=True)
 
+import argparse
+import runpy
 import sys
 
 import torch
 
-from .export_llama_lib import build_args_parser, export_llama
-
 sys.setrecursionlimit(4096)
 
 
+def parse_hydra_arg():
+    """First parse out the arg for whether to use Hydra or the old CLI."""
+    parser = argparse.ArgumentParser(add_help=True)
+    parser.add_argument("--hydra", action="store_true")
+    args, remaining = parser.parse_known_args()
+    return args.hydra, remaining
+
+
 def main() -> None:
     seed = 42
     torch.manual_seed(seed)
-    parser = build_args_parser()
-    args = parser.parse_args()
-    export_llama(args)
+
+    use_hydra, remaining_args = parse_hydra_arg()
+    if use_hydra:
+        # Run the main function of export_llama_hydra with the remaining args
+        # under the Hydra framework, stripping --hydra from sys.argv first.
+        sys.argv = [arg for arg in sys.argv if arg != "--hydra"]
+        print(f"running with {sys.argv}")
+        runpy.run_module(
+            "executorch.examples.models.llama.export_llama_hydra", run_name="__main__"
+        )
+    else:
+        # Use the legacy version of the export_llama script, which uses argparse.
+        from executorch.examples.models.llama.export_llama_args import (
+            main as export_llama_args_main,
+        )
+
+        export_llama_args_main(remaining_args)
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama/export_llama_args.py b/examples/models/llama/export_llama_args.py
new file mode 100644
index 00000000000..7a176d9b7d0
--- /dev/null
+++ b/examples/models/llama/export_llama_args.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama with the legacy argparse setup.
+"""
+
+from .export_llama_lib import build_args_parser, export_llama
+
+
+def main(args=None) -> None:
+    parser = build_args_parser()
+    args = parser.parse_args(args)
+    export_llama(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/export_llama_hydra.py b/examples/models/llama/export_llama_hydra.py
new file mode 100644
index 00000000000..73eca7e2a5a
--- /dev/null
+++ b/examples/models/llama/export_llama_hydra.py
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Run export_llama using the new Hydra CLI.
+"""
+
+import hydra
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
+from executorch.examples.models.llama.export_llama_lib import export_llama
+from hydra.core.config_store import ConfigStore
+
+cs = ConfigStore.instance()
+cs.store(name="llm_config", node=LlmConfig)
+
+
+@hydra.main(version_base=None, config_name="llm_config")
+def main(llm_config: LlmConfig) -> None:
+    export_llama(llm_config)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 11fb2fa3cbb..12406cc762e 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -27,6 +27,8 @@
 from executorch.devtools.backend_debug import print_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
+
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 from executorch.examples.models.llama.hf_download import (
     download_and_convert_hf_checkpoint,
 )
@@ -50,6 +52,7 @@
     get_vulkan_quantizer,
 )
 from executorch.util.activation_memory_profiler import generate_memory_trace
+from omegaconf.dictconfig import DictConfig
 
 from ..model_factory import EagerModelFactory
 from .source_transformation.apply_spin_quant_r1_r2 import (
@@ -567,7 +570,23 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val
 
 
-def export_llama(args) -> str:
+def export_llama(
+    export_options: Union[argparse.Namespace, DictConfig],
+) -> str:
+    if isinstance(export_options, argparse.Namespace):
+        # Legacy CLI.
+        args = export_options
+        llm_config = LlmConfig.from_args(export_options)  # noqa: F841
+    elif isinstance(export_options, DictConfig):
+        # Hydra CLI.
+        llm_config = export_options  # noqa: F841
+    else:
+        raise ValueError(
+            "Input to export_llama must be either of type argparse.Namespace or DictConfig"
+        )
+
+    # TODO: refactor rest of export_llama to use llm_config instead of args.
+
     # If a checkpoint isn't provided for an HF OSS model, download and convert the
     # weights first.
     if not args.checkpoint and args.model in HUGGING_FACE_REPO_IDS:
diff --git a/examples/models/llama/install_requirements.sh b/examples/models/llama/install_requirements.sh
index b9e0f9210c5..580a152a322 100755
--- a/examples/models/llama/install_requirements.sh
+++ b/examples/models/llama/install_requirements.sh
@@ -10,7 +10,7 @@
 # Install tokenizers for hf .json tokenizer.
 # Install snakeviz for cProfile flamegraph
 # Install lm-eval for Model Evaluation with lm-evalution-harness.
-pip install huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
+pip install hydra-core huggingface_hub tiktoken torchtune sentencepiece tokenizers snakeviz lm_eval==0.4.5 blobfile
 
 # Call the install helper for further setup
 python examples/models/llama/install_requirement_helper.py
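
Usage sketch for the two entry points this patch wires up. The module path
matches the runpy call in the dispatcher; the flags and override keys other
than --hydra are illustrative assumptions, not a verified set:

    # Legacy argparse CLI (the default; behavior unchanged):
    python -m executorch.examples.models.llama.export_llama --model llama3 --checkpoint /path/to/model.pth

    # New Hydra CLI, opted into with --hydra. The dispatcher strips --hydra
    # from sys.argv and hands the rest to Hydra, which treats them as
    # key=value overrides against the LlmConfig schema registered in the
    # ConfigStore (the keys below are hypothetical examples):
    python -m executorch.examples.models.llama.export_llama --hydra base.checkpoint=/path/to/model.pth export.output_name=llama3.pte

Because LlmConfig is registered via ConfigStore.store(name="llm_config",
node=LlmConfig), Hydra validates overrides against the dataclass schema, so a
mistyped key fails fast instead of silently adding an unused config entry.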