File tree: 4 files changed, +554 -10 lines changed
examples/models/llama/source_transformation/torchtune

import torch
import torchtune.modules.attention as TorchTuneAttention
from executorch.examples.models.llama2.source_transformation.torchtune.modules.mha import (
    MultiHeadAttention,
)


def _replace_mha_with_inference_mha(module: torch.nn.Module) -> None:
    for name, child in module.named_children():
        if isinstance(child, TorchTuneAttention.MultiHeadAttention):
            setattr(
                module,
                name,
                MultiHeadAttention(
                    embed_dim=child.embed_dim,
                    num_heads=child.num_heads,
                    num_kv_heads=child.num_kv_heads,
                    head_dim=child.head_dim,
                    q_proj=child.q_proj,
                    k_proj=child.k_proj,
                    v_proj=child.v_proj,
                    output_proj=child.output_proj,
                    pos_embeddings=child.pos_embeddings,
                    q_norm=child.q_norm,
                    k_norm=child.k_norm,
                    kv_cache=child.kv_cache,
                    max_seq_len=child.max_seq_len,
                    is_causal=child.is_causal,
                    attn_dropout=child.attn_dropout,
                ),
            )
        else:
            _replace_mha_with_inference_mha(child)


def replace_mha_with_inference_mha(module: torch.nn.Module) -> torch.nn.Module:
    """
    Replace TorchTune's MHA with an inference-friendly version of MHA that
    separates out the inference-related parts for further optimization.
    """
    _replace_mha_with_inference_mha(module)
    return module
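The recursive swap pattern above can be sketched with a self-contained stand-in. This sketch replaces `torch.nn.ReLU` with `torch.nn.GELU` rather than TorchTune's attention (the real target types aren't needed to show the `named_children()` / `setattr` traversal), so the module names here are illustrative only:

```python
import torch


def _replace_relu_with_gelu(module: torch.nn.Module) -> None:
    # Walk direct children; swap targets in place, recurse into the rest.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.ReLU):
            setattr(module, name, torch.nn.GELU())
        else:
            _replace_relu_with_gelu(child)


def replace_relu_with_gelu(module: torch.nn.Module) -> torch.nn.Module:
    _replace_relu_with_gelu(module)
    return module


model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.ReLU(),
    # Nested container, to exercise the recursive branch.
    torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()),
)
model = replace_relu_with_gelu(model)
print(sum(isinstance(m, torch.nn.GELU) for m in model.modules()))  # 2
```

Mutating via `setattr` on the parent (instead of rebuilding the model) preserves the surrounding structure and any weights held by non-replaced submodules, which is why the real transformation passes the child's existing projections and caches into the new `MultiHeadAttention`.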
-## Export Friendly Modules
+## Export-friendly Modules

-Modules in this directory are:
-* Extending `torch.nn.Module`.
-* Guranteed to work out of the box with `torch.export.export()` and `torch.aot_compile()`.
-* Guranteed to be able to work with ExecuTorch.
+Modules in this directory:
+* Extend `torch.nn.Module`.
+* Are guaranteed to work out of the box with `torch.export.export()`.
+* Should work out of the box with `torch.aot_compile()`.
+* Should work with ExecuTorch.

-All modules should be covered by unit tests to make sure they are:
-1. giving the same output as the reference implementation in PyTorch or torchtune
-2. export friendly
-3. AOTI friendly
-4. ExecuTorch friendly
+All modules should be covered by unit tests to make sure they:
+1. Give the same output as the reference eager model in PyTorch or TorchTune
+2. Are export-friendly

+Additionally, we aim to make these modules:
+3. AOTI-friendly
+4. ExecuTorch-friendly

-Notice that these modules are subject to change (may upstream to torchtune) so proceed with caution.
+These modules are subject to change (may upstream to TorchTune) so proceed with caution.