xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors) #3412

Open: wants to merge 3 commits into base: main

pv-pterab-s
Contributor

@pv-pterab-s pv-pterab-s commented Jun 11, 2025

This PR makes three major changes:

  1. Add host-side-scalar memory descriptors. These tell a primitive descriptor that a scalar input is stored in host memory and is to be passed to OpenCL kernels as a scalar parameter rather than as a pointer.

  2. Modify SDPA to accept a host-side-scalar descriptor for the scale input. Given such a descriptor, SDPA passes scale to the OpenCL kernel as a scalar kernel parameter; the old behavior is kept for pre-existing descriptor types.

  3. Update the internal SDPA tests to cover host-side-scalar support.

Because scale (a single value) is passed as a scalar kernel parameter instead of a device memory pointer, this PR avoids a costly host-to-device memory transfer whenever scale changes between SDPA invocations.
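As a rough illustration of where the saving comes from, here is a toy host-only model in plain C++. It is not oneDNN or OpenCL code: `MockDevice`, `kernel_with_pointer`, `kernel_with_scalar`, and the two `count_uploads_*` helpers are hypothetical names, and the "kernel" simply multiplies one value by scale. The only point is that the pointer path needs one upload per new scale value, while the scalar path needs none.

```cpp
#include <vector>

// Toy stand-in for a device: counts host-to-device uploads of the scale.
struct MockDevice {
    float scale_buffer = 0.0f; // pretend this lives in device memory
    int uploads = 0;
    void upload_scale(float v) {
        scale_buffer = v;
        ++uploads;
    }
};

// Pointer-style "kernel": reads scale through (mock) device memory.
inline float kernel_with_pointer(const float *scale_ptr, float qk) {
    return qk * (*scale_ptr);
}

// Scalar-style "kernel": receives scale by value as a kernel parameter.
inline float kernel_with_scalar(float scale, float qk) {
    return qk * scale;
}

// Pointer path: every new scale value must be uploaded before the launch.
inline int count_uploads_pointer_path(
        const std::vector<float> &scales, float qk, std::vector<float> &out) {
    MockDevice dev;
    for (float s : scales) {
        dev.upload_scale(s);
        out.push_back(kernel_with_pointer(&dev.scale_buffer, qk));
    }
    return dev.uploads;
}

// Scalar path: the value travels with the launch, so nothing is uploaded.
inline int count_uploads_scalar_path(
        const std::vector<float> &scales, float qk, std::vector<float> &out) {
    MockDevice dev;
    for (float s : scales) {
        out.push_back(kernel_with_scalar(s, qk));
    }
    return dev.uploads;
}
```

Both paths produce the same results; only the number of uploads differs.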

To use these additions, create the SDPA primitive with the scale descriptor set to host_side_scalar. Then, at execution time, pass the scale value as a one-element host-side memory object (created on an engine of kind CPU). Example:

dnnl::engine gpu_engine = dnnl::engine(engine::kind::gpu, 0);
dnnl::stream gpu_stream = dnnl::stream(gpu_engine);
dnnl::engine cpu_engine = dnnl::engine(engine::kind::cpu, 0);
dnnl::stream cpu_stream = dnnl::stream(cpu_engine);

// Create a host-side-scalar memory descriptor for SDPA's primitive descriptor.
// Assumes the C function takes a dnnl_data_type_t, so convert the C++ enum.
dnnl_memory_desc_t tmp_scale_md;
dnnl_memory_desc_create_host_side_scalar(
        &tmp_scale_md, memory::convert_to_c(memory::data_type::f16));
memory::desc scale_md = memory::desc(tmp_scale_md);

// Create a 1-element host-side memory object for SDPA's execution
memory m_scale = memory(memory::desc({1, 1, 1, 1, 1},
                                memory::data_type::f16,
                                memory::format_tag::abcde),
        cpu_engine);

// Provide host-side-scalar scale descriptor to SDPA: scale will be passed as scalar kernel parameter
sdpa::primitive_desc sdpa_prim_pd = sdpa::primitive_desc(gpu_engine ... scale_md ...);

std::unordered_map<int, memory> sdpa_args;
sdpa_args.insert({DNNL_ARG_QUERIES, m_query});
sdpa_args.insert({DNNL_ARG_KEYS, m_keys});
sdpa_args.insert({DNNL_ARG_VALUES, m_value});
sdpa_args.insert({DNNL_ARG_DST, m_output});

// Provide the host-side scale memory to SDPA execute:
sdpa_args.insert({DNNL_ARG_SCALE, m_scale});
sdpa_prim.execute(gpu_stream, sdpa_args);
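
The example declares scale as f16, and standard C++ before C++23 has no half-precision type, so the caller needs some way to produce the 16-bit value to store in the host-side memory object. The sketch below is one possible converter from `float` to IEEE 754 binary16 bits with round-to-nearest-even. `float_to_f16_bits` is a hypothetical helper, not part of oneDNN, and writing its result through the pointer returned by `memory::get_data_handle()` on the CPU-engine memory is an assumption about how this PR expects the value to be supplied.

```cpp
#include <cstdint>
#include <cstring>

// Convert a float to IEEE 754 binary16 bits, rounding to nearest even.
inline std::uint16_t float_to_f16_bits(float value) {
    std::uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));
    const std::uint32_t sign = (bits >> 16) & 0x8000u;
    const std::uint32_t exp = (bits >> 23) & 0xFFu; // biased float exponent
    const std::uint32_t mant = bits & 0x7FFFFFu;

    if (exp == 0xFFu) // Inf or NaN (NaN keeps a non-zero mantissa bit)
        return static_cast<std::uint16_t>(
                sign | 0x7C00u | (mant ? 0x0200u : 0u));
    if (exp >= 143u) // magnitude >= 2^16: overflows to Inf
        return static_cast<std::uint16_t>(sign | 0x7C00u);
    if (exp < 102u) // magnitude < 2^-25: rounds to zero
        return static_cast<std::uint16_t>(sign);

    std::uint32_t half, rem, halfway;
    if (exp >= 113u) { // result is a normal half
        half = ((exp - 112u) << 10) | (mant >> 13);
        rem = mant & 0x1FFFu;
        halfway = 0x1000u;
    } else { // result is a subnormal half
        const std::uint32_t shift = 126u - exp; // 14..24
        const std::uint32_t m = mant | 0x800000u; // restore implicit 1
        half = m >> shift;
        rem = m & ((1u << shift) - 1u);
        halfway = 1u << (shift - 1u);
    }
    // Round to nearest, ties to even; a carry may bump the exponent.
    if (rem > halfway || (rem == halfway && (half & 1u))) ++half;
    return static_cast<std::uint16_t>(sign | half);
}
```

For instance, a scale of 0.125f converts to the bit pattern 0x3000.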

@github-actions github-actions bot added platform:gpu-intel Codeowner: @oneapi-src/onednn-gpu-intel component:api Codeowner: @oneapi-src/onednn-arch component:graph-api Codeowner: @oneapi-src/onednn-graph component:tests Codeowner: @oneapi-src/onednn-arch component:common labels Jun 11, 2025
@pv-pterab-s pv-pterab-s changed the title Pryorgal/host scalar sdpa xe: sdpa: pass scale as a scalar kernel parameter (host side scalar memory descriptors) Jun 11, 2025
@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch from fd880e1 to 0f885e5 Compare June 11, 2025 10:35
@umar456
Contributor

umar456 commented Jun 11, 2025

You should also update the verbose log here: https://github.com/uxlfoundation/oneDNN/blob/main/src/common/verbose.cpp#L1562

@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch 2 times, most recently from ceb4d3a to 54d1948 Compare June 17, 2025 10:57
@pv-pterab-s
Contributor Author

make test
enable os_win
disable test_device_cpu
enable test_device_gpu
disable build_cpu_runtime_omp
disable build_cpu_runtime_sycl
disable build_cpu_runtime_tbb
enable build_graph
enable compiler_icx-previous
enable compiler_gnu9
enable compiler_clang14
enable compiler_vs2022
disable build_gpu_runtime_sycl
disable benchdnn_all
enable benchdnn_softmax
enable benchdnn_graph
enable arch_gpu_xe-hpc
enable arch_gpu_xe-hpg-atsm
enable arch_gpu_xe-hpg-dg2
disable arch_gpu_xe-lp
disable arch_gpu_xe-lpg
disable arch_gpu_xe-lpg+
enable arch_gpu_xe2-hpg-bmg
disable arch_gpu_xe2-lpg

@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch 3 times, most recently from 15dec24 to 2b9207e Compare June 18, 2025 16:42
@pv-pterab-s pv-pterab-s marked this pull request as ready for review June 18, 2025 17:01
@pv-pterab-s pv-pterab-s requested review from a team as code owners June 18, 2025 17:01
@pv-pterab-s pv-pterab-s marked this pull request as draft June 18, 2025 17:03
@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch from 2b9207e to b335eec Compare June 18, 2025 17:13
@pv-pterab-s pv-pterab-s marked this pull request as ready for review June 18, 2025 17:14
@pv-pterab-s
Contributor Author

make test
enable os_win
disable test_device_cpu
enable test_device_gpu
disable build_cpu_runtime_omp
disable build_cpu_runtime_sycl
disable build_cpu_runtime_tbb
enable build_graph
enable compiler_icx-previous
enable compiler_gnu9
enable compiler_clang14
enable compiler_vs2022
disable build_gpu_runtime_sycl
disable benchdnn_all
enable benchdnn_softmax
enable benchdnn_graph
enable arch_gpu_xe-hpc
enable arch_gpu_xe-hpg-atsm
enable arch_gpu_xe-hpg-dg2
disable arch_gpu_xe-lp
disable arch_gpu_xe-lpg
disable arch_gpu_xe-lpg+
enable arch_gpu_xe2-hpg-bmg
disable arch_gpu_xe2-lpg

@pv-pterab-s pv-pterab-s requested a review from mgouicem June 18, 2025 17:15
Contributor

@dzarukin dzarukin left a comment

What about benchdnn validation?

/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
/// @sa @ref dev_guide_sparsity
dnnl_status_t DNNL_API dnnl_memory_desc_create_host_side_scalar(

C++ counterpart?

Comment on lines +214 to +218
#if HOST_SIDE_SCALE
const float scale, const float iscale,
#else
const global SCALE_DATA_T *scale_ptr,
#endif

It feels like i(nverted)scale can be computed down below, so that both cases have the same number of arguments.
Once there is a single argument in both cases, I'd suggest introducing SCALE_T (or something along those lines) that will be float or global SCALE_DATA_T * depending on whether it's a host-side scale or not.

Macros in a kernel signature are hard to deal with at both the caller and callee levels; better to avoid them whenever possible.
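
One way to read this suggestion, sketched in plain C++ rather than OpenCL C so that it compiles on its own: the signature keeps a single scale argument of type `SCALE_T`, and the inverted scale is derived inside the body. `SCALE_T` is the name proposed in the comment; `LOAD_SCALE`, `scale_pair`, and `resolve_scale` are hypothetical, `float` stands in for `SCALE_DATA_T`, and treating iscale as exactly `1 / scale` is an assumption.

```cpp
#ifndef HOST_SIDE_SCALE
#define HOST_SIDE_SCALE 1 // 1: scale passed by value, 0: passed by pointer
#endif

#if HOST_SIDE_SCALE
#define SCALE_T float
#define LOAD_SCALE(s) (s)
#else
#define SCALE_T const float * // would be `global SCALE_DATA_T *` in OpenCL C
#define LOAD_SCALE(s) (*(s))
#endif

struct scale_pair {
    float scale;
    float iscale;
};

// Same argument count in both configurations; the macro only changes the
// argument type, and iscale is computed here instead of being passed in.
inline scale_pair resolve_scale(SCALE_T scale_arg) {
    const float scale = LOAD_SCALE(scale_arg);
    return {scale, 1.0f / scale};
}
```

With this shape the caller sets the same number of kernel arguments either way, and only the type of the scale argument depends on `HOST_SIDE_SCALE`.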

@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch from b335eec to 4b7df7a Compare June 19, 2025 20:38
@pv-pterab-s pv-pterab-s force-pushed the pryorgal/host-scalar-sdpa branch from 4b7df7a to daafaec Compare June 19, 2025 20:49