Skip to content

Commit 5b6dac7

Browse files
committed
graph: backend: dnnl: merge 2 sdpa primitive kernels
1 parent 4dbfac2 commit 5b6dac7

File tree

7 files changed

+62
-531
lines changed

7 files changed

+62
-531
lines changed

src/graph/backend/dnnl/kernels/sdp.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include "graph/backend/dnnl/kernels/large_partition.hpp"
2828
#include "graph/backend/dnnl/kernels/sdp_decomp.hpp"
2929
#include "graph/backend/dnnl/kernels/sdp_primitive.hpp"
30-
#include "graph/backend/dnnl/kernels/sdp_primitive_v1.hpp"
3130

3231
#include "graph/backend/dnnl/dnnl_partition_impl.hpp"
3332

@@ -66,15 +65,7 @@ struct sdp_base_t : public kernel_base_t {
6665

6766
status_t ret = status::unimplemented;
6867

69-
// SDPA Ukernel v1 with fused internal sdpa solution. Support float sdpa
70-
// only.
71-
// TODO(GX): Support quantized sdpa and merge with sdp_primitive_kernel_t.
7268
if (enable_ukernel) {
73-
kernel = std::make_shared<sdp_primitive_v1_kernel_t<quantized>>();
74-
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
75-
}
76-
77-
if (ret != status::success && enable_ukernel) {
7869
kernel = std::make_shared<sdp_primitive_kernel_t<quantized>>();
7970
ret = kernel->compile_impl(part, g_engine, inputs, outputs);
8071
}

src/graph/backend/dnnl/kernels/sdp_primitive.cpp

Lines changed: 59 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024-2025 Intel Corporation
2+
* Copyright 2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
4949
#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
5050
return status::unimplemented;
5151
#endif
52+
5253
p_engine_ = make_dnnl_engine(*g_engine);
5354
g_alloc_
5455
= reinterpret_cast<graph::allocator_t *>(g_engine->get_allocator());
@@ -68,7 +69,6 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
6869

6970
BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
7071
BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask);
71-
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
7272
if (quantized) {
7373
BACKEND_DNNL_ADD_PASS(pipeline, lift_up_typecast);
7474
BACKEND_DNNL_ADD_PASS(pipeline, lift_up_quantize);
@@ -92,44 +92,39 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
9292

9393
pipeline.reset_visualize_arg(true, false);
9494
BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
95+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa);
9596
BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
97+
BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa_gpu);
98+
BACKEND_DNNL_ADD_PASS(pipeline, insert_reshape_for_sdpa);
9699
BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
97100

98-
// bind the memory for each op
101+
// bind the memory for each op
99102
auto memory_plan = [&](std::shared_ptr<subgraph_t> &sg) {
100103
return memory_planner_.run(sg);
101104
};
102105
pipeline.reset_visualize_arg(true, true);
103106
BACKEND_DNNL_ADD_PASS(pipeline, memory_plan);
107+
BACKEND_DNNL_ADD_PASS(pipeline, compile_ops);
104108

105-
auto modify_subgraph = [&] {
106-
// Run the added passes
107-
CHECK(pipeline.run(subgraph_));
108-
109-
// fill information for inputs logical tensors
110-
for (size_t i = 0; i < inputs.size(); i++) {
111-
auto &in = const_cast<logical_tensor_t &>(inputs[i]);
112-
in = subgraph_->ins_[i];
113-
}
109+
// Run the added passes
110+
BACKEND_DNNL_CHECK(pipeline.run(subgraph_));
114111

115-
// fill information for outputs logical tensors
116-
for (size_t i = 0; i < outputs.size(); i++) {
117-
auto &out = const_cast<logical_tensor_t &>(outputs[i]);
118-
out = subgraph_->outs_[i];
119-
}
112+
// fill information for inputs logical tensors
113+
for (size_t i = 0; i < inputs.size(); i++) {
114+
auto &in = const_cast<logical_tensor_t &>(inputs[i]);
115+
in = subgraph_->ins_[i];
116+
}
120117

121-
return status::success;
122-
};
118+
// fill information for outputs logical tensors
119+
for (size_t i = 0; i < outputs.size(); i++) {
120+
auto &out = const_cast<logical_tensor_t &>(outputs[i]);
121+
out = subgraph_->outs_[i];
122+
}
123123

