@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2024-2025 Intel Corporation
+* Copyright 2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
 #if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
     return status::unimplemented;
 #endif
+
     p_engine_ = make_dnnl_engine(*g_engine);
     g_alloc_
             = reinterpret_cast<graph::allocator_t *>(g_engine->get_allocator());
@@ -68,7 +69,6 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
 
     BACKEND_DNNL_ADD_PASS(pipeline, lower_down);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask);
-    BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa);
     if (quantized) {
         BACKEND_DNNL_ADD_PASS(pipeline, lift_up_typecast);
         BACKEND_DNNL_ADD_PASS(pipeline, lift_up_quantize);
@@ -92,44 +92,39 @@ status_t sdp_primitive_kernel_t<quantized>::compile_impl(
 
     pipeline.reset_visualize_arg(true, false);
     BACKEND_DNNL_ADD_PASS(pipeline, infer_shape);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa);
     BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_transpose_to_predecessor);
+    BACKEND_DNNL_ADD_PASS(pipeline, fuse_reshape_for_gqa_gpu);
+    BACKEND_DNNL_ADD_PASS(pipeline, insert_reshape_for_sdpa);
     BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation);
 
-    // bind the memory for each op
+    // bind the memory for each op`
     auto memory_plan = [&](std::shared_ptr<subgraph_t> &sg) {
         return memory_planner_.run(sg);
     };
     pipeline.reset_visualize_arg(true, true);
     BACKEND_DNNL_ADD_PASS(pipeline, memory_plan);
+    BACKEND_DNNL_ADD_PASS(pipeline, compile_ops);
 
-    auto modify_subgraph = [&] {
-        // Run the added passes
-        CHECK(pipeline.run(subgraph_));
-
-        // fill information for inputs logical tensors
-        for (size_t i = 0; i < inputs.size(); i++) {
-            auto &in = const_cast<logical_tensor_t &>(inputs[i]);
-            in = subgraph_->ins_[i];
-        }
+    // Run the added passes
+    BACKEND_DNNL_CHECK(pipeline.run(subgraph_));
 
-        // fill information for outputs logical tensors
-        for (size_t i = 0; i < outputs.size(); i++) {
-            auto &out = const_cast<logical_tensor_t &>(outputs[i]);
-            out = subgraph_->outs_[i];
-        }
+    // fill information for inputs logical tensors
+    for (size_t i = 0; i < inputs.size(); i++) {
+        auto &in = const_cast<logical_tensor_t &>(inputs[i]);
+        in = subgraph_->ins_[i];
+    }
 
-        return status::success;
-    };
+    // fill information for outputs logical tensors
+    for (size_t i = 0; i < outputs.size(); i++) {
+        auto &out = const_cast<logical_tensor_t &>(outputs[i]);
+        out = subgraph_->outs_[i];
+    }
 
     resource_ctor_ = [this]() {
         return this->memory_planner_.get_exec_args_set().clone();
     };
 
-    CHECK(modify_subgraph());
-
-    cfg_.quantized_ = quantized;
-    CHECK(cfg_.init(subgraph_, p_engine_, inputs, outputs));
-
     return status::success;
 }
 
@@ -145,67 +140,13 @@ void sdp_primitive_kernel_t<quantized>::prepare_args_set(
         mem_idx.first.set_data_handle(
                 outputs[mem_idx.second].get_data_handle());
     }
-}
 
