From 21f8a0748595f5544e17dbc701efa8971de7387b Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 31 Mar 2025 15:43:53 -0700 Subject: [PATCH 1/5] [ExecuTorch][Llama] Change runner to enable chunked prefill This diff adds code to chunk prompt longer than max_seq_len to enable prefill of larger context Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/) [ghstack-poisoned] --- examples/models/llama/runner/runner.cpp | 26 ++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp index 0ecc611ef6c..abefb98eb57 100644 --- a/examples/models/llama/runner/runner.cpp +++ b/examples/models/llama/runner/runner.cpp @@ -11,6 +11,7 @@ #include +#include #include #include @@ -221,11 +222,11 @@ Error Runner::generate( ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token"); ET_CHECK_MSG( - num_prompt_tokens < metadata_.at(kMaxSeqLen), + num_prompt_tokens < metadata_.at(kMaxContextLen), "num_prompt_tokens %d >= max_seq_len_ %" PRId64 ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py", num_prompt_tokens, - metadata_.at(kMaxSeqLen)); + metadata_.at(kMaxContextLen)); ET_CHECK_MSG( num_prompt_tokens < seq_len, "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()", @@ -241,11 +242,26 @@ Error Runner::generate( wrapped_callback(prompt); } int64_t pos = 0; - auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos); + uint64_t cur_token; + int max_seq_len = metadata_.at(kMaxSeqLen) - + 1; // -1 because for some reason tracing results in this upperbound + int num_tokens_to_process = 0; + while (num_tokens_to_process < num_prompt_tokens) { + auto num_tokens_to_prefill_with = + std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len); + std::vector prompt_tokens_to_process(num_tokens_to_prefill_with); + std::copy( + prompt_tokens.begin() + num_tokens_to_process, + prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with, + prompt_tokens_to_process.begin()); + auto prefill_res = + text_prefiller_->prefill(prompt_tokens_to_process, pos); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + cur_token = prefill_res.get(); + num_tokens_to_process += num_tokens_to_prefill_with; + } stats_.first_token_ms = llm::time_in_ms(); stats_.prompt_eval_end_ms = llm::time_in_ms(); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - uint64_t cur_token = prefill_res.get(); // print the first token from prefill. No prev_token so use cur_token for it. 
wrapped_callback( From d6c9e456ad19755c98fe10b8667d56401aa40f46 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 31 Mar 2025 15:46:54 -0700 Subject: [PATCH 2/5] Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill" This diff adds code to chunk prompt longer than max_seq_len to enable prefill of larger context Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/) [ghstack-poisoned] --- examples/models/llama/runner/runner.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp index abefb98eb57..e0208fbb290 100644 --- a/examples/models/llama/runner/runner.cpp +++ b/examples/models/llama/runner/runner.cpp @@ -252,10 +252,10 @@ Error Runner::generate( std::vector prompt_tokens_to_process(num_tokens_to_prefill_with); std::copy( prompt_tokens.begin() + num_tokens_to_process, - prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with, + prompt_tokens.begin() + num_tokens_to_process + + num_tokens_to_prefill_with, prompt_tokens_to_process.begin()); - auto prefill_res = - text_prefiller_->prefill(prompt_tokens_to_process, pos); + auto prefill_res = text_prefiller_->prefill(prompt_tokens_to_process, pos); ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); cur_token = prefill_res.get(); num_tokens_to_process += num_tokens_to_prefill_with; From b5dfef920734c6547479d38f0024b46f202a7e33 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 31 Mar 2025 19:26:48 -0700 Subject: [PATCH 3/5] Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill" This diff adds code to chunk prompt longer than max_seq_len to enable prefill of larger context Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/) [ghstack-poisoned] --- examples/models/llama/runner/runner.cpp | 24 +++----------- extension/llm/runner/text_prefiller.cpp | 44 +++++++++++++++++++++++-- extension/llm/runner/text_prefiller.h | 14 +++++++- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp index e0208fbb290..429e4b61c36 100644 --- a/examples/models/llama/runner/runner.cpp +++ b/examples/models/llama/runner/runner.cpp @@ -141,7 +141,8 @@ Error Runner::load() { text_prefiller_ = std::make_unique( text_decoder_runner_.get(), metadata_.at(kUseKVCache), - metadata_.at(kEnableDynamicShape)); + metadata_.at(kEnableDynamicShape), + metadata_.at(kMaxSeqLen)); text_token_generator_ = std::make_unique( tokenizer_.get(), @@ -242,24 +243,9 @@ Error Runner::generate( wrapped_callback(prompt); } int64_t pos = 0; - uint64_t cur_token; - int max_seq_len = metadata_.at(kMaxSeqLen) - - 1; // -1 because for some reason tracing results in this upperbound - int num_tokens_to_process = 0; - while (num_tokens_to_process < num_prompt_tokens) { - auto num_tokens_to_prefill_with = - std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len); - std::vector prompt_tokens_to_process(num_tokens_to_prefill_with); - std::copy( - prompt_tokens.begin() + num_tokens_to_process, - prompt_tokens.begin() + num_tokens_to_process + - num_tokens_to_prefill_with, - prompt_tokens_to_process.begin()); - auto prefill_res = text_prefiller_->prefill(prompt_tokens_to_process, pos); - ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); - cur_token = prefill_res.get(); - num_tokens_to_process += num_tokens_to_prefill_with; - } + auto prefill_res = 
text_prefiller_->prefill(prompt_tokens, pos); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + uint64_t cur_token = prefill_res.get(); stats_.first_token_ms = llm::time_in_ms(); stats_.prompt_eval_end_ms = llm::time_in_ms(); diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 473cc2a3d81..0bf3a478df5 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -10,6 +10,7 @@ // LLM. #include +#include namespace executorch { namespace extension { @@ -18,10 +19,12 @@ namespace llm { TextPrefiller::TextPrefiller( TextDecoderRunner* text_decoder_runner, bool use_kv_cache, - bool enable_parallel_prefill) + bool enable_parallel_prefill, + int64_t max_seq_len) : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache), - enable_parallel_prefill_(enable_parallel_prefill) {} + enable_parallel_prefill_(enable_parallel_prefill), + max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upperbound ::executorch::runtime::Result TextPrefiller::prefill( std::vector& prompt_tokens, @@ -30,6 +33,43 @@ ::executorch::runtime::Result TextPrefiller::prefill( if (!text_decoder_runner_->is_method_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load()); } + + // Check if we need to chunk the prompt tokens + int32_t num_prompt_tokens = prompt_tokens.size(); + + // If prompt tokens exceed max_seq_len_, we need to chunk them + if (num_prompt_tokens > max_seq_len_) { + uint64_t cur_token = 0; + int num_tokens_to_process = 0; + + while (num_tokens_to_process < num_prompt_tokens) { + auto num_tokens_to_prefill_with = + std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len_); + + std::vector prompt_tokens_to_process(num_tokens_to_prefill_with); + std::copy( + prompt_tokens.begin() + num_tokens_to_process, + prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with, + prompt_tokens_to_process.begin()); + + // Process this chunk + auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos); + ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error()); + cur_token = chunk_result.get(); + + num_tokens_to_process += num_tokens_to_prefill_with; + } + + return cur_token; + } else { + // If prompt tokens don't exceed max_seq_len_, process them directly + return prefillChunk(prompt_tokens, start_pos); + } +} + +::executorch::runtime::Result TextPrefiller::prefillChunk( + std::vector& prompt_tokens, + int64_t& start_pos) { // enable_parallel_prefill_ maybe set even when not using kv cache // When kv cache is not used, start pos is ignored int32_t num_prompt_tokens = prompt_tokens.size(); diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index 007f8188f56..3450e417304 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller { TextPrefiller( TextDecoderRunner* text_decoder_runner, bool use_kv_cache_, - bool enable_parallel_prefill); + bool enable_parallel_prefill, + int64_t max_seq_len = 128); /** * Prefill an LLM Module with the given text input. * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by @@ -35,10 +36,21 @@ class ET_EXPERIMENTAL TextPrefiller { std::vector& prompt_tokens, int64_t& start_pos); + /** + * Helper method to prefill a chunk of tokens. + * @param prompt_tokens The chunk of text prompt tokens to process. 
+ * @param start_pos The starting position in KV cache of the input in the LLM Module. + * @return The next token of the LLM Module after prefilling this chunk. + */ + ::executorch::runtime::Result prefillChunk( + std::vector& prompt_tokens, + int64_t& start_pos); + private: TextDecoderRunner* text_decoder_runner_; bool use_kv_cache_; bool enable_parallel_prefill_; + int64_t max_seq_len_; }; } // namespace llm From ae9307816347a20b2526e23c98e2b058ba08f9b8 Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 31 Mar 2025 19:37:14 -0700 Subject: [PATCH 4/5] Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill" This diff adds code to chunk prompt longer than max_seq_len to enable prefill of larger context Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/) [ghstack-poisoned] --- examples/models/llava/runner/llava_runner.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp index d368f8fb1a4..971e126a14c 100644 --- a/examples/models/llava/runner/llava_runner.cpp +++ b/examples/models/llava/runner/llava_runner.cpp @@ -55,7 +55,8 @@ Error LlavaRunner::load() { text_prefiller_ = std::make_unique( text_decoder_runner_.get(), /*use_kv_cache=*/true, - /*enable_parallel_prefill=*/true); + /*enable_parallel_prefill=*/true, + /*max_seq_len=*/128); // Load the image prefiller image_prefiller_ = std::make_unique(module_.get()); From 6312ea66716fc6bf927723e658d3504cf1ed2cde Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Mon, 31 Mar 2025 19:58:10 -0700 Subject: [PATCH 5/5] Update on "[ExecuTorch][Llama] Change runner to enable chunked prefill" This diff adds code to chunk prompt longer than max_seq_len to enable prefill of larger context Differential Revision: [D71833061](https://our.internmc.facebook.com/intern/diff/D71833061/) [ghstack-poisoned] --- extension/llm/runner/text_prefiller.cpp | 27 ++++++++++++++----------- extension/llm/runner/text_prefiller.h | 3 ++- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp index 0bf3a478df5..19c260f5be6 100644 --- a/extension/llm/runner/text_prefiller.cpp +++ b/extension/llm/runner/text_prefiller.cpp @@ -24,7 +24,8 @@ TextPrefiller::TextPrefiller( : text_decoder_runner_(text_decoder_runner), use_kv_cache_(use_kv_cache), enable_parallel_prefill_(enable_parallel_prefill), - max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {} // -1 because for some reason tracing results in this upperbound + max_seq_len_(max_seq_len > 0 ? 
max_seq_len - 1 : 127) { +} // -1 because for some reason tracing results in this upperbound ::executorch::runtime::Result TextPrefiller::prefill( std::vector& prompt_tokens, @@ -33,33 +34,35 @@ ::executorch::runtime::Result TextPrefiller::prefill( if (!text_decoder_runner_->is_method_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load()); } - + // Check if we need to chunk the prompt tokens int32_t num_prompt_tokens = prompt_tokens.size(); - + // If prompt tokens exceed max_seq_len_, we need to chunk them if (num_prompt_tokens > max_seq_len_) { uint64_t cur_token = 0; int num_tokens_to_process = 0; - + while (num_tokens_to_process < num_prompt_tokens) { - auto num_tokens_to_prefill_with = - std::min(num_prompt_tokens - num_tokens_to_process, max_seq_len_); - - std::vector prompt_tokens_to_process(num_tokens_to_prefill_with); + auto num_tokens_to_prefill_with = std::min( + num_prompt_tokens - num_tokens_to_process, max_seq_len_); + + std::vector prompt_tokens_to_process( + num_tokens_to_prefill_with); std::copy( prompt_tokens.begin() + num_tokens_to_process, - prompt_tokens.begin() + num_tokens_to_process + num_tokens_to_prefill_with, + prompt_tokens.begin() + num_tokens_to_process + + num_tokens_to_prefill_with, prompt_tokens_to_process.begin()); - + // Process this chunk auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos); ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error()); cur_token = chunk_result.get(); - + num_tokens_to_process += num_tokens_to_prefill_with; } - + return cur_token; } else { // If prompt tokens don't exceed max_seq_len_, process them directly diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h index 3450e417304..0620eadfe9f 100644 --- a/extension/llm/runner/text_prefiller.h +++ b/extension/llm/runner/text_prefiller.h @@ -39,7 +39,8 @@ class ET_EXPERIMENTAL TextPrefiller { /** * Helper method to prefill a chunk of tokens. * @param prompt_tokens The chunk of text prompt tokens to process. - * @param start_pos The starting position in KV cache of the input in the LLM Module. + * @param start_pos The starting position in KV cache of the input in the LLM + * Module. * @return The next token of the LLM Module after prefilling this chunk. */ ::executorch::runtime::Result prefillChunk(
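
Taken together, the patches above move the chunking loop from the Llama runner into `TextPrefiller::prefill`: any prompt longer than `max_seq_len_` is split into `max_seq_len_`-sized pieces, each piece is run through `prefillChunk`, and `start_pos` carries the KV-cache position forward across chunks so later chunks attend to earlier ones. The sketch below is a minimal, self-contained illustration of that loop only, not the ExecuTorch API itself; `prefill_chunk` is a hypothetical stand-in for `TextPrefiller::prefillChunk`, and the token values and the `max_seq_len` of 128 are placeholders.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for TextPrefiller::prefillChunk: consumes one chunk of
// tokens starting at `start_pos`, advances `start_pos` by the chunk length, and
// returns a next-token prediction (faked here as the last token id + 1).
uint64_t prefill_chunk(const std::vector<uint64_t>& chunk, int64_t& start_pos) {
  start_pos += static_cast<int64_t>(chunk.size());
  return chunk.empty() ? 0 : chunk.back() + 1;
}

// Chunked prefill loop mirroring the logic the patches add: prompts longer
// than max_seq_len are split into max_seq_len-sized pieces and fed through
// prefill_chunk one piece at a time.
uint64_t chunked_prefill(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos,
    int32_t max_seq_len) {
  const int32_t num_prompt_tokens = static_cast<int32_t>(prompt_tokens.size());
  uint64_t cur_token = 0;
  int32_t num_tokens_processed = 0;

  while (num_tokens_processed < num_prompt_tokens) {
    const int32_t chunk_len =
        std::min(num_prompt_tokens - num_tokens_processed, max_seq_len);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + num_tokens_processed,
        prompt_tokens.begin() + num_tokens_processed + chunk_len);

    // Each chunk reuses the same KV cache; start_pos carries the running
    // position forward so the next chunk continues where this one ended.
    cur_token = prefill_chunk(chunk, start_pos);
    num_tokens_processed += chunk_len;
  }
  return cur_token;
}

int main() {
  // A 300-token prompt against a hypothetical max_seq_len of 128: the loop
  // issues chunks of 128, 128, and 44 tokens.
  std::vector<uint64_t> prompt(300);
  for (size_t i = 0; i < prompt.size(); ++i) {
    prompt[i] = static_cast<uint64_t>(i);
  }

  int64_t pos = 0;
  const uint64_t next = chunked_prefill(prompt, pos, /*max_seq_len=*/128);
  std::cout << "processed " << pos << " tokens, next token id " << next << "\n";
  return 0;
}
```

The design choice the later patches make is the same one this sketch assumes: callers such as the Llama and Llava runners keep calling a single `prefill()` entry point, and the chunking is an internal detail of `TextPrefiller`, so every runner gets long-prompt support without duplicating the loop.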