From e018c30a6a05c0ea987749a2a9cee938ce4a94d4 Mon Sep 17 00:00:00 2001 From: lucylq Date: Wed, 10 Sep 2025 12:02:59 -0700 Subject: [PATCH] JNI support for multiple ptd files ^ Differential Revision: [D82072929](https://our.internmc.facebook.com/intern/diff/D82072929/) [ghstack-poisoned] --- .../executorch/extension/llm/LlmModule.java | 25 +++++++++++----- extension/android/jni/jni_layer_llama.cpp | 29 ++++++++++++++----- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index b014ceb75d8..21f853d0cf9 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -11,6 +11,7 @@ import com.facebook.jni.HybridData; import com.facebook.jni.annotations.DoNotStrip; import java.io.File; +import java.util.List; import org.pytorch.executorch.ExecuTorchRuntime; import org.pytorch.executorch.annotations.Experimental; @@ -32,14 +33,14 @@ public class LlmModule { @DoNotStrip private static native HybridData initHybrid( - int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath); + int modelType, String modulePath, String tokenizerPath, float temperature, List dataFiles); /** * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and - * data path. + * dataFiles. */ public LlmModule( - int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { + int modelType, String modulePath, String tokenizerPath, float temperature, List dataFiles) { ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); File modelFile = new File(modulePath); @@ -50,12 +51,22 @@ public LlmModule( if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) { throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); } - mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath); + + mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles); + } + + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + public LlmModule( + int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { + this(modelType, modulePath, tokenizerPath, temperature, List.of(dataPath)); } /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ public LlmModule(String modulePath, String tokenizerPath, float temperature) { - this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, List.of()); } /** @@ -63,12 +74,12 @@ public LlmModule(String modulePath, String tokenizerPath, float temperature) { * path. */ public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, List.of(dataPath)); } /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - this(modelType, modulePath, tokenizerPath, temperature, null); + this(modelType, modulePath, tokenizerPath, temperature, List.of()); } /** Constructs a LLM Module for a model with the given LlmModuleConfig */ diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0c3550f151a..ef03ba639f8 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -140,13 +140,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { facebook::jni::alias_ref model_path, facebook::jni::alias_ref tokenizer_path, jfloat temperature, - facebook::jni::alias_ref data_path) { + facebook::jni::alias_ref data_files) { return makeCxxInstance( model_type_category, model_path, tokenizer_path, temperature, - data_path); + data_files); } ExecuTorchLlmJni( @@ -154,7 +154,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { facebook::jni::alias_ref model_path, facebook::jni::alias_ref tokenizer_path, jfloat temperature, - facebook::jni::alias_ref data_path = nullptr) { + facebook::jni::alias_ref data_files = nullptr) { temperature_ = temperature; #if defined(ET_USE_THREADPOOL) // Reserve 1 thread for the main thread. @@ -173,18 +173,32 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { model_path->toStdString().c_str(), llm::load_tokenizer(tokenizer_path->toStdString())); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { - std::optional data_path_str = data_path - ? std::optional{data_path->toStdString()} - : std::nullopt; + std::unordered_set data_files_set; + if (data_files != nullptr) { + // Convert Java List to C++ unordered_set + auto list_class = facebook::jni::findClassStatic("java/util/List"); + auto size_method = list_class->getMethod("size"); + auto get_method = + list_class->getMethod(jint)>( + "get"); + + jint size = size_method(data_files); + for (jint i = 0; i < size; ++i) { + auto str_obj = get_method(data_files, i); + auto jstr = facebook::jni::static_ref_cast(str_obj); + data_files_set.insert(jstr->toStdString()); + } + } runner_ = executorch::extension::llm::create_text_llm_runner( model_path->toStdString(), llm::load_tokenizer(tokenizer_path->toStdString()), - data_path_str); + data_files_set); #if defined(EXECUTORCH_BUILD_QNN) } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) { std::unique_ptr module = std::make_unique< executorch::extension::Module>( model_path->toStdString().c_str(), + data_files_set, executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); std::string decoder_model = "llama3"; // use llama3 for now runner_ = std::make_unique>( // QNN runner @@ -192,7 +206,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { decoder_model.c_str(), model_path->toStdString().c_str(), tokenizer_path->toStdString().c_str(), - data_path->toStdString().c_str(), ""); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif