diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index dda9ece589d..5f2cac188fc 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -180,6 +180,85 @@ class ExecuTorchLlamaJni
     return 0;
   }
 
+  // Returns a tuple of (error, start_pos).
+  // Contract is valid within an AAR (JNI + corresponding Java code).
+  // If the first element is not Error::Ok, the other element is undefined.
+  facebook::jni::local_ref<jlongArray> prefill_prompt(
+      facebook::jni::alias_ref<jstring> prompt,
+      jlong start_pos,
+      jint bos,
+      jint eos) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+      return tuple_result;
+    }
+
+    auto&& result = multi_modal_runner_->prefill_prompt(
+        prompt->toStdString(), start_pos, bos, eos);
+    tuple_result->pin()[0] = static_cast<jint>(result.error());
+    if (result.ok()) {
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+    }
+    return tuple_result;
+  }
+
+  // Returns a tuple of (error, start_pos).
+  // Contract is valid within an AAR (JNI + corresponding Java code).
+  // If the first element is not Error::Ok, the other element is undefined.
+  facebook::jni::local_ref<jlongArray> prefill_images(
+      facebook::jni::alias_ref<jintArray> image,
+      jint width,
+      jint height,
+      jint channels,
+      jlong start_pos) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+      return tuple_result;
+    }
+
+    auto image_size = image->size();
+    std::vector<Image> images;
+    if (image_size != 0) {
+      std::vector<jint> image_data_jint(image_size);
+      std::vector<uint8_t> image_data(image_size);
+      image->getRegion(0, image_size, image_data_jint.data());
+      for (int i = 0; i < image_size; i++) {
+        image_data[i] = image_data_jint[i];
+      }
+      Image image_runner{image_data, width, height, channels};
+      images.push_back(image_runner);
+    }
+    // TODO(hsz): make start_pos a reference and update it here
+    jint result = static_cast<jint>(
+        multi_modal_runner_->prefill_images(images, start_pos));
+    tuple_result->pin()[0] = result;
+    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+    return tuple_result;
+  }
+
+  jint generate_from_pos(
+      facebook::jni::alias_ref<jstring> prompt,
+      jint seq_len,
+      jlong start_pos,
+      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::NotSupported);
+    }
+    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
+        prompt->toStdString(),
+        seq_len,
+        start_pos,
+        [callback](const std::string& result) { callback->onResult(result); },
+        [callback](const ::executorch::extension::llm::Stats& stats) {
+          callback->onStats(stats);
+        }));
+  }
+
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index bdc8506aa9c..e636c5f3f80 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -94,6 +94,63 @@ public native int generate(
       int seqLen,
       LlamaCallback llamaCallback);
 
+  /**
+   * Prefill an LLaVA Module with the given image input.
+   *
+   * @param image Input image as an int array, one element per pixel-channel byte
+   * @param width Input image width
+   * @param height Input image height
+   * @param channels Input image number of channels
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill fails
+   */
+  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
+    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  // Returns a tuple of (status, updated startPos).
+  private native long[] prefillImagesNative(
+      int[] image, int width, int height, int channels, long startPos);
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param startPos The starting position in KV cache of the input in the LLM. The updated
+   *     position is returned as this method's result, since Java primitives are passed by value.
+   * @param bos The number of BOS (begin of sequence) tokens to prepend.
+   * @param eos The number of EOS (end of sequence) tokens to append.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill fails
+   */
+  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
+    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  // Returns a tuple of (status, updated startPos).
+  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback Callback object to receive results.
+   * @return The error code.
+   */
+  public native int generateFromPos(
+      String prompt, int seqLen, long startPos, LlamaCallback callback);
+
   /** Stop current generate() before it finishes. */
   @DoNotStrip
   public native void stop();
diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h
index 43bbe688448..70ecafee810 100644
--- a/extension/llm/runner/multimodal_runner.h
+++ b/extension/llm/runner/multimodal_runner.h
@@ -61,6 +61,50 @@ class MultimodalRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) = 0;
 
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   * @param images The image input to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   *     It's passed as a reference and will be updated inside this function.
+   * @return The error status of prefilling images.
+   */
+  virtual runtime::Error prefill_images(
+      std::vector<Image>& images,
+      int64_t& start_pos) = 0;
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   *     It's passed as a reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) tokens to prepend.
+   * @param eos The number of EOS (end of sequence) tokens to append.
+   * @return The next token of the LLaVA Module after prefilling the prompt.
+   */
+  virtual runtime::Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0) = 0;
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   *     new tokens.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with the Stats of the runner.
+   * @return The error code.
+   */
+  virtual runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {}) = 0;
+
   inline void stop() {
     text_token_generator_->stop();
   }
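Usage sketch (not part of the patch): the Java example below shows how the three new methods compose on the app side, and why the native layer returns a long[2] of (error, updated startPos): JNI has no out-parameters for Java primitives, so the status and the advanced KV-cache position travel back together, and the public wrappers convert the status slot into a RuntimeException. The model/tokenizer paths, image size, prompts, and the LlamaModule(String, String, float) constructor, load(), and LlamaCallback method shapes are assumptions based on the existing APIs in this package, not definitions introduced here.

// Hedged sketch of the new LLaVA prefill/generate flow; paths and sizes are placeholders.
import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class LlavaPrefillExample {
  public static void main(String[] args) {
    // Assumed existing constructor: (modulePath, tokenizerPath, temperature).
    LlamaModule llava =
        new LlamaModule("/data/local/tmp/llava.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    llava.load(); // assumed to be required before inference, as with generate()

    long startPos = 0;
    // 1) Prefill the text prompt; the return value is the advanced KV-cache position.
    startPos = llava.prefillPrompt("USER: ", startPos, /* bos */ 1, /* eos */ 0);

    // 2) Prefill the image: one int per pixel-channel byte (placeholder 336x336x3).
    int[] image = new int[336 * 336 * 3];
    startPos = llava.prefillImages(image, 336, 336, 3, startPos);

    // 3) Generate from the accumulated position, streaming tokens via the callback.
    int err = llava.generateFromPos(
        "What is in the image? ASSISTANT:",
        /* seqLen */ 768,
        startPos,
        new LlamaCallback() {
          @Override
          public void onResult(String token) {
            System.out.print(token); // each decoded token as it is produced
          }

          @Override
          public void onStats(float tokensPerSecond) {
            System.out.println("\ntps: " + tokensPerSecond);
          }
        });
    if (err != 0) {
      throw new RuntimeException("generateFromPos failed with error code: " + err);
    }
  }
}

Note that generateFromPos deliberately keeps the raw int error code rather than throwing, matching the existing generate() contract, while the two prefill wrappers throw because their primary result is the updated position.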