[wasm] Optimize WASM relaxed simd MlasGemmQuantKernel #25048
Conversation
This change optimizes MlasGemmQuantKernel for the WASM relaxed SIMD build.

| Mlas bench / RPL laptop / node v24.1.0 | baseline | opt | diff |
|----------------------------------------|----------|-----|------|
| QGEMM/UnsignedANoPackB/M:384/N:1024/K:1024/Batch:1/Threads:4/real_time | 2452212 | 1708338 | 44% |
| QGEMM/UnsignedANoPackB/M:384/N:1024/K:3072/Batch:1/Threads:4/real_time | 9053789 | 6395584 | 42% |
| QGEMM/UnsignedANoPackB/M:384/N:1024/K:4096/Batch:1/Threads:4/real_time | 12109727 | 8189719 | 48% |
| QGEMM/UnsignedANoPackB/M:384/N:4096/K:1024/Batch:1/Threads:4/real_time | 11787607 | 7926226 | 49% |
Pull Request Overview
This PR optimizes the WASM relaxed SIMD QGEMM micro kernel by introducing a 6x8 kernel implementation to improve performance. Key changes include:
- Introduction of a generic row-count based implementation (6x8 and 1x8) using the templated function GemmQuantKernelNx8Impl.
- Refactoring of the accumulation and pointer management logic, including a new DotPairAdd helper for the fused multiply-accumulate operations (see the sketch after this list).
- Updated kernel stride configuration in the dispatch structure to align with the 6x8 kernel design.
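For readers unfamiliar with the relaxed SIMD intrinsics involved, here is a minimal sketch of what a DotPairAdd-style helper can look like. The name, signature, and surrounding usage are assumptions for illustration only and may not match the PR's actual code; it requires a toolchain targeting WASM relaxed SIMD (e.g. clang with `-mrelaxed-simd`).

```cpp
// Hypothetical sketch only -- not the PR's implementation.
// Shows the core accumulate step a quantized GEMM micro-kernel repeats:
// multiply signed 8-bit lanes of one operand by 7-bit lanes of the other,
// sum groups of four products, and add them into an int32x4 accumulator.
#include <wasm_simd128.h>

static inline v128_t DotPairAdd(v128_t a, v128_t b, v128_t acc) {
    // i32x4.relaxed_dot_i8x16_i7x16_add_s: per-lane dot product of four
    // 8-bit pairs, added lane-wise to the 32-bit accumulator. On x64 with
    // AVX-VNNI this typically lowers to a single VPDPBUSD instruction.
    return wasm_i32x4_relaxed_dot_i8x16_i7x16_add_s(a, b, acc);
}

// Illustrative use inside an inner loop over the packed K dimension:
//   v128_t acc0 = wasm_i32x4_splat(0);
//   v128_t AVec = wasm_v128_load(a_ptr);
//   v128_t BVec = wasm_v128_load(b_ptr);
//   acc0 = DotPairAdd(AVec, BVec, acc0);
```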
Comments suppressed due to low confidence (2)
onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp:493
- [nitpick] Consider renaming the lambda 'Tail' to a more descriptive name (e.g., 'outputTail' or 'storeTail') to clarify its purpose in handling partial column outputs.
auto Tail = [&](size_t cols, auto load_c, auto store_c) {
onnxruntime/core/mlas/lib/qgemm_kernel_wasmrelaxedsimd.cpp:541
- [nitpick] It would be helpful to add an inline comment clarifying why the CountM parameter is passed as 0 (since it is ignored) to aid reader understanding (a possible wording is sketched below).
return GemmQuantKernelNx8Impl<6>(A, B, C, PackedCountK, 0, CountN, ldc,
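As a purely illustrative sketch of the suggested inline comment (hypothetical wording; it assumes CountM is genuinely unused once the row count is fixed by the template argument), it could read:

```cpp
// Hypothetical comment wording only -- not the PR's actual source.
// CountM is ignored here: the <6> template argument already fixes the number
// of rows GemmQuantKernelNx8Impl processes, so 0 is passed as a placeholder.
```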
Thanks @guschmue for reviewing this change! The latest change incorporates Copilot's suggestions. PTAL, thanks! As for the CI failure in the last run (Windows GPU CUDA CI), the log shows four identical failures related to HuggingFace model downloads, which appear unrelated to this change.
A CI re-run might help clear the transient failures.
/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline,Windows x64 QNN CI Pipeline
Azure Pipelines successfully started running 5 pipeline(s).
Description
This change introduces a 6x8 QGEMM micro kernel for the WASM relaxed SIMD build.
Motivation and Context
This change optimizes QGEMM performance for the WASM relaxed SIMD build, particularly on x64 devices with AVX-VNNI, where the relaxed dot-product instructions lower to VNNI instructions.