274 changes: 104 additions & 170 deletions examples/models/llama2/runner/runner.cpp
@@ -130,133 +130,101 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
}
}

Result<exec_aten::Tensor> Runner::prefill(
const std::vector<uint64_t>& tokens,
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos,
Result<uint64_t> Runner::prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback) {
// enable_parallel_prefill_ maybe set even when not using kv cache
// When kv cache is not used, start pos is ignored
int32_t num_tokens = tokens.size();
if (enable_parallel_prefill_) {
managed_tokens.resize({1, num_tokens});
int64_t* tokens_ptr =
managed_tokens.get_aliasing_tensor().mutable_data_ptr<int64_t>();
for (int i = 0; i < num_tokens; i++) {
// The following assumes batch size = 1
tokens_ptr[i] = tokens[i];
}
std::vector<EValue> inputs;
auto tokens_tensor = managed_tokens.get_aliasing_tensor();
auto start_pos = managed_start_pos.get_aliasing_tensor();
int32_t num_prompt_tokens = prompt_tokens.size();

ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
ET_CHECK_MSG(
num_prompt_tokens < max_seq_len_,
"Max seq length exceeded - please increase max seq len value");

// store the token
uint64_t cur_token;
if (enable_parallel_prefill_ || !use_kv_cache_) {
// initialize tensor wrappers
ManagedTensor managed_tokens(
prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);

// inputs:[tokens, start_pos]
inputs.push_back(tokens_tensor);
inputs.push_back(start_pos);
ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

Result<exec_aten::Tensor> outputs_res =
run_model_step(managed_tokens, managed_start_pos);

Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");
ET_CHECK_MSG(
outputs_res.get()[0].toTensor().size(1) == num_tokens,
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_tokens,
outputs_res.get()[0].toTensor().size(1));

start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;

uint64_t prev = tokens[0];
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
uint64_t cur;
for (int i = 1; i < num_tokens; i++) {
cur = tokens[i];
auto piece_res = tokenizer_->decode(prev, cur);
ET_CHECK_OK_OR_RETURN_ERROR(piece_res.error());
util::safe_printf(piece_res.get().c_str());
fflush(stdout);
for (int i = 1; i < prompt_tokens.size(); i++) {
cur = prompt_tokens[i];
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
prev = cur;
if (token_callback) {
token_callback(piece_res.get().c_str());
}
}
cur = logitsToToken(outputs_res.get()[0].toTensor());
auto piece_res = tokenizer_->decode(prev, cur);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get().c_str();
util::safe_printf(piece);
fflush(stdout);
if (token_callback) {
token_callback(piece_res.get().c_str());
}

// Return the logits tensor
stats_.first_token_ms = util::time_in_ms();
stats_.prompt_eval_end_ms = util::time_in_ms();
return outputs_res.get()[0].toTensor();
cur_token = logitsToToken(outputs_res.get());
} else { // sequential prefill
int64_t pos = 0; // position in the sequence
int64_t cur_token = tokens[0];
int64_t prev_token;
// This is a hack to enable returning a logits tensor from prefill
auto logits_tensor = managed_tokens.get_aliasing_tensor();
while (pos < num_tokens) {
uint64_t prev_token;
// token & pos
int64_t pos_data = 0;
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
cur_token = prompt_tokens[0];

// initialize tensor wrappers
ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);

ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);

while (pos < num_prompt_tokens) {
// Run the model
Result<exec_aten::Tensor> logits_res = run_model_step(
cur_token, managed_tokens, managed_start_pos, num_tokens);
pos_data = start_pos + pos;

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
logits_tensor = logits_res.get();
// Hack to enable returning a logits tensor from prefill
Result<exec_aten::Tensor> logits_res =
run_model_step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
prev_token = cur_token;

pos++;

long sample_start_time_ms = util::time_in_ms();
cur_token = logitsToToken(logits_tensor);
cur_token = pos == num_prompt_tokens ? logitsToToken(logits_res.get())
: prompt_tokens[pos];

stats_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;

// advance the state machine
if (pos < num_tokens - 1) {
// prefill, force the next token to be the next prompt token
cur_token = tokens[pos + 1];
}
pos++;

// print the token as string, decode it with the Tokenizer object
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get().c_str();
util::safe_printf(piece);
fflush(stdout);
if (token_callback) {
token_callback(piece_res.get().c_str());
}
token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
}
auto start_pos = managed_start_pos.get_aliasing_tensor();
start_pos.mutable_data_ptr<int64_t>()[0] = num_tokens;
stats_.first_token_ms = util::time_in_ms();
stats_.prompt_eval_end_ms = util::time_in_ms();
return logits_tensor;
}
// Return the next token
stats_.first_token_ms = util::time_in_ms();
stats_.prompt_eval_end_ms = util::time_in_ms();
return cur_token;
}

// Given an input token. Set up the inputs for the model and execute a single
// step. Returning the logits tensor.
Result<exec_aten::Tensor> Runner::run_model_step(
int64_t input_token,
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos,
size_t max_seq_len) {
ManagedTensor& managed_start_pos) {
// ET_LOG(Info, "Input token %" PRIu64, input_token);
auto tokens = managed_tokens.get_aliasing_tensor();
if (use_kv_cache_) {
auto tokens = managed_tokens.get_aliasing_tensor();
auto start_pos = managed_start_pos.get_aliasing_tensor();

// When using kv-cache our input is always 1 token, so just update to the
// latest.
tokens.mutable_data_ptr<int64_t>()[0] = input_token;

Result<std::vector<EValue>> outputs_res =
module_->forward({tokens, start_pos});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
@@ -267,25 +235,12 @@ Result<exec_aten::Tensor> Runner::run_model_step(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Bump start_pos by 1
start_pos.mutable_data_ptr<int64_t>()[0]++;

// Return the logits tensor
return outputs_res.get()[0].toTensor();
} else { // no kv cache
std::vector<EValue> inputs;
auto tokens = managed_tokens.get_aliasing_tensor();
(void)managed_start_pos; // unused

// When not using kv-cache our input is the entire history of tokens we have
// seen, so resize input to be 1 larger and append the new token to the end.
// TODO does this work in ATen mode?
tokens.mutable_data_ptr<int64_t>()[tokens.size(1) - 1] = input_token;

// inputs:[tokens]
inputs.push_back(tokens);

Result<std::vector<EValue>> outputs_res = module_->forward(inputs);
Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
@@ -294,14 +249,6 @@ Result<exec_aten::Tensor> Runner::run_model_step(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

if (tokens.size(1) < max_seq_len) {
// Resize the tokens tensor to be 1 larger for next step.
// Note that this relies on the fact that underlying memory is the same
// such that previous tokens stored there will still exist.
// Not a good thing to rely upon.
managed_tokens.resize({1, static_cast<int>(tokens.size(1) + 1)});
}

// Return the logits tensor
return outputs_res.get()[0].toTensor();
}
@@ -321,6 +268,15 @@ Error Runner::generate(
stats_.model_load_end_ms = util::time_in_ms();
}

// Wrap the token_callback with print function
std::function<void(const std::string&)> wrapped_callback =
[token_callback](const std::string& piece) {
util::safe_printf(piece.c_str());
fflush(stdout);
if (token_callback) {
token_callback(piece);
}
};
// First token time only measures the time it takes to encode the prompt and
// return a response token.

@@ -349,70 +305,48 @@ Error Runner::generate(
num_prompt_tokens < seq_len,
"Sequence length exceeded - please increase the seq_len value passed to generate()");

// Prefill first
// Here feed all tokens to the model and get the next predicted token
// after the prompt. After that we will enter generate loop.
auto prefill_res = prefill(prompt_tokens, 0, wrapped_callback);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
int64_t cur_token = prefill_res.get();

// print the first token from prefill. No prev_token so use cur_token for it.
wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));

// start the main loop
int64_t pos = 0; // position in the sequence
int64_t pos = num_prompt_tokens; // position in the sequence

// Generate the rest of the sequence
std::vector<int64_t> token_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> token_shape = {1, seq_len};

std::vector<int64_t> start_pos_data; // allocate space for the tokens
std::vector<exec_aten::SizesType> start_pos_shape = {1};
std::vector<exec_aten::SizesType> token_shape;

token_data.resize(seq_len);
if (use_kv_cache_) {
// hard code these to size 1 as kv cache is locked to static size right now.
start_pos_data.resize(1);
start_pos_data.push_back(0);
token_data = {cur_token};
token_shape = {1, 1};
} else {
for (auto tok : prompt_tokens) {
token_data.push_back(tok);
}
token_data.push_back(cur_token);
token_shape = {1, num_prompt_tokens + 1};
}

// initialize tensor wrappers
ManagedTensor tokens_managed(
token_data.data(), token_shape, ScalarType::Long);
// Create with the max shape to approapriately set the capacity of this
// tensor, then resize back to 1 for first input.
tokens_managed.resize({1, 1});

ManagedTensor start_pos_managed(
start_pos_data.data(), start_pos_shape, ScalarType::Long);
ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);

int64_t prev_token;
int64_t cur_token = prompt_tokens[0];

// Prefill first
// Here feed all tokens to the model and get the next predicted token
// after the prompt. After that we will enter generate loop.
auto prefill_res =
prefill(prompt_tokens, tokens_managed, start_pos_managed, token_callback);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
exec_aten::Tensor& prefill_res_tensor = prefill_res.get();
cur_token = logitsToToken(prefill_res_tensor);
if (use_kv_cache_) {
// Prefill could be parallel or sequential.
// Parallel:
// kv cache:
// - tokens_managed should resized to 1 as inference expects one token at
// a time.
// no kv cache:
// - tokens_managed should be resized to prompt length + 1, as inference
// expects all tokens at once.
// Sequential prefill:
// kv cache:
// - tokens_managed should be resized to 1, as inference expects one
// token at a time.
// no kv cache:
// - tokens_managed should be resized to prompt length + 1, as inference
// expects all tokens at once.
tokens_managed.resize({1, 1});
} else {
tokens_managed.resize({1, num_prompt_tokens + 1});
}
pos = num_prompt_tokens;

// Generate our tokens
while (pos < seq_len - 1) {
// Run the model
Result<exec_aten::Tensor> logits_res =
run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
run_model_step(tokens_managed, start_pos_managed);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();
@@ -426,19 +360,19 @@ Error Runner::generate(

pos++;

// print the token as string, decode it with the Tokenizer object
auto piece_res = tokenizer_->decode(prev_token, cur_token);
ET_CHECK(piece_res.ok());
const char* piece = piece_res.get().c_str();

// same as printf("%s", piece), but skips "unsafe" bytes
util::safe_printf(piece);
fflush(stdout);

if (token_callback) {
token_callback(piece);
if (use_kv_cache_) {
// update the token tensor. token_data will not be empty.
// NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
token_data[0] = cur_token;
} else {
// push it to the back
token_data.push_back(cur_token);
tokens_managed.resize({1, static_cast<int>(token_data.size())});
}

// print the token as string, decode it with the Tokenizer object
wrapped_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));

if (shouldStop_) {
break;
}
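Note (not part of the PR): the diff above changes `prefill` to return the next predicted token as a `Result<uint64_t>` instead of a logits tensor, and `run_model_step` now takes only the two `ManagedTensor` wrappers. The following is a minimal, self-contained C++ sketch of that control flow with the ExecuTorch types stubbed out; `ModelStep`, the toy model, and the bracketed token strings are illustrative stand-ins, not the real API.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for one model invocation: (tokens, start_pos) -> next token id.
using ModelStep =
    std::function<uint64_t(const std::vector<uint64_t>&, int64_t)>;

// Mirrors the PR's new prefill contract: consume the whole prompt, echo the
// prompt pieces through the callback, and return the first generated token
// (a plain integer, not a logits tensor).
uint64_t prefill(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t start_pos,
    const ModelStep& step,
    const std::function<void(const std::string&)>& token_callback) {
  for (size_t i = 1; i < prompt_tokens.size(); ++i) {
    token_callback("[" + std::to_string(prompt_tokens[i]) + "]");
  }
  // Single forward pass over the full prompt (the parallel-prefill path).
  return step(prompt_tokens, start_pos);
}

int main() {
  // Toy "model": predicts last token + 1.
  ModelStep step = [](const std::vector<uint64_t>& toks, int64_t /*pos*/) {
    return toks.back() + 1;
  };
  // generate() now wraps the user callback with printing; emulate that here.
  auto wrapped_callback = [](const std::string& piece) { std::cout << piece; };

  std::vector<uint64_t> prompt_tokens = {1, 2, 3};
  uint64_t cur_token = prefill(prompt_tokens, 0, step, wrapped_callback);
  wrapped_callback("[" + std::to_string(cur_token) + "]");

  // Decode loop: after prefill, feed one token per step, advance pos by one.
  int64_t pos = static_cast<int64_t>(prompt_tokens.size());
  for (; pos < 8; ++pos) {
    cur_token = step({cur_token}, pos);
    wrapped_callback("[" + std::to_string(cur_token) + "]");
  }
  std::cout << "\n";
  return 0;
}
```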
13 changes: 5 additions & 8 deletions examples/models/llama2/runner/runner.h
@@ -45,16 +45,13 @@ class Runner {

private:
int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
Result<exec_aten::Tensor> prefill(
const std::vector<uint64_t>& tokens,
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos,
Result<uint64_t> prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback);
Result<exec_aten::Tensor> run_model_step(
int64_t input_token,
ManagedTensor& tokens,
ManagedTensor& start_pos,
size_t max_seq_len);
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos);
// metadata
int32_t vocab_size_;
int32_t bos_id_;
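Note (not part of the PR): `generate()` now wraps the caller's token callback with the printing logic that prefill and the decode loop previously duplicated, so every decoded piece is printed exactly once and still forwarded. A minimal sketch of that wrapping pattern, with `util::safe_printf` swapped for plain `printf` to keep the example standalone:

```cpp
#include <cstdio>
#include <functional>
#include <string>

// Wrap a user-supplied callback so each decoded piece is printed, flushed,
// and then forwarded. This mirrors what wrapped_callback does in generate();
// printf stands in for util::safe_printf here.
std::function<void(const std::string&)> wrap_with_print(
    std::function<void(const std::string&)> user_callback) {
  return [user_callback](const std::string& piece) {
    std::printf("%s", piece.c_str());
    std::fflush(stdout);
    if (user_callback) {
      user_callback(piece); // the caller still sees every piece
    }
  };
}

int main() {
  auto wrapped = wrap_with_print([](const std::string& /*piece*/) {
    // e.g. append to a transcript buffer or stream to a UI
  });
  wrapped("hello ");
  wrapped("world\n");
  return 0;
}
```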