
Commit 3654fa8

Merge branch 'main' into bhamehta/fusedmatmul_find_ops

2 parents: ccce52e + c7d5786
2 files changed (+139, -23 lines)

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 69 additions & 23 deletions
@@ -16,7 +16,9 @@
 from onnxscript import script
 from onnxscript.onnx_opset import opset18 as op
 from onnxscript.onnx_types import FLOAT
+from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
 from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
+from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha
 
 B = 2  # batch size
 N = 4  # number of heads
@@ -190,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
 
 
 @script()
-def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
+def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     divisor = op.Constant(value_float=SQRT_CUSTOM_DIV_SCALE_FACTOR)
     scaled_query = op.Div(query, divisor)
@@ -203,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
 
 
 @script()
-def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
+def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     multiplier = op.Constant(value_float=SQRT_CUSTOM_MUL_SCALE_FACTOR)
     scaled_query = op.Mul(query, multiplier)
@@ -216,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
 
 
 @script()
-def _custom_scale_post_div_sdpa_script(query, key, value, mask):
+def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     divisor = op.Constant(value_float=CUSTOM_DIV_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
@@ -228,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
 
 
 @script()
-def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
+def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
@@ -240,15 +242,19 @@ def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
 
 
 class SDPATestCase:
-    def __init__(self, script_func):
+    def __init__(self, script_func, *, with_mask):
         self.script_func = script_func
+        self.with_mask = with_mask
 
     def get_onnx_model(self):
         if not hasattr(self, "_onnx_model"):
             qkv_type = FLOAT[B, N, S, H]
             mask_type = FLOAT[B, N, S, S]
+            input_types = [qkv_type, qkv_type, qkv_type]
+            if self.with_mask:
+                input_types.append(mask_type)
             model_proto = self.script_func.to_model_proto(
-                input_types=[qkv_type, qkv_type, qkv_type, mask_type], output_types=[qkv_type]
+                input_types=input_types, output_types=[qkv_type]
             )
             self._onnx_model = ir.serde.deserialize_model(model_proto)
         return self._onnx_model
@@ -259,6 +265,35 @@ def get_ort_inputs(self):
                 "query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
                 "key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
                 "value": numpy.random.rand(B, N, S, H).astype(numpy.float32),
+            }
+            if self.with_mask:
+                inputs["mask"] = numpy.random.rand(B, N, S, S).astype(numpy.float32)
+            self._ort_inputs = inputs
+        return self._ort_inputs
+
+
+class InvalidSDPATestCase:
+    def __init__(self, script_func):
+        self.script_func = script_func
+
+    def get_onnx_model(self):
+        if not hasattr(self, "_onnx_model"):
+            qk_type = FLOAT[B, N, S, H]
+            # We broadcast value in the batch dimension, which is not supported by SDPA fusion
+            v_type = FLOAT[1, N, S, H]
+            mask_type = FLOAT[B, N, S, S]
+            model_proto = self.script_func.to_model_proto(
+                input_types=[qk_type, qk_type, v_type, mask_type], output_types=[qk_type]
+            )
+            self._onnx_model = ir.serde.deserialize_model(model_proto)
+        return self._onnx_model
+
+    def get_ort_inputs(self):
+        if not hasattr(self, "_ort_inputs"):
+            inputs = {
+                "query": numpy.random.rand(B, N, S, H).astype(numpy.float32),
+                "key": numpy.random.rand(B, N, S, H).astype(numpy.float32),
+                "value": numpy.random.rand(1, N, S, H).astype(numpy.float32),
                 "mask": numpy.random.rand(B, N, S, S).astype(numpy.float32),
             }
             self._ort_inputs = inputs
@@ -296,35 +331,35 @@ def get_ort_inputs(self):
 class TestSDPAFusion(unittest.TestCase):
     @parameterized.parameterized.expand(
         [
-            ("unmasked_pre_div", _unmasked_pre_div_sdpa_script),
-            ("unmasked_pre_mul", _unmasked_pre_mul_sdpa_script),
-            ("unmasked_post_div", _unmasked_post_div_sdpa_script),
-            ("unmasked_post_mul", _unmasked_post_mul_sdpa_script),
-            ("pre_div", _masked_pre_div_sdpa_script),
-            ("pre_mul", _masked_pre_mul_sdpa_script),
-            ("post_div", _masked_post_div_sdpa_script),
-            ("post_mul", _masked_post_mul_sdpa_script),
+            ("pre_div", _unmasked_pre_div_sdpa_script),
+            ("pre_mul", _unmasked_pre_mul_sdpa_script),
+            ("post_div", _unmasked_post_div_sdpa_script),
+            ("post_mul", _unmasked_post_mul_sdpa_script),
+            ("masked_pre_div", _masked_pre_div_sdpa_script),
+            ("masked_pre_mul", _masked_pre_mul_sdpa_script),
+            ("masked_post_div", _masked_post_div_sdpa_script),
+            ("masked_post_mul", _masked_post_mul_sdpa_script),
             ("custom_scale_post_mul", _custom_scale_post_mul_sdpa_script),
             ("custom_scale_post_div", _custom_scale_post_div_sdpa_script),
             ("custom_scale_pre_mul", _custom_scale_pre_mul_sdpa_script),
             ("custom_scale_pre_div", _custom_scale_pre_div_sdpa_script),
-            ("custom_scale_post_mul_masked", _custom_scale_post_mul_sdpa_script),
-            ("custom_scale_post_div_masked", _custom_scale_post_div_sdpa_script),
-            ("custom_scale_pre_mul_masked", _custom_scale_pre_mul_sdpa_script),
-            ("custom_scale_pre_div_masked", _custom_scale_pre_div_sdpa_script),
+            ("masked_custom_scale_post_mul", _masked_custom_scale_post_mul_sdpa_script),
+            ("masked_custom_scale_post_div", _masked_custom_scale_post_div_sdpa_script),
+            ("masked_custom_scale_pre_mul", _masked_custom_scale_pre_mul_sdpa_script),
+            ("masked_custom_scale_pre_div", _masked_custom_scale_pre_div_sdpa_script),
             (
                 "_custom_multi_scale_pre_mul_sdpa_script",
                 _custom_multi_scale_pre_mul_sdpa_script,
             ),
         ]
     )
     def test_sdpa_fusion(self, name, script_func):
-        test_case = SDPATestCase(script_func)
+        test_case = SDPATestCase(script_func, with_mask="masked" in name)
         model = test_case.get_onnx_model()
         onnxscript.optimizer.optimize(model)
 
-        # inputs = test_case.get_ort_inputs()
-        # original_outputs = ort_run("original", model, inputs)
+        inputs = test_case.get_ort_inputs()
+        original_outputs = ort_run("original", model, inputs)
 
         count = fuse_sdpa(model, debug=True)
         self.assertGreater(count, 0)
@@ -347,8 +382,19 @@ def test_sdpa_fusion(self, name, script_func):
         # of scale_factor (is =default_scaling_factor)
         self.assertIsNone(sdpa_node.attributes.get("scale"))
 
-        # new_outputs = ort_run("optimized", model, inputs)
-        # assert_allclose(new_outputs, original_outputs)
+        replace_sdpa_by_mha(model, debug=True)
+
+        self.assertNotIn("SDPA", [n.op_type for n in model.graph])
+
+        new_outputs = ort_run("optimized", model, inputs)
+        assert_allclose(new_outputs, original_outputs)
+
+    def test_invalid_sdpa_fusion_value_batch_dim(self):
+        test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script)
+        model = test_case.get_onnx_model()
+        onnxscript.optimizer.optimize(model)
+        count = fuse_sdpa(model)
+        self.assertEqual(count, 0)
 
     def test_invalid_sdpa_fusion_value_batch_dim(self):
         test_case = InvalidSDPATestCase(_masked_pre_mul_sdpa_script)
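
For orientation, the hunks above only show the first three lines of each renamed custom-scale script. The sketch below fills in what one full masked body typically looks like: everything after attn_score is an assumed completion following the usual scale, add-mask, softmax, matmul SDPA pattern, the function name carries a "_sketch" suffix to mark it as illustrative, and the constant value is a placeholder; the committed definitions in sdpa_test.py may differ in detail.

from onnxscript import script
from onnxscript.onnx_opset import opset18 as op

CUSTOM_MUL_SCALE_FACTOR = 0.1  # placeholder value; the real constant is defined earlier in sdpa_test.py


@script()
def _masked_custom_scale_post_mul_sdpa_sketch(query, key, value, mask):
    # Q @ K^T, then an explicit post-matmul scale (the "post_mul" variant shown above)
    key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
    multiplier = op.Constant(value_float=CUSTOM_MUL_SCALE_FACTOR)
    attn_score = op.MatMul(query, key_transposed)
    scaled_attn_score = op.Mul(attn_score, multiplier)
    # Assumed tail: additive mask, softmax over the key axis, weighted sum of values
    masked_attn_score = op.Add(scaled_attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
    return op.MatMul(attn_weight, value)
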
onnxscript/rewriter/ort_fusions/sdpa_via_mha.py (new file; path inferred from the import added in sdpa_test.py)

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from typing import Union
+
+import onnxscript.ir as ir
+from onnxscript.rewriter import _fusion_utils, pattern
+
+Dim = Union[int, ir.SymbolicDim]
+
+
+class SDPAImplementation(pattern.RewriteRuleClassBase):
+    def pattern(self, op, query, key_transposed, value):
+        return op.SDPA(
+            query,
+            key_transposed,
+            value,
+            _allow_other_inputs=True,  # Mask is optional
+            _outputs=["sdpa_output"],
+            _domain="ai.onnxruntime.fusion",
+        )
+
+    def check(self, context, query, key_transposed, value, sdpa_output):
+        bindings: dict[str, Dim] = {}
+        _fusion_utils.check_shape(bindings, query, ["B", "H", "S", "Dh"])
+        _fusion_utils.check_shape(bindings, key_transposed, ["B", "H", "Dh", "Skv"])
+        _fusion_utils.check_shape(bindings, value, ["B", "H", "Skv", "Dv"])
+
+        self._num_heads = bindings["H"]
+        if not isinstance(self._num_heads, int):
+            return False
+        self._use_mask_broadcast = True  # TODO: optimize to avoid broadcast if not needed
+        return isinstance(self._num_heads, int)
+
+    def rewrite(self, op, query, key_transposed, value, sdpa_output):
+        sdpa_node = sdpa_output.producer()
+        scale = sdpa_node.attributes.get("scale", None)
+        to_3d_shape = op.Constant(value_ints=[0, 0, -1])
+        to_4d_shape = op.Constant(value_ints=[0, 0, self._num_heads, -1])
+        query_3d = op.Reshape(op.Transpose(query, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op.Reshape(op.Transpose(key_transposed, perm=[0, 3, 1, 2]), to_3d_shape)
+        value_3d = op.Reshape(op.Transpose(value, perm=[0, 2, 1, 3]), to_3d_shape)
+
+        inputs = [query_3d, key_3d, value_3d]
+        if len(sdpa_node.inputs) > 3:
+            mask = sdpa_node.inputs[3]
+
+            if self._use_mask_broadcast:
+                one = op.Constant(value_ints=[1])
+                query_length = op.Shape(query, start=2, end=3)
+                shape_11S1 = op.Concat(one, one, query_length, one, axis=0)
+                mask = op.Expand(mask, shape_11S1)
+
+            inputs.extend([None, None, mask])
+
+        output = op.MultiHeadAttention(
+            *inputs,
+            num_heads=self._num_heads,
+            scale=scale,
+            _domain="com.microsoft",
+        )
+        output_4d = op.Reshape(output, to_4d_shape)
+        output = op.Transpose(output_4d, perm=[0, 2, 1, 3])
+        return output
+
+
+_rules = pattern.RewriteRuleSet([SDPAImplementation.rule()])
+
+replace_sdpa_by_mha = _fusion_utils.apply_fusion_rules(_rules)
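
Taken together with the test changes above, the intended flow is: optimize the model, fuse the attention subgraph into a single SDPA op, lower that SDPA op to com.microsoft MultiHeadAttention, and check numerical parity against the original model with ONNX Runtime. The sketch below simply mirrors test_sdpa_fusion; it reuses SDPATestCase and _masked_pre_mul_sdpa_script from sdpa_test.py for illustration (assuming that test module is importable), and is not a separate documented API.

import onnxscript
import onnxscript.optimizer
from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
from onnxscript.rewriter.ort_fusions.sdpa_via_mha import replace_sdpa_by_mha

# Test helpers shown in the diff above (defined in sdpa_test.py)
from onnxscript.rewriter.ort_fusions.sdpa_test import SDPATestCase, _masked_pre_mul_sdpa_script

test_case = SDPATestCase(_masked_pre_mul_sdpa_script, with_mask=True)
model = test_case.get_onnx_model()
onnxscript.optimizer.optimize(model)

inputs = test_case.get_ort_inputs()
original_outputs = ort_run("original", model, inputs)

count = fuse_sdpa(model, debug=True)    # attention subgraph -> ai.onnxruntime.fusion SDPA
assert count > 0
replace_sdpa_by_mha(model, debug=True)  # SDPA -> com.microsoft MultiHeadAttention

new_outputs = ort_run("optimized", model, inputs)
assert_allclose(new_outputs, original_outputs)
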
