@@ -255,27 +255,39 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
255
255
// attention_probs = softmax(masked_score) = exp(masked_score - stats)
256
256
auto stats_b
257
257
= logical_tensor (id++, dt_inter, stats_sz, layout_type::strided);
258
- auto sub_out_b = logical_tensor (id++, dt, score_sz, layout_type::strided);
258
+ auto sub_out_b
259
+ = logical_tensor (id++, dt_inter, score_sz, layout_type::strided);
259
260
auto subtract_b = op (id++, op::kind::Subtract, " subtract" );
260
261
subtract_b.add_inputs ({masked_score_b, stats_b});
261
262
subtract_b.add_outputs ({sub_out_b});
262
263
263
- auto probs_b = logical_tensor (id++, dt, score_sz, layout_type::strided);
264
+ auto probs_b
265
+ = logical_tensor (id++, dt_inter, score_sz, layout_type::strided);
264
266
auto exp_b = op (id++, op::kind::Exp, " exp" );
265
267
exp_b.add_inputs ({sub_out_b});
266
268
exp_b.add_outputs ({probs_b});
267
269
270
+ // the following bmm doesn't support different input dtypes, insert a typecast
271
+ auto probs_b_cast = probs_b;
272
+ auto typecast_b = op (id++, op::kind::TypeCast, " typecast" );
273
+ if (dt != dt_inter) {
274
+ probs_b_cast = logical_tensor (id++, dt, score_sz, layout_type::strided);
275
+ typecast_b.add_inputs ({probs_b});
276
+ typecast_b.add_outputs ({probs_b_cast});
277
+ }
278
+
268
279
// compute dvalue = P^T * doutput
269
280
auto doutput = logical_tensor (id++, dt, qv_sz, layout_type::strided);
270
281
auto dvalue = logical_tensor (id++, dt, k_sz, layout_type::strided);
271
282
auto bmm_p_do = op (id++, op::kind::MatMul, " bmm1" );
272
283
bmm_p_do.set_attr <bool >(op::attr::transpose_a, true );
273
- bmm_p_do.add_inputs ({probs_b , doutput});
284
+ bmm_p_do.add_inputs ({probs_b_cast , doutput});
274
285
bmm_p_do.add_outputs ({dvalue});
275
286
276
287
// compute dprobs = doutput * value^T
277
288
auto value_b = logical_tensor (id++, dt, k_sz, layout_type::strided);
278
- auto dprobs = logical_tensor (id++, dt, score_sz, layout_type::strided);
289
+ auto dprobs
290
+ = logical_tensor (id++, dt_inter, score_sz, layout_type::strided);
279
291
auto bmm_do_v = op (id++, op::kind::MatMul, " bmm2" );
280
292
bmm_do_v.set_attr <bool >(op::attr::transpose_b, true );
281
293
bmm_do_v.add_inputs ({doutput, value_b});
@@ -291,7 +303,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
291
303
292
304
// compute dscaled_score = dmasked_score / scale
293
305
auto dscaled_score
294
- = logical_tensor (id++, dt_inter , score_sz, layout_type::strided);
306
+ = logical_tensor (id++, dt , score_sz, layout_type::strided);
295
307
auto scale_div_b2 = op (id++, op::kind::Divide, " scale_div" );
296
308
scale_div_b2.add_inputs ({dmasked_score, scale_b});
297
309
scale_div_b2.add_outputs ({dscaled_score});
@@ -322,6 +334,10 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
322
334
sdpa_bwd.add_op (scale_div_b2);
323
335
sdpa_bwd.add_op (bmm_dscaled_score_k);
324
336
sdpa_bwd.add_op (bmm_dscaled_score_q);
337
+ if (dt != dt_inter)
338
+ // Add typecast op to the sdpa graph.
339
+ sdpa_bwd.add_op (typecast_b);
340
+
325
341
sdpa_bwd.finalize ();
326
342
327
343
// Get partitions from the sdpa graph.
0 commit comments