From d4c573363c526342f4cac51a91222e6902eb39b0 Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Fri, 30 Aug 2024 01:39:52 -0700
Subject: [PATCH 1/7] Add echo parameter to llama runner and JNI+Java layer

---
 examples/models/llama2/runner/runner.cpp    |  5 +++-
 examples/models/llama2/runner/runner.h      |  1 +
 extension/android/jni/jni_layer_llama.cpp   |  2 ++
 .../org/pytorch/executorch/LlamaModule.java | 30 +++++++++++++++++--
 4 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 2c72b4c724e..2d5ed12ab11 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -142,6 +142,7 @@ Error Runner::load() {
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
+    bool echo,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
   // Prepare the inputs.
@@ -212,7 +213,9 @@ Error Runner::generate(
   uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
-  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
+  if (echo) {
+    wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
+  }
 
   // start the main loop
   prompt_tokens.push_back(cur_token);
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 4e3c1daef7b..476171db023 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -39,6 +39,7 @@ class Runner {
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
+      bool echo = true,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {});
   void stop();
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index dda9ece589d..f46e25746df 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -150,6 +150,7 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
+      jboolean echo,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
@@ -174,6 +175,7 @@ class ExecuTorchLlamaJni
       runner_->generate(
           prompt->toStdString(),
           seq_len,
+          echo,
           [callback](std::string result) { callback->onResult(result); },
           [callback](const Stats& result) { callback->onStats(result); });
     }
diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
index bdc8506aa9c..da396820c42 100644
--- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
+++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -33,6 +33,7 @@ public class LlamaModule {
   private final HybridData mHybridData;
 
   private static final int DEFAULT_SEQ_LEN = 128;
+  private static final boolean DEFAULT_ECHO = true;
 
   @DoNotStrip
   private static native HybridData initHybrid(
@@ -59,7 +60,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
+    return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
   }
 
   /**
@@ -70,7 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback);
+    return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param echo whether to echo the input prompt or not (text completion vs. chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param echo whether to echo the input prompt or not (text completion vs. chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
   }
 
   /**
@@ -82,6 +106,7 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
    * @param channels Input image number of channels
    * @param prompt Input prompt
    * @param seqLen sequence length
+   * @param echo whether to echo the input prompt or not (text completion vs. chat)
    * @param llamaCallback callback object to receive results.
    */
   @DoNotStrip
@@ -92,6 +117,7 @@ public native int generate(
       int channels,
       String prompt,
       int seqLen,
+      boolean echo,
       LlamaCallback llamaCallback);
 
   /** Stop current generate() before it finishes. */

From 2388966ad98ffd43eb5893d21cdc952aca97522e Mon Sep 17 00:00:00 2001
From: Riandy Riandy
Date: Thu, 5 Sep 2024 11:08:04 -0700
Subject: [PATCH 2/7] Update Android app to support echo flag in generate

Summary: The echo flag was added on the runner side via
https://github.com/pytorch/executorch/pull/5011 by Chirag. Now we update the
app side to use the new echo flag, so that the user prompt is not displayed
in the response.
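For reference, the text-model call site changes like this (lines taken from
the diff below):

    // before: the prompt is echoed back through the token callback
    mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);

    // after: echo=false keeps the prompt out of the streamed response
    mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);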
Differential Revision: D62250116
---
 .../java/com/example/executorchllamademo/MainActivity.java | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index f24254efb31..0d8553bdb4c 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -614,6 +614,7 @@ public void run() {
             ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
             prompt,
             ModelUtils.VISION_MODEL_SEQ_LEN,
+            false,
             MainActivity.this);
       } else {
         // no image selected, we pass in empty int array
@@ -624,10 +625,12 @@ public void run() {
             ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
             prompt,
             ModelUtils.VISION_MODEL_SEQ_LEN,
+            false,
             MainActivity.this);
       }
     } else {
-      mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);
+      mModule.generate(
+          prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
     }
 
     long generateDuration = System.currentTimeMillis() - generateStartTime;

From 3fd201e43c6baf547dd4180b6e7bdabfabfffb16 Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Thu, 5 Sep 2024 12:27:33 -0700
Subject: [PATCH 3/7] Place echo as last argument to preserve default value

---
 examples/models/llama2/runner/runner.cpp  | 4 ++--
 examples/models/llama2/runner/runner.h    | 4 ++--
 extension/android/jni/jni_layer_llama.cpp | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 2d5ed12ab11..b8a5b4df278 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -142,9 +142,9 @@ Error Runner::load() {
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
-    bool echo,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    bool echo) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 476171db023..cec8c61157f 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -39,9 +39,9 @@ class Runner {
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
-      bool echo = true,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {});
+      std::function<void(const Stats&)> stats_callback = {},
+      bool echo = true);
   void stop();
 
  private:
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index f46e25746df..002bb4836dd 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -175,9 +175,9 @@ class ExecuTorchLlamaJni
       runner_->generate(
           prompt->toStdString(),
           seq_len,
-          echo,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
+          [callback](const Stats& result) { callback->onStats(result); },
+          echo);
     }
     return 0;
   }

From def65b1ec2530c52367b0fa847914c093c609e94 Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Thu, 5 Sep 2024 22:56:39 -0700
Subject: [PATCH 4/7] Remove echo condition from the first prefill token

---
 examples/models/llama2/runner/runner.cpp | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index b8a5b4df278..586550c0895 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -213,9 +213,7 @@ Error Runner::generate(
   uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
-  if (echo) {
-    wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
-  }
+  wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
 
   // start the main loop
   prompt_tokens.push_back(cur_token);

From 621f5b9508f7ace9b0ba8cfe610d379d0a754fa2 Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Thu, 5 Sep 2024 23:29:47 -0700
Subject: [PATCH 5/7] Avoid printing newlines on output in app

---
 .../com/example/executorchllamademo/MainActivity.java | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index 0d8553bdb4c..4d3f14b840c 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -73,8 +73,15 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
 
   @Override
   public void onResult(String result) {
-    mResultMessage.appendText(result);
-    run();
+    if(result.equals("\n\n")) {
+      if(!mResultMessage.getText().isEmpty()) {
+        mResultMessage.appendText(result);
+        run();
+      }
+    } else {
+      mResultMessage.appendText(result);
+      run();
+    }
   }
 
   @Override

From a4b63e0dd57a2a1a20bcb298057ee8c956dcfe3e Mon Sep 17 00:00:00 2001
From: cmodi-meta <98582575+cmodi-meta@users.noreply.github.com>
Date: Fri, 6 Sep 2024 11:26:20 -0700
Subject: [PATCH 6/7] Rebase echo parameter PR

---
 examples/models/llama2/runner/runner.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 586550c0895..6051040e8fd 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -204,7 +204,9 @@ Error Runner::generate(
   // after the prompt. After that we will enter generate loop.
 
   // print prompts
-  wrapped_callback(prompt);
+  if (echo) {
+    wrapped_callback(prompt);
+  }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();

From d081b02876e47556b9027de3ad6d99cc1183ed33 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 6 Sep 2024 12:05:00 -0700
Subject: [PATCH 7/7] Lint

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .../java/com/example/executorchllamademo/MainActivity.java | 4 ++--
 examples/models/llama2/export_llama_lib.py                 | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index 4d3f14b840c..96b200303c9 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -73,8 +73,8 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
 
   @Override
   public void onResult(String result) {
-    if(result.equals("\n\n")) {
-      if(!mResultMessage.getText().isEmpty()) {
+    if (result.equals("\n\n")) {
+      if (!mResultMessage.getText().isEmpty()) {
         mResultMessage.appendText(result);
         run();
       }
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 5dac3e9adbb..d6b9650c0ae 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -319,7 +319,6 @@ def build_args_parser() -> argparse.ArgumentParser:
 
 
 def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
-
     path = str(path)
     if verbose_export():
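With the series applied, text-completion callers keep the old behavior (echo
defaults to true) while chat-style callers pass echo=false to receive only
newly generated tokens. Below is a minimal, hypothetical driver for the
finished Java API; it assumes the LlamaModule(String modulePath, String
tokenizerPath, float temperature) constructor and a LlamaCallback exposing
onResult(String) and onStats(float), and the model/tokenizer paths are
placeholders:

    import org.pytorch.executorch.LlamaCallback;
    import org.pytorch.executorch.LlamaModule;

    public class EchoExample implements LlamaCallback {
      @Override
      public void onResult(String result) {
        // Tokens stream in one callback at a time; with echo=false the
        // prompt text itself is never delivered here.
        System.out.print(result);
      }

      @Override
      public void onStats(float tps) {
        System.out.println("\n[tokens/sec: " + tps + "]");
      }

      public static void main(String[] args) {
        // Placeholder paths and temperature; substitute a real exported
        // .pte model and its tokenizer.
        LlamaModule module = new LlamaModule(
            "/data/local/tmp/llama2.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
        EchoExample callback = new EchoExample();

        // Chat-style: echo=false, so the response contains only new tokens.
        module.generate("What is ExecuTorch?", 128, /* echo */ false, callback);

        // Text-completion style: the original overload keeps echo=true.
        module.generate("Once upon a time", callback);
      }
    }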