[MLAS] Add 8-bit weights ARM64 Gemm implementation #25110
base: main
Conversation
@@ -24,7 +24,7 @@
 #include "core/session/ort_env.h"
 #include "core/util/qmath.h"

-#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU)
+#if ((defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64)) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU)
Enables tests on ARM64
Pull Request Overview
This PR adds support for 8-bit weights Gemm on ARM64 via a new MLAS implementation that leverages both vdotq and i8mm instructions. Key changes include updates and additions to the test suites (sq8bitgemm, matmul_8bits_test, matmul_4bits_test), integration of a new source file (sqnbitgemm_kernel_neon_int8_i8mm.cpp) with corresponding build system adjustments, and modifications to various MLAS functions to propagate a BlkBitWidth parameter and handle additional block-sum data.
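For context on the block-sum data mentioned above: precomputing a per-block sum of the quantized weights is the usual way to cancel the activation zero-point contribution when converting the integer accumulator back to fp32. The following is a minimal sketch only, assuming symmetric weight quantization (weight zero point of 0); the function and parameter names are illustrative, not the actual MLAS API.

```cpp
#include <cstdint>

// Sketch: recover the fp32 dot product over one quantization block from the
// integer accumulator, using a precomputed block sum of the quantized weights.
//   a = scale_a * (a_q - zp_a),  b = scale_b * b_q
//   => sum(a * b) = scale_a * scale_b * (sum(a_q * b_q) - zp_a * sum(b_q))
float DequantizeBlockDot(int32_t int_acc,    // sum over the block of a_q * b_q
                         int32_t blk_sum_b,  // precomputed sum over the block of b_q
                         int32_t zp_a,       // activation zero point
                         float scale_a,      // activation scale
                         float scale_b) {    // weight scale for this block
    return scale_a * scale_b * static_cast<float>(int_acc - zp_a * blk_sum_b);
}
```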
Reviewed Changes
Copilot reviewed 18 out of 18 changed files in this pull request and generated no comments.
File | Description |
---|---|
onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp | Updates to function signatures (e.g. additional blkSum2 parameters) and test kernel evaluations for ARM64 paths. |
onnxruntime/test/contrib_ops/matmul_8bits_test.cc and matmul_4bits_test.cc | Renaming test cases to include “4b” or “8b” for clarity and updating test configurations. |
onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp and related MLAS files | New implementation file added with adjustments for i8mm instructions and extended packing routines including blkSum2. |
cmake/onnxruntime_mlas.cmake | Build system changes to compile the new source file with proper ARM64 flags (-march=armv8.2-a+i8mm). |
onnxruntime/core/mlas/lib/platform.cpp and related header files | Updates to dispatch functions and the introduction of the BlkBitWidth parameter with conditional selection for ARM64. |
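To make the dispatch change in the last row concrete, here is a rough sketch of selecting a kernel by BlkBitWidth and i8mm availability. `SelectQNBitGemmKernel` and the kernel names are hypothetical stand-ins, not the actual MLAS symbols.

```cpp
#include <cstddef>

using QNBitGemmKernelFn = void (*)();

void Sq4BitGemmKernelNeon() {}      // existing 4-bit path (unchanged)
void Sq8BitGemmKernelNeonDot() {}   // new 8-bit path using vdotq (U8U8)
void Sq8BitGemmKernelNeonI8mm() {}  // new 8-bit path using vusdotq (U8S8), needs i8mm

QNBitGemmKernelFn SelectQNBitGemmKernel(size_t BlkBitWidth, bool HasI8mmSupport) {
    if (BlkBitWidth == 8) {
        // Prefer the i8mm variant when the CPU reports support for it.
        return HasI8mmSupport ? Sq8BitGemmKernelNeonI8mm : Sq8BitGemmKernelNeonDot;
    }
    return Sq4BitGemmKernelNeon;
}
```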
Comments suppressed due to low confidence (3)
onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp:32
- The addition of 'refBlkSum2_' in the buffer declarations and subsequent changes in PrepackB and CheckBlkSum functions increases complexity. Consider adding a brief comment describing how and why the additional block sum accumulation is used to improve code clarity.
MatrixGuardBuffer<uint8_t> inputB_, inputZp_, refB_, packedBuffer_;
onnxruntime/core/mlas/lib/platform.cpp:588
- The flag 'ArmNeonQuantAUnsigned' is deliberately overridden when I8MM support is detected; a comment explaining the rationale behind switching from unsigned to signed mode in this context would help maintainability and clarity for future maintainers.
this->ArmNeonQuantAUnsigned = false;
cmake/onnxruntime_mlas.cmake:441
- Ensure that the compile flag '-march=armv8.2-a+i8mm' is consistently used across all ARM64 targets for the new file. Double-check that the flag matches the expected support level when using i8mm instructions and that documentation in the build files reflects this requirement.
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
if constexpr (QuantAUnsigned) {
  {
    assert(QuantBBlkSum2 != nullptr);
The assertion fails on Android (QuantAUnsigned = true): assertion "QuantBBlkSum2 != nullptr" failed, when testing with phi-4-mini-instruct: cpu-int4-kquant-block-128-mixed-acc-level-4/v3
Description
Enable 8-bit weights Gemm on ARM64 via MLAS
Supports two flavors of the 8-bit Gemm kernel: one uses `vdotq` (U8U8), and the other uses `vusdotq` (U8S8) on platforms where I8MM is supported (see the intrinsics sketch after the test list below). Provides access to these new MLAS Gemm kernels via the `MatmulNBits` contrib operator.

Tests:
MLAS

3 new sets of tests:

- `SQ8BitQuantA`: Tests the dynamic activation quantization MLAS kernel (`fp32 -> uint8_t`, or `fp32 -> int8_t` on I8MM platforms)
- `SQ8BitPrepack`: Tests the prepacking of the weights for the 8-bit Gemm kernels
- `SQ8BitGemm`: Tests the 8-bit Gemm kernels

MatmulNBits contrib tests
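A minimal sketch of the two intrinsic paths referenced above (not the PR's actual kernels; the helper names `DotU8U8` and `DotU8S8` are made up for illustration, and the feature macros guard the same dotprod/i8mm requirements the build flags express):

```cpp
#include <arm_neon.h>
#include <cstdint>

#if defined(__ARM_FEATURE_DOTPROD)
// U8U8 path: 16-element u8 x u8 dot product with vdotq_u32.
inline uint32_t DotU8U8(const uint8_t* a, const uint8_t* b) {
    uint32x4_t acc = vdotq_u32(vdupq_n_u32(0), vld1q_u8(a), vld1q_u8(b));
    return vaddvq_u32(acc);  // horizontal add of the four lane accumulators
}
#endif

#if defined(__ARM_FEATURE_MATMUL_INT8)
// U8S8 path: 16-element u8 x s8 dot product with vusdotq_s32 (requires i8mm,
// hence the -march=armv8.2-a+i8mm flag on the new source file).
inline int32_t DotU8S8(const uint8_t* a, const int8_t* b) {
    int32x4_t acc = vusdotq_s32(vdupq_n_s32(0), vld1q_u8(a), vld1q_s8(b));
    return vaddvq_s32(acc);
}
#endif
```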
Motivation and Context
Enable 8-bit weights Gemm on ARM64 via MLAS
Based on work and contribution by @fajin-corp