diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 6d6c3f7375..3dad27103a 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -81,9 +81,16 @@ def forward(self, hidden_states): class DeepseekV3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, dtype=torch.get_default_dtype()): super().__init__() + if dtype is None: + dtype = torch.get_default_dtype() + if not isinstance(dtype, torch.dtype): + raise TypeError( + f"DeepseekV3RotaryEmbedding: dtype must be a torch.dtype, got {type(dtype).__name__} with value {dtype}" + ) + self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -94,7 +101,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, - dtype=torch.get_default_dtype(), + dtype=dtype, ) self.max_seq_len_cached = None @@ -122,6 +129,7 @@ def forward(self, x, seq_len=None): class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): def __init__( self, + dtype, dim, max_position_embeddings=2048, base=10000, @@ -139,7 +147,7 @@ def __init__( self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim - super().__init__(dim, max_position_embeddings, base, device) + super().__init__(dim, max_position_embeddings, base, device, dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -987,6 +995,7 @@ def __qeff_init__(self): if key in self.config.rope_scaling } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.config.torch_dtype, self.config.qk_rope_head_dim, max_position_embeddings=MAX_POSITION_EMBEDDINGS, scaling_factor=scaling_factor, diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 9d6271cc84..7987200d6c 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -46,7 +46,7 @@ def __init__(self, config: FalconConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 481cbd9d19..8a99e5e178 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -48,7 +48,7 @@ def __init__(self, config: GemmaConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 38845f1f9c..3caedf94af 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -51,7 +51,7 @@ def __init__(self, config: Gemma2Config, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index c0b7053ab6..3ca5e82ef7 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -85,9 +85,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=config.torch_dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f805bfd4c..9b9d64cc8b 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -520,7 +520,7 @@ def __init__(self, config: GptOssConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 9c004e72ca..1373080197 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -46,7 +46,7 @@ class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding): def __init__(self, config: GraniteConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8728b4d3e4..65349d72ea 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -52,7 +52,7 @@ def __init__( super().__init__(config=config) # Initialize nn.Module self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len: int, device=None, dtype=None): diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 0c30cc68fb..697e10a907 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -47,7 +47,7 @@ def __init__(self, config: LlamaConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index abad6ca04b..f782773f1f 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -51,7 +51,7 @@ def __init__(self, config: MistralConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 637787d3e7..7b8b4aad13 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -52,7 +52,7 @@ def __init__(self, config: MixtralConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e9b661103f..5c498711c1 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -110,7 +110,7 @@ def __init__(self, config: MllamaConfig, device=None): self._set_cos_sin_cache( seq_len=self.original_max_seq_len, device=self.inv_freq.device, - dtype=torch.get_default_dtype(), + dtype=config.torch_dtype, ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index cf1beddc87..0e545d8eab 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -158,7 +158,7 @@ def __init__(self, config, device=None): self.inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) self.original_max_seq_len = config.max_position_embeddings or config.max_sequence_length self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 358526077a..ad4c32df47 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -41,7 +41,7 @@ def __init__(self, config: Olmo2Config, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 1df0590f2f..556253e82a 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -41,7 +41,7 @@ def __init__(self, config: Phi3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 6512255599..db709a9c09 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -51,7 +51,7 @@ def __init__(self, config: Qwen2Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index a20350205e..ccdff8160d 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -51,7 +51,7 @@ def __init__(self, config: Qwen3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index de92eae8f7..d05a4a1b06 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -43,7 +43,7 @@ def __init__(self, config: Qwen3MoeConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 21847f25de..045ce62b54 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -112,7 +112,7 @@ def __init__(self, config: Qwen3VLTextConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 240d04d996..19ca83bfb1 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -111,7 +111,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype):