Commit 9e6b119

examples: graph: update sdpa training example
1 parent c8f29a0

File tree: 1 file changed (+21, -5 lines)

examples/graph/sdpa_training.cpp

Lines changed: 21 additions & 5 deletions
@@ -255,27 +255,39 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     // attention_probs = softmax(masked_score) = exp(masked_score - stats)
     auto stats_b
             = logical_tensor(id++, dt_inter, stats_sz, layout_type::strided);
-    auto sub_out_b = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto sub_out_b
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto subtract_b = op(id++, op::kind::Subtract, "subtract");
     subtract_b.add_inputs({masked_score_b, stats_b});
     subtract_b.add_outputs({sub_out_b});
 
-    auto probs_b = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto probs_b
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto exp_b = op(id++, op::kind::Exp, "exp");
     exp_b.add_inputs({sub_out_b});
     exp_b.add_outputs({probs_b});
 
+    // the following bmm doesn't support different input dtypes, insert a typecast
+    auto probs_b_cast = probs_b;
+    auto typecast_b = op(id++, op::kind::TypeCast, "typecast");
+    if (dt != dt_inter) {
+        probs_b_cast = logical_tensor(id++, dt, score_sz, layout_type::strided);
+        typecast_b.add_inputs({probs_b});
+        typecast_b.add_outputs({probs_b_cast});
+    }
+
     // compute dvalue = P^T * doutput
     auto doutput = logical_tensor(id++, dt, qv_sz, layout_type::strided);
     auto dvalue = logical_tensor(id++, dt, k_sz, layout_type::strided);
     auto bmm_p_do = op(id++, op::kind::MatMul, "bmm1");
     bmm_p_do.set_attr<bool>(op::attr::transpose_a, true);
-    bmm_p_do.add_inputs({probs_b, doutput});
+    bmm_p_do.add_inputs({probs_b_cast, doutput});
     bmm_p_do.add_outputs({dvalue});
 
     // compute dprobs = doutput * value^T
     auto value_b = logical_tensor(id++, dt, k_sz, layout_type::strided);
-    auto dprobs = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto dprobs
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
     auto bmm_do_v = op(id++, op::kind::MatMul, "bmm2");
     bmm_do_v.set_attr<bool>(op::attr::transpose_b, true);
     bmm_do_v.add_inputs({doutput, value_b});
@@ -291,7 +303,7 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
 
     // compute dscored_score = dmasked_score / scale
     auto dscaled_score
-            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
+            = logical_tensor(id++, dt, score_sz, layout_type::strided);
     auto scale_div_b2 = op(id++, op::kind::Divide, "scale_div");
     scale_div_b2.add_inputs({dmasked_score, scale_b});
     scale_div_b2.add_outputs({dscaled_score});
@@ -322,6 +334,10 @@ void bench_sdpa(engine::kind ekind, logical_tensor::data_type dt,
     sdpa_bwd.add_op(scale_div_b2);
     sdpa_bwd.add_op(bmm_dscaled_score_k);
     sdpa_bwd.add_op(bmm_dscaled_score_q);
+    if (dt != dt_inter)
+        // Add typecast op to the sdpa graph.
+        sdpa_bwd.add_op(typecast_b);
+
     sdpa_bwd.finalize();
 
     // Get partitions from the sdpa graph.
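Note (not part of the commit): the core of this change is the conditional TypeCast between the softmax probabilities, which stay in the intermediate data type dt_inter, and the backward MatMul, whose inputs use the I/O data type dt. The sketch below isolates that pattern with the oneDNN graph C++ API. The header paths and API calls follow the public dnnl::graph interface; the shapes, data types, and variable names are illustrative assumptions and are not taken from sdpa_training.cpp.

    // Minimal sketch of the conditional TypeCast pattern, assuming bf16 I/O
    // with f32 intermediates and made-up 4D shapes.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    #include "oneapi/dnnl/dnnl.hpp"
    #include "oneapi/dnnl/dnnl_graph.hpp"

    using namespace dnnl::graph;

    int main() {
        using dtype = logical_tensor::data_type;
        using ltype = logical_tensor::layout_type;

        const dtype dt = dtype::bf16;      // I/O data type (assumed)
        const dtype dt_inter = dtype::f32; // intermediate data type (assumed)
        const std::vector<int64_t> score_sz {1, 16, 128, 128}; // assumed shape
        const std::vector<int64_t> out_sz {1, 16, 128, 64};    // assumed shape

        size_t id = 0;

        // Exp produces probabilities in the intermediate data type.
        auto exp_in = logical_tensor(id++, dt_inter, score_sz, ltype::strided);
        auto probs = logical_tensor(id++, dt_inter, score_sz, ltype::strided);
        auto exp_op = op(id++, op::kind::Exp, "exp");
        exp_op.add_inputs({exp_in});
        exp_op.add_outputs({probs});

        // The downstream MatMul takes both inputs in the I/O data type, so
        // cast the probabilities only when the two data types actually differ.
        auto probs_cast = probs;
        auto typecast = op(id++, op::kind::TypeCast, "typecast");
        if (dt != dt_inter) {
            probs_cast = logical_tensor(id++, dt, score_sz, ltype::strided);
            typecast.add_inputs({probs});
            typecast.add_outputs({probs_cast});
        }

        auto doutput = logical_tensor(id++, dt, out_sz, ltype::strided);
        auto dvalue = logical_tensor(id++, dt, out_sz, ltype::strided);
        auto bmm = op(id++, op::kind::MatMul, "bmm");
        bmm.set_attr<bool>(op::attr::transpose_a, true);
        bmm.add_inputs({probs_cast, doutput});
        bmm.add_outputs({dvalue});

        // Build the graph; the TypeCast op is attached only when it was wired up.
        graph g(dnnl::engine::kind::cpu);
        g.add_op(exp_op);
        g.add_op(bmm);
        if (dt != dt_inter) g.add_op(typecast);
        g.finalize();

        // Retrieve partitions; compilation and execution would follow as in
        // the full example.
        auto partitions = g.get_partitions();
        return partitions.empty() ? 1 : 0;
    }

Because the TypeCast op is only added when dt != dt_inter, a configuration where the two types match (for example, all-f32) builds the graph without any extra op, mirroring the guarded sdpa_bwd.add_op(typecast_b) call in the diff above.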