124124
resource_ctor_ = [this]() {
125125
return this->memory_planner_.get_exec_args_set().clone();
126126
};
127127

128-
CHECK(modify_subgraph());
129-
130-
cfg_.quantized_ = quantized;
131-
CHECK(cfg_.init(subgraph_, p_engine_, inputs, outputs));
132-
133128
return status::success;
134129
}
135130

@@ -145,67 +140,13 @@ void sdp_primitive_kernel_t<quantized>::prepare_args_set(
145140
mem_idx.first.set_data_handle(
146141
outputs[mem_idx.second].get_data_handle());
147142
}
148-
}
149143

150-
template <bool quantized>
151-
status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
152-
exec_args_t &args, memory (&mem_storage)[10],
153-
const execution_args_set_t *res) const {
154-
bool ok = res->find_value_mem_map(cfg_.q_.get(), mem_storage[0])
155-
&& res->find_value_mem_map(cfg_.k_.get(), mem_storage[1])
156-
&& res->find_value_mem_map(cfg_.v_.get(), mem_storage[2])
157-
&& res->find_value_mem_map(cfg_.dst_.get(), mem_storage[3]);
158-
159-
if (cfg_.scale_)
160-
ok = ok && res->find_value_mem_map(cfg_.scale_.get(), mem_storage[4]);
161-
if (cfg_.attn_mask_)
162-
ok = ok
163-
&& res->find_value_mem_map(
164-
cfg_.attn_mask_.get(), mem_storage[5]);
165-
if (quantized && !(cfg_.k_scale_ || cfg_.v_scale_))
166-
return status::invalid_arguments;
167-
if (cfg_.k_scale_)
168-
ok = ok && res->find_value_mem_map(cfg_.k_scale_.get(), mem_storage[6]);
169-
if (cfg_.v_scale_)
170-
ok = ok && res->find_value_mem_map(cfg_.v_scale_.get(), mem_storage[7]);
171-
172-
if (cfg_.k_zero_points_)
173-
ok = ok
174-
&& res->find_value_mem_map(
175-
cfg_.k_zero_points_.get(), mem_storage[8]);
176-
if (cfg_.v_zero_points_)
177-
ok = ok
178-
&& res->find_value_mem_map(
179-
cfg_.v_zero_points_.get(), mem_storage[9]);
180-
181-
VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
182-
status::runtime_error,
183-
"sdp_primitive_kernel get_prim_exec_args failed");
184-
185-
memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
186-
memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
187-
memory_arg_t mem_arg_v = {mem_storage[2].get(), true};
188-
memory_arg_t mem_arg_dst = {mem_storage[3].get(), false};
189-
memory_arg_t mem_arg_scale = {mem_storage[4].get(true), true};
190-
memory_arg_t mem_arg_mask = {mem_storage[5].get(true), true};
191-
memory_arg_t mem_arg_k_scale = {mem_storage[6].get(true), true};
192-
memory_arg_t mem_arg_v_scale = {mem_storage[7].get(true), true};
193-
memory_arg_t mem_arg_k_zero_points = {mem_storage[8].get(true), true};
194-
memory_arg_t mem_arg_v_zero_points = {mem_storage[9].get(true), true};
195-
196-
args.clear();
197-
args[DNNL_ARG_QUERIES] = mem_arg_q;
198-
args[DNNL_ARG_KEYS] = mem_arg_k;
199-
args[DNNL_ARG_VALUES] = mem_arg_v;
200-
args[DNNL_ARG_DST] = mem_arg_dst;
201-
args[DNNL_ARG_SCALE] = mem_arg_scale;
202-
args[DNNL_ARG_ATTN_MASK] = mem_arg_mask;
203-
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale;
204-
args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = mem_arg_v_scale;
205-
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS] = mem_arg_k_zero_points;
206-
args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES] = mem_arg_v_zero_points;
144+
grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
145+
scratchpad.get_buffer());
207146

208-
return status::success;
147+
for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
148+
mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
149+
}
209150
}
210151

