Skip to content

Commit 2b9207e

Browse files
committed
xe: sdpa: compute scale/iscale on host for host-side-scale
1 parent 748501a commit 2b9207e

File tree

3 files changed

+39
-29
lines changed

3 files changed

+39
-29
lines changed

src/gpu/intel/ocl/micro_sdpa.cl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
212212
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
213213
const global VAL_DATA_T *V, global DST_DATA_T *A,
214214
#if HOST_SIDE_SCALE
215-
const SCALE_DATA_T scale_value,
215+
const float scale, const float iscale,
216216
#else
217217
const global SCALE_DATA_T *scale_ptr,
218218
#endif
@@ -339,29 +339,23 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
339339
#endif
340340
}
341341

342+
#if HOST_SIDE_SCALE == 0
342343
/* Load scale */
343344
float scale = 1.0f;
344345
float iscale = 1.0f;
345346
if (k0end > 0) {
346347
#if WITH_ATTN_SCALE
347348
#if INVERT_SCALE
348-
#if HOST_SIDE_SCALE
349-
iscale = SCALES_TO_FLOAT(scale_value);
350-
#else
351349
iscale = SCALES_TO_FLOAT(*scale_ptr);
352-
#endif
353350
scale = native_recip(iscale);
354-
#else
355-
#if HOST_SIDE_SCALE
356-
scale = SCALES_TO_FLOAT(scale_value);
357351
#else
358352
scale = SCALES_TO_FLOAT(*scale_ptr);
359-
#endif
360353
iscale = native_recip(scale);
361354
#endif
362355
#endif
363356
scale *= 1.442695f; // log2(e)
364357
}
358+
#endif
365359

366360
#ifdef PREFETCH_K0
367361
if (k0end > 0) {

src/gpu/intel/ocl/micro_sdpa.cpp

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -564,15 +564,31 @@ status_t micro_sdpa_t::execute(const exec_ctx_t &ctx) const {
564564
arg_list.append(dst);
565565

566566
if (pd()->scale_md()->format_kind == format_kind::host_side_scalar) {
567-
void *handle = scale.data_handle();
568-
switch (pd()->scale_md()->data_type) {
569-
case data_type::f16: arg_list.append(*(float16_t *)handle); break;
570-
case data_type::bf16:
571-
arg_list.append(*(unsigned short *)handle);
572-
break;
573-
case data_type::f32: arg_list.append(*(float *)handle); break;
574-
default: break;
567+
const void *handle = scale.data_handle();
568+
569+
const float scale_value = [&]() {
570+
switch (pd()->scale_md()->data_type) {
571+
case data_type::f16: return float(*(float16_t *)handle);
572+
case data_type::bf16: return float(*(bfloat16_t *)handle);
573+
case data_type::f32: return *(float *)handle;
574+
default: return 0.f;
575+
}
576+
}();
577+
578+
float scale = 1.f, iscale = 1.f;
579+
if (pd()->with_attn_scale()) {
580+
if (pd()->desc()->invert_scale) {
581+
iscale = scale_value;
582+
scale = 1.f / iscale;
583+
} else {
584+
scale = scale_value;
585+
iscale = 1.f / scale;
586+
}
575587
}
588+
scale *= 1.442695f;
589+
590+
arg_list.append(scale);
591+
arg_list.append(iscale);
576592
} else {
577593
arg_list.append(scale);
578594
}

src/gpu/intel/ocl/reusable_softmax.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,18 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
205205

206206
conf.algorithm_number = [&]() { // -> int
207207
if (arch != arch_t::xe_hpg) {
208-
// if (rt_conf.softmax_axis_stride == 1
209-
// && rt_conf.softmax_axis_size >= 128
210-
// && nelems > (1 << 17)
211-
// && dnnl::impl::utils::div_up(
212-
// rt_conf.softmax_axis_size,
213-
// conf.subgroup_size)
214-
// <= 1024)
215-
// return vectorized;
216-
// if (rt_conf.softmax_axis_stride == 1
217-
// && rt_conf.softmax_axis_size <= conf.subgroup_size
218-
// && nelems < (1 << 15))
219-
// return small;
208+
if (rt_conf.softmax_axis_stride == 1
209+
&& rt_conf.softmax_axis_size >= 128
210+
&& nelems > (1 << 17)
211+
&& dnnl::impl::utils::div_up(
212+
rt_conf.softmax_axis_size,
213+
conf.subgroup_size)
214+
<= 1024)
215+
return vectorized;
216+
if (rt_conf.softmax_axis_stride == 1
217+
&& rt_conf.softmax_axis_size <= conf.subgroup_size
218+
&& nelems < (1 << 15))
219+
return small;
220220
}
221221
if (rt_conf.softmax_axis_size < 6 && nelems > 64000)
222222
return many_reductions_per_workgroup;

0 commit comments

Comments
 (0)