From efe81eebd5671b355cbf1d6cd8967641481e15c4 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 8 Sep 2025 21:05:55 -0700 Subject: [PATCH 01/10] LlmModule prefill refactor --- .../executorch/extension/llm/LlmModule.java | 46 +++++++++++++------ extension/android/jni/jni_layer_llama.cpp | 31 +++++-------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index b014ceb75d8..7c35dbf2989 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -173,20 +173,23 @@ public native int generate( * @param height Input image height * @param channels Input image number of channels * @param startPos The starting position in KV cache of the input in the LLM. - * @return The updated starting position in KV cache of the input in the LLM. + * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer + * exposed to user. * @throws RuntimeException if the prefill failed */ + @Deprecated public long prefillImages(int[] image, int width, int height, int channels, long startPos) { - long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos); - if (nativeResult[0] != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); + if (startPos == 0) { + resetContext(); } - return nativeResult[1]; + int nativeResult = prefillImagesNative(image, width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); + } + return 0; } - // returns a tuple of (status, updated startPos) - private native long[] prefillImagesNative( - int[] image, int width, int height, int channels, long startPos); + private native int prefillImagesNative(int[] image, int width, int height, int channels); /** * Prefill an LLaVA Module with the given text input. @@ -196,23 +199,30 @@ private native long[] prefillImagesNative( * reference and will be updated inside this function. * @param bos The number of BOS (begin of sequence) token. * @param eos The number of EOS (end of sequence) token. - * @return The updated starting position in KV cache of the input in the LLM. + * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer + * exposed to user. * @throws RuntimeException if the prefill failed */ + @Deprecated public long prefillPrompt(String prompt, long startPos, int bos, int eos) { - long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos); - if (nativeResult[0] != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); + if (startPos == 0) { + resetContext(); } - return nativeResult[1]; + int nativeResult = prefillPromptNative(prompt, bos, eos); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); + } + return 0; } // returns a tuple of (status, updated startPos) - private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos); + private native int prefillPromptNative(String prompt, int bos, int eos); /** * Generate tokens from the given prompt, starting from the given position. * + *
+   * <p>This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)}
+   *
    * @param prompt The text prompt to LLaVA.
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
@@ -220,9 +230,17 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param callback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
+  @Deprecated
   public native int generateFromPos(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
 
+  /**
+   * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
+   *
+   * <p>The startPos will be reset to 0.
+   */
+  public native void resetContext();
+
   /** Stop current generate() before it finishes. */
   @DoNotStrip
   public native void stop();
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 0c3550f151a..b4a9320e20b 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -260,28 +260,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Returns a tuple of (error, start_pos)
   // Contract is valid within an AAR (JNI + corresponding Java code)
   // If the first element is not Error::Ok, the other element is undefined.
-  facebook::jni::local_ref<jlongArray> prefill_prompt(
+  jint prefill_prompt(
       facebook::jni::alias_ref<jstring> prompt,
-      jlong start_pos,
       jint bos,
       jint eos) {
     prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-    tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
-    return tuple_result;
+    return 0;
   }
 
-  // Returns a tuple of (error, start_pos)
-  // Contract is valid within an AAR (JNI + corresponding Java code)
-  // If the first element is not Error::Ok, the other element is undefined.
-
-  facebook::jni::local_ref<jlongArray> prefill_images(
+  jint prefill_images(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
-      jint channels,
-      jlong start_pos) {
+      jint channels) {
     std::vector<llm::Image> images;
     auto image_size = image->size();
     if (image_size != 0) {
@@ -296,11 +287,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           llm::MultimodalInput{std::move(image_runner)});
     }
 
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-
-    tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
-    return tuple_result;
+    return 0;
   }
 
   jint generate_from_pos(
@@ -325,9 +312,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         .seq_len = seq_len,
         .temperature = temperature_,
     };
-    return static_cast<jint>(runner_->generate_from_pos(
+    return static_cast<jint>(runner_->generate(
         prompt->toStdString(),
-        start_pos,
         config,
         [callback](std::string result) { callback->onResult(result); },
         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
@@ -343,6 +329,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     }
   }
 
+  void reset_context() {
+    runner_->reset();
+  }
+
   jint load() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       return static_cast<jint>(multi_modal_runner_->load());
@@ -364,6 +354,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
             "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
             "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
+        makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
       });
   }
};

From 015a6ab78f7f9c80e1d3970b83b583d6de9f8557 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:11:35 -0700
Subject: [PATCH 02/10] Rename prefill native methods to
 appendTextInput/appendImagesInput

---
 .../org/pytorch/executorch/extension/llm/LlmModule.java | 8 ++++----
 extension/android/jni/jni_layer_llama.cpp               | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 7c35dbf2989..e4be53f65cd 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -182,14 +182,14 @@ public long prefillImages(int[] image, int width, int height, int channels, long
     if (startPos == 0) {
       resetContext();
     }
-    int nativeResult = prefillImagesNative(image, width, height, channels);
+    int nativeResult = appendImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int prefillImagesNative(int[] image, int width, int height, int channels);
+  private native int appendImagesInput(int[] image, int width, int height, int channels);
 
   /**
    * Prefill an LLaVA Module with the given text input.
@@ -208,7 +208,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
     if (startPos == 0) {
       resetContext();
     }
-    int nativeResult = prefillPromptNative(prompt, bos, eos);
+    int nativeResult = appendTextInput(prompt, bos, eos);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
@@ -216,7 +216,7 @@
   }
 
   // returns a tuple of (status, updated startPos)
-  private native int prefillPromptNative(String prompt, int bos, int eos);
+  private native int appendTextInput(String prompt, int bos, int eos);
 
   /**
    * Generate tokens from the given prompt, starting from the given position.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index b4a9320e20b..85d97ed2797 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -260,7 +260,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Returns a tuple of (error, start_pos)
   // Contract is valid within an AAR (JNI + corresponding Java code)
   // If the first element is not Error::Ok, the other element is undefined.
-  jint prefill_prompt(
+  jint append_text_input(
       facebook::jni::alias_ref<jstring> prompt,
       jint bos,
       jint eos) {
@@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  jint prefill_images(
+  jint append_images_input(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
@@ -349,9 +349,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         makeNativeMethod("stop", ExecuTorchLlmJni::stop),
         makeNativeMethod("load", ExecuTorchLlmJni::load),
         makeNativeMethod(
-            "prefillImagesNative", ExecuTorchLlmJni::prefill_images),
+            "appendImagesInput", ExecuTorchLlmJni::append_images_input),
         makeNativeMethod(
-            "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
+            "appendTextInput", ExecuTorchLlmJni::append_text_input),
         makeNativeMethod(
             "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
         makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),

From 16b6d1cb564a3f5cf1fbb06c9ede5bcae19e2117 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:14:59 -0700
Subject: [PATCH 03/10] Java layer no longer needs a separate generateFromPos

---
 .../executorch/extension/llm/LlmModule.java |  6 ++--
 extension/android/jni/jni_layer_llama.cpp   | 33 -------------------
 2 files changed, 4 insertions(+), 35 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index e4be53f65cd..ec2f38bb7d3 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -231,8 +231,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @return The error code.
    */
   @Deprecated
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    return generate(prompt, seqLen, callback, echo);
+  }
 
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 85d97ed2797..7cb827bf827 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -290,37 +290,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  jint generate_from_pos(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint seq_len,
-      jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni::javaobject> callback,
-      jboolean echo) {
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
-      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      return static_cast<jint>(multi_modal_runner_->generate(
-          inputs,
-          llm::GenerationConfig{
-              .echo = static_cast<bool>(echo), .seq_len = seq_len},
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      executorch::extension::llm::GenerationConfig config{
-          .echo = static_cast<bool>(echo),
-          .seq_len = seq_len,
-          .temperature = temperature_,
-      };
-      return static_cast<jint>(runner_->generate(
-          prompt->toStdString(),
-          config,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    }
-    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
-  }
-
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -352,8 +321,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
             "appendImagesInput", ExecuTorchLlmJni::append_images_input),
         makeNativeMethod(
             "appendTextInput", ExecuTorchLlmJni::append_text_input),
-        makeNativeMethod(
-            "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
         makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
       });
   }

From e08df474e53c31e0bf7f2f04cd7daed053630075 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:17:12 -0700
Subject: [PATCH 04/10] Remove generateFromPos API

---
 .../executorchllamademo/MainActivity.java   |  3 +--
 .../executorch/extension/llm/LlmModule.java | 18 ------------------
 2 files changed, 1 insertion(+), 20 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index b26031d89a6..fb7cc01206c 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -778,10 +778,9 @@ public void run() {
                         mCurrentSettingsFields.getModelType(),
                         mCurrentSettingsFields.getBackendType())
                     == ModelUtils.VISION_MODEL) {
-              mModule.generateFromPos(
+              mModule.generate(
                   finalPrompt,
                   ModelUtils.VISION_MODEL_SEQ_LEN,
-                  startPos,
                   MainActivity.this,
                   false);
             } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index ec2f38bb7d3..6599cb4c15d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ 
b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -218,24 +218,6 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { // returns a tuple of (status, updated startPos) private native int appendTextInput(String prompt, int bos, int eos); - /** - * Generate tokens from the given prompt, starting from the given position. - * - *
-   * <p>This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)}
-   *
-   * @param prompt The text prompt to LLaVA.
-   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
-   * @param startPos The starting position in KV cache of the input in the LLM.
-   * @param callback callback object to receive results.
-   * @param echo indicate whether to echo the input prompt or not.
-   * @return The error code.
-   */
-  @Deprecated
-  public int generateFromPos(
-      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
-    return generate(prompt, seqLen, callback, echo);
-  }
-
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
    *

From eed96ba189e45915d91a9888a92a97511186972f Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:49:24 -0700
Subject: [PATCH 05/10] Refactor generate

---
 .../executorch/extension/llm/LlmModule.java | 19 ++++++-------
 extension/android/jni/jni_layer_llama.cpp   | 31 ++++++-------------
 2 files changed, 19 insertions(+), 31 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 6599cb4c15d..fe6c61e8287 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -125,9 +125,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
    * @param llmCallback callback object to receive results
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
-    return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
-  }
+  public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
 
   /**
    * Start generating tokens from the module.
@@ -152,8 +152,7 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa
    * @param llmCallback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  @DoNotStrip
-  public native int generate(
+  public int generate(
       int[] image,
       int width,
       int height,
@@ -163,7 +160,11 @@ public native int generate(
       String prompt,
       int seqLen,
       LlmCallback llmCallback,
-      boolean echo);
+      boolean echo) {
+    prefillPrompt(prompt, 0, 0, 0);
+    prefillImages(image, width, height, channels, 0);
+    generate("", llmCallback, echo);
+  }
 
   /**
    * Prefill an LLaVA Module with the given images input.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 7cb827bf827..f83d944a188 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -208,10 +208,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   }
 
   jint generate(
-      facebook::jni::alias_ref<jintArray> image,
-      jint width,
-      jint height,
-      jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni::javaobject> callback,
@@ -219,18 +215,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
       prefill_inputs_.clear();
-      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      auto image_size = image->size();
-      std::vector<llm::Image> images;
-      if (image_size != 0) {
-        std::vector<jint> image_data_jint(image_size);
-        std::vector<uint8_t> image_data(image_size);
-        image->getRegion(0, image_size, image_data_jint.data());
-        for (int i = 0; i < image_size; i++) {
-          image_data[i] = image_data_jint[i];
-        }
-        llm::Image image_runner{image_data, width, height, channels};
-        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
+      if (!prompt->toStdString().empty()) {
+        inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       }
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
@@ -257,23 +243,24 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  // Returns a tuple of (error, start_pos)
+  // Returns status_code
   // Contract is valid within an AAR (JNI + corresponding Java code)
-  // If the first element is not Error::Ok, the other element is undefined.
-  jint append_text_input(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint bos,
-      jint eos) {
+  jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
     prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
     return 0;
   }
 
+  // Returns status_code
+  // Contract is valid within an AAR (JNI + corresponding Java code)
   jint append_images_input(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
       jint channels) {
     std::vector<llm::Image> images;
+    if (image == nullptr) {
+      return static_cast<jint>(Error::EndOfMethod);
+    }
     auto image_size = image->size();
     if (image_size != 0) {
       std::vector<jint> image_data_jint(image_size);

From 19c77cdd9cfa6e301c0ee0d10f1b2e449d069e8a Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 10 Sep 2025 15:49:12 -0700
Subject: [PATCH 06/10] Fix

---
 .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index fe6c61e8287..00d568b66f6 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -163,7 +163,7 @@ public int generate(
       boolean echo) {
     prefillPrompt(prompt, 0, 0, 0);
     prefillImages(image, width, height, channels, 0);
-    generate("", llmCallback, echo);
+    return generate("", llmCallback, echo);
   }
 
   /**

From d82455127547f060de32a48f7f8e8500d40792fb Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 10 Sep 2025 16:05:20 -0700
Subject: [PATCH 07/10] remove unused args

---
 .../example/executorchllamademo/MainActivity.java  | 10 ++++------
 .../executorch/extension/llm/LlmModule.java        | 15 ++-------------
 2 files changed, 6 insertions(+), 19 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index 922f0598f1d..338bc94cd14 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -81,7 +81,6 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
   private Runnable memoryUpdater;
   private boolean mThinkMode = false;
   private int promptID = 0;
-  private long startPos = 0;
   private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
   private Executor executor;
 
@@ -178,7 +177,8 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
 
     if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
       ETLogging.getInstance().log("Llava start prefill prompt");
-      startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
+      mModule.resetContext();
+      mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt());
       ETLogging.getInstance().log("Llava completes prefill prompt");
     }
   }
@@ -645,13 +645,11 @@ private void showMediaPreview(List<Uri> uris) {
           ETLogging.getInstance().log("Starting runnable prefill image");
           ETImage img = processedImageList.get(0);
           ETLogging.getInstance().log("Llava start prefill image");
-          startPos =
-              mModule.prefillImages(
+          mModule.prefillImages(
                   img.getInts(),
                   img.getWidth(),
                   img.getHeight(),
-                  ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
-                  startPos);
+                  ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
         };
     executor.execute(runnable);
   }
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 00d568b66f6..8fd164a1b7d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -173,16 +173,12 @@ public int generate(
    * @param width Input image width
    * @param height Input image height
    * @param channels Input image number of channels
-   * @param startPos The starting position in KV cache of the input in the LLM.
    * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
   @Deprecated
-  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    if (startPos == 0) {
-      resetContext();
-    }
+  public long prefillImages(int[] image, int width, int height, int channels) {
     int nativeResult = appendImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
@@ -196,19 +192,12 @@ public long prefillImages(int[] image, int width, int height, int channels, long
    * Prefill an LLaVA Module with the given text input.
    *
    * @param prompt The text prompt to LLaVA.
-   * @param startPos The starting position in KV cache of the input in the LLM. It's passed as
-   *     reference and will be updated inside this function.
-   * @param bos The number of BOS (begin of sequence) token.
-   * @param eos The number of EOS (end of sequence) token.
    * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
   @Deprecated
-  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    if (startPos == 0) {
-      resetContext();
-    }
+  public long prefillPrompt(String prompt) {
     int nativeResult = appendTextInput(prompt);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);

From c48a08b545098bc566bc90e550d2db186e5a188f Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 10 Sep 2025 16:07:20 -0700
Subject: [PATCH 08/10] lint

---
 .../com/example/executorchllamademo/MainActivity.java   | 8 ++++----
 .../org/pytorch/executorch/extension/llm/LlmModule.java | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index 338bc94cd14..f995c5bc65a 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -646,10 +646,10 @@ private void showMediaPreview(List<Uri> uris) {
           ETImage img = processedImageList.get(0);
           ETLogging.getInstance().log("Llava start prefill image");
           mModule.prefillImages(
-          img.getInts(),
-          img.getWidth(),
-          img.getHeight(),
-          ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
+              img.getInts(),
+              img.getWidth(),
+              img.getHeight(),
+              ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
         };
     executor.execute(runnable);
   }
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 8fd164a1b7d..f702e9d0e0d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -161,10 +161,10 @@ public int generate(
       int seqLen,
       LlmCallback llmCallback,
       boolean echo) {
-      prefillPrompt(prompt, 0, 0, 0);
-      prefillImages(image, width, height, channels, 0);
-      return generate("", llmCallback, echo);
-  }
+    prefillPrompt(prompt, 0, 0, 0);
+    prefillImages(image, width, height, channels, 0);
+    return generate("", llmCallback, echo);
+  }
 
   /**
    * Prefill an LLaVA Module with the given images input.
From 8ee21aff1f845f6859e592fc25b9acc48481d700 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 10 Sep 2025 16:12:02 -0700 Subject: [PATCH 09/10] Fix --- .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index f702e9d0e0d..ed51ff7ce57 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -161,7 +161,7 @@ public int generate( int seqLen, LlmCallback llmCallback, boolean echo) { - prefillPrompt(prompt, 0, 0, 0); + prefillPrompt(prompt); prefillImages(image, width, height, channels, 0); return generate("", llmCallback, echo); } From a036eb949ee4d3eda0bb8678da22e1e822cc29ab Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 10 Sep 2025 16:14:58 -0700 Subject: [PATCH 10/10] Fix --- .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index ed51ff7ce57..289df5defd9 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -162,7 +162,7 @@ public int generate( LlmCallback llmCallback, boolean echo) { prefillPrompt(prompt); - prefillImages(image, width, height, channels, 0); + prefillImages(image, width, height, channels); return generate("", llmCallback, echo); }
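
Usage after this series (a minimal sketch, not taken from the patches above): callers no
longer thread startPos through prefill calls; inputs are queued and the KV cache is
cleared explicitly. The model path, tokenizer path, temperature, sequence length, and
the LlmCallback implementation below are illustrative assumptions; the method
signatures are the ones this series leaves in LlmModule.

    // Assumed setup: paths, temperature, and `callback` are placeholders.
    LlmModule module = new LlmModule("/data/local/tmp/llava.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load();

    module.resetContext();               // clear the KV cache; no startPos bookkeeping
    module.prefillPrompt(presetPrompt);  // queue text via the native appendTextInput
    module.prefillImages(pixels, width, height, 3);  // queue image via appendImagesInput
    int err = module.generate("", 768, callback, /* echo */ false);  // consume queued inputs

Passing an empty prompt to generate() is safe here: since PATCH 05 the JNI layer only
appends the prompt to the queued multimodal inputs when it is non-empty.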