diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py
index 817c77d6d..61b5c00f6 100644
--- a/QEfficient/base/onnx_transforms.py
+++ b/QEfficient/base/onnx_transforms.py
@@ -34,25 +34,33 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:
 
 class FP16ClipTransform(OnnxTransform):
     """
-    Clips the tensor values to be in FP16 range.
+    Clips the tensor values to be in FP16 range, but preserves -inf values.
     """
 
     @classmethod
     def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
         """
-        :param onnx_base_dir: Base directory to load tensors (if not already loaded).
+        :param onnx_base_dir: Base directory to load tensors
         """
         finfo = np.finfo(np.float16)
         fp16_max = finfo.max
         fp16_min = finfo.min
         transformed = False
+
         for tensor in external_data_helper._get_all_tensors(model):
             nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
             if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
-                nptensor = np.clip(nptensor, fp16_min, fp16_max)
-                new_tensor = numpy_helper.from_array(nptensor, tensor.name)
+                neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
+                clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)
+
+                # Restore -inf values
+                if neg_inf_mask.any():
+                    clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)
+
+                new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
                 tensor.CopyFrom(new_tensor)
                 transformed = True
+
         return model, transformed
 
 
diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py
index 0a0e4d54b..72b7acd98 100644
--- a/QEfficient/transformers/modeling_utils.py
+++ b/QEfficient/transformers/modeling_utils.py
@@ -88,6 +88,7 @@
 )
 
 from QEfficient.customop import CustomRMSNormAIC
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 # Placeholder for all non-transformer models
 from .models.codegen.modeling_codegen import (
@@ -307,12 +308,12 @@ def _prepare_cross_attention_mask(
     # invert the mask
     inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
     cross_attention_mask = inverted_cross_attn_mask.masked_fill(
-        inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
+        inverted_cross_attn_mask.to(torch.bool), torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
     )
 
     # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
     # last dimension contains negative infinity values, otherwise it's 1
-    negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
+    negative_inf_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
     full_text_row_masked_out_mask = (
         (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
     )
@@ -342,7 +343,11 @@ def _prepare_aspect_ratio_attention_mask(
     # Reshape to 2D and create 4D attention mask
     # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
     attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
-    attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
+    attention_mask = (
+        attention_mask
+        @ attention_mask.transpose(-1, -2)
+        * torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
+    )
     attention_mask = attention_mask.unsqueeze(1)
 
     return attention_mask
diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py
index 09400c51e..b4b33a8ff 100644
--- a/QEfficient/transformers/models/codegen/modeling_codegen.py
+++ b/QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -23,6 +23,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffCodeGenAttention(CodeGenAttention):
@@ -48,10 +49,10 @@ def _attn(
         attn_weights = attn_weights / self.scale_attn
 
         # Minimum value for causal mask
-        mask_value = -10000.0
+
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)
 
         if attention_mask is not None:
             # Apply the attention mask
diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py
index 9dca5f050..79e3ebc01 100644
--- a/QEfficient/transformers/models/falcon/modeling_falcon.py
+++ b/QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -31,6 +31,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffFalconRotaryEmbedding(FalconRotaryEmbedding):
@@ -148,7 +149,9 @@ def forward(
         attention_scores = query_layer @ key_layer.transpose(-1, -2)
         attention_scores /= math.sqrt(self.head_dim)
 
-        attention_scores = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attention_scores)
+        attention_scores = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
+        )
         attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
         # It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
         attn_output = attention_scores @ value_layer
diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py
index bd5e85d84..0cefbcfee 100644
--- a/QEfficient/transformers/models/gemma/modeling_gemma.py
+++ b/QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding):
@@ -110,7 +111,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
index fa0b3cc49..173da1798 100644
--- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py
+++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -30,6 +30,7 @@
 
 # from transformers.utils import is_torchdynamo_compiling
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding):
@@ -116,7 +117,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
index 9e9544b7e..4d615418c 100644
--- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py
+++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -162,7 +162,9 @@ def eager_attention_forward(
         attn_weights = attn_weights * softcap
 
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -265,7 +267,11 @@ def forward(
             attn_weights = attn_weights * self.config.attn_logit_softcapping
 
         if attention_mask is not None:  # no matter the length, we just slice it
-            attn_weights = torch.where(attention_mask.bool(), torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask.bool(),
+                torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32),
+                attn_weights,
+            )
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py
index 0b458fbbe..a2b84c139 100644
--- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py
+++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py
@@ -17,6 +17,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
@@ -30,15 +31,16 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
         # if only "normal" attention layer implements causal mask
         query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
-        mask_value = -10000.0
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+        mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
         attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
 
     if attention_mask is not None:
         # Apply the attention mask
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py
index 5daa4ace3..6b11e3f4f 100644
--- a/QEfficient/transformers/models/gptj/modeling_gptj.py
+++ b/QEfficient/transformers/models/gptj/modeling_gptj.py
@@ -28,6 +28,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
@@ -62,7 +63,9 @@ def _attn(
 
         if attention_mask is not None:
             # Apply the attention mask
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py
index af4ebfc92..13b308547 100644
--- a/QEfficient/transformers/models/granite/modeling_granite.py
+++ b/QEfficient/transformers/models/granite/modeling_granite.py
@@ -26,6 +26,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding):
@@ -107,7 +108,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
index 6e99e2ffa..8f840b4b4 100644
--- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py
@@ -29,6 +29,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffGraniteMoeRotaryEmbedding(GraniteMoeRotaryEmbedding):
@@ -153,7 +154,9 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
         if attention_mask is not None:
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         dropout = 0.0 if not self.training else self.attention_dropout
diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py
index a28cb1699..21516ff5f 100644
--- a/QEfficient/transformers/models/grok_1/modeling_grok1.py
+++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py
@@ -20,6 +20,7 @@
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 from QEfficient.transformers.models.llama.modeling_llama import qeff_apply_rotary_pos_emb
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEFFGrok1CustomRMSNormAIC(nn.Module):
@@ -110,7 +111,9 @@ def forward(
 
         attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)
         if attention_mask is not None:
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index 0cccd7fcf..a285f00dc 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -109,7 +110,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py
index 50a7eba65..d936f629d 100644
--- a/QEfficient/transformers/models/llama4/modeling_llama4.py
+++ b/QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -55,7 +55,9 @@ def eager_attention_forward_vision(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -380,7 +382,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
index f6cf2de49..b73e8d822 100644
--- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
+++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py
@@ -29,6 +29,7 @@
     QEffLlamaRotaryEmbedding,
     qeff_apply_rotary_pos_emb,
 )
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffLlamaSwiftKVConfig(LlamaConfig):
@@ -120,7 +121,9 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         if attention_mask is not None:  # no matter the length, we just slice it
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         # attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py
index 59c19baa2..60b1c929d 100644
--- a/QEfficient/transformers/models/mistral/modeling_mistral.py
+++ b/QEfficient/transformers/models/mistral/modeling_mistral.py
@@ -31,6 +31,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffMistralRotaryEmbedding(MistralRotaryEmbedding):
@@ -114,7 +115,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
index 808f6baf2..ef51c3421 100644
--- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
+++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py
@@ -32,6 +32,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding):
@@ -115,7 +116,9 @@ def eager_attention_forward(
 
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py
index 1cfafae58..761773d9a 100644
--- a/QEfficient/transformers/models/mllama/modeling_mllama.py
+++ b/QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -178,9 +178,6 @@ def forward(
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
-            # attn_weights = torch.where(
-            #     attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights
-            # )
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_output = torch.matmul(attn_weights, value_states)
@@ -256,7 +253,9 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
         if attention_mask is not None:  # no matter the length, we just slice it
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py
index 359a32672..89d474e15 100644
--- a/QEfficient/transformers/models/mpt/modeling_mpt.py
+++ b/QEfficient/transformers/models/mpt/modeling_mpt.py
@@ -21,6 +21,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffMptAttention(MptAttention):
@@ -78,7 +79,7 @@ def forward(
 
         if attention_mask is not None:
             attention_scores = torch.where(
-                attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attention_scores
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
             )
 
         # (batch_size, n_heads, seq_length, key_length)
diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py
index e08dfa528..18557f1ca 100644
--- a/QEfficient/transformers/models/phi/modeling_phi.py
+++ b/QEfficient/transformers/models/phi/modeling_phi.py
@@ -24,6 +24,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def eager_attention_forward(
@@ -40,7 +41,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py
index 3a54a1e83..602a73c84 100644
--- a/QEfficient/transformers/models/phi3/modeling_phi3.py
+++ b/QEfficient/transformers/models/phi3/modeling_phi3.py
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding):
@@ -108,7 +109,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
index 67c71b32c..00a3989d8 100644
--- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py
+++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py
@@ -30,6 +30,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 # Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology
@@ -124,7 +125,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py
index 9ea508f5c..e3db4b490 100644
--- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 def eager_attention_forward(
@@ -44,7 +45,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
 
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py
index 2dda9ed96..afa2a6b07 100644
--- a/QEfficient/transformers/models/whisper/modeling_whisper.py
+++ b/QEfficient/transformers/models/whisper/modeling_whisper.py
@@ -30,6 +30,7 @@
 from QEfficient.transformers.cache_utils import QEffEncoderDecoderCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 from QEfficient.utils._utils import IOInfo
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffWhisperPositionalEmbedding(WhisperPositionalEmbedding):
@@ -116,7 +117,9 @@ def forward(
                     f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                 )
             # updated to use torch.where, to prevent overflow in fp16 computation
-            attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+            attn_weights = torch.where(
+                attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+            )
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 9a8085bde..696cc564e 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -28,6 +28,8 @@
 # Compiler defaults
 DEFAULT_AIC_NUM_CORES = 16
 DEFAULT_AIC_MXPF6_MATMUL = False
+# Minimum value for causal mask
+MIN_MASKED_ATTENTION_VALUE = float("-inf")
 
 # Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
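
For reference, a minimal standalone sketch (not part of the patch) of the -inf-preserving clip that the updated FP16ClipTransform performs above; the helper name and array contents are illustrative only.

    import numpy as np

    fp16_info = np.finfo(np.float16)

    def clip_to_fp16_preserving_neg_inf(arr: np.ndarray) -> np.ndarray:
        # Remember which entries were -inf before clipping.
        neg_inf_mask = np.isinf(arr) & (arr < 0)
        # Clamp everything into the representable FP16 range [-65504, 65504].
        clipped = np.clip(arr, fp16_info.min, fp16_info.max)
        # Put the -inf sentinels back so fully-masked attention scores stay -inf.
        return np.where(neg_inf_mask, np.float32("-inf"), clipped)

    x = np.array([1e20, -1e20, float("-inf"), 3.0], dtype=np.float32)
    print(clip_to_fp16_preserving_neg_inf(x))  # approximately [65504., -65504., -inf, 3.]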
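The modeling-side changes all share one pattern: mask positions with MIN_MASKED_ATTENTION_VALUE before the softmax, so masked keys receive exactly zero probability regardless of the scale of the unmasked scores. A small illustrative sketch (not part of the patch; tensor values are made up):

    import torch

    MIN_MASKED_ATTENTION_VALUE = float("-inf")  # as defined in QEfficient/utils/constants.py

    # Toy attention scores for one query over four key positions; the last two keys are masked.
    attn_weights = torch.tensor([[2.0, 1.0, 0.5, 0.1]])
    attention_mask = torch.tensor([[False, False, True, True]])  # True = masked out

    attn_weights = torch.where(
        attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
    )
    probs = torch.nn.functional.softmax(attn_weights, dim=-1)
    print(probs)  # tensor([[0.7311, 0.2689, 0.0000, 0.0000]]); masked keys contribute nothing

These -inf mask constants can survive ONNX export precisely because the FP16ClipTransform change above clips finite out-of-range values but leaves -inf in place.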