Merged

24 commits
6f9b741
[llava][13/N] Move metadata util to a separate header for reuse
larryliu0820 Aug 5, 2024
411e45d
Update on "[llava][13/N] Move metadata util to a separate header for …
larryliu0820 Aug 5, 2024
60e24ab
Update on "[llava][13/N] Move metadata util to a separate header for …
larryliu0820 Aug 6, 2024
4e81051
Update on "[llava][13/N] Move metadata util to a separate header for …
larryliu0820 Aug 6, 2024
5b76148
Update on "[llava][13/N] Move metadata util to a separate header for …
larryliu0820 Aug 6, 2024
2dad738
[llava][14/N] Refactor runner prefill() and run_model_step()
larryliu0820 Aug 6, 2024
94290d1
Update on "[llava][14/N] Refactor runner prefill() and run_model_step()"
larryliu0820 Aug 6, 2024
188a000
Update on "[llava][14/N] Refactor runner prefill() and run_model_step()"
larryliu0820 Aug 6, 2024
d963403
[llava][15/N] Extract out text decoder runner
larryliu0820 Aug 6, 2024
aa08c52
Update on "[llava][14/N] Refactor runner prefill() and run_model_step()"
larryliu0820 Aug 6, 2024
7b3c6d8
Update on "[llava][15/N] Extract out text decoder runner"
larryliu0820 Aug 6, 2024
5a43186
Update base for Update on "[llava][15/N] Extract out text decoder run…
larryliu0820 Aug 7, 2024
a1e6c5e
Update on "[llava][15/N] Extract out text decoder runner"
larryliu0820 Aug 7, 2024
20a6a23
[llava][16/N] Extract out prefill logic into a new class
larryliu0820 Aug 7, 2024
61c08d7
Update base for Update on "[llava][16/N] Extract out prefill logic in…
larryliu0820 Aug 7, 2024
5716659
Update on "[llava][16/N] Extract out prefill logic into a new class"
larryliu0820 Aug 7, 2024
0589bdf
Update base for Update on "[llava][16/N] Extract out prefill logic in…
larryliu0820 Aug 7, 2024
7f6b71d
Update on "[llava][16/N] Extract out prefill logic into a new class"
larryliu0820 Aug 7, 2024
99a29e2
Update base for Update on "[llava][16/N] Extract out prefill logic in…
larryliu0820 Aug 7, 2024
9a60fde
Update on "[llava][16/N] Extract out prefill logic into a new class"
larryliu0820 Aug 7, 2024
a52eeb4
Update base for Update on "[llava][17/N] Move util.h into /e/llm/runner"
larryliu0820 Aug 7, 2024
450c71d
Update base for Update on "[llava][17/N] Move util.h into /e/llm/runner"
larryliu0820 Aug 8, 2024
6edf64e
Update base for Update on "[llava][17/N] Move util.h into /e/llm/runner"
larryliu0820 Aug 8, 2024
e79c362
[llava][17/N] Move util.h into /e/llm/runner
larryliu0820 Aug 9, 2024
334 changes: 61 additions & 273 deletions examples/models/llama2/runner/runner.cpp

Large diffs are not rendered by default.

26 changes: 11 additions & 15 deletions examples/models/llama2/runner/runner.h
@@ -19,6 +19,8 @@
#include <unordered_map>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
@@ -44,17 +46,6 @@ class Runner {
void stop();

private:
int32_t logitsToToken(const exec_aten::Tensor& logits_tensor);
Result<exec_aten::Tensor> prefill(
const std::vector<uint64_t>& tokens,
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos,
std::function<void(const std::string&)> token_callback);
Result<exec_aten::Tensor> run_model_step(
int64_t input_token,
ManagedTensor& tokens,
ManagedTensor& start_pos,
size_t max_seq_len);
// metadata
int32_t vocab_size_;
int32_t bos_id_;
@@ -65,16 +56,21 @@
bool use_kv_cache_;
bool use_sdpa_with_kv_cache_;
bool append_eos_;
float temperature_;
bool enable_parallel_prefill_;
bool shouldStop_{false};

// model
std::unordered_set<std::string> model_methods_;
std::string model_path_;
std::unique_ptr<Module> module_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::unique_ptr<TextPrefiller> text_prefiller_;
std::string tokenizer_path_;
float temperature_;
std::unique_ptr<Tokenizer> tokenizer_;
std::unique_ptr<Sampler> sampler_;
bool shouldStop_{false};

// stats
Stats stats_;
bool enable_parallel_prefill_;
};

} // namespace torch::executor
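
For context, a minimal sketch of how the refactored Runner could construct the new members declared above. Member names come from this header and the constructor arguments from text_decoder_runner.h and text_prefiller.h later in this diff; the actual wiring lives in runner.cpp, whose large diff is not rendered here, so treat this as illustrative rather than the PR's exact code:

```cpp
// Sketch only: field names are from runner.h above; constructor signatures
// are from the new text_decoder_runner.h / text_prefiller.h in this PR.
text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
    module_.get(), use_kv_cache_, vocab_size_, temperature_);
text_prefiller_ = std::make_unique<TextPrefiller>(
    tokenizer_.get(),
    text_decoder_runner_.get(),
    use_kv_cache_,
    enable_parallel_prefill_);
```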
4 changes: 2 additions & 2 deletions examples/models/llama2/runner/targets.bzl
@@ -22,7 +22,6 @@ def define_common_targets():
],
exported_headers = [
"runner.h",
"util.h",
],
preprocessor_flags = [
"-DUSE_ATEN_LIB",
@@ -34,7 +33,8 @@
exported_deps = [
"//executorch/backends/xnnpack:xnnpack_backend",
"//executorch/extension/llm/runner:stats",
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
"//executorch/extension/evalue_util:print_evalue" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
38 changes: 37 additions & 1 deletion extension/llm/runner/targets.bzl
@@ -3,8 +3,44 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
def define_common_targets():
runtime.cxx_library(
name = "stats",
exported_headers = ["stats.h"],
exported_headers = [
"stats.h",
"util.h",
],
visibility = [
"@EXECUTORCH_CLIENTS",
],
)

for aten in (True, False):
aten_suffix = "_aten" if aten else ""

runtime.cxx_library(
name = "text_decoder_runner" + aten_suffix,
exported_headers = ["text_decoder_runner.h"],
srcs = ["text_decoder_runner.cpp"],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
":stats",
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)

runtime.cxx_library(
name = "text_prefiller" + aten_suffix,
exported_headers = ["text_prefiller.h"],
srcs = ["text_prefiller.cpp"],
visibility = [
"@EXECUTORCH_CLIENTS",
],
exported_deps = [
":text_decoder_runner" + aten_suffix,
"//executorch/extension/llm/tokenizer:tokenizer_header",
"//executorch/extension/module:module" + aten_suffix,
"//executorch/extension/runner_util:managed_tensor" + aten_suffix,
],
)
72 changes: 72 additions & 0 deletions extension/llm/runner/text_decoder_runner.cpp
@@ -0,0 +1,72 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given inputs, run a text decoder and return logits.

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <ctime>

namespace torch::executor {

// NOTE: we observed ~2x loading performance increase on iPhone 15
// and a ~5% improvement on Galaxy S22 by switching to
// FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
TextDecoderRunner::TextDecoderRunner(
Module* module,
bool use_kv_cache,
int32_t vocab_size,
float temperature)
: module_(module),
sampler_(std::make_unique<Sampler>(
vocab_size,
temperature,
::executorch::llm::kTopp,
static_cast<unsigned long long>(std::time(nullptr)))),
use_kv_cache_(use_kv_cache) {}

// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
// outer loop (call site) is responsible for managing state.
Result<exec_aten::Tensor> TextDecoderRunner::step(
ManagedTensor& managed_tokens,
ManagedTensor& managed_start_pos) {
auto tokens = managed_tokens.get_aliasing_tensor();
// ET_LOG(Info, "Input token %" PRIu64, input_token);
if (use_kv_cache_) {
auto start_pos = managed_start_pos.get_aliasing_tensor();
Result<std::vector<EValue>> outputs_res =
module_->forward({tokens, start_pos});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
} else { // no kv cache
(void)managed_start_pos; // unused

Result<std::vector<EValue>> outputs_res = module_->forward({tokens});
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_CHECK_MSG(
outputs_res.get().size() == 1,
"More then one output returned from executing LLM.");
ET_CHECK_MSG(
outputs_res.get()[0].isTensor(),
"Non Tensor Output returned from executing LLM");

// Return the logits tensor
return outputs_res.get()[0].toTensor();
}
}

} // namespace torch::executor
95 changes: 95 additions & 0 deletions extension/llm/runner/text_decoder_runner.h
@@ -0,0 +1,95 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given inputs, run a text decoder in an LLM and return the output.

#pragma once

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>
// patternlint-disable-next-line executorch-cpp-nostdinc
#include <functional>

namespace torch::executor {

class TextDecoderRunner {
public:
TextDecoderRunner(
Module* module,
bool use_kv_cache,
int32_t vocab_size,
float temperature);
/**
* Run LLM text decoder with inputs to generate next token.
* @param input The input to the LLM Module.
* @param start_pos The starting position in KV cache of the input in the LLM
* Module.
* @return The output of the LLM Module. This will be a tensor of logits.
*/
Result<exec_aten::Tensor> step(
ManagedTensor& input,
ManagedTensor& start_pos);

/**
* Load the Module for a given method name.
* @param method_name The name of the method to load.
* @return The error code.
*/
inline Error load(const std::string& method_name = "forward") {
return module_->load_method(method_name);
}

/**
* Check if the Module is loaded.
* @return True if the Module is loaded, false otherwise.
*/
inline bool is_method_loaded(const std::string& method_name = "forward") {
return module_->is_method_loaded(method_name);
}

/**
* Sample the next token from the logits tensor.
* @param logits_tensor The logits tensor.
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);

switch (logits_tensor.scalar_type()) {
case ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
case ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
default:
ET_CHECK_MSG(
false,
"Unsupported dtype output %hhd",
static_cast<int8_t>(logits_tensor.scalar_type()));
}
}

protected:
// TODO: use shared_ptr for module
Module* module_;
std::unique_ptr<Sampler> sampler_;
bool use_kv_cache_;
};

} // namespace torch::executor
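
A minimal usage sketch for the new TextDecoderRunner API (not part of this diff). It assumes a decoder constructed as in this header with the KV cache enabled and its method already loaded via `load()`; the token id and variable names are illustrative:

```cpp
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/runner_util/managed_tensor.h>

using namespace torch::executor;

// `decoder` is a TextDecoderRunner built elsewhere, e.g.
// TextDecoderRunner decoder(&module, /*use_kv_cache=*/true, vocab_size, temperature);
void decode_one_step(TextDecoderRunner& decoder) {
  int64_t cur_token = 123; // illustrative token id
  int64_t pos = 0;         // position in the KV cache

  // With KV cache enabled, the tokens tensor holds a single token per step.
  ManagedTensor tokens(&cur_token, {1, 1}, ScalarType::Long);
  ManagedTensor start_pos(&pos, {1}, ScalarType::Long);

  Result<exec_aten::Tensor> logits = decoder.step(tokens, start_pos);
  if (logits.ok()) {
    // Sample the next token from the logits at the last position.
    int32_t next = decoder.logits_to_token(logits.get());
    (void)next;
  }
}
```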
104 changes: 104 additions & 0 deletions extension/llm/runner/text_prefiller.cpp
@@ -0,0 +1,104 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// Given a text prompt, encode it using a tokenizer and prefill the KV cache of
// an LLM.

#include <executorch/extension/llm/runner/text_prefiller.h>

namespace torch::executor {

TextPrefiller::TextPrefiller(
Tokenizer* tokenizer,
TextDecoderRunner* text_decoder_runner,
bool use_kv_cache,
bool enable_parallel_prefill)
: tokenizer_(tokenizer),
text_decoder_runner_(text_decoder_runner),
use_kv_cache_(use_kv_cache),
enable_parallel_prefill_(enable_parallel_prefill) {}

Result<uint64_t> TextPrefiller::prefill(
std::vector<uint64_t>& prompt_tokens,
int64_t start_pos,
std::function<void(const std::string&)> token_callback) {
ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
if (!text_decoder_runner_->is_method_loaded()) {
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
}
// enable_parallel_prefill_ may be set even when not using kv cache
// When kv cache is not used, start pos is ignored
int32_t num_prompt_tokens = prompt_tokens.size();

// store the token
uint64_t cur_token;
if (enable_parallel_prefill_ || !use_kv_cache_) {
// initialize tensor wrappers
ManagedTensor managed_tokens(
prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);

ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);

Result<exec_aten::Tensor> outputs_res =
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
uint64_t cur;
for (int i = 1; i < prompt_tokens.size(); i++) {
cur = prompt_tokens[i];
token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
prev = cur;
}
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
} else { // sequential prefill
int64_t pos = 0; // position in the sequence
int64_t prev_token;
// token & pos
int64_t pos_data = 0;
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
cur_token = prompt_tokens[0];

// initialize tensor wrappers
ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);

ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);

while (pos < num_prompt_tokens) {
// Run the model
pos_data = start_pos + pos;

Result<exec_aten::Tensor> logits_res =
text_decoder_runner_->step(managed_tokens, managed_start_pos);

ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
prev_token = cur_token;

pos++;

cur_token = pos == num_prompt_tokens
? text_decoder_runner_->logits_to_token(logits_res.get())
: prompt_tokens[pos];

// print the token as string, decode it with the Tokenizer object
token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
}
}
return cur_token;
}

} // namespace torch::executor
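
A usage sketch for TextPrefiller (not part of this diff), assuming the tokenizer and decoder runner were created as in the Runner changes above; the prompt token ids and function name are illustrative:

```cpp
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <iostream>
#include <vector>

using namespace torch::executor;

void prefill_prompt(Tokenizer* tokenizer, TextDecoderRunner* decoder_runner) {
  TextPrefiller prefiller(
      tokenizer,
      decoder_runner,
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true);

  // Token ids would normally come from tokenizer->encode(...); these are made up.
  std::vector<uint64_t> prompt_tokens = {1, 15043, 3186};

  Result<uint64_t> first_token = prefiller.prefill(
      prompt_tokens,
      /*start_pos=*/0,
      [](const std::string& piece) { std::cout << piece; });

  if (first_token.ok()) {
    // first_token.get() is the first token sampled after the prompt; the KV
    // cache now covers positions [0, prompt_tokens.size()).
  }
}
```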