diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py
index a285f00dc..5106e2dc4 100644
--- a/QEfficient/transformers/models/llama/modeling_llama.py
+++ b/QEfficient/transformers/models/llama/modeling_llama.py
@@ -174,7 +174,7 @@ def forward(
         )

         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-        attn_output = self.o_proj(attn_output)
+        attn_output = self.o_proj(attn_output, **kwargs)
         return attn_output, attn_weights, past_key_value
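
For context, a minimal sketch of what forwarding **kwargs into the output projection implies: the module bound to self.o_proj must accept extra keyword arguments at call time, for example when it has been swapped for a wrapper that consumes per-call metadata. This is an illustrative sketch, not the QEfficient implementation; the KwargAwareLinear class and the extra_flag keyword below are hypothetical names.

import torch
import torch.nn as nn


class KwargAwareLinear(nn.Linear):
    # Hypothetical wrapper: tolerates extra keyword arguments passed down from the
    # attention forward instead of raising TypeError on unexpected arguments.
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # kwargs are available here for layer-specific behaviour; this sketch ignores them.
        return super().forward(x)


if __name__ == "__main__":
    o_proj = KwargAwareLinear(64, 64, bias=False)
    attn_output = torch.randn(2, 8, 64)
    # A plain nn.Linear would reject the extra keyword; the wrapper accepts it.
    # "extra_flag" is an illustrative kwarg name, not one used by QEfficient.
    out = o_proj(attn_output, extra_flag=True)
    print(out.shape)  # torch.Size([2, 8, 64])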