From 47b64059d08386c1622dc260e124a563cd5d7238 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Sep 2024 22:09:17 -0700
Subject: [PATCH 1/8] Add JNI layer for prefill API

Need to add the Java layer next. Need to pass the ref back to Java.
---
 extension/android/jni/jni_layer_llama.cpp | 65 +++++++++++++++++++++++
 extension/llm/runner/multimodal_runner.h  | 44 ++++++++++++++++
 2 files changed, 109 insertions(+)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index dda9ece589d..5cfecdc56ff 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -180,6 +180,71 @@ class ExecuTorchLlamaJni
     return 0;
   }
 
+  jint prefill_prompt(
+      facebook::jni::alias_ref<jstring> prompt,
+      jlong start_pos,
+      jint bos,
+      jint eos,
+      jlong generated_token) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::NotSupported);
+    }
+
+    auto&& result = multi_modal_runner_->prefill_prompt(
+        prompt->toStdString(), start_pos, bos, eos);
+    if (result.ok()) {
+      // TODO(hsz): make generated_token a reference and update it here
+      generated_token = result.get();
+      return 0;
+    }
+    return static_cast<jint>(result.error());
+  }
+
+  jint prefill_images(
+      facebook::jni::alias_ref<jintArray> image,
+      jint width,
+      jint height,
+      jint channels,
+      jlong start_pos) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::NotSupported);
+    }
+
+    auto image_size = image->size();
+    std::vector<Image> images;
+    if (image_size != 0) {
+      std::vector<jint> image_data_jint(image_size);
+      std::vector<uint8_t> image_data(image_size);
+      image->getRegion(0, image_size, image_data_jint.data());
+      for (int i = 0; i < image_size; i++) {
+        image_data[i] = image_data_jint[i];
+      }
+      Image image_runner{image_data, width, height, channels};
+      images.push_back(image_runner);
+    }
+    // TODO(hsz): make start_pos a reference and update it here
+    jint result = static_cast<jint>(
+        multi_modal_runner_->prefill_images(images, start_pos));
+  }
+
+  jint generate_from_pos(
+      facebook::jni::alias_ref<jstring> prompt,
+      jint seq_len,
+      jlong start_pos,
+      facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::NotSupported);
+    }
+    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
+        prompt->toStdString(),
+        seq_len,
+        start_pos,
+        [callback](const std::string& result) { callback->onResult(result); },
+        [callback](const ::executorch::extension::llm::Stats& stats) {
+          callback->onStats(stats);
+        }));
+  }
+
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h
index 43bbe688448..70ecafee810 100644
--- a/extension/llm/runner/multimodal_runner.h
+++ b/extension/llm/runner/multimodal_runner.h
@@ -61,6 +61,50 @@ class MultimodalRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) = 0;
 
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   * @param images The image input to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @return The error status of prefilling images.
+   */
+  virtual runtime::Error prefill_images(
+      std::vector<Image>& images,
+      int64_t& start_pos) = 0;
+
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   * @param prompt The text prompt to LLaVA.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * It's passed as reference and will be updated inside this function.
+   * @param bos The number of BOS (begin of sequence) tokens.
+   * @param eos The number of EOS (end of sequence) tokens.
+   * @return The generated token of the LLaVA Module after prefilling the prompt.
+   */
+  virtual runtime::Result<uint64_t> prefill_prompt(
+      const std::string& prompt,
+      int64_t& start_pos,
+      int8_t bos = 0,
+      int8_t eos = 0) = 0;
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   * @param prompt The text prompt to LLaVA.
+   * @param seq_len The total sequence length, including the prompt tokens and
+   * new tokens.
+   * @param start_pos The starting position in KV cache of the input in the LLM.
+   * @param token_callback What to do after a token is generated.
+   * @param stats_callback What to do with Stats.
+   * @return The error code.
+   */
+  virtual runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int32_t seq_len = 1024,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {}) = 0;
+
   inline void stop() {
     text_token_generator_->stop();
   }

From efef19e6ab946094633c81785ff1d68271d32861 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Sep 2024 22:33:22 -0700
Subject: [PATCH 2/8] Use a tuple of return values instead of a passed-in
 variable ref

---
 extension/android/jni/jni_layer_llama.cpp | 32 ++++++++++++++++++-----
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 5cfecdc56ff..65ceaddfcfb 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -180,34 +180,51 @@ class ExecuTorchLlamaJni
     return 0;
   }
 
-  jint prefill_prompt(
+  // Returns a tuple of (error, token, start_pos)
+  // Contract is valid within an AAR (JNI + corresponding Java code)
+  // If the first element is not Error::Ok, the other two elements are
+  // undefined.
+  facebook::jni::local_ref<jlongArray> prefill_prompt(
       facebook::jni::alias_ref<jstring> prompt,
       jlong start_pos,
       jint bos,
       jint eos,
       jlong generated_token) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(3);
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(Error::NotSupported);
+      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+      return tuple_result;
     }
 
     auto&& result = multi_modal_runner_->prefill_prompt(
         prompt->toStdString(), start_pos, bos, eos);
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     if (result.ok()) {
       // TODO(hsz): make generated_token a reference and update it here
       generated_token = result.get();
-      return 0;
+      tuple_result->pin()[1] = static_cast<jlong>(generated_token);
+      tuple_result->pin()[2] = static_cast<jlong>(start_pos);
     }
-    return static_cast<jint>(result.error());
+    return tuple_result;
  }
 
-  jint prefill_images(
+  // Returns a tuple of (error, start_pos)
+  // Contract is valid within an AAR (JNI + corresponding Java code)
+  // If the first element is not Error::Ok, the other element is undefined.
+
+  facebook::jni::local_ref<jlongArray> prefill_images(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
       jint channels,
       jlong start_pos) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(Error::NotSupported);
+      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+      return tuple_result;
     }
 
     auto image_size = image->size();
@@ -225,6 +242,9 @@ class ExecuTorchLlamaJni
     // TODO(hsz): make start_pos a reference and update it here
     jint result = static_cast<jint>(
         multi_modal_runner_->prefill_images(images, start_pos));
+    tuple_result->pin()[0] = result;
+    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+    return tuple_result;
   }
 
   jint generate_from_pos(

From 6715585b23614c882007e1c2f985c3ede2571edf Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 5 Sep 2024 22:39:33 -0700
Subject: [PATCH 3/8] Add the Java layer

---
 .../java/org/pytorch/executorch/LlamaModule.java | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index bdc8506aa9c..b22ca49e332 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -94,6 +94,44 @@ public native int generate(
       int seqLen,
       LlamaCallback llamaCallback);
 
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param startPos The starting position in KV cache of the input in the LLM. The updated
+   *     position is returned as part of the result tuple.
+   * @param bos The number of BOS (begin of sequence) tokens.
+   * @param eos The number of EOS (end of sequence) tokens.
+   * @return a tuple of (error, token, updated startPos)
+   */
+  public static native long[] prefill_prompt(
+      String prompt, long startPos, int bos, int eos, long generatedToken);
+
+  /**
+   * Prefill an LLaVA Module with the given images input.
+   *
+   * @param image Input image as an int array
+   * @param width Input image width
+   * @param height Input image height
+   * @param channels Input image number of channels
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @return a tuple of (error code, updated startPos)
+   */
+  public static native long[] prefill_images(
+      int[] image, int width, int height, int channels, long startPos);
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param llamaCallback callback object to receive results.
+   * @return The error code.
+   */
+  public static native int generate_from_pos(
+      String prompt, int seqLen, long startPos, ExecuTorchLlamaCallback callback);
+
   /** Stop current generate() before it finishes. */
   @DoNotStrip
   public native void stop();
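
A note on the long[] contract used by the Java bindings above: element 0 carries the
native error code, and the remaining elements are only meaningful when it is 0
(Error::Ok). A minimal caller-side sketch of unpacking the prefill_prompt result; the
helper name and the exception type are illustrative only, and the method itself is
renamed to prefillPrompt in the next patch:

    // Hypothetical helper: unpack the (error, token, updated startPos) tuple
    // returned by LlamaModule.prefill_prompt(). Element 0 is the status
    // (0 == Error::Ok), element 1 the generated token, element 2 the new
    // KV-cache position.
    static long unpackPrefillPromptResult(long[] tuple) {
      if (tuple[0] != 0) {
        throw new IllegalStateException("prefill_prompt failed with error " + tuple[0]);
      }
      long generatedToken = tuple[1]; // first token produced after the prompt
      return tuple[2];                // continue generation from this position
    }
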
From 9e24e2f171cf45115470986a194b9b0f20aac234 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 6 Sep 2024 09:33:43 -0700
Subject: [PATCH 4/8] Fix: drop the generated_token parameter and switch the
 Java methods to camelCase

---
 extension/android/jni/jni_layer_llama.cpp              |  7 ++-----
 .../main/java/org/pytorch/executorch/LlamaModule.java  | 10 +++++-----
 2 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 65ceaddfcfb..41d635e5d54 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -188,8 +188,7 @@ class ExecuTorchLlamaJni
       facebook::jni::alias_ref<jstring> prompt,
       jlong start_pos,
       jint bos,
-      jint eos,
-      jlong generated_token) {
+      jint eos) {
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(3);
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
@@ -201,9 +200,7 @@ class ExecuTorchLlamaJni
         prompt->toStdString(), start_pos, bos, eos);
     tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     if (result.ok()) {
-      // TODO(hsz): make generated_token a reference and update it here
-      generated_token = result.get();
-      tuple_result->pin()[1] = static_cast<jlong>(generated_token);
+      tuple_result->pin()[1] = static_cast<jlong>(result.get());
       tuple_result->pin()[2] = static_cast<jlong>(start_pos);
     }
     return tuple_result;
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index b22ca49e332..2504485e7d5 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -104,8 +104,8 @@ public native int generate(
    * @param eos The number of EOS (end of sequence) tokens.
    * @return a tuple of (error, token, updated startPos)
    */
-  public static native long[] prefill_prompt(
-      String prompt, long startPos, int bos, int eos, long generatedToken);
+  public static native long[] prefillPrompt(
+      String prompt, long startPos, int bos, int eos);
 
   /**
    * Prefill an LLaVA Module with the given images input.
@@ -117,7 +117,7 @@ public static native long[] prefill_prompt(
    * @param startPos The starting position in KV cache of the input in the LLM.
    * @return a tuple of (error code, updated startPos)
    */
-  public static native long[] prefill_images(
+  public static native long[] prefillImages(
       int[] image, int width, int height, int channels, long startPos);
 
   /**
@@ -129,8 +129,8 @@ public static native long[] prefill_images(
    * @param llamaCallback callback object to receive results.
    * @return The error code.
    */
-  public static native int generate_from_pos(
-      String prompt, int seqLen, long startPos, ExecuTorchLlamaCallback callback);
+  public static native int generateFromPos(
+      String prompt, int seqLen, long startPos, LlamaCallback callback);
 
   /** Stop current generate() before it finishes. */
   @DoNotStrip
From f9755563f1441b895e79ae6fbacbc99f432b6b68 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 6 Sep 2024 09:51:59 -0700
Subject: [PATCH 5/8] Simplify user-facing API to return startPos only

---
 .../java/org/pytorch/executorch/LlamaModule.java | 44 ++++++++++++++--------
 1 file changed, 28 insertions(+), 16 deletions(-)

diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index 2504485e7d5..503f9df7d34 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -94,19 +94,6 @@ public native int generate(
       int seqLen,
       LlamaCallback llamaCallback);
 
-  /**
-   * Prefill an LLaVA Module with the given text input.
-   *
-   * @param prompt The text prompt to LLaVA.
-   * @param startPos The starting position in KV cache of the input in the LLM. The updated
-   *     position is returned as part of the result tuple.
-   * @param bos The number of BOS (begin of sequence) tokens.
-   * @param eos The number of EOS (end of sequence) tokens.
-   * @return a tuple of (error, token, updated startPos)
-   */
-  public static native long[] prefillPrompt(
-      String prompt, long startPos, int bos, int eos);
-
   /**
    * Prefill an LLaVA Module with the given images input.
    *
@@ -115,11 +102,36 @@ public static native long[] prefillPrompt(
    * @param height Input image height
    * @param channels Input image number of channels
    * @param startPos The starting position in KV cache of the input in the LLM.
-   * @return a tuple of (error code, updated startPos)
+   * @return The updated starting position in KV cache of the input in the LLM.
    */
-  public static native long[] prefillImages(
+  public long prefillImages(
+      int[] image, int width, int height, int channels, long startPos) {
+    return prefillImagesNative(image, width, height, channels, startPos)[1];
+  }
+
+  // returns a tuple of (error code, updated startPos)
+  private native long[] prefillImagesNative(
       int[] image, int width, int height, int channels, long startPos);
 
+  /**
+   * Prefill an LLaVA Module with the given text input.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param startPos The starting position in KV cache of the input in the LLM. The updated
+   *     position is returned by this function.
+   * @param bos The number of BOS (begin of sequence) tokens.
+   * @param eos The number of EOS (end of sequence) tokens.
+   * @return The updated starting position in KV cache of the input in the LLM.
+   */
+  public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
+    return prefillPromptNative(prompt, startPos, bos, eos)[2];
+  }
+
+
+  // returns a tuple of (error, token, updated startPos)
+  private native long[] prefillPromptNative(
+      String prompt, long startPos, int bos, int eos);
+
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
@@ -129,8 +141,8 @@ public static native long[] prefillImages(
    * @param llamaCallback callback object to receive results.
    * @return The error code.
    */
-  public static native int generateFromPos(
+  public native int generateFromPos(
       String prompt, int seqLen, long startPos, LlamaCallback callback);
 
   /** Stop current generate() before it finishes. */
   @DoNotStrip
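
With the patch-5 wrappers, a caller threads startPos through successive prefill calls
instead of decoding tuples by hand. A sketch of the intended LLaVA flow; the 336x336
RGB image shape and the one-BOS / zero-EOS token counts are assumptions for
illustration, not fixed by this patch:

    // Sketch: chaining the simplified wrappers. Each call returns the
    // advanced KV-cache position, which feeds into the next call.
    long prefillLlava(LlamaModule module, int[] pixels, String prompt) {
      long startPos = 0;
      startPos = module.prefillImages(pixels, 336, 336, 3, startPos);
      startPos = module.prefillPrompt(prompt, startPos, 1, 0);
      return startPos;
    }
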
From d888c243eb1e099e50c4e7fddd7765b0aed18258 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 6 Sep 2024 12:59:20 -0700
Subject: [PATCH 6/8] Throw an exception if the native result is not OK

---
 .../java/org/pytorch/executorch/LlamaModule.java | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index 503f9df7d34..bfa3d470f83 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -104,9 +104,12 @@ public native int generate(
    * @param startPos The starting position in KV cache of the input in the LLM.
    * @return The updated starting position in KV cache of the input in the LLM.
    */
-  public long prefillImages(
-      int[] image, int width, int height, int channels, long startPos) {
-    return prefillImagesNative(image, width, height, channels, startPos)[1];
+  public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
+    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
   }
 
   // returns a tuple of (error code, updated startPos)
@@ -124,13 +127,15 @@ private native long[] prefillImagesNative(
    * @return The updated starting position in KV cache of the input in the LLM.
    */
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    return prefillPromptNative(prompt, startPos, bos, eos)[2];
+    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[2];
   }
 
-
   // returns a tuple of (error, token, updated startPos)
-  private native long[] prefillPromptNative(
-      String prompt, long startPos, int bos, int eos);
+  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
 
   /**
    * Generate tokens from the given prompt, starting from the given position.

From 7b5a0bc604d6384ae874056ac322dc599a0854f7 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 6 Sep 2024 13:00:11 -0700
Subject: [PATCH 7/8] Update docstrings

---
 .../java/org/pytorch/executorch/LlamaModule.java | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index bfa3d470f83..cd2ffa2ccdd 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -103,6 +103,7 @@ public native int generate(
    * @param channels Input image number of channels
    * @param startPos The starting position in KV cache of the input in the LLM.
    * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill failed
    */
   public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
     long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
@@ -125,6 +126,7 @@ private native long[] prefillImagesNative(
    * @param bos The number of BOS (begin of sequence) tokens.
    * @param eos The number of EOS (end of sequence) tokens.
    * @return The updated starting position in KV cache of the input in the LLM.
+   * @throws RuntimeException if the prefill failed
    */
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
     long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
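
After patch 6, a failed prefill surfaces as a RuntimeException rather than an error
code, for example when these methods are called on a non-multimodal model. A sketch of
defensive usage; the image shape, logging, and fallback behavior are illustrative
assumptions:

    // Sketch: guarding a prefill call against the RuntimeException thrown
    // when the native status element is non-zero (e.g. Error::NotSupported).
    long safePrefillImages(LlamaModule module, int[] pixels, long startPos) {
      try {
        return module.prefillImages(pixels, 336, 336, 3, startPos); // 336x336 RGB assumed
      } catch (RuntimeException e) {
        System.err.println("prefill failed: " + e.getMessage());
        return startPos; // KV-cache position unchanged on failure
      }
    }
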
From 8d375bfd511185f677d59274e9e35b1d0e41e80c Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Fri, 6 Sep 2024 13:03:49 -0700
Subject: [PATCH 8/8] Address comments about what to expose

---
 extension/android/jni/jni_layer_llama.cpp              | 10 ++++------
 .../main/java/org/pytorch/executorch/LlamaModule.java  |  6 +++---
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 41d635e5d54..5f2cac188fc 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -180,17 +180,16 @@ class ExecuTorchLlamaJni
     return 0;
   }
 
-  // Returns a tuple of (error, token, start_pos)
+  // Returns a tuple of (error, start_pos)
   // Contract is valid within an AAR (JNI + corresponding Java code)
-  // If the first element is not Error::Ok, the other two elements are
-  // undefined.
+  // If the first element is not Error::Ok, the other element is undefined.
   facebook::jni::local_ref<jlongArray> prefill_prompt(
       facebook::jni::alias_ref<jstring> prompt,
       jlong start_pos,
       jint bos,
       jint eos) {
     facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(3);
+        facebook::jni::make_long_array(2);
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
       tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
       return tuple_result;
@@ -200,8 +199,7 @@ class ExecuTorchLlamaJni
         prompt->toStdString(), start_pos, bos, eos);
     tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     if (result.ok()) {
-      tuple_result->pin()[1] = static_cast<jlong>(result.get());
-      tuple_result->pin()[2] = static_cast<jlong>(start_pos);
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
     }
     return tuple_result;
   }
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index cd2ffa2ccdd..e636c5f3f80 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -113,7 +113,7 @@ public long prefillImages(int[] image, int width, int height, int channels, long
     return nativeResult[1];
   }
 
-  // returns a tuple of (error code, updated startPos)
+  // returns a tuple of (status, updated startPos)
   private native long[] prefillImagesNative(
       int[] image, int width, int height, int channels, long startPos);
 
@@ -133,10 +133,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
     if (nativeResult[0] != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
     }
-    return nativeResult[2];
+    return nativeResult[1];
   }
 
-  // returns a tuple of (error, token, updated startPos)
+  // returns a tuple of (status, updated startPos)
   private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
 
   /**
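
Taken together, the series leaves LlamaModule with prefillImages and prefillPrompt
returning only the updated startPos (throwing on error) and an instance method
generateFromPos. An end-to-end sketch of the intended LLaVA usage; the constructor
arguments, image size, BOS/EOS counts, and the LlamaCallback method shapes are
assumptions based on the surrounding ExecuTorch Android API, not guarantees of these
patches:

    import org.pytorch.executorch.LlamaCallback;
    import org.pytorch.executorch.LlamaModule;

    public class LlavaPrefillExample {
      public static void main(String[] args) {
        // Assumed constructor: model path, tokenizer path, temperature.
        LlamaModule module =
            new LlamaModule("/data/local/tmp/llava.pte", "/data/local/tmp/tokenizer.bin", 0.8f);

        int[] pixels = new int[336 * 336]; // pixel values for one 336x336 image, size assumed
        long startPos = 0;
        startPos = module.prefillImages(pixels, 336, 336, 3, startPos);
        startPos = module.prefillPrompt("What is in the image?", startPos, 1, 0);

        // Generation resumes from the prefilled KV-cache position.
        module.generateFromPos(
            "Describe the image briefly.",
            256, // seqLen: prompt tokens plus new tokens
            startPos,
            new LlamaCallback() {
              @Override
              public void onResult(String token) {
                System.out.print(token);
              }

              @Override
              public void onStats(float tokensPerSecond) {
                System.out.println("\ntokens/s: " + tokensPerSecond);
              }
            });
      }
    }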