Enable infinite generation with RoPE position remapping for attention sink (#19011)
meta-codesync[bot] merged 1 commit into main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19011
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 3 Unrelated Failures as of commit 9ae6844 with merge base 1d37abd.
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D100728748.
This PR needs a
7cce9a4 to 2a34458
2a34458 to a6472e5
Pull request overview
This PR enables “infinite” token generation for LLaMA attention-sink models by remapping RoPE positions into a bounded range aligned with the KV-cache ring buffer, preventing out-of-bounds indexing when decoding past max_context_len.
Changes:
- Add RoPE position remapping logic in RopeWithAttentionSink.get_freqs: sink positions preserved; window positions wrapped into [sink_size, sink_size + 2*window_size), as sketched below.
- Add an end-to-end test that generates beyond max_context_len and validates outputs remain finite.
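A minimal standalone sketch of this mapping (a hypothetical helper, using the sink_size=4 / window_size=16 values from the test below, not the module's actual code):

```python
# Hypothetical illustration of the remapping rule described above.
sink_size, window_size = 4, 16
ring_size = 2 * window_size  # 32

def remap(pos: int) -> int:
    # Sink positions are preserved; window positions wrap into
    # [sink_size, sink_size + ring_size).
    if pos < sink_size:
        return pos
    return sink_size + (pos - sink_size) % ring_size

# Raw positions grow without bound; remapped positions stay in [0, 36).
print([remap(p) for p in (0, 3, 4, 35, 36, 100)])  # [0, 3, 4, 35, 4, 4]
```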
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| examples/models/llama/source_transformation/attention_sink.py | Implements RoPE position remapping for attention-sink + ring-buffer KV cache to avoid OOB past max_context_len. |
| examples/models/llama/source_transformation/test_attention_sink.py | Adds E2E regression coverage for generating beyond max_context_len. |
```python
assert input_pos is not None
# Use torch._check for export compatibility (data-dependent guard)
torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
return super().get_freqs(input_pos, seq_len)
if not self.params.use_kv_cache:
    return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
```
```python
self.sink_size = sink_size
# max_context_len from params is used for RoPE frequencies (should be large)
self.max_context_length = self.params.max_context_len
self.ring_size = window_size * 2
```
```python
def test_beyond_max_context_len(self):
    """Generate tokens beyond max_context_len with RoPE position remapping."""
    sink_size = 4
    window_size = 16
    # KV cache size = 36, max_context_len = 64
    # Generate 100 tokens — well beyond max_context_len
    args = self._make_args(max_context_len=64)
    model = self._build_model(args, sink_size, window_size, use_custom_sdpa=False)

    outputs = self._run_generation(model, args, num_tokens=100)

    self.assertEqual(len(outputs), 97)  # 1 prefill + 96 decode steps
    for out in outputs:
        self.assertTrue(
            torch.isfinite(out).all(),
            "Output contains non-finite values beyond max_context_len",
        )
```
```python
# Dynamic shape: input_pos is [start_pos], remap and narrow
input_pos_item = input_pos[-1].item()
if input_pos_item < self.sink_size:
    remapped_item = input_pos_item
else:
    remapped_item = (
        self.sink_size
        + (input_pos_item - self.sink_size) % self.ring_size
    )
```
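For reference, here is a hedged sketch of how such a remapped scalar might be used to slice a bounded frequency table (the tensor name and head_dim are placeholders, not the module's real fields). With max_context_len = 64 as in the test above, the raw decode position 100 would index out of bounds, while the remapped one stays in range:

```python
import torch

# Placeholder table standing in for the precomputed RoPE frequencies.
max_context_len, head_dim = 64, 8  # head_dim is a made-up example value
freqs_cos = torch.randn(max_context_len, head_dim // 2)

sink_size, window_size = 4, 16
ring_size = 2 * window_size

raw_pos, seq_len = 100, 1  # single-token decode step far past max_context_len
# freqs_cos.narrow(0, raw_pos, seq_len) would fail: start 100 >= table size 64.
remapped = sink_size + (raw_pos - sink_size) % ring_size  # 4
freqs = freqs_cos.narrow(0, remapped, seq_len)  # in bounds
print(freqs.shape)  # torch.Size([1, 4])
```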
a6472e5 to a451868
a451868 to db1328f
db1328f to cdf3644
cdf3644 to 5faf3a6
5faf3a6 to bfff183
bfff183 to 311be20
311be20 to 9ae6844
Summary:
Previously, attention sink models could not generate beyond max_context_len
because RoPE used the raw monotonic input_pos to index into the pre-computed
freqs_cis table, causing OOB when pos >= max_context_len.
This change adds position remapping in RopeWithAttentionSink:
- Sink token positions (< sink_size) are preserved as-is
- Window token positions are wrapped into the ring buffer range [sink_size, sink_size + ring_size) using modular arithmetic
The 2x ring buffer (ring_size = 2 * window_size) ensures the live window
of tokens never spans a wrap boundary, preserving correct relative
distances in RoPE space.
This enables attention sink models to generate indefinitely — the KV cache
ring buffer recycles space while RoPE positions stay bounded.
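The wrap-boundary claim can be spot-checked with a small script. This is only a sketch under one stated assumption: the live window advances in window_size-aligned chunks, as the eviction scheme described above implies. Within any such chunk, remapped positions stay consecutive, so relative distances in RoPE space are unchanged:

```python
sink_size, window_size = 4, 16
ring_size = 2 * window_size  # the 2x ring buffer from the summary above

def remap(pos: int) -> int:
    return pos if pos < sink_size else sink_size + (pos - sink_size) % ring_size

# Assumption: the live window advances in window_size-aligned chunks.
# Check that no chunk straddles a wrap boundary in remapped space.
for k in range(10):
    start = sink_size + k * window_size
    remapped = [remap(p) for p in range(start, start + window_size)]
    deltas = [b - a for a, b in zip(remapped, remapped[1:])]
    assert all(d == 1 for d in deltas), f"wrap inside chunk {k}"
print("no window-aligned chunk crosses a wrap boundary")
```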
Reviewed By: lucylq
Differential Revision: D100728748