Skip to content

Commit 4b7df7a

Browse files
committed
xe: sdpa: compute scale/iscale on host for host-side-scale
1 parent 79742d6 commit 4b7df7a

File tree

2 files changed

+14
-21
lines changed

2 files changed

+14
-21
lines changed

src/gpu/intel/ocl/micro_sdpa.cl

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -212,7 +212,7 @@ __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
212212
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
213213
const global VAL_DATA_T *V, global DST_DATA_T *A,
214214
#if HOST_SIDE_SCALE
215-
const SCALE_DATA_T scale_value,
215+
const float scale, const float iscale,
216216
#else
217217
const global SCALE_DATA_T *scale_ptr,
218218
#endif
@@ -346,23 +346,16 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
346346
if (k0end > 0) {
347347
#if WITH_ATTN_SCALE
348348
#if INVERT_SCALE
349-
#if HOST_SIDE_SCALE
350-
iscale = SCALES_TO_FLOAT(scale_value);
351-
#else
352349
iscale = SCALES_TO_FLOAT(*scale_ptr);
353-
#endif
354350
scale = native_recip(iscale);
355-
#else
356-
#if HOST_SIDE_SCALE
357-
scale = SCALES_TO_FLOAT(scale_value);
358351
#else
359352
scale = SCALES_TO_FLOAT(*scale_ptr);
360-
#endif
361353
iscale = native_recip(scale);
362354
#endif
363355
#endif
364356
scale *= 1.442695f; // log2(e)
365357
}
358+
#endif
366359

367360
#ifdef PREFETCH_K0
368361
if (k0end > 0) {

src/gpu/intel/ocl/reusable_softmax.hpp

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -205,18 +205,18 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
205205

206206
conf.algorithm_number = [&]() { // -> int
207207
if (arch != arch_t::xe_hpg) {
208-
// if (rt_conf.softmax_axis_stride == 1
209-
// && rt_conf.softmax_axis_size >= 128
210-
// && nelems > (1 << 17)
211-
// && dnnl::impl::utils::div_up(
212-
// rt_conf.softmax_axis_size,
213-
// conf.subgroup_size)
214-
// <= 1024)
215-
// return vectorized;
216-
// if (rt_conf.softmax_axis_stride == 1
217-
// && rt_conf.softmax_axis_size <= conf.subgroup_size
218-
// && nelems < (1 << 15))
219-
// return small;
208+
if (rt_conf.softmax_axis_stride == 1
209+
&& rt_conf.softmax_axis_size >= 128
210+
&& nelems > (1 << 17)
211+
&& dnnl::impl::utils::div_up(
212+
rt_conf.softmax_axis_size,
213+
conf.subgroup_size)
214+
<= 1024)
215+
return vectorized;
216+
if (rt_conf.softmax_axis_stride == 1
217+
&& rt_conf.softmax_axis_size <= conf.subgroup_size
218+
&& nelems < (1 << 15))
219+
return small;
220220
}
221221
if (rt_conf.softmax_axis_size < 6 && nelems > 64000)
222222
return many_reductions_per_workgroup;

0 commit comments

Comments
 (0)