From efe81eebd5671b355cbf1d6cd8967641481e15c4 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 8 Sep 2025 21:05:55 -0700 Subject: [PATCH 1/5] LlmModule prefill refactor --- .../executorch/extension/llm/LlmModule.java | 46 +++++++++++++------ extension/android/jni/jni_layer_llama.cpp | 31 +++++-------- 2 files changed, 43 insertions(+), 34 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index b014ceb75d8..7c35dbf2989 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -173,20 +173,23 @@ public native int generate( * @param height Input image height * @param channels Input image number of channels * @param startPos The starting position in KV cache of the input in the LLM. - * @return The updated starting position in KV cache of the input in the LLM. + * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer + * exposed to user. 
* @throws RuntimeException if the prefill failed */ + @Deprecated public long prefillImages(int[] image, int width, int height, int channels, long startPos) { - long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos); - if (nativeResult[0] != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); + if (startPos == 0) { + resetContext(); } - return nativeResult[1]; + int nativeResult = prefillImagesNative(image, width, height, channels); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); + } + return 0; } - // returns a tuple of (status, updated startPos) - private native long[] prefillImagesNative( - int[] image, int width, int height, int channels, long startPos); + private native int prefillImagesNative(int[] image, int width, int height, int channels); /** * Prefill an LLaVA Module with the given text input. @@ -196,23 +199,30 @@ private native long[] prefillImagesNative( * reference and will be updated inside this function. * @param bos The number of BOS (begin of sequence) token. * @param eos The number of EOS (end of sequence) token. - * @return The updated starting position in KV cache of the input in the LLM. + * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer + * exposed to user. 
* @throws RuntimeException if the prefill failed */ + @Deprecated public long prefillPrompt(String prompt, long startPos, int bos, int eos) { - long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos); - if (nativeResult[0] != 0) { - throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); + if (startPos == 0) { + resetContext(); } - return nativeResult[1]; + int nativeResult = prefillPromptNative(prompt, bos, eos); + if (nativeResult != 0) { + throw new RuntimeException("Prefill failed with error code: " + nativeResult); + } + return 0; } // returns a tuple of (status, updated startPos) - private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos); + private native int prefillPromptNative(String prompt, int bos, int eos); /** * Generate tokens from the given prompt, starting from the given position. * + *
<p>
This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)} + * * @param prompt The text prompt to LLaVA. * @param seqLen The total sequence length, including the prompt tokens and new tokens. * @param startPos The starting position in KV cache of the input in the LLM. @@ -220,9 +230,17 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { * @param echo indicate whether to echo the input prompt or not. * @return The error code. */ + @Deprecated public native int generateFromPos( String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo); + /** + * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. + * + *
<p>
The startPos will be reset to 0. + */ + public native void resetContext(); + /** Stop current generate() before it finishes. */ @DoNotStrip public native void stop(); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0c3550f151a..b4a9320e20b 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -260,28 +260,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { // Returns a tuple of (error, start_pos) // Contract is valid within an AAR (JNI + corresponding Java code) // If the first element is not Error::Ok, the other element is undefined. - facebook::jni::local_ref prefill_prompt( + jint prefill_prompt( facebook::jni::alias_ref prompt, - jlong start_pos, jint bos, jint eos) { prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()}); - facebook::jni::local_ref tuple_result = - facebook::jni::make_long_array(2); - tuple_result->pin()[0] = static_cast(Error::Ok); - return tuple_result; + return 0; } - // Returns a tuple of (error, start_pos) - // Contract is valid within an AAR (JNI + corresponding Java code) - // If the first element is not Error::Ok, the other element is undefined. 
- - facebook::jni::local_ref prefill_images( + jint prefill_images( facebook::jni::alias_ref image, jint width, jint height, - jint channels, - jlong start_pos) { + jint channels) { std::vector images; auto image_size = image->size(); if (image_size != 0) { @@ -296,11 +287,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { llm::MultimodalInput{std::move(image_runner)}); } - facebook::jni::local_ref tuple_result = - facebook::jni::make_long_array(2); - - tuple_result->pin()[0] = static_cast(Error::Ok); - return tuple_result; + return 0; } jint generate_from_pos( @@ -325,9 +312,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { .seq_len = seq_len, .temperature = temperature_, }; - return static_cast(runner_->generate_from_pos( + return static_cast(runner_->generate( prompt->toStdString(), - start_pos, config, [callback](std::string result) { callback->onResult(result); }, [callback](const llm::Stats& stats) { callback->onStats(stats); })); @@ -343,6 +329,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { } } + void reset_context() { + runner_->reset(); + } + jint load() { if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { return static_cast(multi_modal_runner_->load()); @@ -364,6 +354,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt), makeNativeMethod( "generateFromPos", ExecuTorchLlmJni::generate_from_pos), + makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), }); } }; From 015a6ab78f7f9c80e1d3970b83b583d6de9f8557 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 8 Sep 2025 21:11:35 -0700 Subject: [PATCH 2/5] Doing some rename --- .../org/pytorch/executorch/extension/llm/LlmModule.java | 8 ++++---- extension/android/jni/jni_layer_llama.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java 
b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 7c35dbf2989..e4be53f65cd 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -182,14 +182,14 @@ public long prefillImages(int[] image, int width, int height, int channels, long if (startPos == 0) { resetContext(); } - int nativeResult = prefillImagesNative(image, width, height, channels); + int nativeResult = appendImagesInput(image, width, height, channels); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } return 0; } - private native int prefillImagesNative(int[] image, int width, int height, int channels); + private native int appendImagesInput(int[] image, int width, int height, int channels); /** * Prefill an LLaVA Module with the given text input. @@ -208,7 +208,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { if (startPos == 0) { resetContext(); } - int nativeResult = prefillPromptNative(prompt, bos, eos); + int nativeResult = appendTextInput(prompt, bos, eos); if (nativeResult != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult); } @@ -216,7 +216,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { } // returns a tuple of (status, updated startPos) - private native int prefillPromptNative(String prompt, int bos, int eos); + private native int appendTextInput(String prompt, int bos, int eos); /** * Generate tokens from the given prompt, starting from the given position. 
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index b4a9320e20b..85d97ed2797 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -260,7 +260,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { // Returns a tuple of (error, start_pos) // Contract is valid within an AAR (JNI + corresponding Java code) // If the first element is not Error::Ok, the other element is undefined. - jint prefill_prompt( + jint append_text_input( facebook::jni::alias_ref prompt, jint bos, jint eos) { @@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { return 0; } - jint prefill_images( + jint append_images_input( facebook::jni::alias_ref image, jint width, jint height, @@ -349,9 +349,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { makeNativeMethod("stop", ExecuTorchLlmJni::stop), makeNativeMethod("load", ExecuTorchLlmJni::load), makeNativeMethod( - "prefillImagesNative", ExecuTorchLlmJni::prefill_images), + "appendImagesInput", ExecuTorchLlmJni::append_images_input), makeNativeMethod( - "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt), + "appendTextInput", ExecuTorchLlmJni::append_text_input), makeNativeMethod( "generateFromPos", ExecuTorchLlmJni::generate_from_pos), makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), From 16b6d1cb564a3f5cf1fbb06c9ede5bcae19e2117 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 8 Sep 2025 21:14:59 -0700 Subject: [PATCH 3/5] Java layer no longer need a separate generateFromPos --- .../executorch/extension/llm/LlmModule.java | 6 ++-- extension/android/jni/jni_layer_llama.cpp | 33 ------------------- 2 files changed, 4 insertions(+), 35 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java 
index e4be53f65cd..ec2f38bb7d3 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -231,8 +231,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { * @return The error code. */ @Deprecated - public native int generateFromPos( - String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo); + public int generateFromPos( + String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) { + return generate(prompt, seqLen, callback, echo); + } /** * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 85d97ed2797..7cb827bf827 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -290,37 +290,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { return 0; } - jint generate_from_pos( - facebook::jni::alias_ref prompt, - jint seq_len, - jlong start_pos, - facebook::jni::alias_ref callback, - jboolean echo) { - if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { - std::vector inputs = prefill_inputs_; - prefill_inputs_.clear(); - inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()}); - return static_cast(multi_modal_runner_->generate( - inputs, - llm::GenerationConfig{ - .echo = static_cast(echo), .seq_len = seq_len}, - [callback](const std::string& result) { callback->onResult(result); }, - [callback](const llm::Stats& stats) { callback->onStats(stats); })); - } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { - executorch::extension::llm::GenerationConfig config{ - .echo = static_cast(echo), - .seq_len = seq_len, - .temperature = temperature_, - }; - return static_cast(runner_->generate( - 
prompt->toStdString(), - config, - [callback](std::string result) { callback->onResult(result); }, - [callback](const llm::Stats& stats) { callback->onStats(stats); })); - } - return static_cast(executorch::runtime::Error::InvalidArgument); - } - void stop() { if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { multi_modal_runner_->stop(); @@ -352,8 +321,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "appendImagesInput", ExecuTorchLlmJni::append_images_input), makeNativeMethod( "appendTextInput", ExecuTorchLlmJni::append_text_input), - makeNativeMethod( - "generateFromPos", ExecuTorchLlmJni::generate_from_pos), makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), }); } From 039d83172c923e147cfcb2c9d277191110490511 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 8 Sep 2025 22:40:35 -0700 Subject: [PATCH 4/5] Lint --- .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index ec2f38bb7d3..9494b0fe5cb 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -233,8 +233,8 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { @Deprecated public int generateFromPos( String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) { - return generate(prompt, seqLen, callback, echo); - } + return generate(prompt, seqLen, callback, echo); + } /** * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. 
From f6f69b4b22e33c2627671ab5a0e2e6433aa63930 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 10 Sep 2025 11:40:50 -0700 Subject: [PATCH 5/5] Fix --- .../org/pytorch/executorch/extension/llm/LlmModule.java | 7 ------- extension/android/jni/jni_layer_llama.cpp | 9 --------- 2 files changed, 16 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 4dc3dff2dce..9494b0fe5cb 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -243,13 +243,6 @@ public int generateFromPos( */ public native void resetContext(); - /** - * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM. - * - *
<p>
The startPos will be reset to 0. - */ - public native void resetContext(); - /** Stop current generate() before it finishes. */ @DoNotStrip public native void stop(); diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 2fe3b71a918..331c20ee6f1 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -299,16 +299,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { } void reset_context() { -<<<<<<< HEAD - runner_->reset(); -======= if (runner_ != nullptr) { runner_->reset(); } if (multi_modal_runner_ != nullptr) { multi_modal_runner_->reset(); } ->>>>>>> origin/main } jint load() { @@ -330,11 +326,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { "appendImagesInput", ExecuTorchLlmJni::append_images_input), makeNativeMethod( "appendTextInput", ExecuTorchLlmJni::append_text_input), -<<<<<<< HEAD -======= - makeNativeMethod( - "generateFromPos", ExecuTorchLlmJni::generate_from_pos), ->>>>>>> origin/main makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context), }); }