
graph: doc, interface, backend: support SDPA training #3396

Open
ElaineBao wants to merge 16 commits into main from yixin/sdpa-training-impl
Conversation

ElaineBao (Contributor) commented on Jun 6, 2025

Description

  • This PR implements rfcs: graph api: support SDPA training #3233

    • Extend the MatMul and SoftMaxBackward operations to support the new data type combinations used in SDPA training. The related doc, library implementation, and benchdnn cases are added.
    • Extend the SoftMax operation to support an optional stats output (see the API sketch after this list). The related doc and library implementation are added.
    • Extend benchdnn graph to support validation of the SoftMax stats output.
    • Support the SDPA training backward pattern (the forward pattern is the same as for inference). The related library reference implementation is added.
    • Document the training pattern.
    • Validate the correctness of the SDPA training forward and backward patterns with benchdnn graph.
    • Add an example to demonstrate the entire training forward and backward flow.
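
For context, a minimal sketch (not taken from this PR's diff) of how a user could request the optional stats output through the graph API. The tensor ids, shapes, the reduced stats shape, and the assumption that stats is appended as a second SoftMax output are illustrative guesses based on the description above:

```cpp
#include <cstdint>
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    // Illustrative attention-score shape: [mb, heads, seq, seq].
    const logical_tensor::dims score_shape {1, 16, 384, 384};
    // Assumed stats shape: the softmax axis reduced to 1 (per-row value).
    const logical_tensor::dims stats_shape {1, 16, 384, 1};

    logical_tensor score {0, logical_tensor::data_type::bf16, score_shape,
            logical_tensor::layout_type::strided};
    logical_tensor probs {1, logical_tensor::data_type::bf16, score_shape,
            logical_tensor::layout_type::strided};
    // Per the data type table below, stats stays in f32 even for bf16/f16 I/O.
    logical_tensor stats {2, logical_tensor::data_type::f32, stats_shape,
            logical_tensor::layout_type::strided};

    op softmax {0, op::kind::SoftMax, "softmax_fwd"};
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_input(score);
    softmax.add_output(probs);
    softmax.add_output(stats); // optional second output described by this PR

    graph g {dnnl::engine::kind::cpu};
    g.add_op(softmax);
    g.finalize();

    // The returned partition(s) can then be compiled and executed as usual.
    auto partitions = g.get_partitions();
    (void)partitions;
    return 0;
}
```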

ElaineBao self-assigned this on Jun 6, 2025
ElaineBao added the component:graph-api label on Jun 6, 2025
github-actions bot added the documentation, component:tests, and component:examples labels on Jun 6, 2025
ElaineBao force-pushed the yixin/sdpa-training-impl branch 2 times, most recently from 2828827 to 9e6b119 on June 18, 2025
ElaineBao marked this pull request as ready for review on June 18, 2025
ElaineBao requested review from a team as code owners on June 18, 2025
ElaineBao changed the title from "[DO NOT REVIEW] graph: doc, interface, backend: support SDPA training" to "graph: doc, interface, backend: support SDPA training" on Jun 18, 2025
| Src | Dst | Stats |
|:-----|:----------------|:------|
| f32 | f32, bf16, f16 | f32 |
| bf16 | bf16 | f32 |
| f16 | f16 | f32 |
Contributor

@LuFinch Could you please confirm that the stats tensor is always f32 in PyTorch?

@ElaineBao It would be better to clarify this in the RFC document as well. Thanks.


Confirmed that all the CUDA FlashAttention/EfficientAttention/CudnnAttention backends use fp32 for logsumexp.
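
For reference, assuming the stats tensor follows the same logsumexp convention as those backends (an assumption here, since the exact definition lives in this PR's doc changes), the per-row value kept for the backward pass is

$$\mathrm{stats}_{m} = \log \sum_{n} \exp\big(S_{m,n}\big),$$

where $S$ is the scaled (and optionally masked) score tensor that feeds the SoftMax and $n$ runs over the softmax axis. Keeping this value in f32 preserves precision for the backward computation even when src/dst are bf16/f16.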

    // The order of the input logical tensors in inputs is not fixed, so we
    // need to record the input offsets following a fixed order of ops.
    CHECK_BOOL(record_input_offset(sg, inputs));
    dims src1_user_dims = ltw(inputs[graph_inport[mm1_src]]).vdims();
    ndims = src1_user_dims.size();
    VCHECK_SDP_DECOMP(ndims == 4 || ndims == 5, false,
            "Input dims should be 4 or 5, but got %zu", src1_user_dims.size());
    VCHECK_SDP_DECOMP(
            outputs.size() == 1, false, "Doesn't support SDPA training yet");
Contributor

Better to make the error message consistent with the checks.

Suggested change:
- outputs.size() == 1, false, "Doesn't support SDPA training yet");
+ outputs.size() == 1, false, "does not support multiple outputs");

    // At least 3 inputs: Q, K, V
    VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
            "At least 3 inputs are required");
    VCHECK_SDP_PRIMITIVE(outputs.size() == 1, status::unimplemented,
            "Doesn't support SDPA training yet");
Contributor

Ditto.

}

auto f32_dst = dst;
if (f32_dst->get_logical_tensor().data_type == dnnl::impl::data_type::f32) {
Contributor

I guess below should work:

Suggested change:
- if (f32_dst->get_logical_tensor().data_type == dnnl::impl::data_type::f32) {
+ if (f32_dst->get_logical_tensor().data_type == data_type::f32) {

= empty_logical_tensor_with_default_id();
f32_dst = std::make_shared<value_t>(
*new_softmax_op, 0, softmax_op_out_lt, true);
f32_dst->set_data_type(dnnl::impl::data_type::f32);
Contributor

Same here. I don't think it's necessary to spell out all the nested namespaces.

@@ -698,6 +698,168 @@ static status_t select_handler(
return status::success;
}

static status_t softmax_handler(
const std::shared_ptr<op_t> &op, subgraph_rewriter_t &rewriter) {
Contributor

What's the purpose of the lowering function?

Contributor Author

The main purpose is to compute the stats output with multiple primitives. If the stats output does not exist, a single softmax primitive is used instead.
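
For context, a reference-only sketch (not the primitive-based lowering added in this PR) of what the decomposition has to produce, assuming stats is the row-wise logsumexp consumed by the SDPA training backward pass:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference computation for a single softmax row: writes the softmax of x
// into y and returns the logsumexp value that the stats output carries.
// The library's lowering produces the same quantities with a sequence of
// primitives instead of this scalar loop.
float softmax_row_with_stats(const std::vector<float> &x, std::vector<float> &y) {
    if (x.empty()) return -std::numeric_limits<float>::infinity();
    const float mx = *std::max_element(x.begin(), x.end()); // numerical stability
    float sum = 0.f;
    y.resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(x[i] - mx);
        sum += y[i];
    }
    for (float &v : y)
        v /= sum; // softmax output
    return mx + std::log(sum); // stats = logsumexp over the row
}
```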

@@ -177,6 +177,48 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_fusion_gpu)
return std::make_shared<sdp_base_t<>>();
});

DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion)
Contributor

Does that mean the inference fusion pattern will be reused for training forward?

Contributor Author

Yes, correct

ElaineBao force-pushed the yixin/sdpa-training-impl branch from 9e6b119 to 6f12739 on June 18, 2025
ElaineBao force-pushed the yixin/sdpa-training-impl branch from 6f12739 to 4efa167 on June 19, 2025

    switch (exec_arg) {
        case DNNL_ARG_SRC:
            SAFE(::custom::fill_mem(mem, ref_mem, -8, 7), WARN);
Contributor

How do we decide this data range?

@@ -203,6 +210,13 @@ int ref_partition_t::init_graph_mem(
}

void ref_partition_t::exec_ops(res_t *res) {
    // Check whether there is a softmax backward op in the partition, which
    // makes it a candidate for the SDPA training backward pattern.
    bool has_softmax_backward = std::any_of(partition_ops_ref_.begin(),
Contributor

May use static?

@@ -263,7 +288,7 @@ void ref_partition_t::exec_ops(res_t *res) {
|| (parent0 == "Multiply" && parent1 == "MatMul");
}

-if (is_sdpa_pattern || is_gated_mlp_pattern) {
+if (is_sdpa_pattern || is_sdpa_bwd_pattern || is_gated_mlp_pattern) {
Contributor

The names are a little confusing: they actually indicate whether the current op needs precision downcasting, not the status of the whole pattern. Could you help improve them?

Labels: component:examples, component:graph-api, component:tests, documentation
4 participants