Commit 416c91b

Swap mha
Move to extension/llm/modules
Lint
Add tests
1 parent b4c6fe1 commit 416c91b

File tree

4 files changed: +554 −10 lines changed

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import (
    MultiHeadAttention,
)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            # Recurse with the private helper; the public wrapper's return
            # value is not needed at intermediate levels.
            _replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference-friendly version of MHA that
    separates out the inference-related parts for further optimization.
    """
    _replace_mha_with_inference_mha(module)
    return module
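The transformation above walks the module tree with `named_children()` and swaps matching children in place via `setattr`. The same recursive-swap pattern can be sketched without any torch or torchtune dependency; everything below (`Module`, `TrainingMHA`, `InferenceMHA`, `swap_mha`) is a hypothetical, minimal stand-in, not the actual ExecuTorch code.

```python
class Module:
    """Minimal stand-in for torch.nn.Module's child-module registry."""

    def __init__(self):
        self._children = {}

    def set_child(self, name, child):
        # Register the child under `name`, mirroring attribute assignment
        # on a torch.nn.Module.
        self._children[name] = child
        setattr(self, name, child)

    def named_children(self):
        # Return a copy so callers can replace children while iterating.
        return list(self._children.items())


class TrainingMHA(Module):
    """Stands in for torchtune's MultiHeadAttention."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim


class InferenceMHA(Module):
    """Stands in for the inference-friendly replacement."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim


def swap_mha(module):
    """Recursively replace every TrainingMHA child with an InferenceMHA
    built from the child's existing configuration."""
    for name, child in module.named_children():
        if isinstance(child, TrainingMHA):
            module.set_child(name, InferenceMHA(embed_dim=child.embed_dim))
        else:
            swap_mha(child)
    return module


model = Module()
block = Module()
block.set_child("attn", TrainingMHA(embed_dim=64))
model.set_child("layer0", block)

swap_mha(model)
print(type(model.layer0.attn).__name__)  # → InferenceMHA
```

Because the replacement is constructed from the child's own projections and config, the swap preserves trained weights while changing only the attention implementation.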

extension/llm/modules/README.md

Lines changed: 13 additions & 10 deletions
@@ -1,14 +1,17 @@
-## Export Friendly Modules
+## Export-friendly Modules

-Modules in this directory are:
-* Extending `torch.nn.Module`.
-* Guranteed to work out of the box with `torch.export.export()` and `torch.aot_compile()`.
-* Guranteed to be able to work with ExecuTorch.
+Modules in this directory:
+* Extend `torch.nn.Module`.
+* Are guaranteed to work out of the box with `torch.export.export()`.
+* Should work out of the box with `torch.aot_compile()`.
+* Should work with ExecuTorch.

 All modules should be covered by unit tests to make sure they are:
-1. giving the same output as the reference implementation in PyTorch or torchtune
-2. export friendly
-3. AOTI friendly
-4. ExecuTorch friendly
+1. Giving the same output as the reference eager model in PyTorch or TorchTune
+2. Export-friendly

-Notice that these modules are subject to change (may upstream to torchtune) so proceed with caution.
+Additionally, we aim to make these modules:
+3. AOTI-friendly
+4. ExecuTorch-friendly
+
+These modules are subject to change (may upstream to TorchTune) so proceed with caution.
