Skip to content

xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors) #3412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/oneapi/dnnl/dnnl.h
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,21 @@ dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
dnnl_data_type_t data_type, dnnl_dim_t nnz);

/// Creates a memory descriptor for host-side scalars.
///
/// The created memory descriptor cannot be used to create a memory
/// object. It can only be used to create a primitive descriptor to
/// query the actual memory descriptor (similar to the format tag
/// `any`).
///
/// @param memory_desc Output memory descriptor.
/// @param data_type Elements data type.
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @sa dnnl_memory_desc_destroy
dnnl_status_t DNNL_API dnnl_memory_desc_create_host_side_scalar(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++ counter-part?

dnnl_memory_desc_t *memory_desc, dnnl_data_type_t data_type);

/// Creates a memory descriptor for a region inside an area
/// described by an existing memory descriptor.
///
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,8 @@ struct memory : public handle<dnnl_memory_t> {
blocked = dnnl_blocked,
/// Format kind for sparse tensors.
sparse = dnnl_format_kind_sparse,
/// Format kind for host-side scalars.
host_side_scalar = dnnl_format_kind_host_side_scalar,
/// A special format kind that indicates that tensor format is opaque.
opaque = dnnl_format_kind_opaque,
};
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ typedef enum {
dnnl_format_kind_opaque,
/// Format kind for sparse tensors.
dnnl_format_kind_sparse,
/// Format kind for host side scalars.
dnnl_format_kind_host_side_scalar,
/// Parameter to allow internal only format kinds without undefined
/// behavior. This parameter is chosen to be valid for so long as
/// sizeof(int) >= 2.
Expand Down
1 change: 1 addition & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ const format_kind_t any = dnnl_format_kind_any;
const format_kind_t blocked = dnnl_blocked;
const format_kind_t opaque = dnnl_format_kind_opaque;
const format_kind_t sparse = dnnl_format_kind_sparse;
const format_kind_t host_side_scalar = dnnl_format_kind_host_side_scalar;

// Internal only format kinds.
const format_kind_t internal_only_start = (format_kind_t)(1 << 8);
Expand Down
20 changes: 20 additions & 0 deletions src/common/memory_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ using namespace dnnl::impl::utils;
namespace dnnl {
namespace impl {

// Initializes @p memory_desc as a host-side scalar descriptor: a 1D,
// single-element tensor of @p data_type with the special
// format_kind::host_side_scalar kind (such descriptors cannot back a
// memory object; they only communicate that the value is passed as a
// host-side scalar).
//
// @param memory_desc Descriptor to initialize (fully overwritten).
// @param data_type Element data type of the scalar.
// @returns status::success.
status_t memory_desc_init_host_side_scalar(
        memory_desc_t &memory_desc, data_type_t data_type) {
    // Start from a clean descriptor so no stale fields (strides, extra
    // flags, padded offsets) survive from the caller's storage -- the
    // caller is not required to pass a zero-initialized struct.
    memory_desc = memory_desc_t();
    memory_desc.ndims = 1;
    memory_desc.dims[0] = 1;
    // Keep padded dims consistent with dims, as other init paths do.
    memory_desc.padded_dims[0] = 1;
    memory_desc.data_type = data_type;
    memory_desc.format_kind = format_kind::host_side_scalar;
    return success;
}

status_t memory_desc_init_by_tag(memory_desc_t &memory_desc, int ndims,
const dims_t dims, data_type_t data_type, format_tag_t tag) {
if (ndims == 0 || tag == format_tag::undef) {
Expand Down Expand Up @@ -653,6 +662,17 @@ status_t dnnl_memory_desc_create_with_packed_encoding(
return success;
}

// C API entry point: allocates and initializes a host-side scalar
// memory descriptor of the given data type. Ownership of the returned
// descriptor is transferred to the caller (released via
// dnnl_memory_desc_destroy).
status_t dnnl_memory_desc_create_host_side_scalar(
        memory_desc_t **memory_desc, data_type_t data_type) {
    // Reject a null output pointer up front.
    if (any_null(memory_desc)) return invalid_arguments;

    // Initialize into a local first; allocate only once the descriptor
    // is known to be valid.
    memory_desc_t local_md {};
    CHECK(memory_desc_init_host_side_scalar(local_md, data_type));

    auto md = utils::make_unique<memory_desc_t>(local_md);
    if (!md) return out_of_memory;
    *memory_desc = md.release();
    return success;
}

status_t dnnl_memory_desc_create_submemory(memory_desc_t **memory_desc,
const memory_desc_t *parent_memory_desc, const dims_t dims,
const dims_t offsets) {
Expand Down
4 changes: 4 additions & 0 deletions src/common/memory_desc_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ struct memory_desc_wrapper : public c_compatible {
}
bool is_sparse_desc() const { return format_kind() == format_kind::sparse; }

// True when the wrapped descriptor carries the host_side_scalar format
// kind, i.e. the value is supplied as a host-side scalar rather than a
// device buffer.
bool is_host_side_scalar_desc() const {
    const auto kind = format_kind();
    return kind == format_kind::host_side_scalar;
}

const blocking_desc_t &blocking_desc() const {
assert(is_blocking_desc() || is_sparse_packed_desc());
if (!is_sparse_desc()) return md_->format_desc.blocking;
Expand Down
5 changes: 3 additions & 2 deletions src/common/primitive_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ size_t get_md_hash(const memory_desc_t &md) {
// format desc
switch ((int)md.format_kind) {
case format_kind::undef:
case format_kind::any: break;
case format_kind::any:
case format_kind::host_side_scalar: break;
case format_kind::blocked:
for (int i = 0; i < md.ndims; i++) {
if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
Expand Down Expand Up @@ -734,14 +735,14 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
seed = hash_combine(seed, get_md_hash(desc.q_desc));
seed = hash_combine(seed, get_md_hash(desc.k_desc));
seed = hash_combine(seed, get_md_hash(desc.v_desc));
seed = hash_combine(seed, get_md_hash(desc.scale_desc));
seed = hash_combine(seed, desc.kq_scales.get_hash());
seed = hash_combine(seed, desc.kq_zero_points.get_hash());
seed = hash_combine(seed, desc.vs_scales.get_hash());
seed = hash_combine(seed, desc.vs_zero_points.get_hash());
seed = hash_combine(seed, get_md_hash(desc.dst_desc));
seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc));
// Scale type
seed = hash_combine(seed, static_cast<size_t>(desc.scale_dt));
seed = hash_combine(seed, desc.invert_scale);
seed = hash_combine(seed, desc.kv_head_number);
seed = hash_combine(seed, static_cast<size_t>(desc.mask_type));
Expand Down
2 changes: 1 addition & 1 deletion src/common/primitive_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,13 +577,13 @@ void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
serialize(sstream, desc.q_desc);
serialize(sstream, desc.k_desc);
serialize(sstream, desc.v_desc);
serialize(sstream, desc.scale_desc);
desc.kq_scales.serialize(sstream);
desc.kq_zero_points.serialize(sstream);
desc.vs_scales.serialize(sstream);
desc.vs_zero_points.serialize(sstream);
serialize(sstream, desc.dst_desc);
serialize(sstream, desc.attn_mask_desc);
sstream.append(desc.scale_dt);
sstream.append(desc.invert_scale);
sstream.append(desc.kv_head_number);
sstream.append(desc.mask_type);
Expand Down
7 changes: 6 additions & 1 deletion src/common/sdpa_pd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,19 @@ struct sdpa_pd_t : public primitive_desc_t {
// Accessors for the SDPA operand descriptors stored in desc_.
const memory_desc_t *key_md() const { return &desc_.k_desc; }
const memory_desc_t *val_md() const { return &desc_.v_desc; }
// Optional inputs: callers must check with_attn_mask()/with_attn_scale()
// before relying on the contents of these descriptors.
const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; }
const memory_desc_t *scale_md() const { return &desc_.scale_desc; }

// Input count: Q, K, and V are always present (the constant 3), plus
// the optional attention mask and the optional attention scale.
int n_inputs() const override {
return 3 + int(with_attn_mask()) + int(with_attn_scale());
}
// SDPA produces a single output tensor (dst).
int n_outputs() const override { return 1; }

// True when an attention scale input was supplied. The scale descriptor
// carries data_type::undef when no scale is present, so the data type
// doubles as the presence flag. (The former desc_.scale_dt field was
// replaced by the full scale_desc memory descriptor; the stale return
// statement referencing it was unreachable dead code and is removed.)
bool with_attn_scale() const {
    return (scale_md()->data_type != data_type::undef);
}

// True when the attention scale is passed as a host-side scalar
// (format_kind::host_side_scalar) instead of a device-side buffer.
bool with_host_side_scale() const {
    const auto kind = scale_md()->format_kind;
    return kind == format_kind::host_side_scalar;
}

bool with_attn_mask() const {
Expand Down
9 changes: 4 additions & 5 deletions src/common/sdpa_test_iface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt,
const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc,
bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type,
dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
const_dnnl_primitive_attr_t kq_attr,
Expand All @@ -39,10 +39,9 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create(
query_desc, key_desc, value_desc, engine, attr, kq_attr, vs_attr));

dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc,
key_desc, value_desc, dst_desc, mask_desc,
(dnnl::impl::data_type_t)scale_dt, invert_scale, kv_head_number,
static_cast<attn_mask_type_t>(attn_mask_type), softmax_alg, kq_attr,
vs_attr);
key_desc, value_desc, dst_desc, mask_desc, scale_desc, invert_scale,
kv_head_number, static_cast<attn_mask_type_t>(attn_mask_type),
softmax_alg, kq_attr, vs_attr);
return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine,
(const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr);
}
2 changes: 1 addition & 1 deletion src/common/sdpa_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct sdpa_desc_t : public op_desc_t {
memory_desc_t q_desc; /* queries */
memory_desc_t k_desc; /* keys */
memory_desc_t v_desc; /* values */
memory_desc_t scale_desc; /* scale */

// primitive_attr_t can't be used because of deleted copy-ctor, but desc_t
// must be copyable.
Expand All @@ -79,7 +80,6 @@ struct sdpa_desc_t : public op_desc_t {

memory_desc_t dst_desc;
memory_desc_t attn_mask_desc;
data_type_t scale_dt {};
// invert_scale = false: multiply by scale
// invert_scale = true: divide by scale
bool invert_scale {};
Expand Down
8 changes: 4 additions & 4 deletions src/common/sdpa_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ static inline status_t sdpa_attr_check(const memory_desc_t *q_desc,
static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md,
const memory_desc_t *k_md, const memory_desc_t *v_md,
const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md,
data_type_t scale_dt, bool invert_scale, dim_t kv_head_number,
const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number,
attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg,
const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) {
auto sdpa_desc = sdpa_desc_t();
Expand All @@ -153,7 +153,7 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md,
sdpa_desc.v_desc = *v_md;
sdpa_desc.dst_desc = *dst_md;
if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md;
sdpa_desc.scale_dt = scale_dt;
sdpa_desc.scale_desc = *scale_md;
sdpa_desc.invert_scale = invert_scale;
sdpa_desc.kv_head_number = kv_head_number;
sdpa_desc.mask_type = attn_mask_type;
Expand All @@ -165,7 +165,7 @@ static inline status_t create_sdpa_pd(
std::shared_ptr<primitive_desc_t> &sdpa_pd_, engine_t *engine,
const memory_desc_t *q_md, const memory_desc_t *k_md,
const memory_desc_t *v_md, const memory_desc_t *dst_md,
const memory_desc_t *attn_mask_md, data_type_t scale_dt,
const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md,
bool invert_scale, dim_t kv_head_number,
attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg,
const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr,
Expand All @@ -175,7 +175,7 @@ static inline status_t create_sdpa_pd(
kq_attr, vs_attr));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should add checks in sdpa_desc_check for scale_md.


auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md,
scale_dt, invert_scale, kv_head_number, attn_mask_type, softmax_alg,
scale_md, invert_scale, kv_head_number, attn_mask_type, softmax_alg,
kq_attr, vs_attr);

primitive_attr_t sdpa_attr = attr ? *attr : default_attr();
Expand Down
2 changes: 1 addition & 1 deletion src/common/type_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -979,13 +979,13 @@ inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) {
&& COMPARE_DESC_MEMBERS(q_desc)
&& COMPARE_DESC_MEMBERS(k_desc)
&& COMPARE_DESC_MEMBERS(v_desc)
&& COMPARE_DESC_MEMBERS(scale_desc)
&& COMPARE_DESC_MEMBERS(kq_scales)
&& COMPARE_DESC_MEMBERS(kq_zero_points)
&& COMPARE_DESC_MEMBERS(vs_scales)
&& COMPARE_DESC_MEMBERS(vs_zero_points)
&& COMPARE_DESC_MEMBERS(dst_desc)
&& COMPARE_DESC_MEMBERS(attn_mask_desc)
&& COMPARE_DESC_MEMBERS(scale_dt)
&& COMPARE_DESC_MEMBERS(invert_scale)
&& COMPARE_DESC_MEMBERS(kv_head_number)
&& COMPARE_DESC_MEMBERS(mask_type)
Expand Down
5 changes: 5 additions & 0 deletions src/common/verbose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1593,6 +1593,11 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) {
}
delimiter = " ";
}
if (pd->with_host_side_scale()) {
ss << delimiter << "scale:host";
} else {
ss << delimiter << "scale:device";
}
Comment on lines +1596 to +1600
Copy link
Contributor

@atkassen atkassen Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go down with the attention mask stuff (around L1615).

if (pd->with_key_zp() || pd->with_value_zp()) {
ss << delimiter << "attr-zero-points:";
delimiter = "";
Expand Down
10 changes: 8 additions & 2 deletions src/gpu/intel/ocl/micro_sdpa.cl
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,12 @@ DECLARE_2D_TILE_RSELECT(a_scale_tile_type, SUBGROUP_SIZE, ugemm_vs_sg_tile_n, 1,
__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
const global VAL_DATA_T *V, global DST_DATA_T *A,
const global SCALE_DATA_T *scale_ptr, int d, int k, int q,
const global KEY_ATTR_SCALES_DATA_T *K_scales,
#if HOST_SIDE_SCALE
const float scale, const float iscale,
#else
const global SCALE_DATA_T *scale_ptr,
#endif
Comment on lines +214 to +218
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels like i(nverted)scale can be computed down below to have the same number of arguments.
Once one for both cases, I'd suggest to introduce SCALE_T (or something along the way) that will be float or global SCALE_DATA_T * depending on whether it's a host side scale or not.

Macros in kernel signature is hard to deal with at caller and callee levels, better avoid it whenever possible.

int d, int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
const global KEY_ATTR_ZP_DATA_T *K_zp,
const global VAL_ATTR_SCALES_DATA_T *V_scales,
const global VAL_ATTR_ZP_DATA_T *V_zp, const int attn_mask_type
Expand Down Expand Up @@ -335,6 +339,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
#endif
}

#if !HOST_SIDE_SCALE
/* Load scale */
float scale = 1.0f;
float iscale = 1.0f;
Expand All @@ -350,6 +355,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
#endif
scale *= 1.442695f; // log2(e)
}
#endif

#ifdef PREFETCH_K0
if (k0end > 0) {
Expand Down
39 changes: 37 additions & 2 deletions src/gpu/intel/ocl/micro_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
if (pd()->with_value_scales() || pd()->with_value_zp())
kernel_ctx.define_int("VAL_GROUP_SIZE", pd()->value_group_size());

def_data_type(kernel_ctx, d->scale_dt, "SCALE");
def_data_type(kernel_ctx, pd()->scale_md()->data_type, "SCALE");
kernel_ctx.define_int("INVERT_SCALE", d->invert_scale);
kernel_ctx.define_int("WITH_ATTN_SCALE", pd()->with_attn_scale());
kernel_ctx.define_int("ATTN_MASK_UNDEF", attn_mask_type::undef);
Expand Down Expand Up @@ -503,6 +503,8 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
kernel_ctx.define_int("SOFTMAX_INF_AS_ZERO",
d->softmax_alg == alg_kind::softmax_accurate_inf_as_zero);

kernel_ctx.define_int("HOST_SIDE_SCALE", pd()->with_host_side_scale());

/* Generate microkernel shims */
ShimOptions shimOptions;
shimOptions.subgroupSize = pd()->sg_size();
Expand Down Expand Up @@ -559,7 +561,40 @@ status_t micro_sdpa_t::execute(const exec_ctx_t &ctx) const {
arg_list.append(qry);
arg_list.append(val);
arg_list.append(dst);
arg_list.append(scale);

if (pd()->with_host_side_scale()) {
const void *handle = scale.data_handle();

const float scale_value = [&]() {
switch (pd()->scale_md()->data_type) {
case data_type::f16: return float(*(float16_t *)handle);
case data_type::bf16: return float(*(bfloat16_t *)handle);
case data_type::f32: return *(float *)handle;
default:
assert(!"Unsupported host-side scale datatype");
return 0.f;
}
}();

float scale = 1.f, iscale = 1.f;
if (pd()->with_attn_scale()) {
if (pd()->desc()->invert_scale) {
iscale = scale_value;
scale = 1.f / iscale;
} else {
scale = scale_value;
iscale = 1.f / scale;
}
}
constexpr float log2e = 1.442695f;
scale *= log2e;

arg_list.append(scale);
arg_list.append(iscale);
} else {
arg_list.append(scale);
}

arg_list.append((int)D);
arg_list.append((int)K);
arg_list.append((int)Q);
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/intel/ocl/ref_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ struct ref_sdpa_t : public gpu_primitive_t {
def_data_type(kernel_ctx, pd()->qry_md()->data_type, "QRY");
def_data_type(kernel_ctx, pd()->key_md()->data_type, "KEY");
def_data_type(kernel_ctx, pd()->val_md()->data_type, "VAL");
def_data_type(kernel_ctx, pd()->scale_md()->data_type, "SCALE");
def_data_type(kernel_ctx, pd()->dst_md()->data_type, "DST");
def_data_type(kernel_ctx, pd()->attn_mask_md()->data_type, "MSK");
def_data_type(kernel_ctx, pd()->desc()->scale_dt, "SCALE");
CHECK(create_kernel(engine, &kernel_, "ref_sdpa", kernel_ctx));
if (!kernel_) return status::runtime_error;
return status::success;
Expand Down
Loading