Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
private Runnable memoryUpdater;
private boolean mThinkMode = false;
private int promptID = 0;
private long startPos = 0;
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
private Executor executor;

Expand Down Expand Up @@ -178,7 +177,8 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera

if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
ETLogging.getInstance().log("Llava start prefill prompt");
startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
mModule.resetContext();
mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt());
ETLogging.getInstance().log("Llava completes prefill prompt");
}
}
Expand Down Expand Up @@ -645,13 +645,11 @@ private void showMediaPreview(List<Uri> uris) {
ETLogging.getInstance().log("Starting runnable prefill image");
ETImage img = processedImageList.get(0);
ETLogging.getInstance().log("Llava start prefill image");
startPos =
mModule.prefillImages(
img.getInts(),
img.getWidth(),
img.getHeight(),
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
startPos);
mModule.prefillImages(
img.getInts(),
img.getWidth(),
img.getHeight(),
ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
};
executor.execute(runnable);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
* @param llmCallback callback object to receive results
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
}
public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);

/**
* Start generating tokens from the module.
Expand All @@ -154,16 +152,19 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa
* @param llmCallback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
@DoNotStrip
public native int generate(
public int generate(
int[] image,
int width,
int height,
int channels,
String prompt,
int seqLen,
LlmCallback llmCallback,
boolean echo);
boolean echo) {
prefillPrompt(prompt);
prefillImages(image, width, height, channels);
return generate("", llmCallback, echo);
}

/**
* Prefill an LLaVA Module with the given images input.
Expand All @@ -172,16 +173,12 @@ public native int generate(
* @param width Input image width
* @param height Input image height
* @param channels Input image number of channels
* @param startPos The starting position in KV cache of the input in the LLM.
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
if (startPos == 0) {
resetContext();
}
public long prefillImages(int[] image, int width, int height, int channels) {
int nativeResult = appendImagesInput(image, width, height, channels);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
Expand All @@ -195,28 +192,21 @@ public long prefillImages(int[] image, int width, int height, int channels, long
* Prefill an LLaVA Module with the given text input.
*
* @param prompt The text prompt to LLaVA.
* @param startPos The starting position in KV cache of the input in the LLM. It's passed as
* reference and will be updated inside this function.
* @param bos The number of BOS (begin of sequence) token.
* @param eos The number of EOS (end of sequence) token.
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
*/
@Deprecated
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
if (startPos == 0) {
resetContext();
}
int nativeResult = appendTextInput(prompt, bos, eos);
public long prefillPrompt(String prompt) {
int nativeResult = appendTextInput(prompt);
if (nativeResult != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
}
return 0;
}

// returns a tuple of (status, updated startPos)
private native int appendTextInput(String prompt, int bos, int eos);
// returns status
private native int appendTextInput(String prompt);

/**
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
Expand Down
30 changes: 8 additions & 22 deletions extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,29 +208,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
}

jint generate(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
prefill_inputs_.clear();
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
auto image_size = image->size();
std::vector<llm::Image> images;
if (image_size != 0) {
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
image->getRegion(0, image_size, image_data_jint.data());
for (int i = 0; i < image_size; i++) {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{image_data, width, height, channels};
inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
if (!prompt->toStdString().empty()) {
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
}
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
Expand All @@ -257,23 +243,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return 0;
}

// Returns a tuple of (error, start_pos)
// Returns status_code
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
jint append_text_input(
facebook::jni::alias_ref<jstring> prompt,
jint bos,
jint eos) {
jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
return 0;
}

// Returns status_code
jint append_images_input(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels) {
std::vector<llm::Image> images;
if (image == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
}
auto image_size = image->size();
if (image_size != 0) {
std::vector<jint> image_data_jint(image_size);
Expand Down
Loading