diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index f24254efb31..96b200303c9 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -73,8 +73,15 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
 
   @Override
   public void onResult(String result) {
-    mResultMessage.appendText(result);
-    run();
+    if (result.equals("\n\n")) {
+      if (!mResultMessage.getText().isEmpty()) {
+        mResultMessage.appendText(result);
+        run();
+      }
+    } else {
+      mResultMessage.appendText(result);
+      run();
+    }
   }
 
   @Override
@@ -614,6 +621,7 @@ public void run() {
                     ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
                     prompt,
                     ModelUtils.VISION_MODEL_SEQ_LEN,
+                    false,
                     MainActivity.this);
               } else {
                 // no image selected, we pass in empty int array
@@ -624,10 +632,12 @@ public void run() {
                     ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
                     prompt,
                     ModelUtils.VISION_MODEL_SEQ_LEN,
+                    false,
                     MainActivity.this);
               }
             } else {
-              mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);
+              mModule.generate(
+                  prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
             }
 
             long generateDuration = System.currentTimeMillis() - generateStartTime;
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 5dac3e9adbb..d6b9650c0ae 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -319,7 +319,6 @@ def build_args_parser() -> argparse.ArgumentParser:
 
 
 def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
-
     path = str(path)
     if verbose_export():
diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 2c72b4c724e..6051040e8fd 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -143,7 +143,8 @@ Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    bool echo) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -203,7 +204,9 @@ Error Runner::generate(
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  wrapped_callback(prompt);
+  if (echo) {
+    wrapped_callback(prompt);
+  }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 4e3c1daef7b..cec8c61157f 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -40,7 +40,8 @@ class Runner {
       const std::string& prompt,
       int32_t seq_len = 128,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {});
+      std::function<void(const Stats&)> stats_callback = {},
+      bool echo = true);
   void stop();
 
  private:
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index dda9ece589d..002bb4836dd 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -150,6 +150,7 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
+      jboolean echo,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
@@ -175,7 +176,8 @@ class ExecuTorchLlamaJni
           prompt->toStdString(),
           seq_len,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
+          [callback](const Stats& result) { callback->onStats(result); },
+          echo);
     }
     return 0;
   }
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index bdc8506aa9c..da396820c42 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -33,6 +33,7 @@ public class LlamaModule {
 
   private final HybridData mHybridData;
   private static final int DEFAULT_SEQ_LEN = 128;
+  private static final boolean DEFAULT_ECHO = true;
 
   @DoNotStrip
   private static native HybridData initHybrid(
@@ -59,7 +60,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
+    return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
   }
 
   /**
@@ -70,7 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback);
+    return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
   }
 
   /**
@@ -82,6 +106,7 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
    * @param channels Input image number of channels
    * @param prompt Input prompt
    * @param seqLen sequence length
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    * @param llamaCallback callback object to receive results.
    */
   @DoNotStrip
@@ -92,6 +117,7 @@ public native int generate(
       int channels,
       String prompt,
       int seqLen,
+      boolean echo,
       LlamaCallback llamaCallback);
 
   /** Stop current generate() before it finishes. */
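
For review context, a minimal sketch of how a caller might use the new echo overloads from the Java side. The generate(...) signatures are taken from this diff; the LlamaModule(String, String, float) constructor, the onStats(float) callback shape, and the paths and temperature below are assumptions based on the surrounding demo code, not part of this change:

    import org.pytorch.executorch.LlamaCallback;
    import org.pytorch.executorch.LlamaModule;

    public class EchoDemo implements LlamaCallback {
      @Override
      public void onResult(String result) {
        // Streams token pieces; with echo == false the prompt itself is
        // never forwarded here, so only the completion is printed.
        System.out.print(result);
      }

      @Override
      public void onStats(float tokensPerSecond) {
        System.out.println("\ntokens/s: " + tokensPerSecond);
      }

      public static void main(String[] args) {
        // Hypothetical model/tokenizer paths and temperature, for illustration only.
        LlamaModule module =
            new LlamaModule(
                "/data/local/tmp/llama/model.pte",
                "/data/local/tmp/llama/tokenizer.bin",
                0.8f);
        EchoDemo callback = new EchoDemo();

        // Chat-style call: suppress the prompt in the streamed output.
        module.generate("What is ExecuTorch?", 128, /* echo */ false, callback);

        // Text-completion call: the pre-existing overload resolves to
        // DEFAULT_ECHO == true, so the prompt is streamed back first.
        module.generate("Once upon a time", callback);

        module.resetNative();
      }
    }

Note that the defaults (echo = true in runner.h, DEFAULT_ECHO = true in LlamaModule) keep the previous behavior for existing callers; only call sites that opt in, such as the chat-style demo app above, get the prompt-suppressing behavior.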