diff --git a/.ci/scripts/test_llama_lora.sh b/.ci/scripts/test_llama_lora.sh
index 6337bbf76a2..63325aa7778 100644
--- a/.ci/scripts/test_llama_lora.sh
+++ b/.ci/scripts/test_llama_lora.sh
@@ -94,7 +94,7 @@ else
   exit 1
 fi
 
-# Export LoRA PTE, PTD file.
+# Export LoRA PTE, foundation PTD file.
 MODEL_SEPARATE="${MODEL_NAME}_separate"
 $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
   base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
@@ -114,7 +114,7 @@ $PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
 NOW=$(date +"%H:%M:%S")
 echo "Starting to run llama runner at ${NOW}"
 # shellcheck source=/dev/null
-cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_path=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_SEPARATE}.pte --data_paths=${MODEL_SEPARATE}.ptd --prompt="${PROMPT}" ${RUNTIME_ARGS} > result2.txt
 NOW=$(date +"%H:%M:%S")
 echo "Finished at ${NOW}"
 
@@ -122,8 +122,8 @@ RESULT2=$(cat result2.txt)
 if [[ "${RESULT2}" == "${EXPECTED_PREFIX}"* ]]; then
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
+  # Do not clean up files if test passes, as they're re-used in the next test.
   echo "Success"
-  cleanup_files
 else
   echo "Expected result prefix: ${EXPECTED_PREFIX}"
   echo "Actual result: ${RESULT2}"
@@ -131,3 +131,45 @@ else
   cleanup_files
   exit 1
 fi
+
+# Export LoRA PTE, LoRA PTD, foundation PTD file.
+MODEL_PROGRAM_ONLY="${MODEL_NAME}_program"
+MODEL_LORA_WEIGHTS="lora_weights"
+MODEL_FOUNDATION_WEIGHTS="foundation_weights"
+$PYTHON_EXECUTABLE -m extension.llm.export.export_llm \
+  base.checkpoint="${DOWNLOADED_PATH}/consolidated.00.pth" \
+  base.params="${DOWNLOADED_PATH}/params.json" \
+  base.adapter_checkpoint="${DOWNLOADED_PATH}/adapter_model.pt" \
+  base.adapter_config="${DOWNLOADED_PATH}/adapter_config.json" \
+  base.tokenizer_path="${DOWNLOADED_PATH}/tokenizer.model" \
+  model.use_kv_cache=true \
+  model.use_sdpa_with_kv_cache=true \
+  model.dtype_override="fp32" \
+  backend.xnnpack.enabled=true \
+  backend.xnnpack.extended_ops=true \
+  export.output_name="${MODEL_PROGRAM_ONLY}.pte" \
+  export.foundation_weights_file="${MODEL_FOUNDATION_WEIGHTS}.ptd" \
+  export.lora_weights_file="${MODEL_LORA_WEIGHTS}.ptd"
+
+# Run llama runner.
+NOW=$(date +"%H:%M:%S")
+echo "Starting to run llama runner at ${NOW}"
+# shellcheck source=/dev/null
+cmake-out/examples/models/llama/llama_main --model_path=${MODEL_PROGRAM_ONLY}.pte --data_paths="${MODEL_FOUNDATION_WEIGHTS}.ptd,${MODEL_LORA_WEIGHTS}.ptd" --prompt="${PROMPT}" ${RUNTIME_ARGS} > result3.txt
+NOW=$(date +"%H:%M:%S")
+echo "Finished at ${NOW}"
+
+RESULT3=$(cat result3.txt)
+if [[ "${RESULT3}" == "${EXPECTED_PREFIX}"* ]]; then
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Success"
+else
+  echo "Expected result prefix: ${EXPECTED_PREFIX}"
+  echo "Actual result: ${RESULT3}"
+  echo "Failure; results not the same"
+  cleanup_files
+  exit 1
+fi
+
+cleanup_files
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 7fa9357f23b..6d8ba54c010 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1136,20 +1136,15 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager:  # noqa: C901
         llm_config.backend.xnnpack.enabled = True
 
     if llm_config.backend.xnnpack.enabled:
-        if llm_config.export.foundation_weights_file is not None:
-            if llm_config.export.lora_weights_file is not None:
-                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
-                    llm_config.export.foundation_weights_file
-                    if "lora" not in x.name
-                    else None
-                )
-            else:
-                gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
-                    llm_config.export.foundation_weights_file
-                    if "lora" not in x.name
-                    else llm_config.export.lora_weights_file
-                )
-
+        if (
+            llm_config.export.foundation_weights_file is not None
+            or llm_config.export.lora_weights_file is not None
+        ):
+            gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
+                llm_config.export.foundation_weights_file
+                if "lora" not in x.name
+                else llm_config.export.lora_weights_file
+            )
         from executorch.exir.passes.external_constants_pass import (
             delegate_external_constants_pass_unlifted,
             external_constants_pass,
diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index 078d938ffde..0244a7f5661 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -8,6 +8,8 @@
  */
 
 #include <gflags/gflags.h>
+#include <sstream>
+#include <vector>
 
 #include <executorch/examples/models/llama/runner/runner.h>
 
@@ -21,7 +23,10 @@ DEFINE_string(
     "llama2.pte",
     "Model serialized in flatbuffer format.");
 
-DEFINE_string(data_path, "", "Data file for the model.");
+DEFINE_string(
+    data_paths,
+    "",
+    "Data files for the model. If multiple files are provided, they should be comma separated.");
 
 DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
 
@@ -54,6 +59,26 @@ DEFINE_int32(
 
 DEFINE_bool(warmup, false, "Whether to run a warmup run.");
 
+// Helper function to parse comma-separated string lists
+std::vector<std::string> parseStringList(const std::string& input) {
+  std::vector<std::string> result;
+  if (input.empty()) {
+    return result;
+  }
+
+  std::stringstream ss(input);
+  std::string item;
+  while (std::getline(ss, item, ',')) {
+    // Trim whitespace
+    item.erase(0, item.find_first_not_of(" \t"));
+    item.erase(item.find_last_not_of(" \t") + 1);
+    if (!item.empty()) {
+      result.push_back(item);
+    }
+  }
+  return result;
+}
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
   // Create a loader to get the data of the program file. There are other
   // ways to load the program, like mmaping the file in before calling
   // and users can create their own DataLoaders to load from arbitrary sources.
   const char* model_path = FLAGS_model_path.c_str();
 
-  std::optional<std::string> data_path = std::nullopt;
-  if (!FLAGS_data_path.empty()) {
-    data_path = FLAGS_data_path.c_str();
-  }
+  std::vector<std::string> data_paths = parseStringList(FLAGS_data_paths);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
 
@@ -92,7 +114,7 @@ int32_t main(int32_t argc, char** argv) {
 #endif
   // create llama runner
   std::unique_ptr<::executorch::extension::llm::TextLLMRunner> runner =
-      example::create_llama_runner(model_path, tokenizer_path, data_path);
+      example::create_llama_runner(model_path, tokenizer_path, data_paths);
 
   if (runner == nullptr) {
     ET_LOG(Error, "Failed to create llama runner");
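
Note on the new flag (not part of the patch itself): the runner change hinges on parseStringList turning the single --data_paths value into a list of .ptd files. The sketch below is a minimal standalone illustration of that splitting and trimming behavior, reusing the foundation_weights.ptd / lora_weights.ptd names from the CI script above; it duplicates the helper's logic rather than linking against llama_main.

// Standalone sketch: demonstrates the comma-splitting and whitespace-trimming
// behavior that the parseStringList helper added in main.cpp implements.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Same logic as the helper in the patch: split on ',', trim spaces/tabs,
// drop empty entries.
std::vector<std::string> parseStringList(const std::string& input) {
  std::vector<std::string> result;
  if (input.empty()) {
    return result;
  }
  std::stringstream ss(input);
  std::string item;
  while (std::getline(ss, item, ',')) {
    item.erase(0, item.find_first_not_of(" \t"));
    item.erase(item.find_last_not_of(" \t") + 1);
    if (!item.empty()) {
      result.push_back(item);
    }
  }
  return result;
}

int main() {
  // Mirrors the CI invocation: --data_paths="foundation_weights.ptd,lora_weights.ptd"
  // (a stray space after the comma is trimmed, so both forms parse the same way).
  for (const auto& path : parseStringList("foundation_weights.ptd, lora_weights.ptd")) {
    std::cout << path << "\n";  // prints each trimmed path on its own line
  }
  return 0;
}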