From 2c919426793ff0765825f5a3cd3f9c4605b858b7 Mon Sep 17 00:00:00 2001
From: Dipankar Sarkar
Date: Wed, 15 Oct 2025 09:30:51 +0000
Subject: [PATCH] Olmo2 Bug fix

Signed-off-by: Dipankar Sarkar
---
 QEfficient/transformers/models/olmo2/modeling_olmo2.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py
index 6dae7ac84..0d23729c1 100644
--- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py
+++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py
@@ -27,6 +27,7 @@
 
 from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
 
 
 class QEffOlmo2RotaryEmbedding(Olmo2RotaryEmbedding):
@@ -109,7 +110,9 @@ def eager_attention_forward(
 
     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
     if attention_mask is not None:
-        attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
+        attn_weights = torch.where(
+            attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
+        )
     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
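
Note (illustration, not part of the applied patch): the change above only swaps the hardcoded -10000.0 mask fill for the shared MIN_MASKED_ATTENTION_VALUE constant. The minimal sketch below shows the masked-softmax pattern the patched code uses; the placeholder constant value and the toy tensor shapes are assumptions for this example, the real value lives in QEfficient.utils.constants.

    # Minimal sketch of the masked eager-attention pattern touched by this patch.
    import torch
    import torch.nn as nn

    # Placeholder assumption; the actual constant is defined in QEfficient.utils.constants.
    MIN_MASKED_ATTENTION_VALUE = -10000.0

    batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
    query = torch.rand(batch, heads, q_len, head_dim)
    key_states = torch.rand(batch, heads, kv_len, head_dim)
    value_states = torch.rand(batch, heads, kv_len, head_dim)

    # Boolean mask: True where a key position must be hidden from the query.
    attention_mask = torch.triu(torch.ones(q_len, kv_len, dtype=torch.bool), diagonal=1)

    scaling = head_dim**-0.5
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    # Same pattern as the patched code: masked positions receive the shared
    # constant instead of a hardcoded -10000.0 literal, so softmax drives them to ~0.
    attn_weights = torch.where(
        attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
    )
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    print(attn_output.shape)  # torch.Size([1, 2, 4, 8])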