diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index ffcdb6cb9ce..395fce85613 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -1044,5 +1044,5 @@ def transfer_weight(linear, conv2d):
 
 @register_attention("static_mha")
 class StaticAttentionMHA(StaticAttention):
-    def __init__(self, config: ModelArgs, layer_id: int, rope: Rope):
-        super().__init__(config, layer_id, rope, split_mha=False)
+    def __init__(self, config: ModelArgs, layer_id: int, rope: Rope, **kwargs: Any):
+        super().__init__(config, layer_id, rope, split_mha=False, **kwargs)
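
The change lets `StaticAttentionMHA` accept and forward arbitrary keyword arguments to `StaticAttention` instead of only the three positional parameters. Below is a minimal, hypothetical sketch (toy names such as `BaseAttention`, `MHAAttention`, and `cache_dtype` are assumptions, not the ExecuTorch implementation) showing why `**kwargs` forwarding matters when subclasses are constructed through a registry that may pass extra options:

```python
# Minimal sketch with hypothetical names; illustrates kwargs forwarding through a
# registry-built subclass, analogous to the StaticAttentionMHA change above.
from typing import Any, Callable, Dict

ATTENTION_REGISTRY: Dict[str, Callable[..., "BaseAttention"]] = {}


def register_attention(name: str):
    def wrapper(cls):
        ATTENTION_REGISTRY[name] = cls
        return cls

    return wrapper


class BaseAttention:
    def __init__(self, dim: int, *, split_mha: bool = True, cache_dtype: str = "fp32"):
        # Extra construction options (e.g. cache_dtype) reach the base class
        # only if subclasses forward them.
        self.dim = dim
        self.split_mha = split_mha
        self.cache_dtype = cache_dtype


@register_attention("mha")
class MHAAttention(BaseAttention):
    def __init__(self, dim: int, **kwargs: Any):
        # Without **kwargs here, passing cache_dtype via the registry would
        # raise a TypeError in this subclass's __init__.
        super().__init__(dim, split_mha=False, **kwargs)


# The registry can pass options the subclass does not name explicitly.
attn = ATTENTION_REGISTRY["mha"](dim=256, cache_dtype="bf16")
print(attn.split_mha, attn.cache_dtype)  # False bf16
```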