diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index a696d92c40c..db7add8d16a 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -589,7 +589,9 @@ class StaticAttentionIOManager { size_t prefill( executorch::runtime::Span tokens, executorch::runtime::Span input_buffer, - executorch::runtime::Method& method) { + executorch::runtime::Method& method, + std::function)> + logits_callback = nullptr) { ET_LOG(Info, "Prefilling at position %zu", input_pos_); size_t input_len = input_buffer.size(); auto& masks = get_mask(input_buffer.size()); @@ -610,6 +612,13 @@ class StaticAttentionIOManager { config_.k_cache_output_indices, config_.v_cache_output_indices, batch_len); + if (logits_callback) { + auto logits_tensor = method.get_output(0).toTensor(); + auto* logits = logits_tensor.const_data_ptr(); + logits_callback(executorch::runtime::Span( + logits, + logits + batch_len * logits_tensor.size(logits_tensor.dim() - 1))); + } } return batch_len - 1; }