Skip to content

Commit ceb4d3a

Browse files
committed
xe: sdpa: (tmp) verify graph api changes w/o inaccurate softmax alg
1 parent 59ec241 commit ceb4d3a

File tree

4 files changed

+29
-24
lines changed

4 files changed

+29
-24
lines changed

src/common/sdpa_utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ static inline status_t create_sdpa_pd(
166166
const memory_desc_t *q_md, const memory_desc_t *k_md,
167167
const memory_desc_t *v_md, const memory_desc_t *dst_md,
168168
const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md,
169-
data_type_t scale_dt, bool invert_scale, dim_t kv_head_number,
169+
bool invert_scale, dim_t kv_head_number,
170170
attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg,
171171
const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr,
172172
const primitive_attr_t *vs_attr = nullptr) {

src/gpu/intel/ocl/reusable_softmax.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -205,18 +205,18 @@ struct reusable_softmax_fwd_t : public gpu_primitive_t {
205205

206206
conf.algorithm_number = [&]() { // -> int
207207
if (arch != arch_t::xe_hpg) {
208-
if (rt_conf.softmax_axis_stride == 1
209-
&& rt_conf.softmax_axis_size >= 128
210-
&& nelems > (1 << 17)
211-
&& dnnl::impl::utils::div_up(
212-
rt_conf.softmax_axis_size,
213-
conf.subgroup_size)
214-
<= 1024)
215-
return vectorized;
216-
if (rt_conf.softmax_axis_stride == 1
217-
&& rt_conf.softmax_axis_size <= conf.subgroup_size
218-
&& nelems < (1 << 15))
219-
return small;
208+
// if (rt_conf.softmax_axis_stride == 1
209+
// && rt_conf.softmax_axis_size >= 128
210+
// && nelems > (1 << 17)
211+
// && dnnl::impl::utils::div_up(
212+
// rt_conf.softmax_axis_size,
213+
// conf.subgroup_size)
214+
// <= 1024)
215+
// return vectorized;
216+
// if (rt_conf.softmax_axis_stride == 1
217+
// && rt_conf.softmax_axis_size <= conf.subgroup_size
218+
// && nelems < (1 << 15))
219+
// return small;
220220
}
221221
if (rt_conf.softmax_axis_size < 6 && nelems > 64000)
222222
return many_reductions_per_workgroup;

src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,11 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
351351
if (attn_mask_)
352352
md_mask = make_dnnl_memory_desc(attn_mask_->get_logical_tensor());
353353

354-
auto scale_dt = impl::data_type::undef;
355-
if (scale_) scale_dt = scale_->get_logical_tensor().data_type;
354+
dnnl::memory::desc scale_md;
355+
if (scale_)
356+
scale_md = {dims {1},
357+
static_cast<data_type>(scale_->get_logical_tensor().data_type),
358+
dnnl::memory::format_tag::a};
356359

357360
dnnl::primitive_attr attr, qk_attr, vs_attr;
358361

@@ -376,9 +379,9 @@ status_t sdp_primitive_config_t::init(std::shared_ptr<subgraph_t> &sg,
376379
? alg_kind::softmax_accurate_inf_as_zero
377380
: alg_kind::softmax_accurate;
378381
CHECK(create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(),
379-
md_v.get(), md_dst.get(), md_mask.get(), dnnl::memory::desc().get(),
380-
scale_dt, invert_scale_, kv_head_number_, mask_type_, softmax_alg,
381-
attr.get(), qk_attr.get(), vs_attr.get()));
382+
md_v.get(), md_dst.get(), md_mask.get(), scale_md.get(),
383+
invert_scale_, kv_head_number_, mask_type_, softmax_alg, attr.get(),
384+
qk_attr.get(), vs_attr.get()));
382385

383386
auto status = sdpa_pd_->create_primitive(sdpa_prim_, p_engine.get());
384387

src/graph/backend/dnnl/op_executable.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,12 +2796,14 @@ struct sdpa_executable_t : public op_executable_t {
27962796
auto md_dst = make_dnnl_memory_desc(
27972797
op->get_output_value(0)->get_logical_tensor());
27982798

2799-
auto scale_dt = impl::data_type::undef;
28002799
size_t idx = 3;
2800+
dnnl::memory::desc scale_md;
28012801
if (with_scale_)
2802-
scale_dt = op->get_input_value(idx++)
2803-
->get_logical_tensor()
2804-
.data_type;
2802+
scale_md = {dims {1},
2803+
static_cast<data_type>(op->get_input_value(idx++)
2804+
->get_logical_tensor()
2805+
.data_type),
2806+
dnnl::memory::format_tag::a};
28052807

28062808
dnnl::memory::desc md_mask;
28072809
with_explicit_mask_ = mask_type_ == attn_mask_type::buffer;
@@ -2826,8 +2828,8 @@ struct sdpa_executable_t : public op_executable_t {
28262828
: alg_kind::softmax_accurate;
28272829
status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(),
28282830
md_k.get(), md_v.get(), md_dst.get(), md_mask.get(),
2829-
dnnl::memory::desc().get(), scale_dt, is_invert_scale_,
2830-
kv_head_number, mask_type_, softmax_alg, attr.get());
2831+
scale_md.get(), is_invert_scale_, kv_head_number, mask_type_,
2832+
softmax_alg, attr.get());
28312833
if (s != dnnl::impl::status::success) {
28322834
is_initialized_ = false;
28332835
} else {

0 commit comments

Comments (0)