Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,29 @@ public long prefillImages(int[] image, int width, int height, int channels) {

private native int appendImagesInput(int[] image, int width, int height, int channels);

/**
* Prefill an LLaVA Module with the given images input.
*
* @param image Input normalized image as a float array
* @param width Input image width
* @param height Input image height
* @param channels Input image number of channels
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
public long prefillImages(float[] image, int width, int height, int channels) {
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}

private native int appendNormalizedImagesInput(
float[] image, int width, int height, int channels);

/**
* Prefill an LLaVA Module with the given text input.
*
Expand Down
29 changes: 29 additions & 0 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,32 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns status_code
jint append_normalized_images_input(
facebook::jni::alias_ref<jfloatArray> image,
jint width,
jint height,
jint channels) {
std::vector<llm::Image> images;
if (image == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
}
auto image_size = image->size();
if (image_size != 0) {
std::vector<jfloat> image_data_jfloat(image_size);
std::vector<float> image_data(image_size);
image->getRegion(0, image_size, image_data_jfloat.data());
for (int i = 0; i < image_size; i++) {
image_data[i] = image_data_jfloat[i];
}
llm::Image image_runner{std::move(image_data), width, height, channels};
prefill_inputs_.emplace_back(
llm::MultimodalInput{std::move(image_runner)});
}

return 0;
}

void stop() {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_->stop();
Expand Down Expand Up @@ -310,6 +336,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
makeNativeMethod(
"appendNormalizedImagesInput",
ExecuTorchLlmJni::append_normalized_images_input),
makeNativeMethod(
"appendTextInput", ExecuTorchLlmJni::append_text_input),
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
Expand Down
Loading