
graph: doc, interface, backend: support SDPA training #3396


Merged 19 commits on Jul 15, 2025

Commits
6e9f529 doc: graph: support new dtype combination for SoftMaxBackward (ElaineBao, Jun 18, 2025)
446e1c5 graph: interface: support new dtype combination for SoftMaxBackward (ElaineBao, Jun 18, 2025)
1bc92e1 graph: backend: dnnl: support new dtype combination for SoftMaxBackward (ElaineBao, Jun 18, 2025)
de352a8 benchdnn: inputs: graph: add cases for SoftMaxBackward (ElaineBao, Jun 18, 2025)
1645e1c graph: interface: add optional stats output to SoftMax op (ElaineBao, May 28, 2025)
72a2487 graph: backend: dnnl: add optional stats output to SoftMax op (ElaineBao, May 28, 2025)
729af38 graph: backend: dnnl: assign external output buffer for softmax dst (ElaineBao, Jun 20, 2025)
264f951 graph: backend: dnnl: support sdpa training backward pattern (ElaineBao, May 28, 2025)
25a8196 doc: graph: add optional stats output to SoftMax op (ElaineBao, Jun 5, 2025)
e97be51 tests: benchdnn: extend softmax driver (ElaineBao, Jul 8, 2025)
8f018f5 tests: benchdnn: graph: check correctness for softmax stats (ElaineBao, Jul 8, 2025)
cc9a03a graph: backend: dnnl: support sdpa training forward with large partition (ElaineBao, Jun 12, 2025)
f4975dc tests: benchdnn: graph: recomputing softmax stats for input displacer (ElaineBao, Jun 13, 2025)
8ab9af0 tests: benchdnn: graph: adjust correctness check for sdpa training (ElaineBao, Jun 18, 2025)
b89f921 tests: benchdnn: graph: add cases for sdpa training (ElaineBao, Jun 18, 2025)
a321457 doc: graph: support SDPA training (ElaineBao, Jul 3, 2025)
7c97dfd tests: benchdnn: graph: add case for softmax (ElaineBao, Jun 20, 2025)
0c7ef5b tests: benchdnn: graph: rename variables in ref_partition (ElaineBao, Jun 20, 2025)
dc16f07 tests: benchdnn: graph: convert opkind2driver (ElaineBao, Jul 1, 2025)
84 changes: 66 additions & 18 deletions doc/graph/fusion_patterns/sdpa.md
@@ -31,7 +31,7 @@ SDPA graph, getting partition from the graph, and optimizing the kernels
underneath. In general, an SDPA pattern is defined as a directed acyclic
graph (DAG) using oneDNN Graph API.

### Floating-point SDPA
### Floating-point SDPA for Inference

oneDNN defines floating-point (f32, bf16, or f16) SDPA as follows. The blue
nodes are required when defining an SDPA pattern while the brown parts are
@@ -89,19 +89,65 @@ optional.

![SDPA-Reorder](images/sdpa-reorder.png)

### Floating-point SDPA for Training Forward Propagation

oneDNN defines floating-point (f32, bf16, or f16) SDPA for training forward
propagation as follows. The blue nodes are required while the brown nodes are optional.

![SDPA pattern](images/sdpa_forward.png)

The only difference between the inference and training forward propagation
patterns is that, for training forward propagation, the `Stats` output of the
SoftMax operation is needed. See [SoftMax](@ref dev_guide_op_softmax) in Graph
API for more details.
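
As an illustration, here is a minimal sketch of requesting the `Stats` output
through the oneDNN Graph C++ API. The tensor ids and shapes are hypothetical;
the point is only the second `add_output` call on the SoftMax op:

```cpp
#include "oneapi/dnnl/dnnl.hpp"
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;

int main() {
    using dt = logical_tensor::data_type;
    using lt = logical_tensor::layout_type;

    // Hypothetical (N, H, S, S) score shape; Stats reduces the softmax axis.
    std::vector<int64_t> score_dims {1, 16, 384, 384};
    std::vector<int64_t> stats_dims {1, 16, 384, 1};

    logical_tensor score {0, dt::f32, score_dims, lt::strided};
    logical_tensor probs {1, dt::f32, score_dims, lt::strided};
    logical_tensor stats {2, dt::f32, stats_dims, lt::strided};

    op softmax {3, op::kind::SoftMax, "softmax"};
    softmax.set_attr<int64_t>(op::attr::axis, -1);
    softmax.add_input(score);
    softmax.add_output(probs);
    softmax.add_output(stats); // optional output needed for training forward

    graph g {dnnl::engine::kind::cpu};
    g.add_op(softmax);
    g.finalize();
    return g.get_partitions().empty() ? 1 : 0;
}
```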

### Floating-point SDPA for Training Backpropagation

oneDNN defines floating-point (f32, bf16, or f16) SDPA for training
backpropagation as follows. The blue nodes are required while the brown nodes
are optional.

![SDPA backward pattern](images/sdpa_backward.png)

1. The first MatMul computes the score between Query and Key, similar to
inference and training forward propagation. See
[MatMul](@ref dev_guide_op_matmul) in Graph API.
2. The Scale node is optional and scales the output of the first MatMul using a
scaling factor. This can be implemented using [Multiply](@ref dev_guide_op_multiply)
or [Divide](@ref dev_guide_op_divide) in Graph API.
3. The Mask node is optional and applies an attention mask to the output of the
previous Scale node. For training backpropagation, only explicit user-generated
masks are currently supported. The mask definition is the same as in
inference and training forward propagation.
4. The Subtract and Exp operations take the masked output and `Stats` as inputs
and recover the probabilities computed by SoftMax in the training forward
propagation, as shown in the identity after this list. See
[Subtract](@ref dev_guide_op_subtract) and [Exp](@ref dev_guide_op_exp)
in Graph API.
5. The TypeCast and MatMul operations after Exp are used to compute the
gradients with respect to Value. TypeCast is required for bf16 and f16
training scenarios. See [TypeCast](@ref dev_guide_op_typecast) in Graph API.
6. The MatMul takes the output gradients (`dO`) and the Value as inputs to
compute the gradients of the probabilities.
7. The SoftMaxBackward operation computes the gradients of the scaled output.
See [SoftMaxBackward](@ref dev_guide_op_softmaxbackward) in Graph API.
8. The Scale node after SoftMaxBackward corresponds to the forward Scale node
and is used to compute the gradients of the score.
9. The TypeCast and two MatMul operations after the Scale node compute the
gradients with respect to Query and Key, respectively. TypeCast is required
for bf16 and f16 training scenarios.
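
Step 4 works because `Stats` stores the log-sum-exp of the masked scores,
\f$ stats = max + \log{\sum_{j=1}^{C} exp(src_j - max)} \f$, so subtracting it
from the scores and applying Exp reproduces the forward SoftMax output without
saving the full probability tensor:

\f[ exp(src_i - stats) = \frac{exp(src_i - max)}{\sum_{j=1}^{C} exp(src_j - max)} = dst_i \f]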

## Data Types

oneDNN supports the floating-point SDPA pattern with data types f32, bf16, and
f16. You can specify the data type via the input and output logical tensors'
data type fields for each operation.

oneDNN supports bf16 or f16 SDPA with f32 intermediate type, which means the
Q/K/V tensors have bf16 or f16 data type while the output of the first MatMul,
Scale, Mask, and the input of SoftMax are in f32 data type.

oneDNN supports the quantized SDPA pattern with int8-f32 mixed precision,
int8-bf16 mixed precision, and int8-f16 mixed precision data types.
oneDNN supports bf16 or f16 SDPA with f32 intermediate type. For
inference and training forward propagation, the Q, K, and V tensors use bf16 or
f16 data types, while the outputs of the first MatMul, Scale, Mask, and the
input of SoftMax are in f32. Similarly, in training backpropagation, the Q, K,
V, dO, dQ, dK, and dV tensors use bf16 or f16, while the Stats input uses f32.
The intermediate tensors are in f32, except those after TypeCast, which are
cast to bf16 or f16.

The definition of the data types and support status on different CPU and GPU
platforms follow the general description in @ref dev_guide_data_types.
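
As a sketch of this mixed-precision rule (tensor ids and shapes are
hypothetical), the boundary tensors of a bf16 training backpropagation graph
could be declared as follows:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dt = logical_tensor::data_type;
using lt = logical_tensor::layout_type;

// Declare the boundary tensors of a hypothetical bf16 training graph:
// bf16 at the boundaries, f32 for Stats. Intermediates (Subtract, Exp,
// SoftMaxBackward, ...) stay in f32 until an explicit TypeCast converts
// them back to bf16.
void declare_bf16_training_boundary() {
    std::vector<int64_t> qkv_dims {1, 16, 384, 64};  // (N, H, S, D)
    std::vector<int64_t> stats_dims {1, 16, 384, 1}; // softmax axis reduced

    logical_tensor query {0, dt::bf16, qkv_dims, lt::strided};
    logical_tensor d_out {1, dt::bf16, qkv_dims, lt::strided};
    logical_tensor stats {2, dt::f32, stats_dims, lt::strided};
    logical_tensor d_query {3, dt::bf16, qkv_dims, lt::strided};
    (void)query; (void)d_out; (void)stats; (void)d_query;
}
```
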
@@ -122,20 +168,22 @@ platforms follow the general description in @ref dev_guide_data_types.
Divide, and Select operations require the input tensors to have the same
shape or the shapes can be properly broadcasted based on the operation
attribute.
3. CPU
- Optimized implementation is available for 4D Q/K tensors with shape defined
as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v).
- Optimized implementation is available for OpenMP runtime and Threadpool
3. Dropout is currently not supported in SDPA training.
4. CPU
- Optimized implementation for inference is available for 4D Q/K tensors with
shape defined as (N, H, S, D_qk) and V tensor with shape defined as
(N, H, S, D_v).
- Optimized implementation for inference is available for OpenMP runtime and Threadpool
runtime on Intel Architecture Processors.
- Specifically for OpenMP runtime, the optimized implementation requires `N *
H > 2 * thread number` to get enough parallelism.
4. GPU
- Optimized implementation is available for 4D Q/K tensors with shape defined
as (N, H, S, D_qk) and V tensor with shape defined as (N, H, S, D_v) where
D_qk equals D_v.
- Optimized implementation is available for `f16` or `bf16` SDPA with `f32`
intermediate data type and `D <= 512` on Intel Graphics Products with
Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.
5. GPU
- Optimized implementation for inference is available for 4D Q/K tensors with
shape defined as (N, H, S, D_qk) and V tensor with shape defined as (N, H,
S, D_v) where D_qk equals D_v.
- Optimized implementation for inference is available for `f16` or `bf16`
SDPA with `f32` intermediate data type and `D <= 512` on Intel Graphics
Products with Intel(R) Xe Matrix Extensions (Intel(R) XMX) support.

## Example

20 changes: 13 additions & 7 deletions doc/graph/operations/Softmax.md
@@ -6,8 +6,13 @@ SoftMax {#dev_guide_op_softmax}
SoftMax operation applies the following formula on every element of \src tensor
(the variable names follow the standard @ref dev_guide_conventions):

\f[ dst_i = \frac{exp(src_i)}{\sum_{j=1}^{C} exp(src_j)} \f]
where \f$ C \f$ is a size of tensor along axis dimension.
\f[ dst_i = \frac{exp(src_i - max)}{\sum_{j=1}^{C} exp(src_j - max)} \f]
where \f$ C \f$ is the size of the tensor along the axis dimension. Subtracting
the maximum value along the axis improves numerical stability.

If the optional `stats` output is requested, it is defined as:

\f[ stats = max + \log{\sum_{j=1}^{C} exp(src_j - max)} \f]
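
For example, with a hypothetical \f$ src = [1, 2, 3] \f$ along the axis,
\f$ max = 3 \f$ and \f$ stats = 3 + \log(e^{-2} + e^{-1} + e^{0}) \approx 3.4076 \f$,
so \f$ exp(src - stats) \approx [0.0900, 0.2447, 0.6652] \f$, which matches the
SoftMax output computed directly.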

## Operation attributes

@@ -38,13 +43,14 @@ constructing an operation.
| Index | Argument Name | Required or Optional |
|:------|:--------------|:---------------------|
| 0 | `dst` | Required |
| 1 | `stats` | Optional |

## Supported data types

SoftMax operation supports the following data type combinations.

| Src | Dst |
|:-----|:----------------|
| f32 | f32, bf16, f16 |
| bf16 | bf16 |
| f16 | f16 |
| Src | Dst | Stats |
|:-----|:----------------|:------|
| f32 | f32, bf16, f16 | f32 |
| bf16 | bf16 | f32 |
| f16 | f16 | f32 |
10 changes: 5 additions & 5 deletions doc/graph/operations/SoftmaxBackward.md
@@ -33,8 +33,8 @@ constructing an operation.

SoftMaxBackward operation supports the following data type combinations.

| Dst | Diff_dst | Diff_src |
|:-----|:---------|:---------|
| f32 | f32 | f32 |
| bf16 | bf16 | bf16 |
| f16 | f16 | f16 |
| Dst | Diff_dst | Diff_src |
|:-----|:---------|:--------------|
| f32 | f32 | f32 |
| bf16 | bf16 | f32, bf16 |
| f16 | f16 | f32, f16 |
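
A minimal sketch of wiring the newly allowed mixed combination (bf16
`dst`/`diff_dst` with f32 `diff_src`) through the Graph C++ API; tensor ids
and shapes are hypothetical:

```cpp
#include "oneapi/dnnl/dnnl_graph.hpp"

using namespace dnnl::graph;
using dt = logical_tensor::data_type;
using lt = logical_tensor::layout_type;

// Build a SoftMaxBackward op with bf16 dst/diff_dst and f32 diff_src.
op make_mixed_softmax_bwd() {
    std::vector<int64_t> dims {1, 16, 384, 384};

    logical_tensor diff_dst {0, dt::bf16, dims, lt::strided};
    logical_tensor dst {1, dt::bf16, dims, lt::strided};
    logical_tensor diff_src {2, dt::f32, dims, lt::strided};

    op softmax_bwd {3, op::kind::SoftMaxBackward, "softmax_bwd"};
    softmax_bwd.set_attr<int64_t>(op::attr::axis, -1);
    softmax_bwd.add_input(diff_dst); // input 0: diff_dst
    softmax_bwd.add_input(dst);      // input 1: dst
    softmax_bwd.add_output(diff_src);
    return softmax_bwd;
}
```
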
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp.cpp
@@ -51,7 +51,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

// Check if it's supported by decomposition kernel
if (!sdp_cfg_.initial_check(subgraph_, inputs))
if (!sdp_cfg_.initial_check(subgraph_, inputs, outputs))
return status::unimplemented;

subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
5 changes: 4 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp
@@ -27,14 +27,17 @@ namespace graph {
namespace dnnl_impl {

bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs) {
const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs) {
// The order of input logical tensors in inputs is not fixed, so we need
// to record the input offsets in a fixed order of ops.
CHECK_BOOL(record_input_offset(sg, inputs));
dims src1_user_dims = ltw(inputs[graph_inport[mm1_src]]).vdims();
ndims = src1_user_dims.size();
VCHECK_SDP_DECOMP(ndims == 4 || ndims == 5, false,
"Input dims should be 4 or 5, but got %zu", src1_user_dims.size());
VCHECK_SDP_DECOMP(
outputs.size() == 1, false, "does not support multiple outputs");

// Initialize SDP input dimension according to the src of mm1
int index = 0;
3 changes: 2 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp
@@ -156,7 +156,8 @@ struct sdp_decomp_config_t {
// If the check passes, initialize a few members according to the inputs
// If not, return unimplemented status directly and fall back to the large kernel
bool initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs);
const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs);

// Used to construct all params that SDP needs
template <bool quantized = false,
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive.cpp
@@ -59,7 +59,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
p_engine_, part->get_fpmath_mode(), false, true);
CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

CHECK(cfg_.initial_check(subgraph_, inputs));
CHECK(cfg_.initial_check(subgraph_, inputs, outputs));

subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
return this->memory_planner_.get_memory_info(val);
5 changes: 4 additions & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp
@@ -169,10 +169,13 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,

status_t sdp_primitive_config_t::initial_check(
const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs, bool v1_kernel) {
const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs, bool v1_kernel) {
// At least 3 inputs: Q, K, V
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
"At least 3 inputs are required");
VCHECK_SDP_PRIMITIVE(outputs.size() == 1, status::unimplemented,
"does not support multiple outputs");

// Ukernel doesn't support f32 datatype now
VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32,
1 change: 1 addition & 0 deletions src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp
@@ -84,6 +84,7 @@ struct sdp_primitive_config_t {
// 3. only support 4-dims tensor
status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
const std::vector<logical_tensor_t> &inputs,
const std::vector<logical_tensor_t> &outputs,
bool v1_kernel = false);

// Initialize parameters and primitive.
2 changes: 1 addition & 1 deletion src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp
@@ -59,7 +59,7 @@ status_t sdp_primitive_v1_kernel_t::compile_impl(
p_engine_, part->get_fpmath_mode(), false, true);
CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

CHECK(cfg_.initial_check(subgraph_, inputs, true));
CHECK(cfg_.initial_check(subgraph_, inputs, outputs, true));

subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
return this->memory_planner_.get_memory_info(val);
10 changes: 6 additions & 4 deletions src/graph/backend/dnnl/op_executable.cpp
@@ -1388,12 +1388,14 @@ softmax_bwd_executable_t::desc_t softmax_bwd_executable_t::create_desc(
assertm(res.first, "Incorrect axis value.");
const auto axis = res.second;

// construct src with layout information from dst and data type information
// from diff_src.
auto dst_lt = op->get_input_value(1)->get_logical_tensor();
dst_lt.data_type = diff_src_lt.data_type;
auto dst = make_dnnl_memory_desc(dst_lt);
const dnnl::memory::desc &src = dst;

// construct src with layout information from dst and data type information
// from diff_src.
auto src_lt = dst_lt;
src_lt.data_type = diff_src_lt.data_type;
auto src = make_dnnl_memory_desc(src_lt);

const dnnl::algorithm algo
= op->get_kind() == dnnl_impl::op_kind::dnnl_logsoftmax_bwd