
Commit 3f87c3b

Fix tests
1 parent 416c91b commit 3f87c3b


2 files changed: +89 -40 lines


extension/llm/modules/mha.py

Lines changed: 1 addition & 0 deletions
@@ -354,6 +354,7 @@ def forward(
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
+
         output = self._attention_fn(
             q,
             k,

extension/llm/modules/test/test_mha.py

Lines changed: 88 additions & 40 deletions
@@ -7,12 +7,14 @@
 import unittest
 
 import torch
+from executorch.exir import EdgeCompileConfig, to_edge
 
 from executorch.extension.llm.modules.mha import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
+from executorch.runtime import Runtime
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
-from torchtune.modules.kv_cache import KVCache
 
 
 torch.manual_seed(0)
@@ -21,76 +23,122 @@
 class AttentionTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        self.embed_dim=2048
-        self.num_heads=32
-        self.num_kv_heads=8
-        self.head_dim=64
+
+        # Constants
+        self.embed_dim = 2048
+        self.num_heads = 32
+        self.num_kv_heads = 8
+        self.head_dim = 64
         self.max_seq_len = 128
+        self.rope_base = 500_000
+        self.scale_factor = 32
+
+        # Module dependency injections.
+        self.q_proj = torch.nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = torch.nn.Linear(
+            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+        )
+        self.v_proj = torch.nn.Linear(
+            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
+        )
+        self.output_proj = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.pos_embeddings = Llama3ScaledRoPE(
+            dim=self.head_dim,
+            max_seq_len=self.max_seq_len,
+            base=self.rope_base,
+            scale_factor=self.scale_factor,
+        )
+
+        # Original TorchTune reference module to test accuracy against.
         self.tt_mha = TTMultiHeadAttention(
             embed_dim=self.embed_dim,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.head_dim,
-            q_proj=torch.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False),
-            k_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            v_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            output_proj=torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False),
-            # pos_embeddings=rope,
+            q_proj=self.q_proj,
+            k_proj=self.k_proj,
+            v_proj=self.v_proj,
+            output_proj=self.output_proj,
+            pos_embeddings=self.pos_embeddings,
             max_seq_len=self.max_seq_len,
-            # attn_dropout=attn_dropout,
         )
+
+        # Source transformed module that we are testing.
         self.et_mha = ETMultiHeadAttention(
             embed_dim=self.embed_dim,
             num_heads=self.num_heads,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.head_dim,
-            q_proj=torch.nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False),
-            k_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            v_proj=torch.nn.Linear(self.embed_dim, self.num_kv_heads * self.head_dim, bias=False),
-            output_proj=torch.nn.Linear(self.embed_dim, self.embed_dim, bias=False),
-            # pos_embeddings=rope,
+            q_proj=self.q_proj,
+            k_proj=self.k_proj,
+            v_proj=self.v_proj,
+            output_proj=self.output_proj,
+            pos_embeddings=self.pos_embeddings,
             max_seq_len=self.max_seq_len,
-            # attn_dropout=attn_dropout,
         )
 
-    def test_self_attention_eager(self):
+        # Common inputs.
         seq_len = 10
-        x = torch.randn(1, seq_len, self.embed_dim)
-        et_res = self.et_mha(x, x) # Self attention.
-        tt_res = self.tt_mha(x, x) # Self attention.
-
+        self.x = torch.randn(1, seq_len, self.embed_dim)
+        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
+        self.dynamic_shapes = (
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
+        )
+
+    def test_attention_eager(self):
+        et_res = self.et_mha(self.x, self.x) # Self attention.
+        tt_res = self.tt_mha(self.x, self.x) # Self attention.
+
         self.assertTrue(torch.allclose(et_res, tt_res))
 
         # TODO: KV cache.
         # self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
        # self.tt_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
-
-        # et_res = self.et_mha(x, x) # Self attention.
-        # tt_res = self.tt_mha(x, x) # Self attention.
 
-        # self.assertTrue(torch.allclose(et_res, tt_res))
+        # et_res = self.et_mha(self.x, self.x) # Self attention.
+        # tt_res = self.tt_mha(self.x, self.x) # Self attention.
 
-    def test_self_attention_export(self):
-        seq_len = 10
-        x = torch.randn(1, seq_len, self.embed_dim)
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-        )
+        # self.assertTrue(torch.allclose(et_res, tt_res))
 
+    def test_attention_export(self):
         # Self attention.
         et_mha_ep = torch.export.export(
             self.et_mha,
-            (x, x),
+            (self.x, self.x),
             kwargs=None,
-            dynamic_shapes=dynamic_shapes,
+            dynamic_shapes=self.dynamic_shapes,
         )
-        et_res = et_mha_ep.module()(x, x)
-        tt_res = self.tt_mha(x, x)
+        et_res = et_mha_ep.module()(self.x, self.x)
+        tt_res = self.tt_mha(self.x, self.x)
         self.assertTrue(torch.allclose(et_res, tt_res))
-
+
         # TODO: KV cache.
 
-    def test_cross_attention_export(self):
+    def test_attention_aoti(self):
+        # TODO.
         pass
+
+    def test_attention_executorch(self):
+        # Self attention.
+        et_mha_ep = torch.export.export(
+            self.et_mha,
+            (self.x, self.x),
+            kwargs=None,
+            dynamic_shapes=self.dynamic_shapes,
+        )
+        et_program = to_edge(
+            et_mha_ep,
+            compile_config=EdgeCompileConfig(),
+        ).to_executorch()
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.x))
+        tt_res = self.tt_mha(self.x, self.x)
+
+        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+
+        # TODO: KV cache.

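For readers unfamiliar with the path the new test_attention_executorch exercises, below is a minimal standalone sketch of exporting a module with a dynamic sequence dimension, lowering it to an ExecuTorch program, and running it through the Python runtime. It only uses APIs that appear in the diff above (torch.export.export, to_edge, EdgeCompileConfig, Runtime); the TinyAttentionProxy module and its shapes are illustrative assumptions, not part of the commit, which exports the real ETMultiHeadAttention instead.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.runtime import Runtime


class TinyAttentionProxy(torch.nn.Module):
    # Hypothetical stand-in for ETMultiHeadAttention, used only to keep the
    # sketch self-contained; it takes two inputs like the attention module does.
    def __init__(self, embed_dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.proj(x) + self.proj(y)


module = TinyAttentionProxy()
x = torch.randn(1, 10, 64)

# Mark the sequence dimension as dynamic, mirroring dynamic_shapes in setUp().
seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
dynamic_shapes = (
    {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
)

# Export, lower to an ExecuTorch program, and execute it via the runtime,
# following the same call sequence as test_attention_executorch.
ep = torch.export.export(module, (x, x), kwargs=None, dynamic_shapes=dynamic_shapes)
et_program = to_edge(ep, compile_config=EdgeCompileConfig()).to_executorch()

runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")
et_out = method.execute((x, x))[0]

# Sanity-check against eager execution of the same module.
assert torch.allclose(et_out, module(x, x), atol=1e-06)

Running the updated test file itself (for example with python -m unittest against extension/llm/modules/test/test_mha.py, assuming the executorch and torchtune packages are importable) should exercise the same eager, export, and ExecuTorch paths end to end.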