211152
template <bool quantized>
@@ -218,17 +159,16 @@ status_t sdp_primitive_kernel_t<quantized>::execute_impl(
218159
execution_args_set_t *res = res_cache.get_or_add(
219160
reinterpret_cast<size_t>(this), resource_ctor_);
220161

221-
// Micro kernel doesn't use scratchpad memory, here we force-set size as
222-
// zero to avoid redundant memory allocation and deallocation.
223-
temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
162+
temporary_scratchpad_t scratchpad(
163+
memory_planner_.total_internal_temporary_size(), p_engine_,
164+
*g_alloc_);
224165
prepare_args_set(res, inputs, outputs, scratchpad);
225166

226-
memory mem_storage[10];
227-
exec_args_t args;
228-
CHECK(get_prim_exec_args(args, mem_storage, res));
229-
exec_ctx_t ctx(p_stream.get(), std::move(args));
167+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
168+
subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]);
169+
}
230170

231-
return cfg_.sdpa_prim_->execute(ctx);
171+
return status::success;
232172
}
233173

234174
#ifdef DNNL_WITH_SYCL
@@ -242,42 +182,31 @@ status_t sdp_primitive_kernel_t<quantized>::sycl_execute_impl(
242182
#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
243183
return status::unimplemented;
244184
#endif
185+
auto deps = sycl_deps;
186+
::sycl::event returned_event;
187+
245188
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
246189

247190
thread_local_cache_t<execution_args_set_t> res_cache;
248191
execution_args_set_t *res = res_cache.get_or_add(
249192
reinterpret_cast<size_t>(this), resource_ctor_);
250193

251-
// Micro kernel doesn't use scratchpad memory, here we force-set size as
252-
// zero to avoid redundant memory allocation and deallocation.
253-
temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
194+
temporary_scratchpad_t scratchpad(
195+
memory_planner_.total_internal_temporary_size(), p_engine_,
196+
*g_alloc_);
254197
prepare_args_set(res, inputs, outputs, scratchpad);
255198

256-
memory mem_storage[10];
257-
exec_args_t args;
258-
CHECK(get_prim_exec_args(args, mem_storage, res));
259-
exec_ctx_t ctx(p_stream.get(), std::move(args));
260-
261-
// Relying on the library's internals here. Since graph API is currently
262-
// enabled only for the Intel vendor it is fine to cast stream to
263-
// gpu::intel::sycl::stream_t unconditionally.
264-
auto *sycl_stream = dnnl::impl::utils::downcast<
265-
dnnl::impl::gpu::intel::sycl::stream_t *>(p_stream.get());
266-
267-
sycl_stream->before_exec_hook();
268-
269-
if (!sycl_deps.empty()) sycl_stream->sycl_ctx().set_deps(sycl_deps);
270-
271-
auto status = cfg_.sdpa_prim_->execute(ctx);
272-
273-
auto return_event = sycl_stream->get_output_event();
274-
275-
scratchpad.set_deps(return_event);
276-
if (sycl_event) *sycl_event = return_event;
199+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
200+
if (subgraph_->is_constant_[i]) continue;
201+
returned_event = subgraph_->execs_[i]->execute_sycl(
202+
p_stream, res->get_exec_args()[i], deps);
203+
deps = {returned_event};
204+
}
277205

278-
sycl_stream->after_exec_hook();
206+
scratchpad.set_deps(returned_event);
207+
if (sycl_event) *sycl_event = returned_event;
279208

280-
return status;
209+
return status::success;
281210
}
282211
#endif
283212

