diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index 4c2efc91203..5bd5ef9d04e 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -130,7 +130,9 @@ class ET_EXPERIMENTAL IRunner {
    * given position in KV cache.
    *
    * @param prompt The input prompt to generate from
-   * @param start_pos The starting position in KV cache of the input
+   * @param start_pos The starting position in KV cache of the input. Note:
+   * depending on the implementation, a runner may manage the position
+   * internally, in which case this argument is ignored.
    * @param config Generation configuration parameters
    * @param token_callback Callback function called for each generated token
    * @param stats_callback Callback function for generation statistics
@@ -146,6 +148,16 @@ class ET_EXPERIMENTAL IRunner {
    * Stop the generation process.
    */
   virtual void stop() = 0;
+
+  /**
+   * Remove prefilled tokens and reset the KV cache start position.
+   *
+   * For some existing runners, overriding this method is not needed because
+   * start_pos is passed as an argument to generate_from_pos.
+   *
+   * This method removes the prefilled tokens from the KV cache and resets the
+   * start position to 0.
+   */
+  virtual void reset() {}
 };
 
 } // namespace llm
diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp
index 8ec48b48ec3..ae56abd3d97 100644
--- a/extension/llm/runner/test/test_text_llm_runner.cpp
+++ b/extension/llm/runner/test/test_text_llm_runner.cpp
@@ -346,64 +346,4 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
   EXPECT_TRUE(runner.is_loaded());
 }
 
-// Test that generate_from_pos() errors out when max_new_tokens is negative
-TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
-  // Create mock instances using helper functions
-  auto tokenizer = createMockTokenizer();
-  auto text_decoder_runner = createMockTextDecoderRunner();
-  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
-
-  // Set up expectations for the tokenizer encode method
-  ON_CALL(*tokenizer, encode(_, _, _))
-      .WillByDefault([&](const std::string&, int8_t, int8_t) {
-        return ::tokenizers::Result<std::vector<uint64_t>>(
-            std::vector<uint64_t>{1, 2, 3});
-      });
-
-  // Set up expectations for load methods
-  ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
-
-  std::unique_ptr stats =
-      std::make_unique();
-  // Create a real TextTokenGenerator
-  auto text_token_generator = createTextTokenGenerator(
-      tokenizer.get(), text_decoder_runner.get(), stats.get());
-
-  // Create a Runner with our mocked components
-  auto module = std::make_unique();
-  auto io_manager =
-      std::make_unique(*module);
-  TextLLMRunner runner(
-      {
-          {"enable_dynamic_shape", false},
-          {"get_max_seq_len", 10},
-          {"get_max_context_len", 10},
-          {"use_kv_cache", true},
-      },
-      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
-      std::move(module),
-      std::move(text_decoder_runner),
-      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
-          text_prefiller.release()),
-      std::move(io_manager),
-      std::move(text_token_generator),
-      std::move(stats));
-
-  // Load
-  runner.load();
-
-  // Set up the generation config with a negative max_new_tokens value
-  GenerationConfig config;
-  config.max_new_tokens = 5;
-  config.echo = false;
-
-  // num_prompt_tokens = 3
-  // max_context_len = 10
-  // start_pos = 8, this should fail because 10 - 8 > 3, even though
-  // config.max_new_tokens = 5 > 3, it's still a failure.
-  Error err = runner.generate_from_pos("test prompt", 8, config);
-
-  // Verify that an InvalidArgument error is returned
-  EXPECT_EQ(err, Error::InvalidArgument);
-}
 } // namespace
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index f0ac9ed0781..b6f41fd7af6 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -72,7 +72,7 @@ Error TextLLMRunner::load() {
 
 Error TextLLMRunner::generate_from_pos(
     const std::string& prompt,
-    int64_t start_pos,
+    ET_UNUSED int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
@@ -123,8 +123,8 @@
   std::vector<uint64_t> prompt_tokens = encode_res.get();
   int num_prompt_tokens = prompt_tokens.size();
 
-  // Reduce max_context_len by start_pos
-  int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
+  // Reduce max_context_len by pos_
+  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
   ET_CHECK_OR_RETURN_ERROR(
       num_prompt_tokens >= 1,
       InvalidArgument,
@@ -138,16 +138,16 @@
       max_context_len);
 
   // Determine max_new_tokens using the GenerationConfig's resolve method,
-  // then subtract start_pos for max_new_tokens.
+  // given a max_context_len that has already been reduced by pos_.
   int max_new_tokens =
       config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);
 
   ET_LOG(
       Info,
-      "Max new tokens resolved: %d, given start_pos %" PRId64
+      "Max new tokens resolved: %d, given pos_ %" PRId64
       ", num_prompt_tokens %zu, max_context_len %" PRId64,
       max_new_tokens,
-      start_pos,
+      pos_,
       prompt_tokens.size(),
       max_context_len);
   ET_CHECK_OR_RETURN_ERROR(
@@ -163,8 +163,7 @@
   if (config.echo) {
     wrapped_callback(prompt);
   }
-  int64_t pos = start_pos;
-  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
   stats_->first_token_ms = time_in_ms();
@@ -217,11 +216,13 @@
   return Error::Ok;
 }
 
+
 Error TextLLMRunner::generate(
     const std::string& prompt,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
+  pos_ = 0;
   return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
 }
 
@@ -246,4 +247,9 @@ void TextLLMRunner::stop() {
   }
 }
 
+void TextLLMRunner::reset() {
+  stats_->reset();
+  pos_ = 0;
+}
+
 } // namespace executorch::extension::llm
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index fd0df786336..21b77fe1dfa 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -102,25 +102,20 @@
       std::function<void(const Stats&)> stats_callback = {}) override;
 
   /**
-   * @brief Generates text based on the provided prompt and start position
+   * Generate text based on the provided prompt and generation config, from a
+   * given position in KV cache.
    *
-   * This method performs text generation using the loaded model. It processes
-   * the input prompt, runs the model in prefill and decode phases using the
-   * start position until max tokens to generate is reached or eos token is
-   * generated, then returns generated text and perf stats through callbacks.
-   *
-   * @param prompt The input text to generate from
-   * @param start_pos The starting position in KV cache of the input
-   * @param config Configuration parameters for text generation (e.g.,
-   * max_new_tokens, temperature)
-   * @param token_callback Function called for each generated token with the
-   * decoded text
-   * @param stats_callback Function called with performance statistics
-   * @return ::executorch::runtime::Error Success or error status
+   * @param prompt The input prompt to generate from
+   * @param start_pos [Unused] The starting position in KV cache of the input,
+   * ignored because the runner manages the position internally.
+   * @param config Generation configuration parameters
+   * @param token_callback Callback function called for each generated token
+   * @param stats_callback Callback function for generation statistics
+   * @return Error::Ok if successful, an error otherwise
    */
-  ::executorch::runtime::Error generate_from_pos(
+  ET_DEPRECATED runtime::Error generate_from_pos(
      const std::string& prompt,
-      int64_t start_pos,
+      ET_UNUSED int64_t start_pos,
      const GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {}) override;
@@ -138,6 +133,13 @@
   ::executorch::runtime::Error warmup(
       const std::string& prompt,
       int32_t max_new_tokens);
+  /**
+   * @brief Remove prefilled tokens and reset the start position and stats.
+   *
+   * This method removes the prefilled tokens from the KV cache and resets the
+   * start position to 0. It also clears the stats from previous runs.
+   */
+  void reset() override;
   /**
    * @brief Stops the ongoing text generation process
    *
@@ -169,6 +171,9 @@
   // temperature.
   // Deprecated, we should rely on the temperature in GenerationConfig instead.
   float temperature_ = -1.0f;
+
+  // The position in KV cache of the input, starting from 0.
+  int64_t pos_ = 0;
 };
 
 } // namespace executorch::extension::llm
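
For reviewers, a minimal caller-side sketch of how the reworked API is meant to be used after this change. It is not part of the patch: the run_two_turns helper, the prompts, and the token-printing callback are made up for illustration, and the TextLLMRunner instance is assumed to have been constructed elsewhere.

#include <iostream>
#include <string>

#include <executorch/extension/llm/runner/text_llm_runner.h>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::TextLLMRunner;
using executorch::runtime::Error;

// Two conversation turns against one runner, followed by an explicit reset.
Error run_two_turns(TextLLMRunner& runner) {
  // Ensure the model is loaded before generating.
  Error err = runner.load();
  if (err != Error::Ok) {
    return err;
  }

  GenerationConfig config;
  config.max_new_tokens = 64;
  auto print_token = [](const std::string& token) { std::cout << token; };

  // Turn 1: generate() zeroes the internal position (pos_) and prefills the
  // prompt from the start of the KV cache.
  err = runner.generate("What is the capital of France?", config, print_token);
  if (err != Error::Ok) {
    return err;
  }

  // Turn 2: generate_from_pos() continues from the runner's internal
  // position; with this change TextLLMRunner ignores the start_pos argument
  // (and the method is marked ET_DEPRECATED, so expect a compiler warning).
  err = runner.generate_from_pos(
      "And what about Germany?", /*start_pos=*/0, config, print_token);
  if (err != Error::Ok) {
    return err;
  }

  // Start a fresh conversation: drop the prefilled start position and clear
  // the perf stats accumulated by the previous runs.
  runner.reset();
  return Error::Ok;
}

Because generate() resets pos_ to 0 on every call, continuing a conversation relies on generate_from_pos() (or a future non-deprecated equivalent), while reset() is the explicit way to discard the cached prefill state between conversations.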