Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,16 @@ def forward(self, hidden_states):


class DeepseekV3RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, dtype=torch.get_default_dtype()):
super().__init__()

if dtype is None:
dtype = torch.get_default_dtype()
if not isinstance(dtype, torch.dtype):
raise TypeError(
f"DeepseekV3RotaryEmbedding: dtype must be a torch.dtype, got {type(dtype).__name__} with value {dtype}"
)

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
Expand All @@ -94,7 +101,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
dtype=dtype,
)
self.max_seq_len_cached = None

Expand Down Expand Up @@ -122,6 +129,7 @@ def forward(self, x, seq_len=None):
class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding):
def __init__(
self,
dtype,
dim,
max_position_embeddings=2048,
base=10000,
Expand All @@ -139,7 +147,7 @@ def __init__(
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
super().__init__(dim, max_position_embeddings, base, device)
super().__init__(dim, max_position_embeddings, base, device, dtype)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
Expand Down Expand Up @@ -987,6 +995,7 @@ def __qeff_init__(self):
if key in self.config.rope_scaling
}
self.rotary_emb = DeepseekV3YarnRotaryEmbedding(
self.config.torch_dtype,
self.config.qk_rope_head_dim,
max_position_embeddings=MAX_POSITION_EMBEDDINGS,
scaling_factor=scaling_factor,
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, config: FalconConfig, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, config: GemmaConfig, device=None):

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, config: Gemma2Config, device=None):

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
4 changes: 1 addition & 3 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device
self.register_buffer("inv_freq", inv_freq, persistent=False)

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=config.torch_dtype)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def __init__(self, config: GptOssConfig, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding):
def __init__(self, config: GraniteConfig, device=None):
super().__init__(config=config)
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
super().__init__(config=config) # Initialize nn.Module

self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len: int, device=None, dtype=None):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, config: LlamaConfig, device=None):
super().__init__(config=config)

self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, config: MistralConfig, device=None):

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, config: MixtralConfig, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/mllama/modeling_mllama.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(self, config: MllamaConfig, device=None):
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
dtype=config.torch_dtype,
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/molmo/modeling_molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(self, config, device=None):
self.inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
self.original_max_seq_len = config.max_position_embeddings or config.max_sequence_length
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/olmo2/modeling_olmo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, config: Olmo2Config, device=None):
super().__init__(config=config)

self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/phi3/modeling_phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, config: Phi3Config, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, config: Qwen2Config, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/qwen3/modeling_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, config: Qwen3Config, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, config: Qwen3MoeConfig, device=None):

# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, config: Qwen3VLTextConfig, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, device=None):
super().__init__(config=config)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
Expand Down
Loading