From 74401d2651ec2454a828edd0d77e272d6c5b8603 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 15 May 2026 13:50:49 +0530 Subject: [PATCH 1/5] Updated the ROPE dtype for custom_dtype Signed-off-by: Asmita Goswami --- QEfficient/transformers/models/falcon/modeling_falcon.py | 2 +- QEfficient/transformers/models/gemma/modeling_gemma.py | 2 +- QEfficient/transformers/models/gemma2/modeling_gemma2.py | 2 +- QEfficient/transformers/models/gemma3/modeling_gemma3.py | 2 +- QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py | 2 +- QEfficient/transformers/models/granite/modeling_granite.py | 2 +- .../transformers/models/granitemoe/modeling_granitemoe.py | 2 +- QEfficient/transformers/models/llama/modeling_llama.py | 2 +- QEfficient/transformers/models/mistral/modeling_mistral.py | 2 +- QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py | 2 +- QEfficient/transformers/models/mllama/modeling_mllama.py | 2 +- QEfficient/transformers/models/molmo/modeling_molmo.py | 2 +- QEfficient/transformers/models/olmo2/modeling_olmo2.py | 2 +- QEfficient/transformers/models/phi3/modeling_phi3.py | 2 +- QEfficient/transformers/models/qwen2/modeling_qwen2.py | 2 +- QEfficient/transformers/models/qwen3/modeling_qwen3.py | 2 +- QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py | 2 +- QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py | 2 +- .../transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py | 2 +- 19 files changed, 19 insertions(+), 19 deletions(-) diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 9d6271cc84..7987200d6c 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -46,7 +46,7 @@ def __init__(self, config: FalconConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index 481cbd9d19..8a99e5e178 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -48,7 +48,7 @@ def __init__(self, config: GemmaConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index 38845f1f9c..3caedf94af 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -51,7 +51,7 @@ def __init__(self, config: Gemma2Config, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index c0b7053ab6..e17257747e 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -86,7 +86,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6f805bfd4c..9b9d64cc8b 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -520,7 +520,7 @@ def __init__(self, config: GptOssConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 9c004e72ca..1373080197 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -46,7 +46,7 @@ class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding): def __init__(self, config: GraniteConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index 8728b4d3e4..65349d72ea 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -52,7 +52,7 @@ def __init__( super().__init__(config=config) # Initialize nn.Module self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len: int, device=None, dtype=None): diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index 0c30cc68fb..697e10a907 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -47,7 +47,7 @@ def __init__(self, config: LlamaConfig, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index abad6ca04b..f782773f1f 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -51,7 +51,7 @@ def __init__(self, config: MistralConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 637787d3e7..7b8b4aad13 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -52,7 +52,7 @@ def __init__(self, config: MixtralConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index e9b661103f..5c498711c1 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -110,7 +110,7 @@ def __init__(self, config: MllamaConfig, device=None): self._set_cos_sin_cache( seq_len=self.original_max_seq_len, device=self.inv_freq.device, - dtype=torch.get_default_dtype(), + dtype=config.torch_dtype, ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index 438eeee4ed..22b87b87ca 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -158,7 +158,7 @@ def __init__(self, config, device=None): self.inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) self.original_max_seq_len = config.max_position_embeddings or config.max_sequence_length self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 358526077a..ad4c32df47 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -41,7 +41,7 @@ def __init__(self, config: Olmo2Config, device=None): super().__init__(config=config) self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 1df0590f2f..556253e82a 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -41,7 +41,7 @@ def __init__(self, config: Phi3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 6512255599..db709a9c09 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -51,7 +51,7 @@ def __init__(self, config: Qwen2Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index a20350205e..ccdff8160d 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -51,7 +51,7 @@ def __init__(self, config: Qwen3Config, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index de92eae8f7..d05a4a1b06 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -43,7 +43,7 @@ def __init__(self, config: Qwen3MoeConfig, device=None): # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 21847f25de..045ce62b54 100644 --- a/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -112,7 +112,7 @@ def __init__(self, config: Qwen3VLTextConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): diff --git a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 240d04d996..19ca83bfb1 100644 --- a/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -111,7 +111,7 @@ def __init__(self, config: Qwen3VLMoeTextConfig, device=None): super().__init__(config=config) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=config.torch_dtype ) def _set_cos_sin_cache(self, seq_len, device, dtype): From e679712234d2c92df92a27da14f4a0e6ea05a834 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Fri, 15 May 2026 13:51:30 +0530 Subject: [PATCH 2/5] Ruff format Signed-off-by: Asmita Goswami --- QEfficient/transformers/models/gemma3/modeling_gemma3.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index e17257747e..3ca5e82ef7 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -85,9 +85,7 @@ def __init__(self, dim, config, max_position_embeddings=2048, base=10000, device self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=config.torch_dtype - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=config.torch_dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len From 0d27aa62b72454fa436ace88c099ca172c41f6f4 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Mon, 18 May 2026 12:30:50 +0530 Subject: [PATCH 3/5] Added Deepseek custom_dtype rope support Signed-off-by: Asmita Goswami --- .../transformers/models/deepseek_v3/modeling_deepseek.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 6d6c3f7375..948b9c9992 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -81,7 +81,7 @@ def forward(self, hidden_states): class DeepseekV3RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, dtype=torch.get_default_dtype()): super().__init__() self.dim = dim @@ -94,7 +94,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): self._set_cos_sin_cache( seq_len=max_position_embeddings, device=self.inv_freq.device, - dtype=torch.get_default_dtype(), + dtype=dtype, ) self.max_seq_len_cached = None @@ -122,6 +122,7 @@ def forward(self, x, seq_len=None): class DeepseekV3YarnRotaryEmbedding(DeepseekV3RotaryEmbedding): def __init__( self, + dtype, dim, max_position_embeddings=2048, base=10000, @@ -139,7 +140,7 @@ def __init__( self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim - super().__init__(dim, max_position_embeddings, base, device) + super().__init__(dim, max_position_embeddings, base, device, dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len @@ -987,6 +988,7 @@ def __qeff_init__(self): if key in self.config.rope_scaling } self.rotary_emb = DeepseekV3YarnRotaryEmbedding( + self.config.torch_dtype, self.config.qk_rope_head_dim, max_position_embeddings=MAX_POSITION_EMBEDDINGS, scaling_factor=scaling_factor, From e18ca492b441053e353e7338eedf0374db8a8da5 Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Mon, 18 May 2026 14:35:21 +0530 Subject: [PATCH 4/5] Qgenie comments Signed-off-by: Asmita Goswami --- .../transformers/models/deepseek_v3/modeling_deepseek.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 948b9c9992..486334a97c 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -84,6 +84,11 @@ class DeepseekV3RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, dtype=torch.get_default_dtype()): super().__init__() + if dtype is None: + dtype = torch.get_default_dtype() + if not isinstance(dtype, torch.dtype): + raise TypeError(f"dtype must be a torch.dtype, got {type(dtype)}") + self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base @@ -140,7 +145,8 @@ def __init__( self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim - super().__init__(dim, max_position_embeddings, base, device, dtype) + self.dtype = dtype if dtype is not None else torch.get_default_dtype() + super().__init__(dim, max_position_embeddings, base, device, self.dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len From 5b6ff0f11475cb57adadc2b004b9dacc4fc0509d Mon Sep 17 00:00:00 2001 From: Asmita Goswami Date: Mon, 18 May 2026 15:07:52 +0530 Subject: [PATCH 5/5] Qgenie comments Signed-off-by: Asmita Goswami --- .../transformers/models/deepseek_v3/modeling_deepseek.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py index 486334a97c..3dad27103a 100644 --- a/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py +++ b/QEfficient/transformers/models/deepseek_v3/modeling_deepseek.py @@ -87,7 +87,9 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, d if dtype is None: dtype = torch.get_default_dtype() if not isinstance(dtype, torch.dtype): - raise TypeError(f"dtype must be a torch.dtype, got {type(dtype)}") + raise TypeError( + f"DeepseekV3RotaryEmbedding: dtype must be a torch.dtype, got {type(dtype).__name__} with value {dtype}" + ) self.dim = dim self.max_position_embeddings = max_position_embeddings @@ -145,8 +147,7 @@ def __init__( self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim - self.dtype = dtype if dtype is not None else torch.get_default_dtype() - super().__init__(dim, max_position_embeddings, base, device, self.dtype) + super().__init__(dim, max_position_embeddings, base, device, dtype) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len