Skip to content

Commit 6f12739

Browse files
committed
tests: benchdnn: graph: add cases for sdpa training
1 parent df167b4 commit 6f12739

File tree

5 files changed

+2015
-0
lines changed

5 files changed

+2015
-0
lines changed

tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@
1111
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16-v2.json
1212
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
1313
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
14+
--reset --dt=0:f32+1:f32+10:f32+13:f32+14:f32 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
15+
--reset --case=complex_fusion/mha/sdpa-plain-training-backward-f32.json
1416

1517
# f16 inputs + f32 intermediates + f16 outputs
1618
--reset --op-kind=1:Multiply,1:Divide --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
@@ -38,6 +40,8 @@
3840
--reset --dt=2:f32+5:f32 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
3941
--reset --dt=2:f32+6:f32 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
4042
--reset --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
43+
--reset --dt=0:f16+1:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
44+
--reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
4145
# q_seq_len != kv_seq_len
4246
--reset --in-shapes=1:1x16x128x64+24:1x16x128x64 --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
4347

@@ -63,6 +67,8 @@
6367
--reset --dt=2:f32+5:f32+0:bf16+1:bf16+4:bf16+7:bf16+9:bf16+10:bf16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
6468
--reset --dt=2:f32+6:f32+0:bf16+1:bf16+5:bf16+7:bf16+8:bf16+9:bf16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f16-bs1.json
6569
--reset --dt=0:bf16+1:bf16+4:bf16+22:bf16+24:bf16+25:bf16 --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
70+
--reset --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
71+
--reset --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
6672

6773
# int8 graphs
6874
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json

tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,17 +12,24 @@
1212
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/GQA-fp16-v2.json
1313
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f16.json
1414
--reset --dt=f32,bf16,f16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json
15+
--reset --dt=0:f32+1:f32+10:f32+13:f32+14:f32 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
16+
--reset --case=complex_fusion/mha/sdpa-plain-training-backward-f32.json
17+
1518
# f16 inputs + f32 intermediates + f16 outputs
1619
--reset --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
1720
--reset --dt=4:f32+9:f32+14:f32 --case=complex_fusion/mha/GQA-fp16-v2.json
1821
--reset --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
1922
--reset --case=complex_fusion/mha/codegemma-bf16-f32.json
2023
--reset --case=complex_fusion/mha/gemma2-bf16-f32.json
24+
--reset --dt=0:f16+1:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
25+
--reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
2126

2227
# bf16 inputs + f32 intermediates + bf16 outputs
2328
--reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json
2429
--reset --dt=4:f32+9:f32+14:f32+1:bf16+3:bf16+8:bf16+11:bf16+16:bf16+20:bf16+19:bf16 --case=complex_fusion/mha/GQA-fp16-v2.json
2530
--reset --dt=0:bf16+1:bf16+4:bf16+22:bf16+24:bf16+25:bf16 --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json
31+
--reset --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json
32+
--reset --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json
2633

2734
# int8 graphs
2835
--reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json

0 commit comments

Comments
 (0)