110 changes: 15 additions & 95 deletions examples/models/llama2/runner/runner.cpp
@@ -94,98 +94,17 @@ Error Runner::load() {
eos_id_ = get_module_metadata<int64_t>(
module_.get(), "get_eos_id", tokenizer_->eos_tok());

  // Create text decoder runner
  // Create text decoder runner and prefiller
  text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
      module_.get(), use_kv_cache_, vocab_size_, temperature_);

  text_prefiller_ = std::make_unique<TextPrefiller>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      use_kv_cache_,
      enable_parallel_prefill_);

  return Error::Ok;
}

Result<uint64_t> Runner::prefill(
    std::vector<uint64_t>& prompt_tokens,
    int64_t start_pos,
    std::function<void(const std::string&)> token_callback) {
  // enable_parallel_prefill_ maybe set even when not using kv cache
  // When kv cache is not used, start pos is ignored
  int32_t num_prompt_tokens = prompt_tokens.size();

  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
  ET_CHECK_MSG(
      num_prompt_tokens < max_seq_len_,
      "Max seq length exceeded - please increase max seq len value");

  // store the token
  uint64_t cur_token;
  if (enable_parallel_prefill_ || !use_kv_cache_) {
    // initialize tensor wrappers
    ManagedTensor managed_tokens(
        prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);

    ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

    Result<exec_aten::Tensor> outputs_res =
        text_decoder_runner_->step(managed_tokens, managed_start_pos);

    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_LOG(
        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
    ET_CHECK_MSG(
        outputs_res.get().size(1) == num_prompt_tokens,
        "Expected number of output tokens %d does not match returned value %zu.",
        num_prompt_tokens,
        outputs_res.get().size(1));
    // insert new token into prompt_tokens
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    uint64_t prev = prompt_tokens[0];
    uint64_t cur;
    for (int i = 1; i < prompt_tokens.size(); i++) {
      cur = prompt_tokens[i];
      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
      prev = cur;
    }
    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
  } else { // sequential prefill
    int64_t pos = 0; // position in the sequence
    uint64_t prev_token;
    // token & pos
    int64_t pos_data = 0;
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    cur_token = prompt_tokens[0];

    // initialize tensor wrappers
    ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);

    ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);

    while (pos < num_prompt_tokens) {
      // Run the model
      pos_data = start_pos + pos;

      Result<exec_aten::Tensor> logits_res =
          text_decoder_runner_->step(managed_tokens, managed_start_pos);

      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      prev_token = cur_token;

      pos++;

      long sample_start_time_ms = util::time_in_ms();

      cur_token = pos == num_prompt_tokens
          ? text_decoder_runner_->logits_to_token(logits_res.get())
          : prompt_tokens[pos];

      stats_.aggregate_sampling_time_ms +=
          util::time_in_ms() - sample_start_time_ms;

      // print the token as string, decode it with the Tokenizer object
      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
    }
  }
  // Return the next token
  stats_.first_token_ms = util::time_in_ms();
  stats_.prompt_eval_end_ms = util::time_in_ms();
  return cur_token;
}

Error Runner::generate(
@@ -242,9 +161,12 @@ Error Runner::generate(
// Prefill first
// Here feed all tokens to the model and get the next predicted token
// after the prompt. After that we will enter generate loop.
auto prefill_res = prefill(prompt_tokens, 0, wrapped_callback);
auto prefill_res =
text_prefiller_->prefill(prompt_tokens, 0, wrapped_callback);
stats_.first_token_ms = util::time_in_ms();
stats_.prompt_eval_end_ms = util::time_in_ms();
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
int64_t cur_token = prefill_res.get();
uint64_t cur_token = prefill_res.get();

// print the first token from prefill. No prev_token so use cur_token for it.
wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
@@ -253,17 +175,15 @@
int64_t pos = num_prompt_tokens; // position in the sequence

// Generate the rest of the sequence
std::vector<int64_t> token_data; // allocate space for the tokens
std::vector<uint64_t> token_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> token_shape;

if (use_kv_cache_) {
// hard code these to size 1 as kv cache is locked to static size right now.
token_data = {cur_token};
token_shape = {1, 1};
} else {
for (auto tok : prompt_tokens) {
token_data.push_back(tok);
}
token_data = prompt_tokens;
token_data.push_back(cur_token);
token_shape = {1, num_prompt_tokens + 1};
}
@@ -274,7 +194,7 @@ Error Runner::generate(

ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);

int64_t prev_token;
uint64_t prev_token;

// Generate our tokens
while (pos < seq_len - 1) {
6 changes: 2 additions & 4 deletions examples/models/llama2/runner/runner.h
@@ -20,6 +20,7 @@

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
@@ -45,10 +46,6 @@ class Runner {
void stop();

private:
Result<uint64_t> prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback);
// metadata
int32_t vocab_size_;
int32_t bos_id_;
@@ -68,6 +65,7 @@
std::string model_path_;
std::unique_ptr<Module> module_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::unique_ptr<TextPrefiller> text_prefiller_;
std::string tokenizer_path_;
std::unique_ptr<Tokenizer> tokenizer_;

1 change: 1 addition & 0 deletions examples/models/llama2/runner/targets.bzl
@@ -35,6 +35,7 @@ def define_common_targets():
"//executorch/backends/xnnpack:xnnpack_backend",
"//executorch/extension/llm/runner:stats",
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
"//executorch/extension/evalue_util:print_evalue" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
15 changes: 15 additions & 0 deletions extension/llm/runner/targets.bzl
@@ -26,3 +26,18 @@ def define_common_targets():
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)

runtime.cxx_library(
name = "text_prefiller" + aten_suffix,
exported_headers = ["text_prefiller.h"],
srcs = ["text_prefiller.cpp"],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
":text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/tokenizer:tokenizer_header",
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)
104 changes: 104 additions & 0 deletions extension/llm/runner/text_prefiller.cpp
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given a text prompt, encode it using tokenizer and prefill the KV cache of a
// LLM.

#include <executorch/extension/llm/runner/text_prefiller.h>

namespace torch::executor {

TextPrefiller::TextPrefiller(
Tokenizer* tokenizer,
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache,
bool enable_parallel_prefill)
: tokenizer_(tokenizer),
text_decoder_runner_(text_decoder_runner),
use_kv_cache_(use_kv_cache),
enable_parallel_prefill_(enable_parallel_prefill) {}

Result<uint64_t> TextPrefiller::prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback) {
ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
if (!text_decoder_runner_->is_method_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
}
// enable_parallel_prefill_ maybe set even when not using kv cache
// When kv cache is not used, start pos is ignored
int32_t num_prompt_tokens = prompt_tokens.size();

// store the token
uint64_t cur_token;
if (enable_parallel_prefill_ || !use_kv_cache_) {
// initialize tensor wrappers
ManagedTensor managed_tokens(
prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);

ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

Result<exec_aten::Tensor> outputs_res =
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
uint64_t cur;
for (int i = 1; i < prompt_tokens.size(); i++) {
cur = prompt_tokens[i];
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
prev = cur;
}
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
} else { // sequential prefill
int64_t pos = 0; // position in the sequence
int64_t prev_token;
// token & pos
int64_t pos_data = 0;
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
cur_token = prompt_tokens[0];

// initialize tensor wrappers
ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);

ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);

while (pos < num_prompt_tokens) {
// Run the model
pos_data = start_pos + pos;

Result<exec_aten::Tensor> logits_res =
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
prev_token = cur_token;

pos++;

cur_token = pos == num_prompt_tokens
? text_decoder_runner_->logits_to_token(logits_res.get())
: prompt_tokens[pos];

// print the token as string, decode it with the Tokenizer object
token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
}
}
return cur_token;
}

} // namespace torch::executor
50 changes: 50 additions & 0 deletions extension/llm/runner/text_prefiller.h
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given a text prompt, encode it using tokenizer and prefill the KV cache of a
// LLM.

#pragma once

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

namespace torch::executor {

class TextPrefiller {
public:
TextPrefiller(
Tokenizer* tokenizer,
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache_,
bool enable_parallel_prefill);
/**
* Prefill an LLM Module with the given text input.
* @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
* tokenizer.
* @param start_pos The starting position in KV cache of the input in the LLM
* Module.
* @param token_callback A callback function that will be called for each
* token in the prompt.
* @return The next token of the LLM Module after prefill.
*/
Result<uint64_t> prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos = 0,
std::function<void(const std::string&)> token_callback = {});

private:
Tokenizer* tokenizer_;
TextDecoderRunner* text_decoder_runner_;
bool use_kv_cache_;
bool enable_parallel_prefill_;
};

} // namespace torch::executor
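
As a usage reference (not part of this diff), here is a minimal sketch of how a caller can wire TextPrefiller together with TextDecoderRunner. The helper name prefill_prompt, its parameter list, the temperature value, and the bos/eos flags are illustrative assumptions; the constructor and prefill() signatures follow the header above, and this mirrors what Runner::load()/generate() do in the first file of this diff.

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>

namespace torch::executor {

// Hypothetical helper, for illustration only: tokenize `prompt`, prefill the
// model's KV cache, and return the first token sampled after the prompt.
// `module`, `tokenizer`, and the metadata arguments are assumed to be already
// loaded by the caller (as Runner::load() does).
Result<uint64_t> prefill_prompt(
    Module* module,
    Tokenizer* tokenizer,
    const std::string& prompt,
    int32_t vocab_size,
    float temperature,
    bool use_kv_cache,
    bool enable_parallel_prefill) {
  TextDecoderRunner decoder_runner(module, use_kv_cache, vocab_size, temperature);
  TextPrefiller prefiller(
      tokenizer, &decoder_runner, use_kv_cache, enable_parallel_prefill);

  // Encode with a leading BOS and no EOS, as the llama2 runner does.
  std::vector<uint64_t> prompt_tokens =
      ET_UNWRAP(tokenizer->encode(prompt, /*bos=*/1, /*eos=*/0));

  // Stream decoded prompt pieces as they are fed to the model; the returned
  // Result carries the first generated token, from which decoding continues
  // at position prompt_tokens.size().
  return prefiller.prefill(
      prompt_tokens, /*start_pos=*/0, [](const std::string& piece) {
        ET_LOG(Info, "%s", piece.c_str());
      });
}

} // namespace torch::executor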