diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h
index 74925a777a2..aeff4e09478 100644
--- a/examples/models/llama/runner/static_attention_io_manager.h
+++ b/examples/models/llama/runner/static_attention_io_manager.h
@@ -602,6 +602,39 @@ class StaticAttentionIOManager {
     }
   }
 
+  /**
+   * Prefill helper. Run multiple inferences as needed depending on the length
+   * of the prompt and method's input length. Returns the position in the output
+   * that corresponds to the end of the prompt during the last inference.
+   */
+  template <typename TokenT>
+  size_t prefill(
+      executorch::runtime::Span<TokenT> tokens,
+      executorch::runtime::Span<TokenT> input_buffer,
+      executorch::runtime::Method& method) {
+    size_t input_len = input_buffer.size();
+    get_mask(input_buffer.size()).set_causal_mask();
+
+    size_t batch_len = 0;
+    for (size_t i = 0; i < tokens.size(); i += input_len) {
+      batch_len = std::min(input_len, tokens.size() - i);
+      std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
+      prepare(method);
+      ET_CHECK(method.execute() == executorch::runtime::Error::Ok);
+      update(
+          method,
+          config_.k_cache_output_indices,
+          config_.v_cache_output_indices,
+          batch_len);
+    }
+    return batch_len - 1;
+  }
+
+  /**
+   * Decode helper. The `sample` argument is called after each inference and
+   * should retrieve the logits from the `method` argument's output and return
+   * the sampled token.
+   */
   template <typename TokenT, typename SampleFunc>
   std::vector<TokenT> decode(
       TokenT prev_tok,
@@ -632,6 +665,11 @@ class StaticAttentionIOManager {
     return generated_tokens;
   }
 
+  /**
+   * Lookahead decode helper. The `sample` argument is called after each
+   * inference and should retrieve the logits from the `method` argument's
+   * output and return the sampled token for all output positions.
+   */
   template <typename TokenT, typename SampleFunc>
   std::vector<TokenT> lookahead_decode(
       TokenT prev_tok,
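
For context, below is a minimal sketch of how a runner might drive the new `prefill()` helper, assuming a loaded `Method`, an already-constructed IO manager, and `int64_t` token IDs. The names `run_prompt`, `io_mgr`, `prompt_tokens`, and `input_buffer` are hypothetical and not part of the patch, and the follow-up `decode()` call is omitted because its full signature falls outside this hunk.

```cpp
// Hypothetical wiring for the prefill() helper added above. All names here
// (run_prompt, io_mgr, prompt_tokens, input_buffer) are illustrative only;
// header paths assume the usual ExecuTorch runtime layout.
#include <cstdint>
#include <vector>

#include <executorch/runtime/core/span.h>
#include <executorch/runtime/executor/method.h>

template <typename IOManagerT>
size_t run_prompt(
    IOManagerT& io_mgr, // StaticAttentionIOManager instance owned by the runner
    executorch::runtime::Method& method, // loaded method using static attention
    std::vector<int64_t>& prompt_tokens, // tokenized prompt
    std::vector<int64_t>& input_buffer) { // sized to the method's input length
  // prefill() chunks the prompt into batches of input_buffer.size() tokens,
  // runs one inference per batch, and updates the KV caches after each run.
  // The return value is the output position of the last prompt token in the
  // final batch, whose logits would seed the first decode() step.
  return io_mgr.prefill(
      executorch::runtime::Span<int64_t>(
          prompt_tokens.data(), prompt_tokens.size()),
      executorch::runtime::Span<int64_t>(
          input_buffer.data(), input_buffer.size()),
      method);
}
```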