diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index 922f0598f1d..f995c5bc65a 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -81,7 +81,6 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
   private Runnable memoryUpdater;
   private boolean mThinkMode = false;
   private int promptID = 0;
-  private long startPos = 0;
   private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
   private Executor executor;
 
@@ -178,7 +177,8 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
 
     if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
       ETLogging.getInstance().log("Llava start prefill prompt");
-      startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
+      mModule.resetContext();
+      mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt());
       ETLogging.getInstance().log("Llava completes prefill prompt");
     }
   }
@@ -645,13 +645,11 @@ private void showMediaPreview(List<Uri> uris) {
           ETLogging.getInstance().log("Starting runnable prefill image");
           ETImage img = processedImageList.get(0);
           ETLogging.getInstance().log("Llava start prefill image");
-          startPos =
-              mModule.prefillImages(
-                  img.getInts(),
-                  img.getWidth(),
-                  img.getHeight(),
-                  ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
-                  startPos);
+          mModule.prefillImages(
+              img.getInts(),
+              img.getWidth(),
+              img.getHeight(),
+              ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
         };
         executor.execute(runnable);
       }
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 6599cb4c15d..289df5defd9 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -125,9 +125,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
    * @param llmCallback callback object to receive results
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
-    return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
-  }
+  public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
 
   /**
    * Start generating tokens from the module.
@@ -154,8 +152,7 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa
    * @param llmCallback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  @DoNotStrip
-  public native int generate(
+  public int generate(
       int[] image,
       int width,
       int height,
@@ -163,7 +160,11 @@ public native int generate(
       String prompt,
       int seqLen,
       LlmCallback llmCallback,
-      boolean echo);
+      boolean echo) {
+    prefillPrompt(prompt);
+    prefillImages(image, width, height, channels);
+    return generate("", llmCallback, echo);
+  }
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -172,16 +173,12 @@ public native int generate(
    * @param width Input image width
    * @param height Input image height
    * @param channels Input image number of channels
-   * @param startPos The starting position in KV cache of the input in the LLM.
    * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
   @Deprecated
-  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    if (startPos == 0) {
-      resetContext();
-    }
+  public long prefillImages(int[] image, int width, int height, int channels) {
     int nativeResult = appendImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
@@ -195,28 +192,21 @@ public long prefillImages(int[] image, int width, int height, int channels, long
    * Prefill an LLaVA Module with the given text input.
    *
    * @param prompt The text prompt to LLaVA.
-   * @param startPos The starting position in KV cache of the input in the LLM. It's passed as
-   *     reference and will be updated inside this function.
-   * @param bos The number of BOS (begin of sequence) token.
-   * @param eos The number of EOS (end of sequence) token.
    * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
    *     exposed to user.
    * @throws RuntimeException if the prefill failed
    */
   @Deprecated
-  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    if (startPos == 0) {
-      resetContext();
-    }
-    int nativeResult = appendTextInput(prompt, bos, eos);
+  public long prefillPrompt(String prompt) {
+    int nativeResult = appendTextInput(prompt);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native int appendTextInput(String prompt, int bos, int eos);
+  // returns status
+  private native int appendTextInput(String prompt);
 
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 331c20ee6f1..23686f01ee7 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -208,10 +208,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   }
 
   jint generate(
-      facebook::jni::alias_ref<jintArray> image,
-      jint width,
-      jint height,
-      jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
@@ -219,18 +215,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
       prefill_inputs_.clear();
-      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      auto image_size = image->size();
-      std::vector<llm::Image> images;
-      if (image_size != 0) {
-        std::vector<jint> image_data_jint(image_size);
-        std::vector<uint8_t> image_data(image_size);
-        image->getRegion(0, image_size, image_data_jint.data());
-        for (int i = 0; i < image_size; i++) {
-          image_data[i] = image_data_jint[i];
-        }
-        llm::Image image_runner{image_data, width, height, channels};
-        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
+      if (!prompt->toStdString().empty()) {
+        inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       }
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
@@ -257,23 +243,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  // Returns a tuple of (error, start_pos)
+  // Returns status_code
   // Contract is valid within an AAR (JNI + corresponding Java code)
-  // If the first element is not Error::Ok, the other element is undefined.
-  jint append_text_input(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint bos,
-      jint eos) {
+  jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
     prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
     return 0;
  }
 
+  // Returns status_code
   jint append_images_input(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
       jint channels) {
     std::vector<llm::Image> images;
+    if (image == nullptr) {
+      return static_cast<jint>(Error::EndOfMethod);
+    }
     auto image_size = image->size();
     if (image_size != 0) {
       std::vector<jint> image_data_jint(image_size);