@@ -287,50 +216,31 @@ status_t sdp_primitive_kernel_t<quantized>::ocl_execute_impl(
287216
const stream_t *g_stream, const std::vector<tensor_t> &inputs,
288217
const std::vector<tensor_t> &outputs,
289218
const std::vector<cl_event> &cl_deps, cl_event *ret_event) {
219+
auto deps = cl_deps;
220+
cl_event returned_event {};
290221

291222
dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
292223

293224
thread_local_cache_t<execution_args_set_t> res_cache;
294225
execution_args_set_t *res = res_cache.get_or_add(
295226
reinterpret_cast<size_t>(this), resource_ctor_);
296227

297-
// Micro kernel doesn't use scratchpad memory, here we force-set size as
298-
// zero to avoid redundant memory allocation and deallocation.
299-
temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
228+
temporary_scratchpad_t scratchpad(
229+
memory_planner_.total_internal_temporary_size(), p_engine_,
230+
*g_alloc_);
300231
prepare_args_set(res, inputs, outputs, scratchpad);
301232

302-
memory mem_storage[10];
303-
exec_args_t args;
304-
CHECK(get_prim_exec_args(args, mem_storage, res));
305-
exec_ctx_t ctx(p_stream.get(), std::move(args));
306-
307-
// TODO (pc): refactor
308-
auto *ocl_stream = dnnl::impl::utils::downcast<gpu::intel::ocl::stream_t *>(
309-
p_stream.get());
310-
311-
ocl_stream->before_exec_hook();
312-
313-
if (!cl_deps.empty()) {
314-
std::vector<xpu::ocl::wrapper_t<cl_event>> events(cl_deps.size());
315-
for (size_t i = 0; i < cl_deps.size(); i++)
316-
events[i] = xpu::ocl::wrapper_t<cl_event>(cl_deps[i], true);
317-
ocl_stream->ocl_ctx().set_deps(events);
233+
for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
234+
if (subgraph_->is_constant_[i]) continue;
235+
returned_event = subgraph_->execs_[i]->execute_ocl(
236+
p_stream, res->get_exec_args()[i], deps);
237+
deps = {returned_event};
318238
}
319239

320-
auto status = cfg_.sdpa_prim_->execute(ctx);
321-
322-
cl_event return_event = nullptr;
323-
if ((ocl_stream->flags() & stream_flags::in_order) == 0) {
324-
auto last = ocl_stream->get_output_event();
325-
return_event = last.release();
326-
}
240+
scratchpad.set_deps(returned_event);
241+
if (ret_event) *ret_event = returned_event;
327242

328-
scratchpad.set_deps(return_event);
329-
if (ret_event) *ret_event = return_event;
330-
331-
ocl_stream->after_exec_hook();
332-
333-
return status;
243+
return status::success;
334244
}
335245
#endif
336246

src/graph/backend/dnnl/kernels/sdp_primitive.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2024 Intel Corporation
2+
* Copyright 2025 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -73,9 +73,6 @@ struct sdp_primitive_kernel_t : public kernel_base_t {
7373
const std::vector<tensor_t> &outputs,
7474
const scratchpad_t &scratchpad);
7575

76-
status_t get_prim_exec_args(exec_args_t &args, memory (&mem_storage)[10],
77-
const execution_args_set_t *res) const;
78-
7976
status_t execute_impl(const stream_t *g_stream,
8077
const std::vector<tensor_t> &inputs,
8178
const std::vector<tensor_t> &outputs) override;

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ status_t sdp_primitive_config_t::locate_io(std::shared_ptr<subgraph_t> &sg,
169169

170170
status_t sdp_primitive_config_t::initial_check(
171171
const std::shared_ptr<subgraph_t> &sg,
172-
const std::vector<logical_tensor_t> &inputs, bool v1_kernel) {
172+
const std::vector<logical_tensor_t> &inputs) {
173173
// At least 3 inputs: Q, K, V
174174
VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments,
175175
"At least 3 inputs are required");
@@ -302,15 +302,6 @@ status_t sdp_primitive_config_t::initial_check(
302302
VCHECK_SDP_PRIMITIVE(q_id != -1 && k_id != -1 && v_id != -1,
303303
status::unimplemented, "Q, K, V are not found");
304304

305-
// Note: sdpa_primitive_v1 kernel accept 5D GQA pattern, and will reshape to
306-
// 4D in later compilation pass.
307-
if (!v1_kernel) {
308-
VCHECK_SDP_PRIMITIVE(ltw(inputs[q_id]).vdims().size() == 4
309-
&& ltw(inputs[k_id]).vdims().size() == 4
310-
&& ltw(inputs[v_id]).vdims().size() == 4,
311-
status::unimplemented, "Q, K, V should be 4-dims");
312-
}
313-
314305
// sdp_primitive only supports single scale value.
315306
if (scale) {
316307
const auto &s = scale->get_input_value(1)->get_logical_tensor();

src/graph/backend/dnnl/kernels/sdp_primitive_config.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ struct sdp_primitive_config_t {
8383
// 2. only support fp16 data type
8484
// 3. only support 4-dims tensor
8585
status_t initial_check(const std::shared_ptr<subgraph_t> &sg,
86-
const std::vector<logical_tensor_t> &inputs,
87-
bool v1_kernel = false);
86+
const std::vector<logical_tensor_t> &inputs);
8887

8988
// Initialize parameters and primitive.
9089
status_t init(std::shared_ptr<subgraph_t> &sg, const dnnl::engine &p_engine,

0 commit comments

Comments (0)