Commit 2828827

graph: backend: dnnl: support sdpa training forward pattern with larger_partition_kernel
1 parent 5641f54 commit 2828827

7 files changed: +14 -6 lines changed

src/graph/backend/dnnl/kernels/sdp_decomp.cpp

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ status_t sdp_decomp_kernel_t<quantized, dt>::compile_impl(
     BACKEND_DNNL_CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

     // Check if it's supported by decomposition kernel
-    if (!sdp_cfg_.initial_check(subgraph_, inputs))
+    if (!sdp_cfg_.initial_check(subgraph_, inputs, outputs))
         return status::unimplemented;

     subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {

src/graph/backend/dnnl/kernels/sdp_decomp_config.cpp

Lines changed: 4 additions & 1 deletion

@@ -27,14 +27,17 @@ namespace graph {
 namespace dnnl_impl {

 bool sdp_decomp_config_t::initial_check(const std::shared_ptr<subgraph_t> &sg,
-        const std::vector<logical_tensor_t> &inputs) {
+        const std::vector<logical_tensor_t> &inputs,
+        const std::vector<logical_tensor_t> &outputs) {
     // The order of input logical tensors in inputs is not certain, we need
     // to record the input offset in a certain order of ops.
     CHECK_BOOL(record_input_offset(sg, inputs));
     dims src1_user_dims = ltw(inputs[graph_inport[mm1_src]]).vdims();
     ndims = src1_user_dims.size();
     VCHECK_SDP_DECOMP(ndims == 4 || ndims == 5, false,
             "Input dims should be 4 or 5, but got %zu", src1_user_dims.size());
+    VCHECK_SDP_DECOMP(
+            outputs.size() == 1, false, "Doesn't support SDPA training yet");

     // Initialize SDP input dimension according to the src of mm1
     int index = 0;
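Why a single outputs.size() == 1 check is enough to enable the training forward pattern: when the decomposition kernel's initial_check fails, compile_impl returns status::unimplemented (see the sdp_decomp.cpp hunk above) and compilation falls back to the larger partition kernel, which can also handle the extra outputs that the training forward graph presumably exposes (e.g. softmax stats for backward). A minimal standalone sketch of that fallback flow, assuming a simplified status model and hypothetical kernel/helper names; this is not oneDNN's actual dispatch code:

    // fallback_sketch.cpp - illustrative only; the kernel names and dispatch
    // helper below are hypothetical stand-ins, not oneDNN internals.
    #include <cstddef>
    #include <cstdio>

    enum class status_t { success, unimplemented };

    // Stand-in for the decomposition kernel: after this commit it rejects
    // partitions with more than one output.
    status_t compile_with_decomp_kernel(std::size_t n_outputs) {
        if (n_outputs != 1) return status_t::unimplemented;
        return status_t::success;
    }

    // Stand-in for larger_partition_kernel, which handles the training
    // forward pattern as well.
    status_t compile_with_larger_partition_kernel(std::size_t) {
        return status_t::success;
    }

    status_t compile_sdpa(std::size_t n_outputs) {
        // Try the specialized kernel first; fall back when it bails out.
        if (compile_with_decomp_kernel(n_outputs) == status_t::success)
            return status_t::success;
        return compile_with_larger_partition_kernel(n_outputs);
    }

    int main() {
        // One output (inference-style SDPA): the specialized kernel handles it.
        bool inference_ok = compile_sdpa(1) == status_t::success;
        // More than one output (training forward): the specialized kernel bails,
        // the larger partition kernel picks it up, compilation still succeeds.
        bool training_ok = compile_sdpa(2) == status_t::success;
        std::printf("%d %d\n", inference_ok, training_ok); // prints: 1 1
        return 0;
    }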

src/graph/backend/dnnl/kernels/sdp_decomp_config.hpp

Lines changed: 2 additions & 1 deletion

@@ -156,7 +156,8 @@ struct sdp_decomp_config_t {
     // If the check passes, initialize few members according to inputs
     // If no, return unimplemented status directly and fallback to large kernel
     bool initial_check(const std::shared_ptr<subgraph_t> &sg,
-            const std::vector<logical_tensor_t> &inputs);
+            const std::vector<logical_tensor_t> &inputs,
+            const std::vector<logical_tensor_t> &outputs);

     // Used to construct all params that SDP need
     template <bool quantized = false,

src/graph/backend/dnnl/kernels/sdp_primitive.cpp

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
             p_engine_, part->get_fpmath_mode(), false, true);
     CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

-    CHECK(cfg_.initial_check(subgraph_, inputs));
+    CHECK(cfg_.initial_check(subgraph_, inputs, outputs));

     subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
         return this->memory_planner_.get_memory_info(val);

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

Lines changed: 4 additions & 1 deletion

@@ -169,10 +169,13 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,

 status_t sdp_primitive_config_t::initial_check(
         const std::shared_ptr<subgraph_t> &sg,
-        const std::vector<logical_tensor_t> &inputs, bool v1_kernel) {
+        const std::vector<logical_tensor_t> &inputs,
+        const std::vector<logical_tensor_t> &outputs, bool v1_kernel) {
     // At least 3 inputs: Q, K, V
     VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
             "At least 3 inputs are required");
+    VCHECK_SDP_PRIMITIVE(outputs.size() == 1, status::unimplemented,
+            "Doesn't support SDPA training yet");

     // Ukernel doesn't support f32 datatype now
     VCHECK_SDP_PRIMITIVE(inputs[0].data_type != dnnl_data_type_t::dnnl_f32,
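The VCHECK_SDP_PRIMITIVE calls above follow the usual check-log-and-bail macro pattern: if the condition fails, report the message and return the supplied status from the enclosing function (here status::unimplemented, which again routes the partition to the fallback kernel). A rough stand-in for that pattern, assuming a hypothetical simplified macro; this is not the actual oneDNN macro definition:

    // vcheck_sketch.cpp - illustrative only; SKETCH_VCHECK is a hypothetical
    // macro in the spirit of VCHECK_SDP_PRIMITIVE, not the real one.
    #include <cstddef>
    #include <cstdio>

    enum class status_t { success, invalid_arguments, unimplemented };

    // On failure: print the message, then return the given status from the caller.
    #define SKETCH_VCHECK(cond, ret, msg)                      \
        do {                                                   \
            if (!(cond)) {                                     \
                std::printf("sdp_primitive: %s\n", (msg));     \
                return (ret);                                  \
            }                                                  \
        } while (0)

    // Mirrors the shape of the new checks in sdp_primitive_config_t::initial_check.
    status_t initial_check_sketch(std::size_t n_inputs, std::size_t n_outputs) {
        SKETCH_VCHECK(n_inputs >= 3, status_t::invalid_arguments,
                "At least 3 inputs are required");
        SKETCH_VCHECK(n_outputs == 1, status_t::unimplemented,
                "Doesn't support SDPA training yet");
        return status_t::success;
    }

    int main() {
        // Q, K, V with a single output passes; an extra (training) output bails.
        bool ok = initial_check_sketch(3, 1) == status_t::success;
        bool training = initial_check_sketch(3, 2) == status_t::unimplemented;
        std::printf("%d %d\n", ok, training); // prints: 1 1
        return 0;
    }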

src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ struct sdp_primitive_config_t {
     // 3. only support 4-dims tensor
     status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
             const std::vector<logical_tensor_t> &inputs,
+            const std::vector<logical_tensor_t> &outputs,
             bool v1_kernel = false);

     // Initialize parameters and primitive.

src/graph/backend/dnnl/kernels/sdp_primitive_v1.cpp

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ status_t sdp_primitive_v1_kernel_t::compile_impl(
             p_engine_, part->get_fpmath_mode(), false, true);
     CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs));

-    CHECK(cfg_.initial_check(subgraph_, inputs, true));
+    CHECK(cfg_.initial_check(subgraph_, inputs, outputs, true));

     subgraph_visualizer_t vis(part->id(), [this](const value_t *val) {
         return this->memory_planner_.get_memory_info(val);