-template <bool quantized>
-status_t sdp_primitive_kernel_t<quantized>::get_prim_exec_args(
-        exec_args_t &args, memory (&mem_storage)[10],
-        const execution_args_set_t *res) const {
-    bool ok = res->find_value_mem_map(cfg_.q_.get(), mem_storage[0])
-            && res->find_value_mem_map(cfg_.k_.get(), mem_storage[1])
-            && res->find_value_mem_map(cfg_.v_.get(), mem_storage[2])
-            && res->find_value_mem_map(cfg_.dst_.get(), mem_storage[3]);
-
-    if (cfg_.scale_)
-        ok = ok && res->find_value_mem_map(cfg_.scale_.get(), mem_storage[4]);
-    if (cfg_.attn_mask_)
-        ok = ok
-                && res->find_value_mem_map(
-                        cfg_.attn_mask_.get(), mem_storage[5]);
-    if (quantized && !(cfg_.k_scale_ || cfg_.v_scale_))
-        return status::invalid_arguments;
-    if (cfg_.k_scale_)
-        ok = ok && res->find_value_mem_map(cfg_.k_scale_.get(), mem_storage[6]);
-    if (cfg_.v_scale_)
-        ok = ok && res->find_value_mem_map(cfg_.v_scale_.get(), mem_storage[7]);
-
-    if (cfg_.k_zero_points_)
-        ok = ok
-                && res->find_value_mem_map(
-                        cfg_.k_zero_points_.get(), mem_storage[8]);
-    if (cfg_.v_zero_points_)
-        ok = ok
-                && res->find_value_mem_map(
-                        cfg_.v_zero_points_.get(), mem_storage[9]);
-
-    VCONDCHECK(graph, exec, check, sdp_primitive_kernel, ok,
-            status::runtime_error,
-            "sdp_primitive_kernel get_prim_exec_args failed");
-
-    memory_arg_t mem_arg_q = {mem_storage[0].get(), true};
-    memory_arg_t mem_arg_k = {mem_storage[1].get(), true};
-    memory_arg_t mem_arg_v = {mem_storage[2].get(), true};
-    memory_arg_t mem_arg_dst = {mem_storage[3].get(), false};
-    memory_arg_t mem_arg_scale = {mem_storage[4].get(true), true};
-    memory_arg_t mem_arg_mask = {mem_storage[5].get(true), true};
-    memory_arg_t mem_arg_k_scale = {mem_storage[6].get(true), true};
-    memory_arg_t mem_arg_v_scale = {mem_storage[7].get(true), true};
-    memory_arg_t mem_arg_k_zero_points = {mem_storage[8].get(true), true};
-    memory_arg_t mem_arg_v_zero_points = {mem_storage[9].get(true), true};
-
-    args.clear();
-    args[DNNL_ARG_QUERIES] = mem_arg_q;
-    args[DNNL_ARG_KEYS] = mem_arg_k;
-    args[DNNL_ARG_VALUES] = mem_arg_v;
-    args[DNNL_ARG_DST] = mem_arg_dst;
-    args[DNNL_ARG_SCALE] = mem_arg_scale;
-    args[DNNL_ARG_ATTN_MASK] = mem_arg_mask;
-    args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale;
-    args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES] = mem_arg_v_scale;
-    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS] = mem_arg_k_zero_points;
-    args[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES] = mem_arg_v_zero_points;
+    grantor_t var_grantor = memory_planner_.internal_temporary_grantor(
+            scratchpad.get_buffer());
 
-    return status::success;
+    for (auto &mem_offkey : res->get_mems_use_internal_temporary()) {
+        mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second));
+    }
 }
 
 template <bool quantized>
@@ -218,17 +159,16 @@ status_t sdp_primitive_kernel_t<quantized>::execute_impl(
     execution_args_set_t *res = res_cache.get_or_add(
             reinterpret_cast<size_t>(this), resource_ctor_);
 
-    // Micro kernel doesn't use scratchpad memory, here we force-set size as
-    // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
-    memory mem_storage[10];
-    exec_args_t args;
-    CHECK(get_prim_exec_args(args, mem_storage, res));
-    exec_ctx_t ctx(p_stream.get(), std::move(args));
+    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
+        subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]);
+    }
 
-    return cfg_.sdpa_prim_->execute(ctx);
+    return status::success;
 }
 
 #ifdef DNNL_WITH_SYCL
