-
Notifications
You must be signed in to change notification settings - Fork 1k
xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors) #3412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
fd880e1
to
0f885e5
Compare
You should also update the verbose log here: https://github.com/uxlfoundation/oneDNN/blob/main/src/common/verbose.cpp#L1562 |
ceb4d3a
to
54d1948
Compare
make test |
15dec24
to
2b9207e
Compare
2b9207e
to
b335eec
Compare
make test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about benchdnn validation?
/// @returns #dnnl_success on success and a status describing the error | ||
/// otherwise. | ||
/// @sa @ref dev_guide_sparsity | ||
dnnl_status_t DNNL_API dnnl_memory_desc_create_host_side_scalar( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++ counter-part?
#if HOST_SIDE_SCALE | ||
const float scale, const float iscale, | ||
#else | ||
const global SCALE_DATA_T *scale_ptr, | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It feels like i(nverted)scale can be computed down below to have the same number of arguments.
Once one for both cases, I'd suggest to introduce SCALE_T
(or something along the way) that will be float
or global SCALE_DATA_T *
depending on whether it's a host side scale or not.
Macros in kernel signature is hard to deal with at caller and callee levels, better avoid it whenever possible.
b335eec
to
4b7df7a
Compare
Implements major changes: 1. Addition of host-side-scalar memory descriptors: indicate to primitive descriptors that a scalar input is stored in host memory and is to be passed as a scalar parameter - not pointer - to OpenCL Kernels. 2. Modify SDPA to accept a host-side-scalar descriptor for scale input. On such a descriptor, pass scale as a scalar kernel parameter to the OpenCL kernel while maintaining old behavior for pre-existing descriptor types. 3. Update internal tests for SDPA to include host-side-scalar support. By passing `scale` (a single value) as a scalar kernel parameter as opposed to a device memory pointer, this PR avoids costly host/device memory transfers on SDPA invocation (as scale changes). To utilize these additions, instantiate the SDPA primitive with the descriptor for scale set as `host_side_scalar`. Then, on SDPA execution, pass the `scale` value as a 1-element-sized host-side memory object (via an engine of type CPU). Example: ``` dnnl::engine gpu_engine = dnnl::engine(engine::kind::gpu, 0); dnnl::stream gpu_stream = dnnl::stream(gpu_engine); dnnl::engine cpu_engine = dnnl::engine(engine::kind::cpu, 0); dnnl::stream cpu_stream = dnnl::stream(cpu_eng); // Create a host-side-scalar memory descriptor for SDPA's primitive // descriptor dnnl_memory_desc_t tmp_scale_md; dnnl_memory_desc_create_host_side_scalar(&tmp_scale_md, memory::data_type::f16); memory::desc scale_md = memory::desc(tmp_scale_md); // Create a host-side memory block of 1-element for SDPA's execution out.m_scale = memory(memory::desc({1, 1, 1, 1, 1}, memory::format_tag data_type::f16, p.qdt, abcde), cpu_engine); // Provide host-side-scalar scale descriptor to SDPA: scale will be // passed as scalar kernel parameter sdpa::primitive_desc sdpa_prim_pd = sdpa::primitive_desc(gpu_engine ... scale_md ...); std::unordered_map<int, memory> sdpa_args; sdpa_args.insert({DNNL_ARG_QUERIES, m_query}); sdpa_args.insert({DNNL_ARG_KEYS, m_keys}); sdpa_args.insert({DNNL_ARG_VALUES, m_value}); sdpa_args.insert({DNNL_ARG_DST, m_output}); // Provide host-side memory to SDPA execute: sdpa_args.insert({DNNL_ARG_SCALE, m_scale}); sdpa_prim.execute(gpu_stream, sdpa_args); ```
4b7df7a
to
daafaec
Compare
Implements major changes:
Addition of host-side-scalar memory descriptors: indicate to primitive descriptors that a scalar input is stored in host memory and is to be passed as a scalar parameter - not pointer - to OpenCL kernels
Modify SDPA to accept a host-side-scalar descriptor for scale input. On such a descriptor, pass scale as a scalar kernel parameter to the OpenCL kernel while maintaining old behavior for pre-existing descriptor types.
Update internal tests for SDPA to include host-side-scalar support.
By passing
scale
(a single value) as a scalar kernel parameter as opposed to a device memory pointer, this PR avoids costly host/device memory transfers on SDPA invocation (as scale changes).To utilize these additions, instantiate the SDPA primitive with the descriptor for scale set as
host_side_scalar
. Then, on SDPA execution, pass thescale
value as a 1-element-sized host-side memory object (via an engine of type CPU). Example: