60 changes: 46 additions & 14 deletions examples/models/llama/source_transformation/attention_sink.py
@@ -29,13 +29,16 @@ class RopeWithAttentionSink(Rope):
"""
Rope subclass for Attention Sink models.

For torch.export compatibility, this passes through the original position
unchanged - the sliding window is handled by the cache index management
(ring buffer), not by position shifting.
Remaps input positions using modular arithmetic so RoPE frequencies stay
within the cache size bounds, enabling generation beyond max_context_len.

Note: This class uses the model's max_context_len (params.max_context_len) for
RoPE frequency table size, which should be large enough to support generation
beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
Position mapping:
- Sink tokens (pos < sink_size): position preserved as-is
- Window tokens (pos >= sink_size): wrapped into ring buffer range
[sink_size, sink_size + ring_size) via modulo

The ring buffer is 2x window_size, so the live window (window_size tokens)
never spans a wrap boundary, preserving correct relative distances in RoPE.
"""

def __init__(
@@ -47,19 +50,48 @@ def __init__(
super().__init__(params)
self.window_size = window_size
self.sink_size = sink_size
# max_context_len from params is used for RoPE frequencies (should be large)
self.max_context_length = self.params.max_context_len
self.ring_size = window_size * 2

def _remap_input_pos(self, input_pos: torch.Tensor) -> torch.Tensor:
"""Remap positions: sink tokens stay, window tokens wrap in ring buffer."""
return torch.where(
input_pos < self.sink_size,
input_pos,
self.sink_size + (input_pos - self.sink_size) % self.ring_size,
)

def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
"""
Get rotary embedding frequencies.
For attention sink, we use the original position - the sliding window
is handled by the cache index management, not by position shifting.
Get rotary embedding frequencies with position remapping.

For dynamic shape mode (input_pos is a single start position), we remap
the start and use narrow. For static shape mode (input_pos is the full
position tensor), we remap all positions and index directly.
"""
assert input_pos is not None
# Use torch._check for export compatibility (data-dependent guard)
torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
return super().get_freqs(input_pos, seq_len)
if not self.params.use_kv_cache:
return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]

if self.params.enable_dynamic_shape:
# Dynamic shape: input_pos is [start_pos], remap and narrow
input_pos_item = input_pos[-1].item()
if input_pos_item < self.sink_size:
remapped_item = input_pos_item
else:
remapped_item = (
self.sink_size + (input_pos_item - self.sink_size) % self.ring_size
)
torch._check_is_size(remapped_item)
torch._check(remapped_item + seq_len <= self.sink_size + self.ring_size)
freqs_cos = self.freqs_cos.narrow(0, remapped_item, seq_len)
freqs_sin = self.freqs_sin.narrow(0, remapped_item, seq_len)
else:
# Static shape: remap full position tensor and index
remapped = self._remap_input_pos(input_pos)
freqs_cos = self.freqs_cos[remapped]
freqs_sin = self.freqs_sin[remapped]

return freqs_cos, freqs_sin
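
(Illustrative aside, not part of the diff.) A small consistency sketch of the two branches above, using an assumed stand-in frequency table and the same hypothetical remap helper: as long as the chunk does not wrap (the condition the torch._check above enforces), the narrow-based dynamic-shape path selects the same rows as the elementwise remapping used in the static-shape path.

```python
import torch

sink_size, window_size = 4, 16
ring_size = 2 * window_size
max_context_len = 64                           # value used by the tests in this PR
freqs_cos = torch.randn(max_context_len, 8)    # stand-in frequency table (assumed shape)

def remap(pos: torch.Tensor) -> torch.Tensor:
    return torch.where(pos < sink_size, pos, sink_size + (pos - sink_size) % ring_size)

start, seq_len = 70, 5                         # chunk starting past max_context_len
positions = torch.arange(start, start + seq_len)

# Dynamic-shape path: remap the scalar start position, then narrow.
remapped_start = sink_size + (start - sink_size) % ring_size
assert remapped_start + seq_len <= sink_size + ring_size   # guard mirrored by torch._check
dynamic_rows = freqs_cos.narrow(0, remapped_start, seq_len)

# Static-shape path: remap every position and index directly.
static_rows = freqs_cos[remap(positions)]

assert torch.equal(dynamic_rows, static_rows)
```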


def _create_causal_mask_for_attention_sink(
@@ -398,6 +398,24 @@ def test_beyond_context_window_basic(self):
torch.isfinite(out).all(), "Output contains non-finite values"
)

def test_beyond_max_context_len(self):
"""Generate tokens beyond max_context_len with RoPE position remapping."""
sink_size = 4
window_size = 16
# KV cache size = 36, max_context_len = 64
# Generate 100 tokens — well beyond max_context_len
args = self._make_args(max_context_len=64)
model = self._build_model(args, sink_size, window_size, use_custom_sdpa=False)

outputs = self._run_generation(model, args, num_tokens=100)

self.assertEqual(len(outputs), 97) # 1 prefill + 96 decode steps
for out in outputs:
self.assertTrue(
torch.isfinite(out).all(),
"Output contains non-finite values beyond max_context_len",
)
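
(Illustrative arithmetic, not part of the diff.) For the numbers in this test, and assuming positions run 0..99 for the 100 generated tokens, the remapping keeps even the highest original position inside the 36-slot cache:

```python
sink_size, window_size = 4, 16
ring_size = 2 * window_size                 # 32
kv_cache_size = sink_size + ring_size       # 36
last_pos = 99                               # assumed last position for 100 tokens
remapped = sink_size + (last_pos - sink_size) % ring_size
assert remapped == 35 and remapped < kv_cache_size
```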

def test_beyond_context_window_custom_sdpa(self):
"""Generate tokens beyond context window with custom SDPA + custom KV cache."""
sink_size = 4