diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index ad1c77a92b9..12a04ea7a13 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -297,16 +297,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jlong start_pos,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(Error::NotSupported);
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
+          prompt->toStdString(),
+          seq_len,
+          start_pos,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& stats) { callback->onStats(stats); },
+          echo));
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      executorch::extension::llm::GenerationConfig config{
+          .echo = static_cast<bool>(echo),
+          .seq_len = seq_len,
+          .temperature = temperature_,
+      };
+      return static_cast<jint>(runner_->generate_from_pos(
+          prompt->toStdString(),
+          start_pos,
+          config,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     }
-    return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-        prompt->toStdString(),
-        seq_len,
-        start_pos,
-        [callback](const std::string& result) { callback->onResult(result); },
-        [callback](const llm::Stats& stats) { callback->onStats(stats); },
-        echo));
+    // Unknown model category: report NotSupported instead of falling off
+    // the end of a jint-returning function (undefined behavior).
+    return static_cast<jint>(Error::NotSupported);
   }
 
   void stop() {