Merged
78 changes: 35 additions & 43 deletions extension/android/jni/jni_layer_llama.cpp
@@ -13,7 +13,6 @@
 #include <unordered_map>
 #include <vector>
 
-#include <executorch/examples/models/llava/runner/llava_runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
+  std::unique_ptr<executorch::extension::llm::MultimodalRunner>
+      multi_modal_runner_;
+  std::vector<llm::MultimodalInput> prefill_inputs_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
     model_type_category_ = model_type_category;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
+      multi_modal_runner_ = llm::create_multimodal_runner(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
       std::optional<const std::string> data_path_str = data_path
           ? std::optional<const std::string>{data_path->toStdString()}
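Note: the hunk above replaces the LLaVA-specific runner with the generic multimodal factory. A minimal standalone sketch of the same construction path, assuming the MultimodalRunner header name and using hypothetical on-device file paths (neither is taken from this diff):

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/multimodal_runner.h> // assumed header

namespace llm = executorch::extension::llm;

// The tokenizer is now loaded explicitly and handed to the factory, instead
// of passing a tokenizer path and temperature to a model-specific constructor.
auto runner = llm::create_multimodal_runner(
    "/data/local/tmp/model.pte", // hypothetical model path
    llm::load_tokenizer("/data/local/tmp/tokenizer.bin")); // hypothetical tokenizer path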
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       auto image_size = image->size();
       std::vector<llm::Image> images;
       if (image_size != 0) {
@@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           image_data[i] = image_data_jint[i];
         }
         llm::Image image_runner{image_data, width, height, channels};
-        images.push_back(image_runner);
+        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
       }
+      executorch::extension::llm::GenerationConfig config{
+          .echo = static_cast<bool>(echo),
+          .seq_len = seq_len,
+          .temperature = temperature_,
+      };
       multi_modal_runner_->generate(
-          std::move(images),
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); },
-          echo);
+          std::move(inputs),
+          config,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& result) { callback->onStats(result); });
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
Expand All @@ -259,19 +265,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
jlong start_pos,
jint bos,
jint eos) {
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
return tuple_result;
}

auto&& result = multi_modal_runner_->prefill_prompt(
prompt->toStdString(), start_pos, bos, eos);
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
if (result.ok()) {
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
}
return tuple_result;
}

@@ -285,16 +282,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint height,
       jint channels,
       jlong start_pos) {
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
-      return tuple_result;
-    }
-
-    auto image_size = image->size();
     std::vector<llm::Image> images;
+    auto image_size = image->size();
     if (image_size != 0) {
       std::vector<jint> image_data_jint(image_size);
       std::vector<uint8_t> image_data(image_size);
@@ -303,13 +292,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         image_data[i] = image_data_jint[i];
       }
       llm::Image image_runner{image_data, width, height, channels};
-      images.push_back(image_runner);
+      prefill_inputs_.emplace_back(
+          llm::MultimodalInput{std::move(image_runner)});
     }
-    // TODO(hsz): make start_pos a reference and update it here
-    jint result = static_cast<jint>(
-        multi_modal_runner_->prefill_images(images, start_pos));
-    tuple_result->pin()[0] = result;
-    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }
 
@@ -320,13 +310,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+      return static_cast<jint>(multi_modal_runner_->generate(
+          inputs,
+          llm::GenerationConfig{
+              .echo = static_cast<bool>(echo), .seq_len = seq_len},
           [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
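Note: after this change the prefill methods no longer run the model eagerly; they only queue llm::MultimodalInput entries in prefill_inputs_, and the next generate call drains the queue and appends the final text prompt. A sketch of that accumulate-then-generate pattern in isolation, where `runner` is a MultimodalRunner built as in the earlier sketch and the image size, prompt, and callbacks are placeholders:

namespace llm = executorch::extension::llm;

std::vector<llm::MultimodalInput> inputs;

// Queue an image first, then the text prompt, mirroring what the JNI layer
// produces when Java calls prefillImages() followed by generate().
std::vector<uint8_t> pixels(336 * 336 * 3); // hypothetical RGB buffer
llm::Image image{pixels, /*width=*/336, /*height=*/336, /*channels=*/3};
inputs.emplace_back(llm::MultimodalInput{std::move(image)});
inputs.emplace_back(llm::MultimodalInput{std::string("Describe the image.")});

// One generate() call consumes the whole queue; echo and seq_len now travel
// in GenerationConfig instead of positional arguments.
llm::GenerationConfig config{.echo = false, .seq_len = 768};
runner->generate(
    std::move(inputs),
    config,
    [](const std::string& token) { /* stream partial results */ },
    [](const llm::Stats& stats) { /* report decoding stats */ });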