[webgpu] Use workgroup memory to reduce register pressure #24286
base: main
Conversation
LGTM, thanks.
const min_value : q_element_t = q_element_t(-65504.0);

// Default SHM usage limit is 16KB in Dawn.
var<workgroup> k_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.
var<workgroup> v_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.

var<workgroup> o_tile_r : array<array<q_value_t, half_qkv_head_size_vec>, workgroup_size_x>; // 48 * 2 * 64 = 6KB.
Explain why we need to use workgroup memory apart from private memory.
Perhaps we should remove this `// 96 * 2 * 16 = 3KB.` comment; the head size is no longer 96 in Phi4. Also the math is confusing: it is actually (96/4) * 8 * 16. That is the head size vectorized (96/4), multiplied by the size of q_value_t, which is vec4<f16> (8 bytes), multiplied by max_k_step.
Was the comment useful for you? Otherwise we can remove it, or say the head size is for Phi3 and show the math: (96/4) * 8 * 16.
var<workgroup> v_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // vec4<f16> * qkv_head_size_vec * max_k_step = 8 * (128/4) * 16 = 4KB. 128 is head_size for phi4.

// Move half of o_tile from private memory into workgroup memory to reduce register pressure. Note that register spill was observed on Qualcomm if the whole o_tile is in private memory.
var<workgroup> o_tile_r : array<array<q_value_t, half_qkv_head_size_vec>, workgroup_size_x>; // vec4<f16> * half_qkv_head_size_vec * workgroup_size_x = 8 * (128/4/2) * 64 = 8KB.
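For reference, the shared-memory budget implied by the comments above can be checked with a short script. This is a sketch: the Phi4 head size (128), `max_k_step` (16), `workgroup_size_x` (64), and the 8-byte size of `q_value_t` (`vec4<f16>`) are all taken from the code comments, and 16KB is the default Dawn SHM limit mentioned in the shader.

```python
# Sketch: verify the workgroup (shared) memory arithmetic from the review comments.
# Values assumed from the code comments: head_size = 128 (Phi4),
# q_value_t = vec4<f16> = 8 bytes, max_k_step = 16, workgroup_size_x = 64.

ELEM_BYTES = 8                       # sizeof(vec4<f16>)
head_size = 128                      # Phi4; Phi3 used 96
qkv_head_size_vec = head_size // 4   # head size in vec4 elements
half_qkv_head_size_vec = qkv_head_size_vec // 2
max_k_step = 16
workgroup_size_x = 64

k_tile = ELEM_BYTES * qkv_head_size_vec * max_k_step              # 8 * 32 * 16
v_tile = ELEM_BYTES * qkv_head_size_vec * max_k_step              # same shape as k_tile
o_tile_r = ELEM_BYTES * half_qkv_head_size_vec * workgroup_size_x # 8 * 16 * 64

total = k_tile + v_tile + o_tile_r
print(k_tile, v_tile, o_tile_r, total)  # 4096 4096 8192 16384
assert total <= 16 * 1024  # exactly at Dawn's default 16KB SHM limit
```

Note that with head_size = 128 the three tiles together consume the entire 16KB default budget, which is presumably why only half of o_tile is moved to workgroup memory.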
nit: does ORT have line wrapping rules? Chromium would want all lines to be no more than 80 columns. You can just move the comment above each variable declaration.
Done
@guschmue Need your help on a full perf test to ensure it won't bring regressions on other GPUs. Thanks.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows x64 QNN CI Pipeline,web_Debug / build_onnxruntime_web,web_Release / build_onnxruntime_web
Azure Pipelines successfully started running 5 pipeline(s).
/azp run build_x64_release,build_x64_release_ep_generic_interface,build_x64_release_vitisai,build_x64_release_xnnpack,build_x86_release,coreml / build-and-test (arm64, Debug),coreml / build-and-test (arm64, Release),coreml / build-and-test (x86_64, Release)
/azp run cpu / build-and-test (arm64, Debug),cpu / build-and-test (arm64, Release),iphone_simulator (arm64),iphone_simulator (x86_64),Linux QNN CI Pipeline,Python format,wasm_Debug / build-wasm,wasm_Release / build-wasm
/azp run web_Debug / build_onnxruntime_web,web_Release / build_onnxruntime_web,webgpu_build_x64_RelWithDebInfo,webgpu_external_dawn_build_x64_RelWithDebInfo,webgpu_minimal_build_edge_build_x64_RelWithDebInfo,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline
On Qualcomm Adreno X1 GPUs, the previous implementation of the FlashAttentionProgram shader in the WebGPU backend was causing high register pressure, leading to performance degradation. This PR uses workgroup memory to reduce the register pressure and improve performance.
TTFT for Phi4 with 1K inputs improves from 40s to 10s on the Qualcomm Adreno X1 GPU.