From dbd9139c59367d489dd9865ecb7c942989dba4ee Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Tue, 12 Nov 2024 18:03:37 -0800 Subject: [PATCH] Accept model type parameter in export_llama (#6507) Summary: Specify model to export in the CLI. Test Plan: Exported the stories 110M model. ``` python -m examples.models.llama.export_llama -c stories110M/stories110M.pt -p stories110M/params.json -X -kv ``` PR chain: - [Add kwarg example inputs to eager model base](https://github.com/pytorch/executorch/pull/5765) - [Llama2 model cleanup](https://github.com/pytorch/executorch/pull/5859) - **YOU ARE HERE ~>** [Accept model type parameter in export_llama](https://github.com/pytorch/executorch/pull/5910) - [Export TorchTune llama3_2_vision in ET](https://github.com/pytorch/executorch/pull/5911) - [Runner changes for TorchTune Llama3.2 vision text decoder](https://github.com/pytorch/executorch/pull/6610) - [Add et version of TorchTune MHA for swapping with custom op](https://github.com/pytorch/executorch/pull/5912) Reviewed By: helunwencser Differential Revision: D65612837 Pulled By: dvorjackz --- ...lama3-qualcomm-ai-engine-direct-backend.md | 2 +- .../docs/delegates/qualcomm_README.md | 2 +- examples/models/llama/README.md | 14 +++++- examples/models/llama/eval_llama_lib.py | 2 +- examples/models/llama/export_llama.py | 3 +- examples/models/llama/export_llama_lib.py | 50 +++++++++++++------ examples/models/llama/runner/eager.py | 2 +- 7 files changed, 53 insertions(+), 22 deletions(-) diff --git a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md index d928377ff28..133f9ec50bb 100644 --- a/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md +++ b/docs/source/llm/build-run-llama3-qualcomm-ai-engine-direct-backend.md @@ -39,7 +39,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure ```bash # Please note that calibration_data must include the prompt template for special tokens. -python -m examples.models.llama.export_llama -t +python -m examples.models.llama.export_llama -t llama3/Meta-Llama-3-8B-Instruct/tokenizer.model -p -c --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ``` diff --git a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md index 8308da6d840..7790f66923c 100644 --- a/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md +++ b/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md @@ -158,7 +158,7 @@ To export Llama 3 8B instruct with the Qualcomm AI Engine Direct Backend, ensure * 8B models might need 16GB RAM on the device to run. ``` # Please note that calibration_data must include the prompt template for special tokens. 
-python -m examples.models.llama.export_llama -t -p -c --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +python -m examples.models.llama.export_llama -t -p -c --use_kv_cache --qnn --pt2e_quantize qnn_16a4w --disable_dynamic_shape --num_sharding 8 --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --optimized_rotation_path --calibration_data "<|start_header_id|>system<|end_header_id|>\n\nYou are a funny chatbot.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nCould you tell me about Facebook?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ``` ## Pushing Model and Tokenizer diff --git a/examples/models/llama/README.md b/examples/models/llama/README.md index 6fc66f6506c..cfa0fe04b1b 100644 --- a/examples/models/llama/README.md +++ b/examples/models/llama/README.md @@ -239,9 +239,19 @@ You can export and run the original Llama 3 8B instruct model. 2. Export model and generate `.pte` file ``` - python -m examples.models.llama.export_llama --checkpoint -p -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' --embedding-quantize 4,32 --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" + python -m examples.models.llama.export_llama \ + --checkpoint \ + -p \ + -kv \ + --use_sdpa_with_kv_cache \ + -X \ + -qmode 8da4w \ + --group_size 128 \ + -d fp32 \ + --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \ + --embedding-quantize 4,32 \ + --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte" ``` - Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size. 
diff --git a/examples/models/llama/eval_llama_lib.py b/examples/models/llama/eval_llama_lib.py index 6e1847deca6..dd01365ba59 100644 --- a/examples/models/llama/eval_llama_lib.py +++ b/examples/models/llama/eval_llama_lib.py @@ -190,7 +190,7 @@ def gen_eval_wrapper( pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) # GPTFastEvalWrapper: Create a wrapper around a pre-exported model - manager: LLMEdgeManager = _prepare_for_llama_export(model_name, args) + manager: LLMEdgeManager = _prepare_for_llama_export(args) if len(quantizers) != 0: manager = manager.export().pt2e_quantize(quantizers) diff --git a/examples/models/llama/export_llama.py b/examples/models/llama/export_llama.py index 1899ccf4df6..eeb425c338c 100644 --- a/examples/models/llama/export_llama.py +++ b/examples/models/llama/export_llama.py @@ -23,10 +23,9 @@ def main() -> None: seed = 42 torch.manual_seed(seed) - modelname = "llama2" parser = build_args_parser() args = parser.parse_args() - export_llama(modelname, args) + export_llama(args) if __name__ == "__main__": diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 6e0ae37e120..0e015418d42 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -81,6 +81,10 @@ verbosity_setting = None +EXECUTORCH_DEFINED_MODELS = ["stories110m", "llama2", "llama3", "llama3_1", "llama3_2"] +TORCHTUNE_DEFINED_MODELS = [] + + class WeightType(Enum): LLAMA = "LLAMA" FAIRSEQ2 = "FAIRSEQ2" @@ -105,7 +109,7 @@ def verbose_export(): def build_model( - modelname: str = "model", + modelname: str = "llama3", extra_opts: str = "", *, par_local_output: bool = False, @@ -116,11 +120,11 @@ def build_model( else: output_dir_path = "." - argString = f"--checkpoint par:{modelname}_ckpt.pt --params par:{modelname}_params.json {extra_opts} --output-dir {output_dir_path}" + argString = f"--model {modelname} --checkpoint par:model_ckpt.pt --params par:model_params.json {extra_opts} --output-dir {output_dir_path}" parser = build_args_parser() args = parser.parse_args(shlex.split(argString)) # pkg_name = resource_pkg_name - return export_llama(modelname, args) + return export_llama(args) def build_args_parser() -> argparse.ArgumentParser: @@ -130,6 +134,12 @@ def build_args_parser() -> argparse.ArgumentParser: # parser.add_argument( # "-q", "--quantized_ckpt", default=None, help="quantized checkpoint file" # ) + parser.add_argument( + "--model", + default="llama3", + choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, + help="The Lllama model architecture to use. stories110M, llama2, llama3, llama3_1, and llama3_2 use the same underlying LlamaTransformer architecture defined in ExecuTorch. 
All other models use TorchTune model definitions.", + ) parser.add_argument( "-E", "--embedding-quantize", @@ -480,13 +490,13 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str: return return_val -def export_llama(modelname, args) -> str: +def export_llama(args) -> str: if args.profile_path is not None: try: from executorch.util.python_profiler import CProfilerFlameGraph with CProfilerFlameGraph(args.profile_path): - builder = _export_llama(modelname, args) + builder = _export_llama(args) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" @@ -497,14 +507,14 @@ def export_llama(modelname, args) -> str: ) return "" else: - builder = _export_llama(modelname, args) + builder = _export_llama(args) assert ( filename := builder.get_saved_pte_filename() ) is not None, "Fail to get file name from builder" return filename -def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: +def _prepare_for_llama_export(args) -> LLMEdgeManager: """ Helper function for export_llama. Loads the model from checkpoint and params, and sets up a LLMEdgeManager with initial transforms and dtype conversion. @@ -530,7 +540,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: return ( _load_llama_model( - modelname=modelname, + args.model, checkpoint=checkpoint_path, checkpoint_dir=checkpoint_dir, params_path=params_path, @@ -553,7 +563,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: args=args, ) .set_output_dir(output_dir_path) - .source_transform(_get_source_transforms(modelname, dtype_override, args)) + .source_transform(_get_source_transforms(args.model, dtype_override, args)) ) @@ -627,12 +637,12 @@ def _validate_args(args): ) -def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 +def _export_llama(args) -> LLMEdgeManager: # noqa: C901 _validate_args(args) pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args) # export_to_edge - builder_exported = _prepare_for_llama_export(modelname, args).export() + builder_exported = _prepare_for_llama_export(args).export() if args.export_only: exit() @@ -830,8 +840,8 @@ def _load_llama_model_metadata( def _load_llama_model( + modelname: str = "llama3", *, - modelname: str = "llama2", checkpoint: Optional[str] = None, checkpoint_dir: Optional[str] = None, params_path: str, @@ -859,15 +869,27 @@ def _load_llama_model( Returns: An instance of LLMEdgeManager which contains the eager mode model. """ + assert ( checkpoint or checkpoint_dir ) and params_path, "Both checkpoint/checkpoint_dir and params can't be empty" logging.info( f"Loading model with checkpoint={checkpoint}, params={params_path}, use_kv_cache={use_kv_cache}, weight_type={weight_type}" ) + + if modelname in EXECUTORCH_DEFINED_MODELS: + module_name = "llama" + model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. + elif modelname in TORCHTUNE_DEFINED_MODELS: + raise NotImplementedError( + "Torchtune Llama models are not yet supported in ExecuTorch export." 
+ ) + else: + raise ValueError(f"{modelname} is not a valid Llama model.") + model, example_inputs, example_kwarg_inputs, _ = EagerModelFactory.create_model( - module_name="llama", - model_class_name="Llama2Model", + module_name, + model_class_name, checkpoint=checkpoint, checkpoint_dir=checkpoint_dir, params=params_path, diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index 0c7168b743d..e6f09b56e8c 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -37,7 +37,7 @@ def __init__(self, args): model_args=model_args, device="cuda" if torch.cuda.is_available() else "cpu", ) - manager: LLMEdgeManager = _prepare_for_llama_export("llama", args) + manager: LLMEdgeManager = _prepare_for_llama_export(args) self.model = manager.model.eval().to(device=self.device) def forward(
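
Illustrative usage note (not part of the patch): with this change, `export_llama` no longer receives the model name as a function argument in code; the architecture is instead selected through the new `--model` CLI flag, which defaults to `llama3` and accepts the ExecuTorch-defined choices (`stories110m`, `llama2`, `llama3`, `llama3_1`, `llama3_2`). A minimal sketch of the Test Plan export with the flag spelled out explicitly, reusing the stories110M checkpoint and params paths from the Test Plan above:

```
python -m examples.models.llama.export_llama --model stories110m -c stories110M/stories110M.pt -p stories110M/params.json -X -kv
```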