2 changes: 2 additions & 0 deletions examples/mediatek/executor_runner/mtk_llama_runner.h
@@ -66,6 +66,8 @@ class MTKLlamaRunner : public executorch::extension::llm::IRunner {
       std::function<void(const std::string&)> token_callback);
   std::unique_ptr<Tokenizer> load_tokenizer();

+  void reset() {}
+
  private:
   // model
   const LlamaModelOptions modeloptions_;
2 changes: 2 additions & 0 deletions examples/models/llama/main.cpp
@@ -105,6 +105,8 @@ int32_t main(int32_t argc, char** argv) {
       ET_LOG(Error, "Failed to warmup llama runner");
       return 1;
     }
+    // reset kv cache pos to 0
+    runner->reset();
   }
   // generate
   executorch::extension::llm::GenerationConfig config{
14 changes: 2 additions & 12 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -354,17 +354,6 @@ Error Runner<T>::generate(
     const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
-}
-
-template <typename T>
-Error Runner<T>::generate_from_pos(
-    const std::string& prompt,
-    int64_t start_pos,
-    const llm::GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
-  // TODO: currently only support start_pos == 0
   return generate_from_prompt_or_file(
       prompt, false, config, token_callback, stats_callback);
 }
@@ -435,7 +424,8 @@ Error Runner<T>::generate_from_prompt_or_file(
   stats_.first_token_ms = time_in_ms();
   stats_.prompt_eval_end_ms = time_in_ms();

-  // print the first token from prefill. No prev_token so use cur_token for it.
+  // print the first token from prefill. No prev_token so use cur_token for
+  // it.
   if (token_callback) {
     token_callback(
         ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
9 changes: 2 additions & 7 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -72,20 +72,15 @@ class Runner : public executorch::extension::llm::IRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const executorch::llm::Stats&)> stats_callback = {})
       override;
-  executorch::runtime::Error generate_from_pos(
-      const std::string& prompt,
-      int64_t start_pos,
-      const executorch::extension::llm::GenerationConfig& config,
-      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
-      override;
+
   executorch::runtime::Error generate_from_prompt_or_file(
       const std::string& prompt,
       bool tokenized_prompt,
       const executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const executorch::llm::Stats&)> stats_callback = {});
   void stop() override {};
+  void reset() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();

  private:
25 changes: 2 additions & 23 deletions extension/llm/runner/irunner.h
@@ -125,39 +125,18 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;

-  /**
-   * Generate text based on the provided prompt and generation config, from a
-   * given position in KV cache.
-   *
-   * @param prompt The input prompt to generate from
-   * @param start_pos The starting position in KV cache of the input. Note:
-   * Depending on the actual implementation, a runner may manage the position
-   * internally, and this may not be respected.
-   * @param config Generation configuration parameters
-   * @param token_callback Callback function called for each generated token
-   * @param stats_callback Callback function for generation statistics
-   * @return Error::Ok if successful, an error otherwise
-   */
-  virtual runtime::Error generate_from_pos(
-      const std::string& prompt,
-      int64_t start_pos,
-      const GenerationConfig& config,
-      std::function<void(const std::string&)> token_callback,
-      std::function<void(const Stats&)> stats_callback) = 0;
   /**
    * Stop the generation process.
    */
   virtual void stop() = 0;

   /**
    * Force remove prefilled tokens and reset KV cache start position
    *
-   * For some existing runners, overriding this method is not needed because
-   * start_pos is passed as an argument to generate_from_pos.
-   *
    * This method removes the prefilled tokens from the KV cache and resets the
    * start position to 0.
    */
-  virtual void reset() {};
+  virtual void reset() = 0;
 };

 } // namespace llm
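With generate_from_pos() removed and reset() now pure virtual, "start from position 0" is expressed by an explicit reset() before generate(). A minimal caller-side sketch against the slimmed-down interface (the run_twice helper and its prompts are illustrative, not part of ExecuTorch; the types come from irunner.h above):

#include <executorch/extension/llm/runner/irunner.h>

#include <functional>
#include <string>

// Illustrative helper: drives any IRunner through the reduced interface,
// rewinding the KV cache between two independent prompts.
void run_twice(executorch::extension::llm::IRunner& runner) {
  executorch::extension::llm::GenerationConfig config{};
  auto on_token = [](const std::string& piece) { /* stream piece to UI */ };

  // First generation prefills from position 0 of an empty KV cache.
  auto err = runner.generate("Hello", config, on_token, /*stats_callback=*/{});
  if (err != executorch::runtime::Error::Ok) {
    return;
  }
  // Previously: generate_from_pos("Bonjour", /*start_pos=*/0, ...).
  // The rewind is now explicit, and the runner owns the position.
  runner.reset();
  (void)runner.generate("Bonjour", config, on_token, /*stats_callback=*/{});
}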
13 changes: 2 additions & 11 deletions extension/llm/runner/text_llm_runner.cpp
@@ -43,6 +43,7 @@ TextLLMRunner::TextLLMRunner(
       io_manager_(std::move(io_manager)),
       text_token_generator_(std::move(text_token_generator)),
       stats_(std::move(stats)),
+      pos_(0),
       temperature_(temperature) {
   // Note: This constructor assumes that text_prefiller and text_token_generator
   // already have references to the Module and TextDecoderRunner they need
@@ -70,9 +71,8 @@ Error TextLLMRunner::load() {
     ET_LOG(Info, format, __VA_ARGS__); \
   }

-Error TextLLMRunner::generate_from_pos(
+Error TextLLMRunner::generate(
     const std::string& prompt,
-    ET_UNUSED int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
@@ -217,15 +217,6 @@
   return Error::Ok;
 }

-Error TextLLMRunner::generate(
-    const std::string& prompt,
-    const GenerationConfig& config,
-    std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
-  pos_ = 0;
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
-}
-
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup
   GenerationConfig config{
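Note that the rename also drops the wrapper that zeroed pos_ on every generate() call: the position now persists across calls (which is why main.cpp above must call reset() after warmup), and reset() is the only rewind. Downstream code that relied on the old generate_from_pos(prompt, /*start_pos=*/0, ...) behavior can keep it with a thin shim; a sketch under that assumption (generate_from_scratch is a made-up name, not an ExecuTorch API):

#include <executorch/extension/llm/runner/irunner.h>

#include <functional>
#include <string>
#include <utility>

// Hypothetical shim: reproduces the retired start-at-0 semantics on top of
// the surviving reset() + generate() interface.
executorch::runtime::Error generate_from_scratch(
    executorch::extension::llm::IRunner& runner,
    const std::string& prompt,
    const executorch::extension::llm::GenerationConfig& config,
    std::function<void(const std::string&)> token_callback = {},
    std::function<void(const executorch::extension::llm::Stats&)> stats_callback = {}) {
  runner.reset(); // drop prefilled tokens; KV cache start position -> 0
  return runner.generate(
      prompt, config, std::move(token_callback), std::move(stats_callback));
}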
21 changes: 2 additions & 19 deletions extension/llm/runner/text_llm_runner.h
@@ -101,25 +101,6 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;

-  /**
-   * Generate text based on the provided prompt and generation config, from a
-   * given position in KV cache.
-   *
-   * @param prompt The input prompt to generate from
-   * @param start_pos [Unused] The starting position in KV cache of the input,
-   * ignored because the runner manages the position internally.
-   * @param config Generation configuration parameters
-   * @param token_callback Callback function called for each generated token
-   * @param stats_callback Callback function for generation statistics
-   * @return Error::Ok if successful, an error otherwise
-   */
-  ET_DEPRECATED runtime::Error generate_from_pos(
-      const std::string& prompt,
-      ET_UNUSED int64_t start_pos,
-      const GenerationConfig& config,
-      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) override;
-
   /**
    * @brief Warms up the model with a sample prompt
    *
@@ -133,13 +114,15 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   ::executorch::runtime::Error warmup(
       const std::string& prompt,
       int32_t max_new_tokens);
+
   /**
    * @brief Remove prefilled tokens and reset start position, and stats.
    *
    * This method removes the prefilled tokens from the KV cache and resets the
    * start position to 0. It also clears the stats for previous runs.
    */
+  void reset() override;
+
   /**
    * @brief Stops the ongoing text generation process
    *
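Since reset() clears the stats along with the KV cache, back-to-back runs on one TextLLMRunner stay independent. A sketch of that pattern (the include path and the run_repeated helper are assumptions, not from this diff):

#include <executorch/extension/llm/runner/text_llm_runner.h>

#include <string>

// Illustrative only: run the same prompt several times, each from an empty
// KV cache and with stats that cover just that run.
void run_repeated(executorch::extension::llm::TextLLMRunner& runner,
                  const std::string& prompt) {
  executorch::extension::llm::GenerationConfig config{};
  for (int i = 0; i < 3; ++i) {
    auto err = runner.generate(
        prompt,
        config,
        /*token_callback=*/{},
        [](const executorch::extension::llm::Stats& stats) {
          (void)stats; // per-run stats; reset() below wipes them
        });
    if (err != executorch::runtime::Error::Ok) {
      return;
    }
    runner.reset(); // prefilled tokens removed, start position back to 0
  }
}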