[webgpu] Use workgroup memory to reduce register pressure #24286

Open
wants to merge 7 commits into base: main
Conversation

@qjia7 (Contributor) commented Apr 3, 2025

On Qualcomm Adreno X1 GPUs, the previous implementation of the FlashAttentionProgram shader in the WebGPU backend was causing high register pressure, leading to performance degradation. This PR uses workgroup memory to reduce the register pressure and improve performance.

TTFT for Phi4 with 1K inputs improves from 40s to 10s on the Qualcomm Adreno X1 GPU.
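A minimal WGSL sketch of the general idea (illustrative only, with made-up names like acc_tile and TILE_SIZE; this is not the PR's actual FlashAttentionProgram code): a large per-invocation accumulator that would otherwise live in private (function-scope) storage, and therefore in registers, is moved into var<workgroup> storage, with each invocation owning one row addressed by its local invocation index.

enable f16;  // requires the shader-f16 feature

const TILE_SIZE : u32 = 16u;       // illustrative tile length
const WORKGROUP_SIZE : u32 = 64u;  // matches @workgroup_size below

// Before (register-hungry on some GPUs): a function-scope array per invocation,
//   var acc : array<vec4<f16>, TILE_SIZE>;
// After: one row of a shared tile per invocation.
// Size: 64 * 16 * 8 bytes (vec4<f16>) = 8KB of workgroup memory.
var<workgroup> acc_tile : array<array<vec4<f16>, TILE_SIZE>, WORKGROUP_SIZE>;

@compute @workgroup_size(64)
fn main(@builtin(local_invocation_index) lidx : u32) {
  // Each invocation initializes and accumulates into its own row only, so no
  // workgroupBarrier() is needed for this particular access pattern.
  for (var i = 0u; i < TILE_SIZE; i = i + 1u) {
    acc_tile[lidx][i] = vec4<f16>(0.0h);
  }
  // ... accumulate into acc_tile[lidx][i] instead of a private array ...
}

The trade-off is that the array now counts against the workgroup-storage limit (16KB by default in Dawn, as noted in the shader comments below) rather than the register file, so only the tiles that actually cause spills are worth moving.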

@qjia7 qjia7 marked this pull request as ready for review April 3, 2025 06:30
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Apr 3, 2025
@sushraja-msft (Contributor):
LGTM thanks

const min_value : q_element_t = q_element_t(-65504.0);

// Default SHM usage limit is 16KB in Dawn.
var<workgroup> k_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.
var<workgroup> v_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // 96 * 2 * 16 = 3KB.

var<workgroup> o_tile_r : array<array<q_value_t, half_qkv_head_size_vec>, workgroup_size_x>; // 48 * 2 * 64 = 6KB.
Contributor:
Explain why we need to use workgroup memory apart from private memory.

Contributor:

Perhaps we should remove this // 96 * 2 * 16 = 3KB. comment; the head size is no longer 96 in Phi4. Also the math is confusing: it is actually (96/4) * 8 * 16, that is, 96 vectorized (96/4), multiplied by the size of q_value_t, which is vec4<f16> (8 bytes), multiplied by max_k_step.

Was the comment useful for you? Otherwise we can remove it, or say the head size is for Phi3 and show the math (96/4) * 8 * 16.
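For concreteness, the byte count that comment is driving at for the Phi3 head size of 96 (assuming q_value_t is vec4<f16>, i.e. 8 bytes, qkv_head_size_vec = 96/4, and max_k_step = 16):

(96 / 4) * 8 * 16 = 24 * 8 * 16 = 3072 bytes = 3KB per tile.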

@qjia7 qjia7 requested review from sushraja-msft and guschmue April 7, 2025 03:20
var<workgroup> v_tile : array<array<q_value_t, qkv_head_size_vec>, max_k_step>; // vec4<f16> * qkv_head_size_vec * max_k_step = 8 * (128/4) * 16 = 4KB. 128 is head_size for phi4.

// Move half of o_tile from private memory into workgroup memory to reduce register pressure. Note that register spill was observed on Qualcomm if whole o_tile is on private memory.
var<workgroup> o_tile_r : array<array<q_value_t, half_qkv_head_size_vec>, workgroup_size_x>; // vec4<f16> * half_qkv_head_size_vec * workgroup_size_x = 8 * (128/4/2) * 64 = 8KB.
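
For reference, a rough workgroup-memory budget implied by these comments (assuming k_tile is declared with the same shape as v_tile, and ignoring any other workgroup variables the shader may declare):

k_tile:   8 * (128/4) * 16   = 4KB
v_tile:   8 * (128/4) * 16   = 4KB
o_tile_r: 8 * (128/4/2) * 64 = 8KB
Total: 16KB, i.e. right at the 16KB default Dawn limit mentioned in the earlier comment.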
@sushraja-msft (Contributor) commented Apr 9, 2025:

nit: does ORT have line-wrapping rules? Chromium would want all lines to be no more than 80 columns. You can just move the comment above each variable declaration.

@qjia7 (Author):
Done

@sushraja-msft previously approved these changes Apr 9, 2025
@qjia7 (Author) commented Apr 9, 2025

@guschmue Need your help on full perf test to ensure it won't bring regressions on other GPUs. Thanks.

@qjia7 qjia7 requested a review from sushraja-msft April 10, 2025 03:28
@guschmue (Contributor):
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows x64 QNN CI Pipeline,web_Debug / build_onnxruntime_web,web_Release / build_onnxruntime_web

Azure Pipelines successfully started running 5 pipeline(s).

@guschmue (Contributor):
azp /run build_x64_release,build_x64_release_ep_generic_interface,build_x64_release_vitisai,build_x64_release_xnnpack,build_x86_release,coreml / build-and-test (arm64, Debug),coreml / build-and-test (arm64, Release),coreml / build-and-test (x86_64, Release),

@guschmue (Contributor):
azp /run cpu / build-and-test (arm64, Debug),cpu / build-and-test (arm64, Release),iphone_simulator (arm64),iphone_simulator (x86_64),Linux QNN CI Pipeline,Python format,wasm_Debug / build-wasm,wasm_Release / build-wasm

@guschmue (Contributor):
azp /run web_Debug / build_onnxruntime_web,web_Release / build_onnxruntime_web,webgpu_build_x64_RelWithDebInfo,webgpu_external_dawn_build_x64_RelWithDebInfo,webgpu_minimal_build_edge_build_x64_RelWithDebInfo,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU CUDA CI Pipeline,Windows GPU DML CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows GPU TensorRT CI Pipeline,Windows OpenVINO CI Pipeline

Labels: ep:WebGPU ort-web webgpu provider
3 participants