vllm/model_executor/layers/rotary_embedding/__init__.py (22 changes: 17 additions & 5 deletions)
@@ -153,11 +153,23 @@ def get_rope(
                 if k in ("extrapolation_factor", "attn_factor", "beta_fast",
                          "beta_slow")
             }
-            rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
-                                                    original_max_position,
-                                                    base, is_neox_style,
-                                                    scaling_factor, dtype,
-                                                    **extra_kwargs)
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    original_max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                    mrope_interleaved=rope_scaling.get("mrope_interleaved",
+                                                       False),
+                    scaling_factor=scaling_factor,
+                    **extra_kwargs)
+            else:
+                rotary_emb = YaRNScalingRotaryEmbedding(
+                    head_size, rotary_dim, original_max_position, base,
+                    is_neox_style, scaling_factor, dtype, **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
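For readers skimming the diff: the yarn branch now routes to MRotaryEmbedding whenever the model's rope_scaling config also carries an mrope_section, and otherwise keeps the original YaRNScalingRotaryEmbedding path. A minimal sketch of such a config follows; the concrete values are illustrative assumptions, not taken from any particular model.

```python
# Hypothetical HF-style rope_scaling dict (values made up for illustration).
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,                               # becomes scaling_factor
    "original_max_position_embeddings": 32768,
    "beta_fast": 32,
    "beta_slow": 1,
    # Present only for M-RoPE (multimodal) models, e.g. the Qwen-VL family;
    # with rotary_dim=128 the sections must sum to rotary_dim // 2 = 64.
    "mrope_section": [16, 24, 24],
}
# With "mrope_section" present, get_rope() now builds an MRotaryEmbedding and
# forwards scaling_factor plus the YaRN kwargs; without it, behaviour is
# unchanged and YaRNScalingRotaryEmbedding is constructed as before.
```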
vllm/model_executor/layers/rotary_embedding/mrope.py (31 changes: 31 additions & 0 deletions)
@@ -12,6 +12,7 @@
 
 from .base import RotaryEmbedding
 from .common import apply_rotary_emb_dispatch
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
 
 
 @triton.jit
@@ -213,7 +214,27 @@ def __init__(
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
         mrope_interleaved: bool = False,
+        # YaRN parameters.
+        *,
+        scaling_factor: Optional[float] = None,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
     ) -> None:
+
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        if self.scaling_factor is not None:
+            # Get n-d magnitude scaling corrected for interpolation
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor) * attn_factor)
+        else:
+            self.mscale = 1.0
+
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
         # a larger the cos and sin cache.
@@ -226,6 +247,16 @@ def __init__(
         if self.mrope_section:
             assert sum(self.mrope_section) == rotary_dim // 2
 
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_inv_freq(base)
+        return YaRNScalingRotaryEmbedding._compute_inv_freq(self, base)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        if self.scaling_factor is None:
+            return super()._compute_cos_sin_cache()
+        return YaRNScalingRotaryEmbedding._compute_cos_sin_cache(self)
+
     def forward_native(
         self,
         positions: torch.Tensor,
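A short usage-level sketch of what the new parameters do, assuming yarn_get_mscale implements the YaRN paper's magnitude correction (roughly 0.1 * ln(s) + 1 for s > 1; the real helper lives in yarn_scaling_rope and is not shown in this diff), and using throwaway classes (YarnLike, MRopeLike) to illustrate the unbound-method delegation in _compute_inv_freq / _compute_cos_sin_cache:

```python
import math

# Assumed formula for yarn_get_mscale (YaRN paper's attention-magnitude
# correction); illustrative only.
def yarn_get_mscale_sketch(scale: float = 1.0) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0

scaling_factor, attn_factor = 4.0, 1.0
mscale = float(yarn_get_mscale_sketch(scaling_factor) * attn_factor)
print(round(mscale, 4))  # ~1.1386, applied to the cos/sin cache when YaRN is on

# Toy illustration of the delegation pattern used above: when scaling_factor is
# set, another class's unbound method is called with this instance as `self`,
# reusing YaRN's frequency math without multiple inheritance.
class YarnLike:
    def compute(self):
        return "yarn-interpolated inv_freq"

class MRopeLike:
    def __init__(self, scaling_factor=None):
        self.scaling_factor = scaling_factor

    def compute(self):
        if self.scaling_factor is None:
            return "plain inv_freq"    # original M-RoPE behaviour
        return YarnLike.compute(self)  # same trick as the diff

print(MRopeLike().compute())     # plain inv_freq
print(MRopeLike(4.0).compute())  # yarn-interpolated inv_freq
```

Delegating to the sibling class's methods only when scaling_factor is supplied keeps MRotaryEmbedding's existing forward path untouched while reusing YaRN's inv_freq and cos/sin-cache construction.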