From bab4ac792e31a18b628abc44cde722cfb918483d Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Tue, 16 Sep 2025 09:55:08 -0700
Subject: [PATCH] StaticAttentionIOManager: optional callback on logits from prefill

Summary: The prefill can be divided into batches; the logits callback will be called on each batch.

Differential Revision: D82150606
---
 .../models/llama/runner/static_attention_io_manager.h | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h
index a696d92c40c..db7add8d16a 100644
--- a/examples/models/llama/runner/static_attention_io_manager.h
+++ b/examples/models/llama/runner/static_attention_io_manager.h
@@ -589,7 +589,9 @@ class StaticAttentionIOManager {
   size_t prefill(
       executorch::runtime::Span tokens,
       executorch::runtime::Span input_buffer,
-      executorch::runtime::Method& method) {
+      executorch::runtime::Method& method,
+      std::function<void(executorch::runtime::Span<const float>)>
+          logits_callback = nullptr) {
     ET_LOG(Info, "Prefilling at position %zu", input_pos_);
     size_t input_len = input_buffer.size();
     auto& masks = get_mask(input_buffer.size());
@@ -610,6 +612,13 @@ class StaticAttentionIOManager {
         config_.k_cache_output_indices,
         config_.v_cache_output_indices,
         batch_len);
+      if (logits_callback) {
+        auto logits_tensor = method.get_output(0).toTensor();
+        auto* logits = logits_tensor.const_data_ptr<float>();
+        logits_callback(executorch::runtime::Span<const float>(
+            logits,
+            logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1)));
+      }
     }
     return batch_len - 1;
   }