From beb17840b98515fa24b9176993b632cf1ad5e195 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Mon, 22 Sep 2025 15:19:28 -0700 Subject: [PATCH] Add a prefill() method for text llm runner --- extension/llm/runner/text_llm_runner.cpp | 22 ++++++++++++++++++++++ extension/llm/runner/text_llm_runner.h | 11 +++++++++++ 2 files changed, 33 insertions(+) diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 333716ac831..ec9c6c5242f 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -217,6 +217,28 @@ Error TextLLMRunner::generate( return Error::Ok; } +Error TextLLMRunner::prefill( + const std::string& prompt, + const GenerationConfig& config) { + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + + ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode( + prompt, + /*bos=*/config.num_bos, + /*eos=*/config.num_eos); + + ET_CHECK_TK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector<uint64_t> prompt_tokens = encode_res.get(); + auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_); + ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error()); + return Error::Ok; +} + Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) { // Create a GenerationConfig for warmup GenerationConfig config{ diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 9dd99d82d59..98fcef94f96 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -101,6 +101,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { std::function<void(const std::string&)> token_callback = {}, std::function<void(const Stats&)> stats_callback = {}) override; + /** + * Prefill text inputs, for example to reload chat history. + * @param prompt Text prompt to prefill.
+ * @param config Configuration parameters; only num_bos and num_eos are + * consulted during prefill (no tokens are generated). + * @return The error code. KV cache position is tracked internally in pos_. + */ + ::executorch::runtime::Error prefill( + const std::string& prompt, + const GenerationConfig& config); + /** * @brief Warms up the model with a sample prompt *