@@ -73,8 +73,15 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCallback {

   @Override
   public void onResult(String result) {
-    mResultMessage.appendText(result);
-    run();
+    if (result.equals("\n\n")) {
+      if (!mResultMessage.getText().isEmpty()) {
+        mResultMessage.appendText(result);
+        run();
+      }
+    } else {
+      mResultMessage.appendText(result);
+      run();
+    }
   }

   @Override
@@ -614,6 +621,7 @@ public void run() {
             ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
             prompt,
             ModelUtils.VISION_MODEL_SEQ_LEN,
+            false,
             MainActivity.this);
       } else {
         // no image selected, we pass in empty int array
@@ -624,10 +632,12 @@ public void run() {
             ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
             prompt,
             ModelUtils.VISION_MODEL_SEQ_LEN,
+            false,
             MainActivity.this);
       }
     } else {
-      mModule.generate(prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this);
+      mModule.generate(
+          prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this);
     }

     long generateDuration = System.currentTimeMillis() - generateStartTime;
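Note on the demo change above: with false now passed for echo, the prompt is no longer streamed back through the callback, and the reworked onResult() drops a bare "\n\n" chunk that arrives before any visible output. A minimal standalone restatement of that rule (hypothetical helper, not part of this diff; mResultMessage's text is what would be passed as shownSoFar):

  // Hypothetical restatement of the filter in onResult() above.
  final class OutputFilter {
    // Append every chunk except a "\n\n" received before any visible
    // output, so the reply does not begin with a blank line now that
    // the prompt itself is no longer echoed.
    static boolean shouldAppend(String chunk, CharSequence shownSoFar) {
      return !chunk.equals("\n\n") || shownSoFar.length() > 0;
    }
  }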
1 change: 0 additions & 1 deletion examples/models/llama2/export_llama_lib.py
@@ -319,7 +319,6 @@ def build_args_parser() -> argparse.ArgumentParser:


 def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
-
     path = str(path)

     if verbose_export():
7 changes: 5 additions & 2 deletions examples/models/llama2/runner/runner.cpp
@@ -143,7 +143,8 @@ Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    bool echo) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -203,7 +204,9 @@ Error Runner::generate(
   // after the prompt. After that we will enter generate loop.

   // print prompts
-  wrapped_callback(prompt);
+  if (echo) {
+    wrapped_callback(prompt);
+  }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   stats_.first_token_ms = util::time_in_ms();
3 changes: 2 additions & 1 deletion examples/models/llama2/runner/runner.h
@@ -40,7 +40,8 @@ class Runner {
       const std::string& prompt,
       int32_t seq_len = 128,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {});
+      std::function<void(const Stats&)> stats_callback = {},
+      bool echo = true);
   void stop();

  private:
4 changes: 3 additions & 1 deletion extension/android/jni/jni_layer_llama.cpp
@@ -150,6 +150,7 @@ class ExecuTorchLlamaJni
       jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
+      jboolean echo,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       auto image_size = image->size();
@@ -175,7 +176,8 @@ class ExecuTorchLlamaJni
           prompt->toStdString(),
           seq_len,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const Stats& result) { callback->onStats(result); });
+          [callback](const Stats& result) { callback->onStats(result); },
+          echo);
     }
     return 0;
   }
30 changes: 28 additions & 2 deletions extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java
@@ -33,6 +33,7 @@ public class LlamaModule {

   private final HybridData mHybridData;
   private static final int DEFAULT_SEQ_LEN = 128;
+  private static final boolean DEFAULT_ECHO = true;

   @DoNotStrip
   private static native HybridData initHybrid(
@@ -59,7 +60,7 @@ public void resetNative() {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, LlamaCallback llamaCallback) {
-    return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback);
+    return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback);
   }

   /**
@@ -70,7 +71,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
    * @param llamaCallback callback object to receive results.
    */
   public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
-    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback);
+    return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback);
   }

+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
+   * @param llamaCallback callback object to receive results.
+   */
+  public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback);
+  }
+
   /**
@@ -82,6 +106,7 @@ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
    * @param channels Input image number of channels
    * @param prompt Input prompt
    * @param seqLen sequence length
+   * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    * @param llamaCallback callback object to receive results.
    */
   @DoNotStrip
@@ -92,6 +117,7 @@ public native int generate(
       int channels,
       String prompt,
       int seqLen,
+      boolean echo,
       LlamaCallback llamaCallback);

   /** Stop current generate() before it finishes. */
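Taken together, these changes make echo an explicit knob on the Java API. A hedged usage sketch against the new overloads (the generate() signatures come from this diff; the constructor arguments, file paths, prompt strings, and the onStats(float) signature are assumptions for illustration):

  import org.pytorch.executorch.LlamaCallback;
  import org.pytorch.executorch.LlamaModule;

  public class EchoExample {
    public static void main(String[] args) {
      // Constructor arguments (model path, tokenizer path, temperature)
      // are assumed; only the generate() overloads appear in this diff.
      LlamaModule module = new LlamaModule(
          "/data/local/tmp/llama2.pte", "/data/local/tmp/tokenizer.bin", 0.8f);

      LlamaCallback callback = new LlamaCallback() {
        @Override
        public void onResult(String result) {
          System.out.print(result); // stream each generated chunk as it arrives
        }

        @Override
        public void onStats(float tps) {
          // tokens/second; exact signature assumed, not shown in this diff
        }
      };

      // Text completion: DEFAULT_ECHO (true) applies, so the prompt is
      // streamed back through onResult() before the model's continuation.
      module.generate("Once upon a time", callback);

      // Chat: echo = false suppresses the prompt, so the UI receives
      // only the model's reply.
      module.generate("[INST] Tell me a joke [/INST]",
          /* seqLen */ 256, /* echo */ false, callback);
    }
  }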