diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index e4be53f65cd..9494b0fe5cb 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -231,8 +231,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @return The error code.
    */
   @Deprecated
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    return generate(prompt, seqLen, callback, echo);
+  }
 
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index d019313ca6a..331c20ee6f1 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -290,37 +290,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  jint generate_from_pos(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint seq_len,
-      jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
-      jboolean echo) {
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
-      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      return static_cast<jint>(multi_modal_runner_->generate(
-          inputs,
-          llm::GenerationConfig{
-              .echo = static_cast<bool>(echo), .seq_len = seq_len},
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      executorch::extension::llm::GenerationConfig config{
-          .echo = static_cast<bool>(echo),
-          .seq_len = seq_len,
-          .temperature = temperature_,
-      };
-      return static_cast<jint>(runner_->generate(
-          prompt->toStdString(),
-          config,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    }
-    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
-  }
-
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -357,8 +326,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           "appendImagesInput", ExecuTorchLlmJni::append_images_input),
       makeNativeMethod(
           "appendTextInput", ExecuTorchLlmJni::append_text_input),
-      makeNativeMethod(
-          "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
       makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
   });
 }