14 changes: 13 additions & 1 deletion extension/llm/runner/irunner.h
@@ -130,7 +130,9 @@ class ET_EXPERIMENTAL IRunner {
* given position in KV cache.
*
* @param prompt The input prompt to generate from
* @param start_pos The starting position in KV cache of the input
* @param start_pos The starting position in KV cache of the input. Note:
* depending on the implementation, a runner may manage the position
* internally, in which case this argument may not be respected.
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
@@ -146,6 +148,16 @@
* Stop the generation process.
*/
virtual void stop() = 0;
/**
* Force remove prefilled tokens and reset KV cache start position
*
* For some existing runners, overriding this method is not needed because
* start_pos is passed as an argument to generate_from_pos.
*
* This method removes the prefilled tokens from the KV cache and resets the
* start position to 0.
*/
virtual void reset() {};
(Review comment from the author: will be pure virtual after all migrations.)

};

} // namespace llm
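For context, here is a minimal usage sketch of the new reset() hook against the IRunner interface. It is not taken from the PR: the include path, the helper function name, and the way callbacks are passed are assumptions, and the runner is presumed to be a concrete, already-loaded implementation such as TextLLMRunner.

```cpp
#include <cstdio>
#include <string>

#include <executorch/extension/llm/runner/irunner.h> // assumed include path

namespace llm = executorch::extension::llm;

// Run two unrelated prompts on an already-loaded runner. reset() drops the
// tokens prefilled by the first generate() call and rewinds the KV-cache
// position to 0, so the second prompt starts from a clean cache instead of
// continuing the first conversation.
void run_two_independent_prompts(llm::IRunner& runner) {
  llm::GenerationConfig config;
  config.max_new_tokens = 32;

  auto print_token = [](const std::string& piece) {
    std::fputs(piece.c_str(), stdout);
  };

  (void)runner.generate("First prompt", config, print_token, {});
  runner.reset(); // prefilled tokens removed, KV-cache position back to 0
  (void)runner.generate("Second, unrelated prompt", config, print_token, {});
}
```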
60 changes: 0 additions & 60 deletions extension/llm/runner/test/test_text_llm_runner.cpp
@@ -346,64 +346,4 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
EXPECT_TRUE(runner.is_loaded());
}

// Test that generate_from_pos() errors out when max_new_tokens is negative
TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
// Create mock instances using helper functions
auto tokenizer = createMockTokenizer();
auto text_decoder_runner = createMockTextDecoderRunner();
auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());

// Set up expectations for the tokenizer encode method
ON_CALL(*tokenizer, encode(_, _, _))
.WillByDefault([&](const std::string&, int8_t, int8_t) {
return ::tokenizers::Result<std::vector<uint64_t>>(
std::vector<uint64_t>{1, 2, 3});
});

// Set up expectations for load methods
ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));

std::unique_ptr<executorch::llm::Stats> stats =
std::make_unique<executorch::llm::Stats>();
// Create a real TextTokenGenerator
auto text_token_generator = createTextTokenGenerator(
tokenizer.get(), text_decoder_runner.get(), stats.get());

// Create a Runner with our mocked components
auto module = std::make_unique<MockModule>();
auto io_manager =
std::make_unique<executorch::extension::llm::IOManager>(*module);
TextLLMRunner runner(
{
{"enable_dynamic_shape", false},
{"get_max_seq_len", 10},
{"get_max_context_len", 10},
{"use_kv_cache", true},
},
std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
std::move(module),
std::move(text_decoder_runner),
std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
text_prefiller.release()),
std::move(io_manager),
std::move(text_token_generator),
std::move(stats));

// Load
runner.load();

// Set up the generation config; the resolved max_new_tokens will be negative
GenerationConfig config;
config.max_new_tokens = 5;
config.echo = false;

// num_prompt_tokens = 3
// max_context_len = 10
// start_pos = 8, this should fail because 10 - 8 = 2 < 3 (num_prompt_tokens);
// even though config.max_new_tokens = 5, the resolved value is negative.
Error err = runner.generate_from_pos("test prompt", 8, config);

// Verify that an InvalidArgument error is returned
EXPECT_EQ(err, Error::InvalidArgument);
}
} // namespace
22 changes: 14 additions & 8 deletions extension/llm/runner/text_llm_runner.cpp
@@ -72,7 +72,7 @@ Error TextLLMRunner::load() {

Error TextLLMRunner::generate_from_pos(
const std::string& prompt,
int64_t start_pos,
ET_UNUSED int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
@@ -123,8 +123,8 @@ Error TextLLMRunner::generate_from_pos(
std::vector<uint64_t> prompt_tokens = encode_res.get();
int num_prompt_tokens = prompt_tokens.size();

// Reduce max_context_len by start_pos
int64_t max_context_len = metadata_.at(kMaxContextLen) - start_pos;
// Reduce max_context_len by pos_
int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
ET_CHECK_OR_RETURN_ERROR(
num_prompt_tokens >= 1,
InvalidArgument,
@@ -138,16 +138,16 @@
max_context_len);

// Determine max_new_tokens using the GenerationConfig's resolve method,
// then subtract start_pos for max_new_tokens.
// then subtract pos_ for max_new_tokens.
int max_new_tokens =
config.resolve_max_new_tokens(max_context_len, num_prompt_tokens);

ET_LOG(
Info,
"Max new tokens resolved: %d, given start_pos %" PRId64
"Max new tokens resolved: %d, given pos_ %" PRId64
", num_prompt_tokens %zu, max_context_len %" PRId64,
max_new_tokens,
start_pos,
pos_,
prompt_tokens.size(),
max_context_len);
ET_CHECK_OR_RETURN_ERROR(
@@ -163,8 +163,7 @@
if (config.echo) {
wrapped_callback(prompt);
}
int64_t pos = start_pos;
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
uint64_t cur_token = prefill_res.get();
stats_->first_token_ms = time_in_ms();
@@ -217,11 +216,13 @@

return Error::Ok;
}

Error TextLLMRunner::generate(
const std::string& prompt,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
pos_ = 0;
return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
}

@@ -246,4 +247,9 @@ void TextLLMRunner::stop() {
}
}

void TextLLMRunner::reset() {
stats_->reset();
pos_ = 0;
}

} // namespace executorch::extension::llm
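To make the bounds arithmetic above concrete, here is a small standalone sketch of the scenario the deleted regression test exercised (kMaxContextLen = 10, pos_ = 8, a 3-token prompt, requested max_new_tokens = 5). It assumes resolve_max_new_tokens() caps the requested value to the context remaining after the prompt, which is what the max_new_tokens > 0 check and the old test imply; the exact resolution logic is not shown in this diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hedged sketch of the generate_from_pos() bounds arithmetic; the capping
// behavior of resolve_max_new_tokens() is assumed, not taken from this diff.
int main() {
  const int64_t kMaxContextLen = 10; // model metadata
  const int64_t pos = 8;             // tokens already in the KV cache (pos_)
  const int num_prompt_tokens = 3;   // e.g. "test prompt" -> {1, 2, 3}
  const int requested_max_new_tokens = 5;

  // max_context_len is reduced by the current position before resolution.
  const int64_t max_context_len = kMaxContextLen - pos; // 2

  // Assumed resolution: cap the request to whatever room is left after the
  // prompt. 2 - 3 = -1, so even a positive request resolves to a negative
  // value and the runner returns Error::InvalidArgument.
  const int max_new_tokens = std::min<int>(
      requested_max_new_tokens,
      static_cast<int>(max_context_len) - num_prompt_tokens);

  std::printf("resolved max_new_tokens = %d\n", max_new_tokens); // -1
  return 0;
}
```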
37 changes: 21 additions & 16 deletions extension/llm/runner/text_llm_runner.h
@@ -102,25 +102,20 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* @brief Generates text based on the provided prompt and start position
* Generate text based on the provided prompt and generation config, from a
* given position in KV cache.
*
* This method performs text generation using the loaded model. It processes
* the input prompt, runs the model in prefill and decode phases using the
* start position until max tokens to generate is reached or eos token is
* generated, then returns generated text and perf stats through callbacks.
*
* @param prompt The input text to generate from
* @param start_pos The starting position in KV cache of the input
* @param config Configuration parameters for text generation (e.g.,
* max_new_tokens, temperature)
* @param token_callback Function called for each generated token with the
* decoded text
* @param stats_callback Function called with performance statistics
* @return ::executorch::runtime::Error Success or error status
* @param prompt The input prompt to generate from
* @param start_pos [Unused] The starting position in KV cache of the input,
* ignored because the runner manages the position internally.
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
* @return Error::Ok if successful, an error otherwise
*/
::executorch::runtime::Error generate_from_pos(
ET_DEPRECATED runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
ET_UNUSED int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;
@@ -138,6 +133,13 @@
::executorch::runtime::Error warmup(
const std::string& prompt,
int32_t max_new_tokens);
/**
* @brief Remove prefilled tokens and reset the start position and stats.
*
* This method removes the prefilled tokens from the KV cache, resets the
* start position to 0, and clears the stats from previous runs.
*/
void reset() override;
/**
* @brief Stops the ongoing text generation process
*
@@ -169,6 +171,9 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
// temperature.
// Deprecated, we should rely on the temperature in GenerationConfig instead.
float temperature_ = -1.0f;

// The position in KV cache of the input, starting from 0.
int64_t pos_ = 0;
};

} // namespace executorch::extension::llm
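Finally, a hedged sketch of how the new position handling behaves from a caller's point of view (runner construction is omitted; the include path and prompts are illustrative): generate() starts from position 0, generate_from_pos() ignores its start_pos argument and continues from the internal pos_, and reset() starts over.

```cpp
#include <string>

#include <executorch/extension/llm/runner/text_llm_runner.h> // assumed include path

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::TextLLMRunner;

// Two turns of the same conversation, then a fresh, unrelated prompt.
void multi_turn_then_reset(TextLLMRunner& runner) {
  GenerationConfig config;
  config.max_new_tokens = 64;

  // Turn 1: generate() rewinds pos_ to 0, prefills the prompt, and pos_
  // advances as tokens are produced.
  (void)runner.generate("User: Hi, who are you?\nAssistant:", config);

  // Turn 2: generate_from_pos() (now marked ET_DEPRECATED) ignores its
  // start_pos argument and continues from the runner's internal pos_, so
  // this prompt extends the existing KV cache.
  (void)runner.generate_from_pos(
      "User: Tell me more.\nAssistant:", /*start_pos=*/0, config);

  // Unrelated prompt: reset() drops the cached tokens, rewinds pos_ to 0,
  // and clears the per-run stats before starting over.
  runner.reset();
  (void)runner.generate(
      "User: Translate 'hello' to French.\nAssistant:", config);
}
```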