diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py
index 8d2641d9d78..d8b43b80415 100644
--- a/examples/models/llama/model_args.py
+++ b/examples/models/llama/model_args.py
@@ -46,6 +46,7 @@ class ModelArgs:
     rope_freq_base: float = 10000.0  # The base frequency for RoPE. Keep it for BC.
     use_scaled_rope: bool = False  # Use scaled RoPE, introduced in llama3.1.
     rope_scale_factor: int = 8
+    high_freq_factor: int = 4
     # Additional Model Metadata needed at runtime
     bos_idx: int = 1
     eos_idx: int = 3
diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py
index 02eb564ed76..ad69f159e7c 100644
--- a/examples/models/llama/rope.py
+++ b/examples/models/llama/rope.py
@@ -17,10 +17,9 @@
 # ======================== Stock Implementation ========================
 
 
-def apply_scaling(freqs: torch.Tensor, scale_factor: int):
+def apply_scaling(freqs: torch.Tensor, scale_factor: int, high_freq_factor: int):
     # Values obtained from grid search
     low_freq_factor = 1
-    high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
 
     low_freq_wavelen = old_context_len / low_freq_factor
@@ -47,6 +46,7 @@ def precompute_freqs_cis(
     theta: float = 10000.0,
     use_scaled: bool = False,
     scale_factor: Optional[int] = None,
+    high_freq_factor: int = 4,
 ):
     freqs = 1.0 / (
         theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
@@ -54,7 +54,7 @@ def precompute_freqs_cis(
     t = torch.arange(end, device=freqs.device)  # pyre-ignore
     if use_scaled:
         assert scale_factor is not None
-        freqs = apply_scaling(freqs, scale_factor)  # pyre-ignore
+        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)  # pyre-ignore
     freqs = torch.outer(t, freqs).float()
     freqs_cos = torch.cos(freqs)
     freqs_sin = torch.sin(freqs)
@@ -242,6 +242,7 @@ def __init__(self, params: ModelArgs):
             precompute_freqs_cis,
             use_scaled=self.params.use_scaled_rope,
             scale_factor=self.params.rope_scale_factor,
+            high_freq_factor=self.params.high_freq_factor,
         )
         self.apply_rotary_emb = RotaryEmbedding()
 
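For context, since the hunk above truncates the body of apply_scaling, here is a minimal, self-contained sketch of the parameterized frequency scaling. The smooth-interpolation logic follows the published Llama 3.1 reference implementation; it illustrates what the new high_freq_factor knob controls and is not a verbatim copy of this file. The demo values at the bottom (dim=64, end=16) are arbitrary.

import math
from typing import Optional

import torch


def apply_scaling(freqs: torch.Tensor, scale_factor: int, high_freq_factor: int):
    # low_freq_factor stays hardcoded; only high_freq_factor is parameterized.
    low_freq_factor = 1
    old_context_len = 8192  # original llama3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency components are left untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are scaled down uniformly.
            new_freqs.append(freq / scale_factor)
        else:
            # Mid band: interpolate smoothly between scaled and unscaled.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    scale_factor: Optional[int] = None,
    high_freq_factor: int = 4,
):
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim)
    )
    t = torch.arange(end, device=freqs.device)
    if use_scaled:
        assert scale_factor is not None
        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)


if __name__ == "__main__":
    # A larger high_freq_factor shrinks the untouched high-frequency band,
    # so more of the spectrum gets scaled or interpolated.
    default_cos, _ = precompute_freqs_cis(64, 16, use_scaled=True, scale_factor=8)
    custom_cos, _ = precompute_freqs_cis(
        64, 16, use_scaled=True, scale_factor=8, high_freq_factor=2
    )
    print(torch.allclose(default_cos, custom_cos))  # False: the scaling band changed

With the default high_freq_factor=4, wavelengths below 8192/4 = 2048 positions are left unscaled; lowering it to 2 widens that untouched band to wavelengths under 4096, which is why the two tables above differ.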