examples/models/llama/rope.py (29 changes: 20 additions & 9 deletions)
@@ -154,21 +154,32 @@ def hf_precompute_freqs_cis(
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
 
-    # Compute the RoPE table in fp64 to minimize ULP-level drift; cast to fp32
-    # once at the end. Phi-4 Mini's narrow decode-time logit margins make the
-    # exported model sensitive to 1-ULP differences in freqs_cos / freqs_sin
-    # under sampling, especially on the Vulkan delegate.
+    # fp64 precompute is required whenever cos/sin will be scaled by a
+    # non-trivial attention_factor (LongRoPE on the Phi-3 / Phi-4 family).
+    # There, fp32 ULP-level rounding in the table is load-bearing on Vulkan
+    # under sampling; an fp32-only regression manifests as decode-time n-gram
+    # looping, not a unit-test failure. For vanilla HF RoPE, fp32 throughout
+    # produces cos/sin tables bit-identical to the non-HF precompute_freqs_cis
+    # path, which the static-attention vs MHA parity tests rely on.
+    #
+    # If you add a new model that needs cos/sin scaling but does not set
+    # short_factor / long_factor / attention_factor, extend the gate below.
+    longrope_active = (short_factor is not None) or (long_factor is not None)
+    needs_fp64 = longrope_active or (
+        attention_factor is not None and attention_factor != 1.0
+    )
+    compute_dtype = torch.float64 if needs_fp64 else torch.float32
+
     inv_freq = 1.0 / (
         theta
         ** (
-            torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(torch.float64)
+            torch.arange(0, dim, 2, device=device, dtype=torch.int64).to(compute_dtype)
             / dim
         )
     )
 
     # LongRoPE: divide inv_freq element-wise by short_factor or long_factor.
     # Selection mirrors HF: long_factor when seq_len > original_max_position_embeddings.
-    longrope_active = (short_factor is not None) or (long_factor is not None)
     if longrope_active:
         chosen = (
             long_factor
@@ -178,7 +189,7 @@ def hf_precompute_freqs_cis(
         if chosen is None:
             # Fall back to whichever factor was provided.
             chosen = short_factor if long_factor is None else long_factor
-        ext_factors = torch.tensor(chosen, dtype=torch.float64, device=device)
+        ext_factors = torch.tensor(chosen, dtype=compute_dtype, device=device)
         assert ext_factors.numel() == inv_freq.numel(), (
             f"LongRoPE factor length {ext_factors.numel()} must equal dim/2 "
             f"({inv_freq.numel()})"
@@ -200,8 +211,8 @@ def hf_precompute_freqs_cis(
     )
 
     # pyre-ignore Undefined attribute [16]: `float` has no attribute `device`.
-    t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(torch.float64)
-    freqs = torch.outer(t, inv_freq).to(torch.float64)  # pyre-ignore
+    t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).to(compute_dtype)
+    freqs = torch.outer(t, inv_freq).to(compute_dtype)  # pyre-ignore
     emb = torch.cat((freqs, freqs), dim=-1)
     cos_tab = torch.cos(emb)
     sin_tab = torch.sin(emb)
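
The dtype gate introduced in the first hunk can be exercised standalone, outside the export flow. Below is a minimal sketch: the helper name rope_table_dtype is invented for illustration, the gate body is copied from the diff, and the sample configs (factor lists of length 48, an attention_factor of 1.19) are placeholder values, not numbers taken from this PR.

import torch

def rope_table_dtype(short_factor, long_factor, attention_factor):
    # Same gate as in hf_precompute_freqs_cis: fp64 only when the cos/sin
    # tables will be scaled (LongRoPE factors or a non-trivial
    # attention_factor); fp32 otherwise, so the vanilla path stays
    # bit-identical to the non-HF precompute_freqs_cis table.
    longrope_active = (short_factor is not None) or (long_factor is not None)
    needs_fp64 = longrope_active or (
        attention_factor is not None and attention_factor != 1.0
    )
    return torch.float64 if needs_fp64 else torch.float32

# Vanilla HF RoPE: fp32 throughout.
assert rope_table_dtype(None, None, None) is torch.float32
assert rope_table_dtype(None, None, 1.0) is torch.float32

# LongRoPE-style configs (illustrative values): fp64.
assert rope_table_dtype([1.0] * 48, [4.0] * 48, 1.19) is torch.float64
assert rope_table_dtype(None, None, 1.19) is torch.float64

Note that attention_factor == 1.0 deliberately falls through to fp32: a scale of exactly 1 cannot perturb the table, so the bit-identity guarantee the parity tests depend on is preserved.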