QEfficient/transformers/models/granite/modeling_granite.py (75 additions, 0 deletions)
@@ -17,6 +17,7 @@
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteConfig,
GraniteDecoderLayer,
GraniteForCausalLM,
GraniteModel,
GraniteRotaryEmbedding,
@@ -173,6 +174,80 @@ def forward(
return attn_output, attn_weights


class QEffGraniteDecoderLayer(GraniteDecoderLayer):
"""
Copied from GraniteDecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granite/modeling_granite.py
The only differences are:
- Adds a new `batch_index` argument for continuous-batching (CB) models, although CB is not supported yet.
"""

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
batch_index: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_values (`Cache`, *optional*): cached past key and value projection states
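batch_index (`torch.LongTensor`, *optional*):
index of each sequence within the batch, added for continuous-batching (CB) models; not supported yet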
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
into the model.
"""
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=output_attentions,
batch_index=batch_index,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states * self.residual_multiplier

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states * self.residual_multiplier # main diff with Llama

outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)

return outputs
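
For context, a minimal sketch (not part of this diff) of the scaled-residual pattern that the `# main diff with Llama` comment refers to: each sub-block's output is multiplied by `residual_multiplier` before being added back to the residual stream, whereas Llama adds it unscaled. The layer shapes and multiplier value below are hypothetical stand-ins:

import torch
import torch.nn as nn

hidden = torch.randn(1, 4, 8)        # (batch, seq_len, embed_dim), toy shapes
norm = nn.LayerNorm(8)               # stand-in for the RMSNorm Granite uses
mlp = nn.Linear(8, 8)                # stand-in for the Granite MLP block
residual_multiplier = 0.5            # illustrative value; Granite reads this from its config

residual = hidden
hidden = residual + mlp(norm(hidden)) * residual_multiplier  # Llama would use: residual + mlp(norm(hidden))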


class QEffGraniteModel(GraniteModel):
def forward(
self,
QEfficient/transformers/models/pytorch_transforms.py (3 additions, 0 deletions)
@@ -63,6 +63,7 @@
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJForCausalLM, GPTJModel
from transformers.models.granite.modeling_granite import (
GraniteAttention,
GraniteDecoderLayer,
GraniteForCausalLM,
GraniteModel,
GraniteRMSNorm,
@@ -268,6 +269,7 @@
)
from QEfficient.transformers.models.granite.modeling_granite import (
QEffGraniteAttention,
QEffGraniteDecoderLayer,
QEffGraniteForCausalLM,
QEffGraniteModel,
)
@@ -531,6 +533,7 @@ class KVCacheTransform(ModuleMappingTransform):
GraniteModel: QEffGraniteModel,
GraniteForCausalLM: QEffGraniteForCausalLM,
GraniteAttention: QEffGraniteAttention,
GraniteDecoderLayer: QEffGraniteDecoderLayer,
# GraniteMoe
GraniteMoeModel: QEffGraniteMoeModel,
GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
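
For readers unfamiliar with `ModuleMappingTransform`, here is a rough sketch of what a class-mapping transform like `KVCacheTransform` is understood to do, assuming it rebinds each matched module's class in place; the `apply_mapping` helper and stand-in classes are illustrative, not QEfficient API:

import torch.nn as nn

class Original(nn.Linear): ...       # stand-in for e.g. GraniteDecoderLayer
class Replacement(Original): ...     # stand-in for e.g. QEffGraniteDecoderLayer

def apply_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    # Swap the class of every matched module in place; parameters are
    # untouched, only the forward implementation changes.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model

model = nn.Sequential(Original(4, 4))
apply_mapping(model, {Original: Replacement})
assert isinstance(model[0], Replacement)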