From dab2eb038faf3a2c150b7e83c2cfae0f9858e4b8 Mon Sep 17 00:00:00 2001 From: Shen Xu Date: Fri, 26 Sep 2025 11:02:54 -0700 Subject: [PATCH] StaticAttentionIOManager: Fix out-of-bounds errors on precomputed RoPE frequencies Summary: The precomputed RoPE frequencies can run out even though the KV caches are circular buffers by default. Differential Revision: D83361153 --- .../llama/runner/static_attention_io_manager.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index db7add8d16a..e2d2bc40c60 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -434,6 +434,7 @@ class StaticAttentionIOManager { std::vector k_cache_output_indices; std::vector v_cache_input_indices; std::vector v_cache_output_indices; + size_t max_context_len{}; RopeT* rope_freqs_cos; RopeT* rope_freqs_sin; StaticAttentionUpdateStyle style = StaticAttentionUpdateStyle::SMART_MASK; @@ -604,6 +605,10 @@ class StaticAttentionIOManager { size_t batch_len = 0; for (size_t i = 0; i < tokens.size(); i += input_len) { batch_len = std::min(input_len, tokens.size() - i); + if (input_pos_ + batch_len > config_.max_context_len) { + ET_LOG(Error, "Maximum context size reached, stopping prefill."); + return input_len - 1; + } std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin()); prepare(method); ET_CHECK(method.execute() == executorch::runtime::Error::Ok); @@ -646,6 +651,10 @@ class StaticAttentionIOManager { while (true) { input_buffer[0] = prev_tok; + if (input_pos_ + 1 > config_.max_context_len) { + ET_LOG(Error, "Maximum context size reached, stopping decode."); + break; + } prepare(method); ET_CHECK(method.execute() == executorch::runtime::Error::Ok); update( @@ -730,6 +739,11 @@ class StaticAttentionIOManager { } // Setup input pointers and RoPE frequencies. 
+ if (input_pos_ + ngram_size > config_.max_context_len) { + ET_LOG( + Error, "Maximum context size reached, stopping lookahead decode."); + break; + } prepare( method, executorch::runtime::Span(pos_offsets.data(), pos_offsets.size()));