Merged
98 changes: 16 additions & 82 deletions examples/models/llama2/runner/runner.cpp
@@ -41,9 +41,9 @@ Runner::Runner(
// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
: module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
tokenizer_path_(tokenizer_path),
temperature_(temperature) {
: temperature_(temperature),
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
tokenizer_path_(tokenizer_path) {
ET_LOG(
Info,
"Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
@@ -52,7 +52,7 @@ Runner::Runner(
}

bool Runner::is_loaded() const {
return module_->is_loaded() && tokenizer_ && sampler_;
return module_->is_loaded() && tokenizer_ && text_decoder_runner_;
}

Error Runner::load() {
@@ -94,42 +94,13 @@ Error Runner::load() {
eos_id_ = get_module_metadata<int64_t>(
module_.get(), "get_eos_id", tokenizer_->eos_tok());

// Create sampler
sampler_ = std::make_unique<Sampler>(
vocab_size_,
temperature_,
::executorch::llm::kTopp,
static_cast<unsigned long long>(std::time(nullptr)));
// Create text decoder runner
text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
module_.get(), use_kv_cache_, vocab_size_, temperature_);

return Error::Ok;
}

int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
auto num_tokens = logits_tensor.size(1);

switch (logits_tensor.scalar_type()) {
case ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
float* logits_last = logits;
logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
return sampler_->sample(logits_last);
}
case ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * tokenizer_->vocab_size();
return sampler_->sample(logits_last);
}
default:
ET_CHECK_MSG(
false,
"Unsupported dtype output %hhd",
static_cast<int8_t>(logits_tensor.scalar_type()));
}
}

Result<uint64_t> Runner::prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
@@ -153,7 +124,7 @@ Result<uint64_t> Runner::prefill(
ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

Result<exec_aten::Tensor> outputs_res =
run_model_step(managed_tokens, managed_start_pos);
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
@@ -172,7 +143,7 @@ Result<uint64_t> Runner::prefill(
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
prev = cur;
}
cur_token = logitsToToken(outputs_res.get());
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
} else { // sequential prefill
int64_t pos = 0; // position in the sequence
uint64_t prev_token;
@@ -191,16 +162,18 @@ Result<uint64_t> Runner::prefill(
pos_data = start_pos + pos;

Result<exec_aten::Tensor> logits_res =
run_model_step(managed_tokens, managed_start_pos);
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
prev_token = cur_token;

pos++;

long sample_start_time_ms = util::time_in_ms();
cur_token = pos == num_prompt_tokens ? logitsToToken(logits_res.get())
: prompt_tokens[pos];

cur_token = pos == num_prompt_tokens
? text_decoder_runner_->logits_to_token(logits_res.get())
: prompt_tokens[pos];

stats_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;
@@ -215,45 +188,6 @@ Result<uint64_t> Runner::prefill(
return cur_token;
}

// Given an input token. Set up the inputs for the model and execute a single
// step. Returning the logits tensor.
Result<exec_aten::Tensor> Runner::run_model_step(
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos) {
// ET_LOG(Info, "Input token %" PRIu64, input_token);
auto tokens = managed_tokens.get_aliasing_tensor();
if (use_kv_cache_) {
auto start_pos = managed_start_pos.get_aliasing_tensor();

Result<std::vector<EValue>> outputs_res =
module_->forward({tokens, start_pos});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
} else { // no kv cache
(void)managed_start_pos; // unused

Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
}
}

Error Runner::generate(
const std::string& prompt,
int32_t seq_len,
@@ -346,15 +280,15 @@ Error Runner::generate(
while (pos < seq_len - 1) {
// Run the model
Result<exec_aten::Tensor> logits_res =
run_model_step(tokens_managed, start_pos_managed);
text_decoder_runner_->step(tokens_managed, start_pos_managed);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
exec_aten::Tensor& logits_tensor = logits_res.get();

prev_token = cur_token;

long sample_start_time_ms = util::time_in_ms();
cur_token = logitsToToken(logits_tensor);
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
stats_.aggregate_sampling_time_ms +=
util::time_in_ms() - sample_start_time_ms;

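Note: `Runner::logitsToToken()` is deleted above, and its replacement `TextDecoderRunner::logits_to_token()` is declared in `text_decoder_runner.h`, which is not part of this diff. A minimal sketch of what the moved helper presumably looks like, reconstructed from the removed body and assuming the vocab size is now read from the logits tensor's last dimension rather than from the tokenizer:

```cpp
// Sketch only -- reconstructed from the removed Runner::logitsToToken();
// the real declaration lives in text_decoder_runner.h (not in this diff).
int32_t TextDecoderRunner::logits_to_token(
    const exec_aten::Tensor& logits_tensor) {
  ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
  auto num_tokens = logits_tensor.size(1);
  auto vocab_size = logits_tensor.size(2); // assumes [1, seq_len, vocab]

  switch (logits_tensor.scalar_type()) {
    case ScalarType::Float: {
      // Sample from the logits of the last position only.
      float* logits = logits_tensor.mutable_data_ptr<float>();
      return sampler_->sample(logits + (num_tokens - 1) * vocab_size);
    }
    case ScalarType::Half: {
      exec_aten::Half* logits =
          logits_tensor.mutable_data_ptr<exec_aten::Half>();
      return sampler_->sample(logits + (num_tokens - 1) * vocab_size);
    }
    default:
      ET_CHECK_MSG(
          false,
          "Unsupported dtype output %hhd",
          static_cast<int8_t>(logits_tensor.scalar_type()));
  }
}
```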
17 changes: 9 additions & 8 deletions examples/models/llama2/runner/runner.h
@@ -19,6 +19,7 @@
#include <unordered_map>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
@@ -44,14 +45,10 @@ class Runner {
void stop();

private:
int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
Result<uint64_t> prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback);
Result<exec_aten::Tensor> run_model_step(
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos);
// metadata
int32_t vocab_size_;
int32_t bos_id_;
@@ -62,16 +59,20 @@
bool use_kv_cache_;
bool use_sdpa_with_kv_cache_;
bool append_eos_;
float temperature_;
bool enable_parallel_prefill_;
bool shouldStop_{false};

// model
std::unordered_set<std::string> model_methods_;
std::string model_path_;
std::unique_ptr<Module> module_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::string tokenizer_path_;
float temperature_;
std::unique_ptr<Tokenizer> tokenizer_;
std::unique_ptr<Sampler> sampler_;
bool shouldStop_{false};

// stats
Stats stats_;
bool enable_parallel_prefill_;
};

} // namespace torch::executor
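For reference, `TextDecoderRunner` is only visible in this PR through its call sites and the `.cpp` file added below; `text_decoder_runner.h` itself is not included. A hypothetical reconstruction of the interface it implies (names, member layout, and exact signatures are inferences, not the committed header):

```cpp
// Hypothetical header sketch; treat every detail as inferred from usage.
#pragma once

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>

namespace torch::executor {

class TextDecoderRunner {
 public:
  TextDecoderRunner(
      Module* module,
      bool use_kv_cache,
      int32_t vocab_size,
      float temperature);

  // Run one decoder forward pass and return the logits tensor.
  Result<exec_aten::Tensor> step(
      ManagedTensor& managed_tokens,
      ManagedTensor& managed_start_pos);

  // Sample the next token id from the logits of the last position.
  int32_t logits_to_token(const exec_aten::Tensor& logits_tensor);

 protected:
  Module* module_; // not owned; loaded and owned by the caller (Runner)
  std::unique_ptr<Sampler> sampler_;
  bool use_kv_cache_;
};

} // namespace torch::executor
```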
2 changes: 1 addition & 1 deletion examples/models/llama2/runner/targets.bzl
@@ -34,7 +34,7 @@ def define_common_targets():
exported_deps = [
"//executorch/backends/xnnpack:xnnpack_backend",
"//executorch/extension/llm/runner:stats",
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
"//executorch/extension/evalue_util:print_evalue" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
18 changes: 18 additions & 0 deletions extension/llm/runner/targets.bzl
@@ -8,3 +8,21 @@ def define_common_targets():
"@EXECUTORCH_CLIENTS",
],
)

for aten in (True, False):
aten_suffix = "_aten" if aten else ""

runtime.cxx_library(
name = "text_decoder_runner" + aten_suffix,
exported_headers = ["text_decoder_runner.h"],
srcs = ["text_decoder_runner.cpp"],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
":stats",
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)
72 changes: 72 additions & 0 deletions extension/llm/runner/text_decoder_runner.cpp
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given inputs, run a text decoder and return logits.

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <ctime>

namespace torch::executor {

// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
TextDecoderRunner::TextDecoderRunner(
Module* module,
bool use_kv_cache,
int32_t vocab_size,
float temperature)
: module_(module),
sampler_(std::make_unique<Sampler>(
vocab_size,
temperature,
::executorch::llm::kTopp,
static_cast<unsigned long long>(std::time(nullptr)))),
use_kv_cache_(use_kv_cache) {}

// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
// outer loop (call site) is responsible for managing state.
Result<exec_aten::Tensor> TextDecoderRunner::step(
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos) {
auto tokens = managed_tokens.get_aliasing_tensor();
// ET_LOG(Info, "Input token %" PRIu64, input_token);
if (use_kv_cache_) {
auto start_pos = managed_start_pos.get_aliasing_tensor();
Result<std::vector<EValue>> outputs_res =
module_->forward({tokens, start_pos});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
} else { // no kv cache
(void)managed_start_pos; // unused

Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
}
}

} // namespace torch::executor
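A minimal sketch of the call pattern the new class expects, mirroring the decode loop in `Runner::generate()` above; the wrapper function, variable names, and exact `ManagedTensor` shapes are illustrative, not taken from this PR:

```cpp
// Illustrative only: one-token-at-a-time decoding with a KV cache.
// The caller owns the token / position buffers; step() itself is stateless.
Error decode_loop(
    TextDecoderRunner& runner,
    int64_t cur_token, // e.g. the last token returned by prefill()
    int64_t start_pos,
    int32_t seq_len) {
  int64_t pos = start_pos;
  ManagedTensor tokens_managed(&cur_token, {1, 1}, ScalarType::Long);
  ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);

  while (pos < seq_len - 1) {
    Result<exec_aten::Tensor> logits_res =
        runner.step(tokens_managed, start_pos_managed);
    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());

    // Writing into cur_token / pos updates the buffers that the aliasing
    // tensors expose, so the next step() call reads the new values.
    cur_token = runner.logits_to_token(logits_res.get());
    pos++;
  }
  return Error::Ok;
}
```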