
[MLAS] Add 8-bit weights ARM64 Gemm implementation #25110


Open: hariharans29 wants to merge 46 commits into main

Conversation

@hariharans29 (Member) commented Jun 18, 2025

Description

Enable 8-bit weights Gemm on ARM64 via MLAS

  1. Supports two flavors of the 8-bit Gemm kernel: one uses vdotq (U8U8) and the other uses vusdotq (U8S8) on platforms where I8MM is supported (see the intrinsics sketch after this list).

  2. Provides access to these new MLAS Gemm kernels via the MatmulNBits contrib operator

  3. Tests:

    MLAS: three new sets of tests

    • SQ8BitQuantA: Tests the dynamic activation quantization MLAS kernel (fp32 -> uint8_t, or fp32 -> int8_t on I8MM platforms)
    • SQ8BitPrepack: Tests the prepacking of the weights for the 8-bit Gemm kernels
    • SQ8BitGemm: Tests the 8-bit Gemm kernels

    MatmulNBits contrib tests

    • Enables the 8-bit Gemm tests on ARM64 (previously only enabled on x86)
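
For illustration only, a minimal sketch of the two dot-product flavors named in item 1, written directly against the NEON intrinsics. This is not the PR's kernel code; the operand roles follow the activation-quantization behavior described in the test list above (activations are quantized to int8_t only on I8MM platforms).

```cpp
#include <arm_neon.h>

// Minimal sketch, not the PR's actual kernel code. Each intrinsic
// accumulates four lanes of 4-element 8-bit dot products into a
// 32-bit accumulator.

// U8U8 flavor (baseline NEON dotprod): both operands unsigned.
uint32x4_t AccumulateU8U8(uint32x4_t acc, uint8x16_t a, uint8x16_t b) {
    return vdotq_u32(acc, a, b);
}

// U8S8 flavor (requires the I8MM extension): the first vector operand is
// unsigned, the second signed. On I8MM platforms the PR quantizes
// activations to int8_t, so one side of each product is signed.
int32x4_t AccumulateU8S8(int32x4_t acc, uint8x16_t u, int8x16_t s) {
    return vusdotq_s32(acc, u, s);
}
```

Building a file that uses vusdotq requires matching target flags, which is why the PR compiles the new i8mm kernel source with -march=armv8.2-a+i8mm (see the cmake change in the review summary below).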

Motivation and Context

Enable 8-bit weights Gemm on ARM64 via MLAS

Based on work and contribution by @fajin-corp

@github-actions bot (Contributor) left a comment:

You can commit the suggested changes from lintrunner.


hariharans29 and others added 4 commits June 25, 2025 12:34
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

hariharans29 and others added 7 commits June 26, 2025 12:11
@hariharans29 changed the title from "[DO NOT REVIEW] [MLAS] 8 bit weights ARM64 Matmul implementation" to "WIP: [MLAS] 8 bit weights ARM64 Matmul implementation" on Jun 27, 2025
@hariharans29 changed the title from "WIP: [MLAS] 8 bit weights ARM64 Matmul implementation" to "[MLAS] 8 bit weights ARM64 Matmul implementation" on Jun 28, 2025

@hariharans29 changed the title from "[MLAS] 8 bit weights ARM64 Matmul implementation" to "[MLAS] Add 8-bit weights ARM64 Gemm implementation" on Jun 28, 2025
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
```diff
@@ -24,7 +24,7 @@
 #include "core/session/ort_env.h"
 #include "core/util/qmath.h"

-#if (defined(MLAS_TARGET_AMD64_IX86) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU)
+#if ((defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64)) && !defined(USE_DML) && !defined(USE_WEBGPU) && !defined(USE_COREML)) || defined(USE_CUDA) || defined(USE_WEBGPU)
```
@hariharans29 (Member, Author) commented:

Enables tests on ARM64.

@jywu-msft requested review from edgchen1 and Copilot on June 28, 2025 20:24
@Copilot (Copilot AI) left a comment:

Pull Request Overview

This PR adds support for 8-bit weights Gemm on ARM64 via a new MLAS implementation that leverages both vdotq and i8mm instructions. Key changes include updates and additions in test suites (sq8bitgemm, matmul_8bits_test, matmul_4bits_test), integration of a new source file (sqnbitgemm_kernel_neon_int8_i8mm.cpp) with corresponding build system adjustments, and modifications in various MLAS functions to propagate a BlkBitWidth parameter and handle additional block-sum data.

Reviewed Changes

Copilot reviewed 18 out of 18 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp | Updates to function signatures (e.g. additional blkSum2 parameters) and test kernel evaluations for ARM64 paths. |
| onnxruntime/test/contrib_ops/matmul_8bits_test.cc and matmul_4bits_test.cc | Renaming test cases to include "4b" or "8b" for clarity and updating test configurations. |
| onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp and related MLAS files | New implementation file added with adjustments for i8mm instructions and extended packing routines including blkSum2. |
| cmake/onnxruntime_mlas.cmake | Build system changes to compile the new source file with proper ARM64 flags (-march=armv8.2-a+i8mm). |
| onnxruntime/core/mlas/lib/platform.cpp and related header files | Updates to dispatch functions and the introduction of the BlkBitWidth parameter with conditional selection for ARM64. |
Comments suppressed due to low confidence (3)

onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp:32

  • The addition of 'refBlkSum2_' in the buffer declarations and the subsequent changes in the PrepackB and CheckBlkSum functions increase complexity. Consider adding a brief comment describing how and why the additional block-sum accumulation is used, to improve code clarity.
  MatrixGuardBuffer<uint8_t> inputB_, inputZp_, refB_, packedBuffer_;

onnxruntime/core/mlas/lib/platform.cpp:588

  • The flag 'ArmNeonQuantAUnsigned' is deliberately overridden when I8MM support is detected; a comment explaining the rationale behind switching from unsigned to signed mode in this context would help maintainability and clarity for future maintainers (see the sketch after this list).
  this->ArmNeonQuantAUnsigned = false;

cmake/onnxruntime_mlas.cmake:441

  • Ensure that the compile flag '-march=armv8.2-a+i8mm' is used consistently across all ARM64 targets for the new file. Double-check that the flag matches the expected support level for i8mm instructions and that documentation in the build files reflects this requirement.
  set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp
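
For the second suppressed comment, a hedged reconstruction of the dispatch override it refers to; only the flag name, the member-style assignment, and the I8MM condition come from this page, and HasI8MMSupport() is a hypothetical stand-in for whatever capability check platform.cpp actually performs.

```cpp
// Hedged sketch of the dispatch override flagged above; HasI8MMSupport()
// is a hypothetical capability check, not an actual MLAS function.
if (HasI8MMSupport()) {
    // Prefer the U8S8 path: quantizing activations signed lets the kernel
    // use vusdotq, and (on the reading sketched further below) avoids the
    // extra block-sum correction the unsigned path needs.
    this->ArmNeonQuantAUnsigned = false;
}
```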


```cpp
if constexpr (QuantAUnsigned) {
    assert(QuantBBlkSum2 != nullptr);
```
A Contributor commented:

The assertion fails on Android [QuantAUnsigned = true]: assertion "QuantBBlkSum2 != nullptr" failed, when testing with phi-4-mini-instruct: cpu-int4-kquant-block-128-mixed-acc-level-4/v3
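
One plausible reading of why the unsigned path needs QuantBBlkSum2 at all (my assumption, not something stated in the PR; the 128 offset and the helper below are purely illustrative): if an unsigned activation is just the signed value shifted, a_u8 = a_s8 + 128, then the raw u8*u8 dot product overshoots the signed result by 128 * sum(b) per block, which is exactly what a per-block sum of B can cancel.

```cpp
#include <cstdint>
#include <cstddef>

// Illustrative only: the correction a per-block sum of B provides when
// activations are quantized unsigned with an assumed offset of 128
// (a_u8 = a_s8 + 128). Not the PR's code.
int32_t CorrectedBlockDot(const uint8_t* a_u8, const uint8_t* b_u8,
                          size_t blk_len, int32_t blk_sum_b) {
    int32_t acc = 0;
    for (size_t k = 0; k < blk_len; ++k) {
        acc += static_cast<int32_t>(a_u8[k]) * static_cast<int32_t>(b_u8[k]);
    }
    // sum((a_s8 + 128) * b) - 128 * sum(b) == sum(a_s8 * b)
    return acc - 128 * blk_sum_b;
}
```

If that reading is right, the Android failure above is the unsigned (QuantAUnsigned = true) path being selected on a device without I8MM while the corresponding block-sum buffer was never packed.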
