16
16
from onnxscript import script
17
17
from onnxscript .onnx_opset import opset18 as op
18
18
from onnxscript .onnx_types import FLOAT
19
+ from onnxscript .rewriter .ort_fusions ._test_utils import assert_allclose , ort_run
19
20
from onnxscript .rewriter .ort_fusions .sdpa import fuse_sdpa
21
+ from onnxscript .rewriter .ort_fusions .sdpa_via_mha import replace_sdpa_by_mha
20
22
21
23
B = 2 # batch size
22
24
N = 4 # number of heads
@@ -190,7 +192,7 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
190
192
191
193
192
194
@script ()
193
- def _custom_scale_pre_div_sdpa_script (query , key , value , mask ):
195
+ def _masked_custom_scale_pre_div_sdpa_script (query , key , value , mask ):
194
196
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
195
197
divisor = op .Constant (value_float = SQRT_CUSTOM_DIV_SCALE_FACTOR )
196
198
scaled_query = op .Div (query , divisor )
@@ -203,7 +205,7 @@ def _custom_scale_pre_div_sdpa_script(query, key, value, mask):
203
205
204
206
205
207
@script ()
206
- def _custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
208
+ def _masked_custom_scale_pre_mul_sdpa_script (query , key , value , mask ):
207
209
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
208
210
multiplier = op .Constant (value_float = SQRT_CUSTOM_MUL_SCALE_FACTOR )
209
211
scaled_query = op .Mul (query , multiplier )
@@ -216,7 +218,7 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value, mask):
216
218
217
219
218
220
@script ()
219
- def _custom_scale_post_div_sdpa_script (query , key , value , mask ):
221
+ def _masked_custom_scale_post_div_sdpa_script (query , key , value , mask ):
220
222
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
221
223
divisor = op .Constant (value_float = CUSTOM_DIV_SCALE_FACTOR )
222
224
attn_score = op .MatMul (query , key_transposed )
@@ -228,7 +230,7 @@ def _custom_scale_post_div_sdpa_script(query, key, value, mask):
228
230
229
231
230
232
@script ()
231
- def _custom_scale_post_mul_sdpa_script (query , key , value , mask ):
233
+ def _masked_custom_scale_post_mul_sdpa_script (query , key , value , mask ):
232
234
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
233
235
multiplier = op .Constant (value_float = CUSTOM_MUL_SCALE_FACTOR )
234
236
attn_score = op .MatMul (query , key_transposed )
@@ -240,15 +242,19 @@ def _custom_scale_post_mul_sdpa_script(query, key, value, mask):
240
242
241
243
242
244
class SDPATestCase :
243
- def __init__ (self , script_func ):
245
+ def __init__ (self , script_func , * , with_mask ):
244
246
self .script_func = script_func
247
+ self .with_mask = with_mask
245
248
246
249
def get_onnx_model (self ):
247
250
if not hasattr (self , "_onnx_model" ):
248
251
qkv_type = FLOAT [B , N , S , H ]
249
252
mask_type = FLOAT [B , N , S , S ]
253
+ input_types = [qkv_type , qkv_type , qkv_type ]
254
+ if self .with_mask :
255
+ input_types .append (mask_type )
250
256
model_proto = self .script_func .to_model_proto (
251
- input_types = [ qkv_type , qkv_type , qkv_type , mask_type ] , output_types = [qkv_type ]
257
+ input_types = input_types , output_types = [qkv_type ]
252
258
)
253
259
self ._onnx_model = ir .serde .deserialize_model (model_proto )
254
260
return self ._onnx_model
@@ -259,6 +265,35 @@ def get_ort_inputs(self):
259
265
"query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
260
266
"key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
261
267
"value" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
268
+ }
269
+ if self .with_mask :
270
+ inputs ["mask" ] = numpy .random .rand (B , N , S , S ).astype (numpy .float32 )
271
+ self ._ort_inputs = inputs
272
+ return self ._ort_inputs
273
+
274
+
275
+ class InvalidSDPATestCase :
276
+ def __init__ (self , script_func ):
277
+ self .script_func = script_func
278
+
279
+ def get_onnx_model (self ):
280
+ if not hasattr (self , "_onnx_model" ):
281
+ qk_type = FLOAT [B , N , S , H ]
282
+ # We broadcast value in the batch dimension, which is not supported by SDPA fusion
283
+ v_type = FLOAT [1 , N , S , H ]
284
+ mask_type = FLOAT [B , N , S , S ]
285
+ model_proto = self .script_func .to_model_proto (
286
+ input_types = [qk_type , qk_type , v_type , mask_type ], output_types = [qk_type ]
287
+ )
288
+ self ._onnx_model = ir .serde .deserialize_model (model_proto )
289
+ return self ._onnx_model
290
+
291
+ def get_ort_inputs (self ):
292
+ if not hasattr (self , "_ort_inputs" ):
293
+ inputs = {
294
+ "query" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
295
+ "key" : numpy .random .rand (B , N , S , H ).astype (numpy .float32 ),
296
+ "value" : numpy .random .rand (1 , N , S , H ).astype (numpy .float32 ),
262
297
"mask" : numpy .random .rand (B , N , S , S ).astype (numpy .float32 ),
263
298
}
264
299
self ._ort_inputs = inputs
@@ -296,35 +331,35 @@ def get_ort_inputs(self):
296
331
class TestSDPAFusion (unittest .TestCase ):
297
332
@parameterized .parameterized .expand (
298
333
[
299
- ("unmasked_pre_div " , _unmasked_pre_div_sdpa_script ),
300
- ("unmasked_pre_mul " , _unmasked_pre_mul_sdpa_script ),
301
- ("unmasked_post_div " , _unmasked_post_div_sdpa_script ),
302
- ("unmasked_post_mul " , _unmasked_post_mul_sdpa_script ),
303
- ("pre_div " , _masked_pre_div_sdpa_script ),
304
- ("pre_mul " , _masked_pre_mul_sdpa_script ),
305
- ("post_div " , _masked_post_div_sdpa_script ),
306
- ("post_mul " , _masked_post_mul_sdpa_script ),
334
+ ("pre_div " , _unmasked_pre_div_sdpa_script ),
335
+ ("pre_mul " , _unmasked_pre_mul_sdpa_script ),
336
+ ("post_div " , _unmasked_post_div_sdpa_script ),
337
+ ("post_mul " , _unmasked_post_mul_sdpa_script ),
338
+ ("masked_pre_div " , _masked_pre_div_sdpa_script ),
339
+ ("masked_pre_mul " , _masked_pre_mul_sdpa_script ),
340
+ ("masked_post_div " , _masked_post_div_sdpa_script ),
341
+ ("masked_post_mul " , _masked_post_mul_sdpa_script ),
307
342
("custom_scale_post_mul" , _custom_scale_post_mul_sdpa_script ),
308
343
("custom_scale_post_div" , _custom_scale_post_div_sdpa_script ),
309
344
("custom_scale_pre_mul" , _custom_scale_pre_mul_sdpa_script ),
310
345
("custom_scale_pre_div" , _custom_scale_pre_div_sdpa_script ),
311
- ("custom_scale_post_mul_masked " , _custom_scale_post_mul_sdpa_script ),
312
- ("custom_scale_post_div_masked " , _custom_scale_post_div_sdpa_script ),
313
- ("custom_scale_pre_mul_masked " , _custom_scale_pre_mul_sdpa_script ),
314
- ("custom_scale_pre_div_masked " , _custom_scale_pre_div_sdpa_script ),
346
+ ("masked_custom_scale_post_mul " , _masked_custom_scale_post_mul_sdpa_script ),
347
+ ("masked_custom_scale_post_div " , _masked_custom_scale_post_div_sdpa_script ),
348
+ ("masked_custom_scale_pre_mul " , _masked_custom_scale_pre_mul_sdpa_script ),
349
+ ("masked_custom_scale_pre_div " , _masked_custom_scale_pre_div_sdpa_script ),
315
350
(
316
351
"_custom_multi_scale_pre_mul_sdpa_script" ,
317
352
_custom_multi_scale_pre_mul_sdpa_script ,
318
353
),
319
354
]
320
355
)
321
356
def test_sdpa_fusion (self , name , script_func ):
322
- test_case = SDPATestCase (script_func )
357
+ test_case = SDPATestCase (script_func , with_mask = "masked" in name )
323
358
model = test_case .get_onnx_model ()
324
359
onnxscript .optimizer .optimize (model )
325
360
326
- # inputs = test_case.get_ort_inputs()
327
- # original_outputs = ort_run("original", model, inputs)
361
+ inputs = test_case .get_ort_inputs ()
362
+ original_outputs = ort_run ("original" , model , inputs )
328
363
329
364
count = fuse_sdpa (model , debug = True )
330
365
self .assertGreater (count , 0 )
@@ -347,8 +382,19 @@ def test_sdpa_fusion(self, name, script_func):
347
382
# of scale_factor (is =default_scaling_factor)
348
383
self .assertIsNone (sdpa_node .attributes .get ("scale" ))
349
384
350
- # new_outputs = ort_run("optimized", model, inputs)
351
- # assert_allclose(new_outputs, original_outputs)
385
+ replace_sdpa_by_mha (model , debug = True )
386
+
387
+ self .assertNotIn ("SDPA" , [n .op_type for n in model .graph ])
388
+
389
+ new_outputs = ort_run ("optimized" , model , inputs )
390
+ assert_allclose (new_outputs , original_outputs )
391
+
392
+ def test_invalid_sdpa_fusion_value_batch_dim (self ):
393
+ test_case = InvalidSDPATestCase (_masked_pre_mul_sdpa_script )
394
+ model = test_case .get_onnx_model ()
395
+ onnxscript .optimizer .optimize (model )
396
+ count = fuse_sdpa (model )
397
+ self .assertEqual (count , 0 )
352
398
353
399
def test_invalid_sdpa_fusion_value_batch_dim (self ):
354
400
test_case = InvalidSDPATestCase (_masked_pre_mul_sdpa_script )
0 commit comments