-
Notifications
You must be signed in to change notification settings - Fork 1k
xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors) #3412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,7 +135,7 @@ static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, | |
static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, | ||
const memory_desc_t *k_md, const memory_desc_t *v_md, | ||
const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, | ||
data_type_t scale_dt, bool invert_scale, dim_t kv_head_number, | ||
const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, | ||
attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, | ||
const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { | ||
auto sdpa_desc = sdpa_desc_t(); | ||
|
@@ -153,7 +153,7 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, | |
sdpa_desc.v_desc = *v_md; | ||
sdpa_desc.dst_desc = *dst_md; | ||
if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md; | ||
sdpa_desc.scale_dt = scale_dt; | ||
sdpa_desc.scale_desc = *scale_md; | ||
sdpa_desc.invert_scale = invert_scale; | ||
sdpa_desc.kv_head_number = kv_head_number; | ||
sdpa_desc.mask_type = attn_mask_type; | ||
|
@@ -165,7 +165,7 @@ static inline status_t create_sdpa_pd( | |
std::shared_ptr<primitive_desc_t> &sdpa_pd_, engine_t *engine, | ||
const memory_desc_t *q_md, const memory_desc_t *k_md, | ||
const memory_desc_t *v_md, const memory_desc_t *dst_md, | ||
const memory_desc_t *attn_mask_md, data_type_t scale_dt, | ||
const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, | ||
bool invert_scale, dim_t kv_head_number, | ||
attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, | ||
const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr, | ||
|
@@ -175,7 +175,7 @@ static inline status_t create_sdpa_pd( | |
kq_attr, vs_attr)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should add checks in |
||
|
||
auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, | ||
scale_dt, invert_scale, kv_head_number, attn_mask_type, softmax_alg, | ||
scale_md, invert_scale, kv_head_number, attn_mask_type, softmax_alg, | ||
kq_attr, vs_attr); | ||
|
||
primitive_attr_t sdpa_attr = attr ? *attr : default_attr(); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1593,6 +1593,11 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) { | |
} | ||
delimiter = " "; | ||
} | ||
if (pd->with_host_side_scale()) { | ||
ss << delimiter << "scale:host"; | ||
} else { | ||
ss << delimiter << "scale:device"; | ||
} | ||
Comment on lines +1596 to +1600
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be moved down next to the attention mask handling (around L1615). |
||
if (pd->with_key_zp() || pd->with_value_zp()) { | ||
ss << delimiter << "attr-zero-points:"; | ||
delimiter = ""; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,8 +211,12 @@ DECLARE_2D_TILE_RSELECT(a_scale_tile_type, SUBGROUP_SIZE, ugemm_vs_sg_tile_n, 1, | |
__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void | ||
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, | ||
const global VAL_DATA_T *V, global DST_DATA_T *A, | ||
const global SCALE_DATA_T *scale_ptr, int d, int k, int q, | ||
const global KEY_ATTR_SCALES_DATA_T *K_scales, | ||
#if HOST_SIDE_SCALE | ||
const float scale, const float iscale, | ||
#else | ||
const global SCALE_DATA_T *scale_ptr, | ||
#endif | ||
Comment on lines +214 to +218
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It feels like the inverted scale (iscale) can be computed down below, so that the kernel takes the same number of arguments in both cases. Macros in a kernel signature are hard to deal with at both the caller and callee levels; it is better to avoid them whenever possible. |
||
int d, int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales, | ||
const global KEY_ATTR_ZP_DATA_T *K_zp, | ||
const global VAL_ATTR_SCALES_DATA_T *V_scales, | ||
const global VAL_ATTR_ZP_DATA_T *V_zp, const int attn_mask_type | ||
|
@@ -335,6 +339,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, | |
#endif | ||
} | ||
|
||
#if !HOST_SIDE_SCALE | ||
/* Load scale */ | ||
float scale = 1.0f; | ||
float iscale = 1.0f; | ||
|
@@ -350,6 +355,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, | |
#endif | ||
scale *= 1.442695f; // log2(e) | ||
} | ||
#endif | ||
|
||
#ifdef PREFETCH_K0 | ||
if (k0end > 0) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++ counterpart?