From fa17199161510898211d265e2a366e843e85ba8c Mon Sep 17 00:00:00 2001
From: lucylq
Date: Fri, 10 Oct 2025 09:51:22 -0700
Subject: [PATCH] Export lora weights to sep file

Pull Request resolved: https://github.com/pytorch/executorch/pull/14756

ghstack-source-id: 315259024
@exported-using-ghexport

Differential Revision: [D83777195](https://our.internmc.facebook.com/intern/diff/D83777195/)
---
 .ci/scripts/test_llama_lora.sh            | 48 +++++++++++++++++++++--
 examples/models/llama/export_llama_lib.py |  8 ++--
 examples/models/llama/main.cpp            | 34 +++++++++++++---
 extension/llm/export/config/llm_config.py | 10 +++--
 4 files changed, 85 insertions(+), 15 deletions(-)

diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh
index 6337bbf76a2..63325aa7778 100644
--- a/.ci/scripts/test_llama_lora.sh
+++ b/.ci/scripts/test_llama_lora.sh
@@ -94,7 +94,7 @@ else
   exit 1
 fi
 
-# Export LoRA PTE, PTD file.
+# Export LoRA PTE, foundation PTD file.
 MODEL_SEPARATE="${MODEL_NAME}_separate"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
@@ -114,7 +114,7 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_paths=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
 RESULT2=$(cat result2.txt)
 if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
@@ -131,3 +131,45 @@ else
   cleanup_files
   exit 1
 fi
+
+# Export LoRA PTE, LoRA PTD, foundation PTD file.
+MODEL_PROGRAM_ONLY="${MODEL_NAME}_program"
+MODEL_LORA_WEIGHTS="lora_weights"
+MODEL_FOUNDATION_WEIGHTS="foundation_weights"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_PROGRAM_ONLY}.pte" \
+  export.foundation_weights_file="${MODEL_FOUNDATION_WEIGHTS}.ptd" \
+  export.lora_weights_file="${MODEL_LORA_WEIGHTS}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_PROGRAM_ONLY}.pte --data_paths="${MODEL_FOUNDATION_WEIGHTS}.ptd,${MODEL_LORA_WEIGHTS}.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result3.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT3=$(cat result3.txt)
+if [[ "${RESULT3}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Success"
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
+
+cleanup_files
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index aa3b157c8da..7c2705f0a15 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1088,13 +1088,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
-        if llm_config.export.foundation_weights_file is not None:
+        if (
+            llm_config.export.foundation_weights_file is not None
+            or llm_config.export.lora_weights_file is not None
+        ):
             gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
                 llm_config.export.foundation_weights_file
                 if "lora" not in x.name
-                else None
+                else llm_config.export.lora_weights_file
             )
-
             from executorch.exir.passes.external_constants_pass import (
                 delegate_external_constants_pass_unlifted,
                 external_constants_pass,
diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 078d938ffde..0244a7f5661 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -8,6 +8,8 @@
  */
 
 #include <gflags/gflags.h>
+#include <sstream>
+#include <vector>
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
@@ -21,7 +23,10 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
-DEFINE_string(data_path, "", "Data file for the model.");
+DEFINE_string(
+    data_paths,
+    "",
+    "Data files for the model. If multiple files are provided, they should be comma separated.");
 
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
@@ -54,6 +59,26 @@ DEFINE_int32(
 
 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
 
+// Helper function to parse comma-separated string lists
+std::vector<std::string> parseStringList(const std::string& input) {
+  std::vector<std::string> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  std::stringstream ss(input);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    // Trim whitespace
+    item.erase(0, item.find_first_not_of(" \t"));
+    item.erase(item.find_last_not_of(" \t") + 1);
+    if (!item.empty()) {
+      result.push_back(item);
+    }
+  }
+  return result;
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
-  std::optional<std::string> data_path = std::nullopt;
-  if (!FLAGS_data_path.empty()) {
-    data_path = FLAGS_data_path.c_str();
-  }
+  std::vector<std::string> data_paths = parseStringList(FLAGS_data_paths);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
@@ -92,7 +114,7 @@
 #endif
   // create llama runner
   std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner =
-      example::create_llama_runner(model_path, tokenizer_path, data_path);
+      example::create_llama_runner(model_path, tokenizer_path, data_paths);
 
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create llama runner");
diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py
index b13001c005b..f15aad9e000 100644
--- a/extension/llm/export/config/llm_config.py
+++ b/extension/llm/export/config/llm_config.py
@@ -215,9 +215,10 @@ class ExportConfig:
         so_library: Shared library to specify custom quantized operators.
         export_only: Whether to stop right after torch.export() and just save
            the exported .pt2 graph file.
-        foundation_weights_file: configure the foundation weights of a model
-            to be placed in a separate file, external to the PTE. Pass the
-            intended file name here.
+        foundation_weights_file: place the foundation weights of the model into
+            a separate file, external to the PTE. Pass the file name here.
+        lora_weights_file: place the lora weights of the model into a
+            separate file, external to the PTE. Pass the file name here.
     """
 
     max_seq_length: int = 128
@@ -227,6 +228,7 @@ class ExportConfig:
     so_library: Optional[str] = None
     export_only: bool = False
     foundation_weights_file: Optional[str] = None
+    lora_weights_file: Optional[str] = None
 
     def __post_init__(self):
         if self.max_context_length < self.max_seq_length:
@@ -572,6 +574,8 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig":  # noqa: C901
             llm_config.export.export_only = args.export_only
         if hasattr(args, "foundation_weights_file"):
             llm_config.export.foundation_weights_file = args.foundation_weights_file
+        if hasattr(args, "lora_weights_file"):
+            llm_config.export.lora_weights_file = args.lora_weights_file
 
         # QuantizationConfig
         if hasattr(args, "quantization_mode"):
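
Note: the routing rule introduced in `_export_llama` above can be read as: any constant whose FX node name contains "lora" is tagged with `export.lora_weights_file`, and every other constant is tagged with `export.foundation_weights_file`. Below is a minimal, self-contained sketch of that rule, not the export pipeline itself; the node names and output file names are hypothetical examples.

```python
# Sketch of the gen_tag_fn routing rule from export_llama_lib.py.
# Node names below are hypothetical, not taken from a real graph.
from typing import Optional

foundation_weights_file = "foundation_weights.ptd"
lora_weights_file = "lora_weights.ptd"

def gen_tag(node_name: str) -> Optional[str]:
    # Constants whose names mention "lora" go to the LoRA data file;
    # everything else goes to the foundation data file.
    return lora_weights_file if "lora" in node_name else foundation_weights_file

for name in ["layers.0.attention.wq.weight", "layers.0.attention.wq.lora_a.weight"]:
    print(f"{name} -> {gen_tag(name)}")
```

At runtime both files are then handed to the runner together, e.g. `--data_paths="foundation_weights.ptd,lora_weights.ptd"`, as exercised by the new test case above.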