diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 1b86c0bd215..a4c3205bf07 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -94,98 +94,17 @@ Error Runner::load() {
   eos_id_ = get_module_metadata(
       module_.get(), "get_eos_id", tokenizer_->eos_tok());
-  // Create text decoder runner
+  // Create text decoder runner and prefiller
   text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
       module_.get(), use_kv_cache_, vocab_size_, temperature_);
-  return Error::Ok;
-}
-
-Result<uint64_t> Runner::prefill(
-    std::vector<uint64_t>& prompt_tokens,
-    int64_t start_pos,
-    std::function<void(const std::string&)> token_callback) {
-  // enable_parallel_prefill_ maybe set even when not using kv cache
-  // When kv cache is not used, start pos is ignored
-  int32_t num_prompt_tokens = prompt_tokens.size();
-
-  ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
-  ET_CHECK_MSG(
-      num_prompt_tokens < max_seq_len_,
-      "Max seq length exceeded - please increase max seq len value");
-
-  // store the token
-  uint64_t cur_token;
-  if (enable_parallel_prefill_ || !use_kv_cache_) {
-    // initialize tensor wrappers
-    ManagedTensor managed_tokens(
-        prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);
-
-    ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
-
-    Result<exec_aten::Tensor> outputs_res =
-        text_decoder_runner_->step(managed_tokens, managed_start_pos);
-
-    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
-    ET_LOG(
-        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-    ET_CHECK_MSG(
-        outputs_res.get().size(1) == num_prompt_tokens,
-        "Expected number of output tokens %d does not match returned value %zu.",
-        num_prompt_tokens,
-        outputs_res.get().size(1));
-    // insert new token into prompt_tokens
-    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
-    uint64_t prev = prompt_tokens[0];
-    uint64_t cur;
-    for (int i = 1; i < prompt_tokens.size(); i++) {
-      cur = prompt_tokens[i];
-      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
-      prev = cur;
-    }
-    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
-  } else { // sequential prefill
-    int64_t pos = 0; // position in the sequence
-    uint64_t prev_token;
-    // token & pos
-    int64_t pos_data = 0;
-    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
-    cur_token = prompt_tokens[0];
-
-    // initialize tensor wrappers
-    ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);
-
-    ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
-
-    while (pos < num_prompt_tokens) {
-      // Run the model
-      pos_data = start_pos + pos;
-
-      Result<exec_aten::Tensor> logits_res =
-          text_decoder_runner_->step(managed_tokens, managed_start_pos);
-
-      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
-      prev_token = cur_token;
+  text_prefiller_ = std::make_unique<TextPrefiller>(
+      tokenizer_.get(),
+      text_decoder_runner_.get(),
+      use_kv_cache_,
+      enable_parallel_prefill_);
 
-      pos++;
-
-      long sample_start_time_ms = util::time_in_ms();
-
-      cur_token = pos == num_prompt_tokens
-          ? text_decoder_runner_->logits_to_token(logits_res.get())
-          : prompt_tokens[pos];
-
-      stats_.aggregate_sampling_time_ms +=
-          util::time_in_ms() - sample_start_time_ms;
-
-      // print the token as string, decode it with the Tokenizer object
-      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
-    }
-  }
-  // Return the next token
-  stats_.first_token_ms = util::time_in_ms();
-  stats_.prompt_eval_end_ms = util::time_in_ms();
-  return cur_token;
+  return Error::Ok;
 }
@@ -242,9 +161,12 @@ Error Runner::generate(
   // Prefill first
   // Here feed all tokens to the model and get the next predicted token
   // after the prompt. After that we will enter generate loop.
-  auto prefill_res = prefill(prompt_tokens, 0, wrapped_callback);
+  auto prefill_res =
+      text_prefiller_->prefill(prompt_tokens, 0, wrapped_callback);
+  stats_.first_token_ms = util::time_in_ms();
+  stats_.prompt_eval_end_ms = util::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
-  int64_t cur_token = prefill_res.get();
+  uint64_t cur_token = prefill_res.get();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(ET_UNWRAP(tokenizer_->decode(cur_token, cur_token)));
@@ -253,7 +175,7 @@ Error Runner::generate(
   int64_t pos = num_prompt_tokens; // position in the sequence
 
   // Generate the rest of the sequence
-  std::vector<int64_t> token_data; // allocate space for the tokens
+  std::vector<uint64_t> token_data; // allocate space for the tokens
   std::vector token_shape;
 
   if (use_kv_cache_) {
@@ -261,9 +183,7 @@
     token_data = {cur_token};
     token_shape = {1, 1};
   } else {
-    for (auto tok : prompt_tokens) {
-      token_data.push_back(tok);
-    }
+    token_data = prompt_tokens;
     token_data.push_back(cur_token);
     token_shape = {1, num_prompt_tokens + 1};
   }
@@ -274,7 +194,7 @@
 
   ManagedTensor start_pos_managed(&pos, {1}, ScalarType::Long);
 
-  int64_t prev_token;
+  uint64_t prev_token;
 
   // Generate our tokens
   while (pos < seq_len - 1) {
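Net effect of the runner.cpp hunks above, restated as a condensed sketch assembled only from code visible in this diff (elided sections marked "..." are unchanged Runner code): load() now builds both helpers, and generate() delegates prefill while keeping the timing bookkeeping at the call site.

    // Condensed restatement of the hunks above, not a drop-in replacement.
    Error Runner::load() {
      // ... module / tokenizer / metadata loading as before ...
      text_decoder_runner_ = std::make_unique<TextDecoderRunner>(
          module_.get(), use_kv_cache_, vocab_size_, temperature_);
      // The prefiller borrows raw pointers, so the Runner's unique_ptr members
      // must outlive it.
      text_prefiller_ = std::make_unique<TextPrefiller>(
          tokenizer_.get(),
          text_decoder_runner_.get(),
          use_kv_cache_,
          enable_parallel_prefill_);
      return Error::Ok;
    }

    Error Runner::generate(/* prompt, seq_len, callbacks, ... */) {
      // ...
      auto prefill_res =
          text_prefiller_->prefill(prompt_tokens, 0, wrapped_callback);
      // Timing that used to live inside Runner::prefill() is now recorded here.
      stats_.first_token_ms = util::time_in_ms();
      stats_.prompt_eval_end_ms = util::time_in_ms();
      ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
      uint64_t cur_token = prefill_res.get();
      // ... decode loop as before ...
    }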
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 3f02149ddef..7ce210d51c1 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -20,6 +20,7 @@
 #include 
 #include 
+#include <executorch/extension/llm/runner/text_prefiller.h>
 #include 
 #include 
 #include 
@@ -45,10 +46,6 @@ class Runner {
   void stop();
 
  private:
-  Result<uint64_t> prefill(
-      std::vector<uint64_t>& prompt_tokens,
-      int64_t start_pos,
-      std::function<void(const std::string&)> token_callback);
   // metadata
   int32_t vocab_size_;
   int32_t bos_id_;
@@ -68,6 +65,7 @@
   std::string model_path_;
   std::unique_ptr<Module> module_;
   std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
+  std::unique_ptr<TextPrefiller> text_prefiller_;
   std::string tokenizer_path_;
   std::unique_ptr<Tokenizer> tokenizer_;
diff --git a/examples/models/llama2/runner/targets.bzl b/examples/models/llama2/runner/targets.bzl
index 534fcd7fdd9..e67c9abd910 100644
--- a/examples/models/llama2/runner/targets.bzl
+++ b/examples/models/llama2/runner/targets.bzl
@@ -35,6 +35,7 @@ def define_common_targets():
                 "//executorch/backends/xnnpack:xnnpack_backend",
                 "//executorch/extension/llm/runner:stats",
                 "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
+                "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
                 "//executorch/extension/evalue_util:print_evalue" + aten_suffix,
                 "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
                 "//executorch/extension/module:module" + aten_suffix,
diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl
index 4c9a218b4b1..e0522c1ee57 100644
--- a/extension/llm/runner/targets.bzl
+++ b/extension/llm/runner/targets.bzl
@@ -26,3 +26,18 @@ def define_common_targets():
             "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
         ],
     )
+
+    runtime.cxx_library(
+        name = "text_prefiller" + aten_suffix,
+        exported_headers = ["text_prefiller.h"],
+        srcs = ["text_prefiller.cpp"],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            ":text_decoder_runner" + aten_suffix,
+            "//executorch/extension/llm/tokenizer:tokenizer_header",
+            "//executorch/extension/module:module" + aten_suffix,
+            "//executorch/extension/runner_util:managed_tensor" + aten_suffix,
+        ],
+    )
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
new file mode 100644
index 00000000000..961c43d8c93
--- /dev/null
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -0,0 +1,104 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Given a text prompt, encode it using tokenizer and prefill the KV cache of a
+// LLM.
+
+#include <executorch/extension/llm/runner/text_prefiller.h>
+
+namespace torch::executor {
+
+TextPrefiller::TextPrefiller(
+    Tokenizer* tokenizer,
+    TextDecoderRunner* text_decoder_runner,
+    bool use_kv_cache,
+    bool enable_parallel_prefill)
+    : tokenizer_(tokenizer),
+      text_decoder_runner_(text_decoder_runner),
+      use_kv_cache_(use_kv_cache),
+      enable_parallel_prefill_(enable_parallel_prefill) {}
+
+Result<uint64_t> TextPrefiller::prefill(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t start_pos,
+    std::function<void(const std::string&)> token_callback) {
+  ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
+  if (!text_decoder_runner_->is_method_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
+  }
+  // enable_parallel_prefill_ maybe set even when not using kv cache
+  // When kv cache is not used, start pos is ignored
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // store the token
+  uint64_t cur_token;
+  if (enable_parallel_prefill_ || !use_kv_cache_) {
+    // initialize tensor wrappers
+    ManagedTensor managed_tokens(
+        prompt_tokens.data(), {1, num_prompt_tokens}, ScalarType::Long);
+
+    ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
+
+    Result<exec_aten::Tensor> outputs_res =
+        text_decoder_runner_->step(managed_tokens, managed_start_pos);
+
+    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+    ET_LOG(
+        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
+    ET_CHECK_MSG(
+        outputs_res.get().size(1) == num_prompt_tokens,
+        "Expected number of output tokens %d does not match returned value %zu.",
+        num_prompt_tokens,
+        outputs_res.get().size(1));
+    // insert new token into prompt_tokens
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
+    uint64_t prev = prompt_tokens[0];
+    uint64_t cur;
+    for (int i = 1; i < prompt_tokens.size(); i++) {
+      cur = prompt_tokens[i];
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev, cur)));
+      prev = cur;
+    }
+    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
+  } else { // sequential prefill
+    int64_t pos = 0; // position in the sequence
+    int64_t prev_token;
+    // token & pos
+    int64_t pos_data = 0;
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
+    cur_token = prompt_tokens[0];
+
+    // initialize tensor wrappers
+    ManagedTensor managed_tokens(&cur_token, {1, 1}, ScalarType::Long);
+
+    ManagedTensor managed_start_pos(&pos_data, {1}, ScalarType::Long);
+
+    while (pos < num_prompt_tokens) {
+      // Run the model
+      pos_data = start_pos + pos;
+
+      Result<exec_aten::Tensor> logits_res =
+          text_decoder_runner_->step(managed_tokens, managed_start_pos);
+
+      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
+      prev_token = cur_token;
+
+      pos++;
+
+      cur_token = pos == num_prompt_tokens
+          ? text_decoder_runner_->logits_to_token(logits_res.get())
+          : prompt_tokens[pos];
+
+      // print the token as string, decode it with the Tokenizer object
+      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));
+    }
+  }
+  return cur_token;
+}
+
+} // namespace torch::executor
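Since the file above only defines the class, a usage sketch may help. It is illustrative and not part of the change: the helper name, the literal vocab size and temperature, and the printf-based callback are placeholders, while the TextPrefiller constructor and prefill() signatures are the ones defined above and the TextDecoderRunner constructor arguments mirror the Runner::load() call earlier in this diff.

    // Hypothetical caller: builds a TextPrefiller by hand and streams the
    // decoded prompt text while the KV cache is being filled.
    #include <cstdio>

    #include <executorch/extension/llm/runner/text_prefiller.h>

    using namespace torch::executor;

    uint64_t prefill_and_echo(
        Module* module, // assumed to already point at a loaded llama2 Module
        Tokenizer* tokenizer, // assumed to already be loaded
        std::vector<uint64_t>& prompt_tokens) {
      // Placeholder vocab size / temperature; Runner::load() reads these from
      // model metadata instead.
      TextDecoderRunner decoder_runner(
          module,
          /*use_kv_cache=*/true,
          /*vocab_size=*/32000,
          /*temperature=*/0.8f);
      TextPrefiller prefiller(
          tokenizer,
          &decoder_runner,
          /*use_kv_cache=*/true,
          /*enable_parallel_prefill=*/true);

      // The callback receives the decoded text of the prompt as prefill
      // progresses, which is how the runner echoes the prompt to the caller.
      Result<uint64_t> next_token = prefiller.prefill(
          prompt_tokens, /*start_pos=*/0, [](const std::string& piece) {
            printf("%s", piece.c_str());
          });
      ET_CHECK_MSG(next_token.ok(), "prefill failed");
      return next_token.get(); // first sampled token for the decode loop
    }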
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
new file mode 100644
index 00000000000..7293fdca2a4
--- /dev/null
+++ b/extension/llm/runner/text_prefiller.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Given a text prompt, encode it using tokenizer and prefill the KV cache of a
+// LLM.
+
+#pragma once
+
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/tokenizer/tokenizer.h>
+// patternlint-disable-next-line executorch-cpp-nostdinc
+#include <functional>
+
+namespace torch::executor {
+
+class TextPrefiller {
+ public:
+  TextPrefiller(
+      Tokenizer* tokenizer,
+      TextDecoderRunner* text_decoder_runner,
+      bool use_kv_cache,
+      bool enable_parallel_prefill);
+  /**
+   * Prefill an LLM Module with the given text input.
+   * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
+   * tokenizer.
+   * @param start_pos The starting position in KV cache of the input in the LLM
+   * Module.
+   * @param token_callback A callback function that will be called for each
+   * token in the prompt.
+   * @return The next token of the LLM Module after prefill.
+   */
+  Result<uint64_t> prefill(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t start_pos = 0,
+      std::function<void(const std::string&)> token_callback = {});
+
+ private:
+  Tokenizer* tokenizer_;
+  TextDecoderRunner* text_decoder_runner_;
+  bool use_kv_cache_;
+  bool enable_parallel_prefill_;
+};
+
+} // namespace torch::executor
diff --git a/extension/llm/tokenizer/targets.bzl b/extension/llm/tokenizer/targets.bzl
index 8229bced89e..f8e4df095ca 100644
--- a/extension/llm/tokenizer/targets.bzl
+++ b/extension/llm/tokenizer/targets.bzl
@@ -59,16 +59,29 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "tokenizer_header",
+        exported_headers = [
+            "tokenizer.h",
+        ],
+        exported_deps = [
+            "//executorch/runtime/core:core",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
     runtime.cxx_library(
         name = "bpe_tokenizer",
         srcs = [
             "bpe_tokenizer.cpp",
         ],
         exported_headers = [
-            "tokenizer.h",
             "bpe_tokenizer.h",
         ],
         exported_deps = [
+            ":tokenizer_header",
             "//executorch/runtime/core:core",
         ],
         visibility = [
@@ -82,11 +95,11 @@
             "tiktoken.cpp",
         ],
         exported_headers = [
-            "tokenizer.h",
             "tiktoken.h",
             "base64.h",
         ],
         exported_deps = [
+            ":tokenizer_header",
             "//executorch/runtime/core:core",
         ],
         visibility = [
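The tokenizer_header split above is what lets interface-only consumers, the new text_prefiller library among them, include tokenizer.h without also depending on bpe_tokenizer or tiktoken. A small, hypothetical illustration follows; the helper and its name are invented for this sketch, and only the decode(prev, cur) call and the Result return type are taken from the code shown in this diff.

    // Hypothetical interface-only consumer: it is written purely against the
    // abstract Tokenizer, so a build target containing it would only need the
    // :tokenizer_header dependency, not a concrete tokenizer implementation.
    #include <executorch/extension/llm/tokenizer/tokenizer.h>

    namespace torch::executor {

    // Decode one (previous, current) token-id pair into its text piece, the
    // same call TextPrefiller::prefill() makes for its streaming callback.
    Result<std::string> decode_piece(
        Tokenizer& tokenizer, uint64_t prev_token, uint64_t cur_token) {
      return tokenizer.decode(prev_token, cur_token);
    }

    } // namespace torch::executor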