From 1bc10dd244a97736752952f448b1e2c20cad8423 Mon Sep 17 00:00:00 2001
From: Shen Xu
Date: Wed, 15 Oct 2025 16:59:25 -0700
Subject: [PATCH] Source transform to use static attention (#15176)

Summary:
Introduce a source transform that converts AttentionMHA modules to
StaticAttention, so this step is aligned with the other source transforms we
run. It also makes the flow less error prone, since ordering constraints
(e.g. the HF RoPE transformation needs to happen before linears are turned
into conv2ds) are handled inside the transform.
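
Example usage (a minimal sketch: the import path is assumed from the module
touched by this diff, and `model` stands in for any nn.Module tree whose
attention layers are AttentionMHA instances, e.g. the llama Transformer used
in the tests):

    from executorch.examples.models.llama.static_attention import (
        transform_attention_mha_to_static_attention,
    )

    # Swap every AttentionMHA child for a StaticAttention module that carries
    # the same weights. inplace=False deep-copies the model first; use_hf_rope
    # and use_conv2d apply adopt_hf_rope() / linear_to_conv2d() to each
    # converted layer in the required order.
    static_model = transform_attention_mha_to_static_attention(
        model,
        split_mha=True,
        inplace=False,
        use_hf_rope=True,
        use_conv2d=False,
    )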

Differential Revision: D84769599
---
 examples/models/llama/static_attention.py    | 67 +++++++++++++++++++
 .../llama/tests/test_static_attention.py     | 26 +++----
 2 files changed, 80 insertions(+), 13 deletions(-)

diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index 849718527ed..1880a09f5c6 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -764,6 +764,39 @@ def __init__(
             self.q_norm = torch.nn.Identity()
             self.k_norm = torch.nn.Identity()
 
+    @classmethod
+    def from_attention_mha(
+        cls,
+        other: AttentionMHA,
+        split_mha: bool = True,
+        rms_norm_class=torch.nn.RMSNorm,
+        **kwargs: Any,
+    ) -> "StaticAttention":
+        config = ModelArgs(
+            dim=other.dim,
+            n_layers=1,  # Not used in attention layer
+            n_heads=other.n_heads,
+            n_kv_heads=other.n_kv_heads,
+            head_dim=other.head_dim,
+            max_batch_size=other.max_batch_size,
+            max_context_len=other.max_context_len,
+            attention_qkv_bias=other.attention_qkv_bias,
+            use_qk_norm=other.use_qk_norm,
+            qk_norm_before_rope=other.qk_norm_before_rope,
+            norm_eps=other.q_norm_fn.eps if other.use_qk_norm else 1e-5,
+        )
+
+        instance = cls(
+            config=config,
+            layer_id=other.layer_id,
+            rope=other.rope,
+            split_mha=split_mha,
+            **kwargs,
+        )
+        instance.load_weights_from_attention_mha(other, rms_norm_class=rms_norm_class)
+
+        return instance
+
     def forward(
         self,
         x: torch.Tensor,
@@ -1059,3 +1092,37 @@ def transfer_weight(linear, conv2d):
 class StaticAttentionMHA(StaticAttention):
     def __init__(self, config: ModelArgs, layer_id: int, rope: Rope, **kwargs: Any):
         super().__init__(config, layer_id, rope, split_mha=False, **kwargs)
+
+
+def transform_attention_mha_to_static_attention(
+    model: nn.Module,
+    split_mha: bool = True,
+    inplace: bool = True,
+    use_conv2d: bool = False,
+    use_hf_rope: bool = False,
+    **kwargs: Any,
+) -> nn.Module:
+    if not inplace:
+        import copy
+
+        model = copy.deepcopy(model)
+
+    def helper(m):
+        for name, child in list(m.named_children()):
+            if isinstance(child, AttentionMHA):
+                static_attn = StaticAttention.from_attention_mha(
+                    child, split_mha=split_mha, **kwargs
+                )
+                # Note: HF RoPE needs to be applied before linear to conv2d
+                if use_hf_rope:
+                    static_attn.adopt_hf_rope()
+                if use_conv2d:
+                    static_attn.linear_to_conv2d()
+
+                setattr(m, name, static_attn)
+            else:
+                helper(child)
+
+        return m
+
+    return helper(model)
diff --git a/examples/models/llama/tests/test_static_attention.py b/examples/models/llama/tests/test_static_attention.py
index 8786c70da11..0d407968c0e 100644
--- a/examples/models/llama/tests/test_static_attention.py
+++ b/examples/models/llama/tests/test_static_attention.py
@@ -14,6 +14,7 @@
     StaticAttentionMask,
     StaticKCache,
     StaticKVCache,
+    transform_attention_mha_to_static_attention,
 )
 
 
@@ -76,7 +77,6 @@ def test(
             layer_id = 0
             rope = Rope(config)
             attn_mha = AttentionMHA(config, layer_id, rope).eval()
-            static_attn = StaticAttention(config, layer_id, rope).eval()
             if use_qk_norm:
                 with torch.no_grad():
                     attn_mha.q_norm_fn.weight.copy_(
@@ -85,7 +85,9 @@ def test(
                     attn_mha.k_norm_fn.weight.copy_(
                         torch.rand(config.head_dim) * 0.2 + 0.9
                     )
-            static_attn.load_weights_from_attention_mha(attn_mha)
+            static_attn = StaticAttention.from_attention_mha(
+                attn_mha, split_mha=split_mha
+            ).eval()
             if adopt_hf_rope:
                 static_attn.adopt_hf_rope()
             if use_conv2d:
@@ -131,8 +133,7 @@ def test_with_cache(self):
         layer_id = 0
         rope = Rope(config)
         attn_mha = AttentionMHA(config, layer_id, rope).eval()
-        static_attn = StaticAttention(config, layer_id, rope).eval()
-        static_attn.load_weights_from_attention_mha(attn_mha)
+        static_attn = StaticAttention.from_attention_mha(attn_mha).eval()
         static_attn.adopt_hf_rope()
 
         x = torch.rand(1, config.max_seq_len, config.dim)
@@ -198,17 +199,16 @@ def test_with_style(style):
 
     def _get_test_transformers(self, config, attention_type="static", use_conv2d=False):
         mha_transformer = construct_transformer(config).eval()
+        static_transformer = transform_attention_mha_to_static_attention(
+            mha_transformer,
+            split_mha=(attention_type == "static"),
+            inplace=False,
+            use_conv2d=use_conv2d,
+            use_hf_rope=True,
+        ).eval()
+
         config = copy.copy(config)
         config.attention_type = attention_type
-        static_transformer = construct_transformer(config).eval()
-        static_transformer.load_state_dict(mha_transformer.state_dict(), strict=False)
-        for mha_layer, static_layer in zip(
-            mha_transformer.layers, static_transformer.layers
-        ):
-            static_layer.attention.load_weights_from_attention_mha(mha_layer.attention)
-            static_layer.attention.adopt_hf_rope()
-            if use_conv2d:
-                static_layer.linear_to_conv2d()
         config.use_hf_rope = True
 
         return mha_transformer, static_transformer, config