@@ -242,42 +182,31 @@ status_t sdp_primitive_kernel_t<quantized>::sycl_execute_impl(
 #if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL
     return status::unimplemented;
 #endif
+    auto deps = sycl_deps;
+    ::sycl::event returned_event;
+
     dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
 
     thread_local_cache_t<execution_args_set_t> res_cache;
     execution_args_set_t *res = res_cache.get_or_add(
             reinterpret_cast<size_t>(this), resource_ctor_);
 
-    // Micro kernel doesn't use scratchpad memory, here we force-set size as
-    // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
-    memory mem_storage[10];
-    exec_args_t args;
-    CHECK(get_prim_exec_args(args, mem_storage, res));
-    exec_ctx_t ctx(p_stream.get(), std::move(args));
-
-    // Relying on the library's internals here. Since graph API is currently
-    // enabled only for the Intel vendor it is fine to cast stream to
-    // gpu::intel::sycl::stream_t unconditionally.
-    auto *sycl_stream = dnnl::impl::utils::downcast<
-            dnnl::impl::gpu::intel::sycl::stream_t *>(p_stream.get());
-
-    sycl_stream->before_exec_hook();
-
-    if (!sycl_deps.empty()) sycl_stream->sycl_ctx().set_deps(sycl_deps);
-
-    auto status = cfg_.sdpa_prim_->execute(ctx);
-
-    auto return_event = sycl_stream->get_output_event();
-
-    scratchpad.set_deps(return_event);
-    if (sycl_event) *sycl_event = return_event;
+    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
+        if (subgraph_->is_constant_[i]) continue;
+        returned_event = subgraph_->execs_[i]->execute_sycl(
+                p_stream, res->get_exec_args()[i], deps);
+        deps = {returned_event};
+    }
 
-    sycl_stream->after_exec_hook();
+    scratchpad.set_deps(returned_event);
+    if (sycl_event) *sycl_event = returned_event;
 
-    return status;
+    return status::success;
 }
 #endif
 
@@ -287,50 +216,31 @@ status_t sdp_primitive_kernel_t<quantized>::ocl_execute_impl(
         const stream_t *g_stream, const std::vector<tensor_t> &inputs,
         const std::vector<tensor_t> &outputs,
         const std::vector<cl_event> &cl_deps, cl_event *ret_event) {
+    auto deps = cl_deps;
+    cl_event returned_event {};
 
     dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);
 
     thread_local_cache_t<execution_args_set_t> res_cache;
     execution_args_set_t *res = res_cache.get_or_add(
             reinterpret_cast<size_t>(this), resource_ctor_);
 
-    // Micro kernel doesn't use scratchpad memory, here we force-set size as
-    // zero to avoid redundant memory allocation and deallocation.
-    temporary_scratchpad_t scratchpad(0, p_engine_, *g_alloc_);
+    temporary_scratchpad_t scratchpad(
+            memory_planner_.total_internal_temporary_size(), p_engine_,
+            *g_alloc_);
     prepare_args_set(res, inputs, outputs, scratchpad);
 
-    memory mem_storage[10];
-    exec_args_t args;
-    CHECK(get_prim_exec_args(args, mem_storage, res));
-    exec_ctx_t ctx(p_stream.get(), std::move(args));
-
-    // TODO (pc): refactor
-    auto *ocl_stream = dnnl::impl::utils::downcast<gpu::intel::ocl::stream_t *>(
-            p_stream.get());
-
-    ocl_stream->before_exec_hook();
-
-    if (!cl_deps.empty()) {
-        std::vector<xpu::ocl::wrapper_t<cl_event>> events(cl_deps.size());
-        for (size_t i = 0; i < cl_deps.size(); i++)
-            events[i] = xpu::ocl::wrapper_t<cl_event>(cl_deps[i], true);
-        ocl_stream->ocl_ctx().set_deps(events);
+    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
+        if (subgraph_->is_constant_[i]) continue;
+        returned_event = subgraph_->execs_[i]->execute_ocl(
+                p_stream, res->get_exec_args()[i], deps);
+        deps = {returned_event};
     }
 
-    auto status = cfg_.sdpa_prim_->execute(ctx);
-
-    cl_event return_event = nullptr;
-    if ((ocl_stream->flags() & stream_flags::in_order) == 0) {
-        auto last = ocl_stream->get_output_event();
-        return_event = last.release();
-    }
+    scratchpad.set_deps(returned_event);
+    if (ret_event) *ret_event = returned_event;
 
-    scratchpad.set_deps(return_event);
-    if (ret_event) *ret_event = return_event;
-
-    ocl_stream->after_exec_hook();
-
-    return status;
+    return status::success;
 }
 #endif
 