diff --git a/examples/models/llama/source_transformation/attention_sink.py b/examples/models/llama/source_transformation/attention_sink.py
index 8addf2a6cec..cada2ac4e6d 100644
--- a/examples/models/llama/source_transformation/attention_sink.py
+++ b/examples/models/llama/source_transformation/attention_sink.py
@@ -29,13 +29,16 @@ class RopeWithAttentionSink(Rope):
     """
     Rope subclass for Attention Sink models.
 
-    For torch.export compatibility, this passes through the original position
-    unchanged - the sliding window is handled by the cache index management
-    (ring buffer), not by position shifting.
+    Remaps input positions using modular arithmetic so RoPE frequencies stay
+    within the cache size bounds, enabling generation beyond max_context_len.
 
-    Note: This class uses the model's max_context_len (params.max_context_len) for
-    RoPE frequency table size, which should be large enough to support generation
-    beyond the sliding window. The actual KV cache size is sink_size + window_size * 2.
+    Position mapping:
+    - Sink tokens (pos < sink_size): position preserved as-is
+    - Window tokens (pos >= sink_size): wrapped into ring buffer range
+      [sink_size, sink_size + ring_size) via modulo
+
+    The ring buffer is 2x window_size, so the live window (window_size tokens)
+    never spans a wrap boundary, preserving correct relative distances in RoPE.
     """
 
     def __init__(
@@ -47,19 +50,48 @@ def __init__(
         super().__init__(params)
         self.window_size = window_size
         self.sink_size = sink_size
-        # max_context_len from params is used for RoPE frequencies (should be large)
-        self.max_context_length = self.params.max_context_len
+        self.ring_size = window_size * 2
+
+    def _remap_input_pos(self, input_pos: torch.Tensor) -> torch.Tensor:
+        """Remap positions: sink tokens stay, window tokens wrap in ring buffer."""
+        return torch.where(
+            input_pos < self.sink_size,
+            input_pos,
+            self.sink_size + (input_pos - self.sink_size) % self.ring_size,
+        )
 
     def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int):
         """
-        Get rotary embedding frequencies.
-        For attention sink, we use the original position - the sliding window
-        is handled by the cache index management, not by position shifting.
+        Get rotary embedding frequencies with position remapping.
+
+        For dynamic shape mode (input_pos is a single start position), we remap
+        the start and use narrow. For static shape mode (input_pos is the full
+        position tensor), we remap all positions and index directly.
         """
         assert input_pos is not None
-        # Use torch._check for export compatibility (data-dependent guard)
-        torch._check(input_pos[0].item() + seq_len <= self.max_context_length)
-        return super().get_freqs(input_pos, seq_len)
+        if not self.params.use_kv_cache:
+            return self.freqs_cos[:seq_len], self.freqs_sin[:seq_len]
+
+        if self.params.enable_dynamic_shape:
+            # Dynamic shape: input_pos is [start_pos], remap and narrow
+            input_pos_item = input_pos[-1].item()
+            if input_pos_item < self.sink_size:
+                remapped_item = input_pos_item
+            else:
+                remapped_item = (
+                    self.sink_size + (input_pos_item - self.sink_size) % self.ring_size
+                )
+            torch._check_is_size(remapped_item)
+            torch._check(remapped_item + seq_len <= self.sink_size + self.ring_size)
+            freqs_cos = self.freqs_cos.narrow(0, remapped_item, seq_len)
+            freqs_sin = self.freqs_sin.narrow(0, remapped_item, seq_len)
+        else:
+            # Static shape: remap full position tensor and index
+            remapped = self._remap_input_pos(input_pos)
+            freqs_cos = self.freqs_cos[remapped]
+            freqs_sin = self.freqs_sin[remapped]
+
+        return freqs_cos, freqs_sin
 
 
 def _create_causal_mask_for_attention_sink(
diff --git a/examples/models/llama/source_transformation/test_attention_sink.py b/examples/models/llama/source_transformation/test_attention_sink.py
index 54cf1e57ac5..8cdb00951f2 100644
--- a/examples/models/llama/source_transformation/test_attention_sink.py
+++ b/examples/models/llama/source_transformation/test_attention_sink.py
@@ -398,6 +398,24 @@ def test_beyond_context_window_basic(self):
             torch.isfinite(out).all(), "Output contains non-finite values"
         )
 
+    def test_beyond_max_context_len(self):
+        """Generate tokens beyond max_context_len with RoPE position remapping."""
+        sink_size = 4
+        window_size = 16
+        # KV cache size = 36, max_context_len = 64
+        # Generate 100 tokens — well beyond max_context_len
+        args = self._make_args(max_context_len=64)
+        model = self._build_model(args, sink_size, window_size, use_custom_sdpa=False)
+
+        outputs = self._run_generation(model, args, num_tokens=100)
+
+        self.assertEqual(len(outputs), 97)  # 1 prefill + 96 decode steps
+        for out in outputs:
+            self.assertTrue(
+                torch.isfinite(out).all(),
+                "Output contains non-finite values beyond max_context_len",
+            )
+
     def test_beyond_context_window_custom_sdpa(self):
         """Generate tokens beyond context window with custom SDPA + custom KV cache."""
         sink_size = 4