diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index c9653aa9e440..3576368981c7 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -153,11 +153,23 @@ def get_rope( if k in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") } - rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, - original_max_position, - base, is_neox_style, - scaling_factor, dtype, - **extra_kwargs) + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + mrope_interleaved=rope_scaling.get("mrope_interleaved", + False), + scaling_factor=scaling_factor, + **extra_kwargs) + else: + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) elif scaling_type == "deepseek_yarn": scaling_factor = rope_scaling["factor"] original_max_position = rope_scaling[ diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 17d04a1ad715..9bf0d6bd15e7 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -12,6 +12,7 @@ from .base import RotaryEmbedding from .common import apply_rotary_emb_dispatch +from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale @triton.jit @@ -213,7 +214,27 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[list[int]] = None, mrope_interleaved: bool = False, + # YaRN parameters. + *, + scaling_factor: Optional[float] = None, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, ) -> None: + + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + if self.scaling_factor is not None: + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + yarn_get_mscale(self.scaling_factor) * attn_factor) + else: + self.mscale = 1.0 + # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get # a larger the cos and sin cache. @@ -226,6 +247,16 @@ def __init__( if self.mrope_section: assert sum(self.mrope_section) == rotary_dim // 2 + def _compute_inv_freq(self, base: float) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_inv_freq(base) + return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + if self.scaling_factor is None: + return super()._compute_cos_sin_cache() + return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self) + def forward_native( self, positions: torch.Tensor,