diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 0ecc611ef6c..429e4b61c36 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -11,6 +11,7 @@
 #include
 
+#include
 #include
 #include
@@ -140,7 +141,8 @@ Error Runner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kMaxSeqLen));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
@@ -221,11 +223,11 @@ Error Runner::generate(
 
   ET_CHECK_MSG(num_prompt_tokens >= 1, "Expected at least 1 prompt token");
   ET_CHECK_MSG(
-      num_prompt_tokens < metadata_.at(kMaxSeqLen),
+      num_prompt_tokens < metadata_.at(kMaxContextLen),
       "num_prompt_tokens %d >= max_seq_len_ %" PRId64
       ", Max seq length exceeded - please increase max seq len value in .../llama2/model.py",
       num_prompt_tokens,
-      metadata_.at(kMaxSeqLen));
+      metadata_.at(kMaxContextLen));
   ET_CHECK_MSG(
       num_prompt_tokens < seq_len,
       "num_prompt_tokens %d >= seq_len %d, Sequence length exceeded - please increase the seq_len value passed to generate()",
@@ -242,10 +244,10 @@ Error Runner::generate(
   }
   int64_t pos = 0;
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
+  stats_.first_token_ms = llm::time_in_ms();
+  stats_.prompt_eval_end_ms = llm::time_in_ms();
 
   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp
index d368f8fb1a4..971e126a14c 100644
--- a/examples/models/llava/runner/llava_runner.cpp
+++ b/examples/models/llava/runner/llava_runner.cpp
@@ -55,7 +55,8 @@ Error LlavaRunner::load() {
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
       /*use_kv_cache=*/true,
-      /*enable_parallel_prefill=*/true);
+      /*enable_parallel_prefill=*/true,
+      /*max_seq_len=*/128);
 
   // Load the image prefiller
   image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 473cc2a3d81..19c260f5be6 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -10,6 +10,7 @@
 // LLM.
 
 #include <executorch/extension/llm/runner/text_prefiller.h>
+#include <algorithm>
 
 namespace executorch {
 namespace extension {
@@ -18,10 +19,13 @@ namespace llm {
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    int64_t max_seq_len)
     : text_decoder_runner_(text_decoder_runner),
       use_kv_cache_(use_kv_cache),
-      enable_parallel_prefill_(enable_parallel_prefill) {}
+      enable_parallel_prefill_(enable_parallel_prefill),
+      max_seq_len_(max_seq_len > 0 ? max_seq_len - 1 : 127) {
+} // -1 because for some reason tracing results in this upperbound
 
 ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     std::vector<uint64_t>& prompt_tokens,
@@ -30,6 +34,45 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
   if (!text_decoder_runner_->is_method_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
   }
+
+  // Check if we need to chunk the prompt tokens
+  int32_t num_prompt_tokens = prompt_tokens.size();
+
+  // If prompt tokens exceed max_seq_len_, we need to chunk them
+  if (num_prompt_tokens > max_seq_len_) {
+    uint64_t cur_token = 0;
+    int num_tokens_to_process = 0;
+
+    while (num_tokens_to_process < num_prompt_tokens) {
+      auto num_tokens_to_prefill_with = std::min<int64_t>(
+          num_prompt_tokens - num_tokens_to_process, max_seq_len_);
+
+      std::vector<uint64_t> prompt_tokens_to_process(
+          num_tokens_to_prefill_with);
+      std::copy(
+          prompt_tokens.begin() + num_tokens_to_process,
+          prompt_tokens.begin() + num_tokens_to_process +
+              num_tokens_to_prefill_with,
+          prompt_tokens_to_process.begin());
+
+      // Process this chunk
+      auto chunk_result = prefillChunk(prompt_tokens_to_process, start_pos);
+      ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
+      cur_token = chunk_result.get();
+
+      num_tokens_to_process += num_tokens_to_prefill_with;
+    }
+
+    return cur_token;
+  } else {
+    // If prompt tokens don't exceed max_seq_len_, process them directly
+    return prefillChunk(prompt_tokens, start_pos);
+  }
+}
+
+::executorch::runtime::Result<uint64_t> TextPrefiller::prefillChunk(
+    std::vector<uint64_t>& prompt_tokens,
+    int64_t& start_pos) {
   // enable_parallel_prefill_ maybe set even when not using kv cache
   // When kv cache is not used, start pos is ignored
   int32_t num_prompt_tokens = prompt_tokens.size();
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index 007f8188f56..0620eadfe9f 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -22,7 +22,8 @@ class ET_EXPERIMENTAL TextPrefiller {
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      int64_t max_seq_len = 128);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
@@ -35,10 +36,22 @@ class ET_EXPERIMENTAL TextPrefiller {
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);
 
+  /**
+   * Helper method to prefill a chunk of tokens.
+   * @param prompt_tokens The chunk of text prompt tokens to process.
+   * @param start_pos The starting position in KV cache of the input in the LLM
+   * Module.
+   * @return The next token of the LLM Module after prefilling this chunk.
+   */
+  ::executorch::runtime::Result<uint64_t> prefillChunk(
+      std::vector<uint64_t>& prompt_tokens,
+      int64_t& start_pos);
+
  private:
   TextDecoderRunner* text_decoder_runner_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
+  int64_t max_seq_len_;
 };
 
 } // namespace llm
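
For reference, the chunking behavior this patch adds to TextPrefiller::prefill() can be reasoned about in isolation. The following is a minimal, self-contained C++ sketch of the same loop structure, not the ExecuTorch API itself: chunked_prefill and prefill_chunk are hypothetical stand-ins (prefill_chunk plays the role of TextPrefiller::prefillChunk, advancing start_pos by reference and returning the last token it produced), and the 127-token chunk size mirrors the max_seq_len - 1 adjustment made in the constructor.

// Hypothetical sketch of the chunked-prefill loop introduced by this patch.
// prefill_chunk stands in for TextPrefiller::prefillChunk: it "consumes" one
// chunk, advances start_pos, and returns the last token it saw.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {

uint64_t prefill_chunk(const std::vector<uint64_t>& chunk, int64_t& start_pos) {
  start_pos += static_cast<int64_t>(chunk.size()); // KV cache position advances
  return chunk.back(); // pretend the last token is the decoder's output
}

uint64_t chunked_prefill(
    const std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos,
    int64_t max_seq_len) {
  const int64_t num_prompt_tokens = static_cast<int64_t>(prompt_tokens.size());
  uint64_t cur_token = 0;
  int64_t processed = 0;
  while (processed < num_prompt_tokens) {
    // Take at most max_seq_len tokens per chunk, fewer for the final chunk.
    const int64_t n =
        std::min<int64_t>(num_prompt_tokens - processed, max_seq_len);
    std::vector<uint64_t> chunk(
        prompt_tokens.begin() + processed, prompt_tokens.begin() + processed + n);
    cur_token = prefill_chunk(chunk, start_pos);
    processed += n;
  }
  // Only the last chunk's output token is surfaced to the caller.
  return cur_token;
}

} // namespace

int main() {
  std::vector<uint64_t> prompt(300);
  for (size_t i = 0; i < prompt.size(); ++i) {
    prompt[i] = i;
  }
  int64_t start_pos = 0;
  // With max_seq_len = 127 this runs three chunks: 127 + 127 + 46 tokens.
  const uint64_t last = chunked_prefill(prompt, start_pos, 127);
  std::cout << "last token: " << last << ", final start_pos: " << start_pos
            << "\n"; // prints 299 and 300
  return 0;
}

The design point the sketch preserves is that start_pos is passed by reference, so each chunk's tokens are written into the KV cache immediately after the previous chunk's, and the caller only ever sees the token produced after the final chunk.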