16 changes: 12 additions & 4 deletions QEfficient/base/onnx_transforms.py
@@ -34,25 +34,33 @@ def apply(cls, model: ModelProto, **kwargs) -> Tuple[ModelProto, bool]:

class FP16ClipTransform(OnnxTransform):
"""
Clips the tensor values to be in FP16 range.
Clips the tensor values to be in FP16 range, but preserves -inf values.
"""

@classmethod
def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
"""
:param onnx_base_dir: Base directory to load tensors (if not already loaded).
:param onnx_base_dir: Base directory to load tensors
"""
finfo = np.finfo(np.float16)
fp16_max = finfo.max
fp16_min = finfo.min
transformed = False

for tensor in external_data_helper._get_all_tensors(model):
nptensor = numpy_helper.to_array(tensor, onnx_base_dir)
if nptensor.dtype == np.float32 and (np.any(nptensor > fp16_max) or np.any(nptensor < fp16_min)):
nptensor = np.clip(nptensor, fp16_min, fp16_max)
new_tensor = numpy_helper.from_array(nptensor, tensor.name)
neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)

# Restore -inf values
if neg_inf_mask.any():
clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

new_tensor = numpy_helper.from_array(clipped_tensor, tensor.name)
tensor.CopyFrom(new_tensor)
transformed = True

return model, transformed
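
A minimal standalone sketch of the clipping behaviour above (illustrative values only; the actual transform iterates over the model's ONNX tensors rather than a hand-built array):

```python
import numpy as np

finfo = np.finfo(np.float16)
fp16_max, fp16_min = finfo.max, finfo.min

# FP32 tensor with out-of-range values and an explicit -inf (e.g. an attention mask fill value).
nptensor = np.array([1.0, 1e38, -1e38, -np.inf], dtype=np.float32)

neg_inf_mask = np.isinf(nptensor) & (nptensor < 0)      # remember where -inf was
clipped_tensor = np.clip(nptensor, fp16_min, fp16_max)  # clip everything into FP16 range

# Restore -inf so fully masked positions are not turned into finite values.
if neg_inf_mask.any():
    clipped_tensor = np.where(neg_inf_mask, np.float32("-inf"), clipped_tensor)

print(clipped_tensor)  # [ 1.0  65504.0  -65504.0  -inf ]
```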


11 changes: 8 additions & 3 deletions QEfficient/transformers/modeling_utils.py
@@ -88,6 +88,7 @@
)

from QEfficient.customop import CustomRMSNormAIC
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE

# Placeholder for all non-transformer models
from .models.codegen.modeling_codegen import (
@@ -307,12 +308,12 @@ def _prepare_cross_attention_mask(
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
inverted_cross_attn_mask.to(torch.bool), torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
)

# apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] whose value is 0 if a full row in the cross attn mask's
# last dimension contains negative infinity values, otherwise 1
negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
negative_inf_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
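
A short sketch of the full-row bias described in the comment above; MIN_MASKED_ATTENTION_VALUE is shown with an illustrative stand-in value, since its definition in QEfficient/utils/constants.py is not part of this diff:

```python
import torch

MIN_MASKED_ATTENTION_VALUE = -10000.0  # illustrative stand-in; real value in QEfficient/utils/constants.py

# Cross-attention mask of shape [B, H, S1, S2]: first query row fully masked, second row partially masked.
cross_attention_mask = torch.tensor(
    [[[[MIN_MASKED_ATTENTION_VALUE, MIN_MASKED_ATTENTION_VALUE],
       [0.0, MIN_MASKED_ATTENTION_VALUE]]]]
)
negative_inf_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)

# 0 where an entire row is masked, 1 otherwise; result shape [B, H, S1, 1].
full_text_row_masked_out_mask = (
    (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
print(full_text_row_masked_out_mask)  # tensor([[[[0.], [1.]]]])
```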
@@ -342,7 +343,11 @@ def _prepare_aspect_ratio_attention_mask(
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
attention_mask = (
attention_mask
@ attention_mask.transpose(-1, -2)
* torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
)
attention_mask = attention_mask.unsqueeze(1)

return attention_mask
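
The same substitution (literal -10000.0 replaced by MIN_MASKED_ATTENTION_VALUE) recurs in each model file below. A minimal sketch of the shared masking pattern, again with a stand-in value for the constant:

```python
import torch

MIN_MASKED_ATTENTION_VALUE = -10000.0  # stand-in; defined in QEfficient/utils/constants.py

# attn_weights: [batch, heads, q_len, kv_len]; attention_mask: True where a position must be masked.
attn_weights = torch.randn(1, 2, 4, 4)
attention_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)  # causal mask

attn_weights = torch.where(
    attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)
attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
print(attn_probs[0, 0, 0])  # masked positions receive (near-)zero probability
```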
5 changes: 3 additions & 2 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -23,6 +23,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffCodeGenAttention(CodeGenAttention):
@@ -48,10 +49,10 @@ def _attn(

attn_weights = attn_weights / self.scale_attn
# Minimum value for causal mask
mask_value = -10000.0

# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
mask_value = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype).to(attn_weights.device)

if attention_mask is not None:
# Apply the attention mask
5 changes: 4 additions & 1 deletion QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -31,6 +31,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffFalconRotaryEmbedding(FalconRotaryEmbedding):
@@ -148,7 +149,9 @@ def forward(

attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)
attention_scores = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attention_scores)
attention_scores = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
attn_output = attention_scores @ value_layer
5 changes: 4 additions & 1 deletion QEfficient/transformers/models/gemma/modeling_gemma.py
@@ -27,6 +27,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffGemmaRotaryEmbedding(GemmaRotaryEmbedding):
@@ -110,7 +111,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
5 changes: 4 additions & 1 deletion QEfficient/transformers/models/gemma2/modeling_gemma2.py
@@ -30,6 +30,7 @@

# from transformers.utils import is_torchdynamo_compiling
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffGemma2RotaryEmbedding(Gemma2RotaryEmbedding):
@@ -116,7 +117,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
10 changes: 8 additions & 2 deletions QEfficient/transformers/models/gemma3/modeling_gemma3.py
@@ -162,7 +162,9 @@ def eager_attention_forward(
attn_weights = attn_weights * softcap

if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
@@ -265,7 +267,11 @@ def forward(
attn_weights = attn_weights * self.config.attn_logit_softcapping

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask.bool(), torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask.bool(),
torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32),
attn_weights,
)

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
8 changes: 5 additions & 3 deletions QEfficient/transformers/models/gpt2/modeling_gpt2.py
@@ -17,6 +17,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


def eager_attention_forward(module, query, key, value, attention_mask, head_mask=None, **kwargs):
@@ -30,15 +31,16 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = module.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = -10000.0
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
mask_value = torch.full([], MIN_MASKED_ATTENTION_VALUE, dtype=attn_weights.dtype, device=attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)

5 changes: 4 additions & 1 deletion QEfficient/transformers/models/gptj/modeling_gptj.py
@@ -28,6 +28,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
@@ -62,7 +63,9 @@ def _attn(

if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
@@ -26,6 +26,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffGraniteRotaryEmbedding(GraniteRotaryEmbedding):
@@ -107,7 +108,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
@@ -29,6 +29,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffGraniteMoeRotaryEmbedding(GraniteMoeRotaryEmbedding):
@@ -153,7 +154,9 @@ def forward(

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
dropout = 0.0 if not self.training else self.attention_dropout
5 changes: 4 additions & 1 deletion QEfficient/transformers/models/grok_1/modeling_grok1.py
@@ -20,6 +20,7 @@
from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.transformers.models.llama.modeling_llama import qeff_apply_rotary_pos_emb
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEFFGrok1CustomRMSNormAIC(nn.Module):
@@ -110,7 +111,9 @@ def forward(
attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val)

if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
5 changes: 4 additions & 1 deletion QEfficient/transformers/models/llama/modeling_llama.py
@@ -27,6 +27,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
@@ -109,7 +110,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
8 changes: 6 additions & 2 deletions QEfficient/transformers/models/llama4/modeling_llama4.py
@@ -55,7 +55,9 @@ def eager_attention_forward_vision(
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
@@ -380,7 +382,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
@@ -29,6 +29,7 @@
QEffLlamaRotaryEmbedding,
qeff_apply_rotary_pos_emb,
)
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffLlamaSwiftKVConfig(LlamaConfig):
@@ -120,7 +121,9 @@ def forward(

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
# attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@@ -31,6 +31,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffMistralRotaryEmbedding(MistralRotaryEmbedding):
@@ -114,7 +115,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
@@ -32,6 +32,7 @@

from QEfficient.transformers.cache_utils import QEffDynamicCache
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE


class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding):
@@ -115,7 +116,9 @@ def eager_attention_forward(

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
7 changes: 3 additions & 4 deletions QEfficient/transformers/models/mllama/modeling_mllama.py
@@ -178,9 +178,6 @@ def forward(
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# attn_weights = torch.where(
# attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights
# )

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
@@ -256,7 +253,9 @@ def forward(
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None: # no matter the length, we just slice it
attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights)
attn_weights = torch.where(
attention_mask, torch.tensor(constants.MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
)

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)