Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NEON-accelerated int8mm for bfloat16 #125290

Closed
wants to merge 2 commits into from
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 63 additions & 12 deletions aten/src/ATen/native/cpu/int8mm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,43 @@ inline void tinygemm_kernel(
#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>

// Horizontal sum of all four f32 lanes of x.
// On aarch64, vaddvq_f32 performs the across-vector add in a single
// instruction, replacing the two pairwise-add (vpaddq_f32) steps plus
// the lane extract of the hand-rolled version.
inline float reduce(float32x4_t x) {
  return vaddvq_f32(x);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const Half* RESTRICT A,
// Widen 8 consecutive fp16 values at `ptr` into two float32x4 vectors:
// val[0] holds elements 0-3, val[1] holds elements 4-7.
inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
  const float16x8_t raw = vld1q_f16(reinterpret_cast<const float16_t*>(ptr));
  float32x4x2_t widened;
  widened.val[0] = vcvt_f32_f16(vget_low_f16(raw));
  widened.val[1] = vcvt_f32_f16(vget_high_f16(raw));
  return widened;
}

// Widen 4 consecutive fp16 values at `ptr` into one float32x4 vector.
inline float32x4_t load_as_float32x4(const Half* ptr) {
  const float16x4_t raw = vld1_f16(reinterpret_cast<const float16_t*>(ptr));
  return vcvt_f32_f16(raw);
}

// Widen 8 consecutive bf16 values at `ptr` into two float32x4 vectors
// (val[0] = elements 0-3, val[1] = elements 4-7).
// A bf16 value is the upper 16 bits of the corresponding f32, so the
// conversion is: widen u16 -> u32 with a left shift of 16, then
// reinterpret the bit pattern as f32. vshll_n_u16 fuses the widen and
// the shift into one instruction, replacing the separate vmovl_u16 +
// vshlq_u32 pair and the materialized shift-amount vector.
inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
  uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t*>(ptr));
  return {
      vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(u16_val), 16)),
      vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(u16_val), 16))};
}

// Widen 4 consecutive bf16 values at `ptr` into one float32x4 vector.
// bf16 is the upper half of an f32, so shift left by 16 while widening
// u16 -> u32 (single vshll_n_u16 instruction, replacing vmovl_u16 +
// vshlq_u32) and reinterpret the bits as f32.
inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
  return vreinterpretq_f32_u32(
      vshll_n_u16(vld1_u16(reinterpret_cast<const uint16_t*>(ptr)), 16));
}


template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
const T* RESTRICT A,
const int8_t* RESTRICT B,
const Half* RESTRICT scales,
Half* RESTRICT C,
const T* RESTRICT scales,
T* RESTRICT C,
int lda,
int ldb,
int ldc,
Expand All @@ -207,24 +233,49 @@ inline void tinygemm_kernel(
c_val[i] = vdupq_n_f32(0.0);
});
for (int k = 0; k < K; k += 8) {
float16x8_t a_val = vld1q_f16(reinterpret_cast<const float16_t *>(A) + m * lda + k);
auto a_val_low = vcvt_f32_f16(vget_low_f16(a_val));
auto a_val_high = vcvt_f32_f16(vget_high_f16(a_val));
auto a_val = load_as_float32x4x2(A + m * lda + k);
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
c_val[i] = vfmaq_f32(c_val[i], a_val_high, b_val_high);
c_val[i] = vfmaq_f32(c_val[i], a_val_low, b_val_low);
c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
});
}

float32x4_t scale_val = vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(scales)));
float32x4_t scale_val = load_as_float32x4(scales);
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
});
}
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const Half* RESTRICT A,
const int8_t* RESTRICT B,
const Half* RESTRICT scales,
Half* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const int8_t* RESTRICT B,
const BFloat16* RESTRICT scales,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}

#endif

// non-vectorized version
Expand Down
Loading