Add LoRA support to StaticAttention for split_mha=False #18345
lucylq wants to merge 5 commits into gh/lucylq/142/head
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18345

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures as of commit de2e79b with merge base 02bad9d.

NEW FAILURE - The following job has failed.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This PR adds LoRA-aware projection construction to the StaticAttention implementation when using the non-split MHA path (split_mha=False), so that q/k/v/o projections can become LoRALinear based on ModelArgs.target_modules, while keeping existing behavior unchanged when target_modules is None.
Changes:
- For split_mha=False, conditionally instantiate LoRALinear for the q/k/v projections when their corresponding target names are present in config.target_modules.
- For split_mha=False, conditionally instantiate LoRALinear for the output projection (wo) when output_proj or o_proj is targeted.
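To make the configuration surface concrete, here is a minimal sketch of the config fields this PR keys off. The dataclass below is a stand-in for illustration only; the real ModelArgs in executorch has many more fields, and exact defaults may differ.

```python
from dataclasses import dataclass
from typing import List, Optional

# Stand-in for the ModelArgs fields referenced by this PR (assumed shape,
# not the real executorch ModelArgs).
@dataclass
class ModelArgs:
    dim: int = 64
    n_heads: int = 4
    n_kv_heads: int = 2
    target_modules: Optional[List[str]] = None  # e.g. ["q_proj", "v_proj"]
    r: Optional[int] = None          # LoRA rank
    lora_alpha: Optional[float] = None

# With target_modules set, the named q/k/v/o projections become LoRALinear.
lora_cfg = ModelArgs(target_modules=["q_proj", "v_proj", "o_proj"], r=8, lora_alpha=16.0)

# With target_modules=None (the default), behavior is unchanged: plain nn.Linear.
plain_cfg = ModelArgs()
```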
    has_lora = config.target_modules is not None
    _PROJ_TARGET = {
        "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
        "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
        "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
    }
    for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
        if has_lora and target in config.target_modules:
            proj = LoRALinear(
                in_dim=in_dim,
                out_dim=out_dim,
                rank=config.r,
                alpha=config.lora_alpha,
                use_bias=self.attention_qkv_bias,
            )
When config.target_modules is set but config.r and/or config.lora_alpha are left as None (both are Optional in ModelArgs), this path will attempt to construct LoRALinear(rank=None, alpha=None) and fail with a low-signal TypeError. Consider adding an explicit validation (ValueError with a clear message) before creating any LoRALinear modules, similar to LoRAFeedForward.
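A pre-construction check along the lines the reviewer suggests could look like the sketch below. This is an illustrative helper, not code from the PR; the field names (target_modules, r, lora_alpha) follow the PR's ModelArgs, and SimpleNamespace stands in for a real config object.

```python
from types import SimpleNamespace

# Hypothetical validation helper mirroring the reviewer's suggestion:
# fail loudly with a ValueError before any LoRALinear is constructed.
def validate_lora_config(config) -> None:
    if config.target_modules is None:
        return  # LoRA not requested; nothing to check
    missing = [
        name
        for name, value in (("r", config.r), ("lora_alpha", config.lora_alpha))
        if value is None
    ]
    if missing:
        raise ValueError(
            f"target_modules={config.target_modules!r} requires "
            f"{' and '.join(missing)} to be set on ModelArgs"
        )

# Stand-in configs for illustration (real code would pass ModelArgs).
bad_cfg = SimpleNamespace(target_modules=["q_proj"], r=None, lora_alpha=None)
good_cfg = SimpleNamespace(target_modules=["q_proj"], r=8, lora_alpha=16.0)
```

Calling `validate_lora_config(bad_cfg)` raises a ValueError naming the missing fields, instead of the low-signal TypeError from `LoRALinear(rank=None, alpha=None)`.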
    has_lora = config.target_modules is not None
    _PROJ_TARGET = {
        "wqs": ("q_proj", self.dim, self.head_dim * self.n_heads),
        "wks": ("k_proj", self.dim, self.head_dim * self.n_kv_heads),
        "wvs": ("v_proj", self.dim, self.head_dim * self.n_kv_heads),
    }
    for attr, (target, in_dim, out_dim) in _PROJ_TARGET.items():
        if has_lora and target in config.target_modules:
            proj = LoRALinear(
                in_dim=in_dim,
                out_dim=out_dim,
                rank=config.r,
                alpha=config.lora_alpha,
                use_bias=self.attention_qkv_bias,
            )
        else:
            proj = nn.Linear(in_dim, out_dim, bias=self.attention_qkv_bias)
        setattr(self, attr, nn.ModuleList([proj]))
New behavior is introduced here (direct StaticAttention(..., split_mha=False) now conditionally builds LoRALinear based on config.target_modules), but existing tests in test_static_attention.py only exercise LoRA via from_attention_mha. Please add a unit test that directly constructs StaticAttention with split_mha=False and target_modules set, and asserts the expected projection types and a forward equivalence check.
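The type-assertion part of such a test could take the shape below. This is a self-contained sketch: Linear, LoRALinear, and build_projections are stand-ins for the real torch/executorch modules, mirroring only the PR's selection logic; a real test would construct StaticAttention directly and add a forward equivalence check.

```python
# Stand-in classes so the selection logic can be checked without torch.
class Linear:
    def __init__(self, in_dim, out_dim, bias=False):
        self.in_dim, self.out_dim = in_dim, out_dim

class LoRALinear(Linear):
    def __init__(self, in_dim, out_dim, rank, alpha, use_bias=False):
        super().__init__(in_dim, out_dim)
        self.rank, self.alpha = rank, alpha

def build_projections(target_modules, dim=32, head_dim=8,
                      n_heads=4, n_kv_heads=2, r=4, alpha=8.0):
    """Mirror the PR's q/k/v selection: LoRALinear when targeted, Linear otherwise."""
    spec = {
        "wqs": ("q_proj", dim, head_dim * n_heads),
        "wks": ("k_proj", dim, head_dim * n_kv_heads),
        "wvs": ("v_proj", dim, head_dim * n_kv_heads),
    }
    has_lora = target_modules is not None
    return {
        attr: (
            LoRALinear(in_dim, out_dim, rank=r, alpha=alpha)
            if has_lora and target in target_modules
            else Linear(in_dim, out_dim)
        )
        for attr, (target, in_dim, out_dim) in spec.items()
    }
```

A test would then assert, e.g., that targeting only q_proj and v_proj yields LoRALinear for wqs/wvs and plain Linear for wks, and that target_modules=None yields Linear everywhere.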
When ModelArgs.target_modules is set, create LoRALinear instead of
nn.Linear for targeted q/k/v/o projections. Only applies to the
split_mha=False path. Existing behavior is unchanged when
target_modules is None.
Authored with Claude.