Add NEON-accelerated int8mm for bfloat16 (#125290)
Apparently, `vshlq_u32` is faster than `vcvt_f32_f16`.

Refactor the NEON `tinygemm_kernel` to rely on `load_as_float32x4` and `load_as_float32x4x2`, and implement them for float16 (using `vcvt`), bfloat16 (using a left shift), and plain float32 (no conversion needed).

As a result, stories110M runs at 60 tokens/sec with f16, 66 tokens/sec with bf16, and 75 tokens/sec with f32, though growing bandwidth demand starts to favor the reduced-precision floating-point types as model size increases.
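The bfloat16 fast path exists because bf16 is just the upper 16 bits of an IEEE-754 float32, so no real format conversion is required. As a rough illustration (not part of this diff; the helper name below is made up), a scalar equivalent of the shift trick looks like this:

```cpp
// Illustrative scalar sketch (not from this PR): bfloat16 stores the upper
// 16 bits of an IEEE-754 binary32 value, so widening the raw bits and
// shifting them left by 16 reconstructs the exact float32 -- the same idea
// the NEON kernel expresses with vmovl_u16 + vshlq_u32.
#include <cstdint>
#include <cstring>
#include <cstdio>

float bf16_bits_to_float(uint16_t bits) {
  uint32_t widened = static_cast<uint32_t>(bits) << 16;  // pad the missing mantissa bits with zeros
  float result;
  std::memcpy(&widened_to_float_src = widened, &widened, 0);  // placeholder removed below
  std::memcpy(&result, &widened, sizeof(result));             // bit-cast, no rounding involved
  return result;
}

int main() {
  // 0x3FC0 is the bfloat16 encoding of 1.5 (sign 0, exponent 0x7F, mantissa 0b1000000)
  std::printf("%f\n", bf16_bits_to_float(0x3FC0));  // prints 1.500000
  return 0;
}
```

float16, by contrast, has a different exponent and mantissa width than float32, so it needs the genuine conversion that `vcvt_f32_f16` performs; that is the instruction the shift-based path avoids.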

Pull Request resolved: #125290
Approved by: https://github.com/mikekgfb
malfet authored and pytorchmergebot committed May 1, 2024
1 parent 9fbb4df commit a271514
1 changed file: aten/src/ATen/native/cpu/int8mm_kernel.cpp (82 additions, 12 deletions)
@@ -185,17 +185,50 @@ inline void tinygemm_kernel(
 #if !defined(C10_MOBILE) && defined(__aarch64__)
 #include <arm_neon.h>
 
-static inline float reduce(float32x4_t x) {
+inline float reduce(float32x4_t x) {
   auto sum = vpaddq_f32(x, x);
   return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
 }
 
-template <int BLOCK_M, int BLOCK_N>
-inline void tinygemm_kernel(
-    const Half* RESTRICT A,
+inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
+  float16x8_t f16_val = vld1q_f16(reinterpret_cast<const float16_t *>(ptr));
+  auto val_low = vcvt_f32_f16(vget_low_f16(f16_val));
+  auto val_high = vcvt_f32_f16(vget_high_f16(f16_val));
+  return {val_low, val_high};
+}
+
+inline float32x4_t load_as_float32x4(const Half* ptr) {
+  return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
+}
+
+inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
+  int32x4_t shift = vdupq_n_s32(16);
+  uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t *>(ptr));
+  uint32x4_t int_low = vmovl_u16(vget_low_u16(u16_val));
+  uint32x4_t int_high = vmovl_u16(vget_high_u16(u16_val));
+  return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
+}
+
+inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
+  int32x4_t shift = vdupq_n_s32(16);
+  uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
+  return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
+}
+
+inline float32x4_t load_as_float32x4(const float* ptr) {
+  return vld1q_f32(ptr);
+}
+
+inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
+  return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
+}
+
+template <int BLOCK_M, int BLOCK_N, typename T>
+inline void tinygemm_kernel_(
+    const T* RESTRICT A,
     const int8_t* RESTRICT B,
-    const Half* RESTRICT scales,
-    Half* RESTRICT C,
+    const T* RESTRICT scales,
+    T* RESTRICT C,
     int lda,
     int ldb,
     int ldc,
@@ -207,24 +240,61 @@ inline void tinygemm_kernel(
       c_val[i] = vdupq_n_f32(0.0);
     });
     for (int k = 0; k < K; k += 8) {
-      float16x8_t a_val = vld1q_f16(reinterpret_cast<const float16_t *>(A) + m * lda + k);
-      auto a_val_low = vcvt_f32_f16(vget_low_f16(a_val));
-      auto a_val_high = vcvt_f32_f16(vget_high_f16(a_val));
+      auto a_val = load_as_float32x4x2(A + m * lda + k);
       c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
         int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
         auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
         auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
-        c_val[i] = vfmaq_f32(c_val[i], a_val_high, b_val_high);
-        c_val[i] = vfmaq_f32(c_val[i], a_val_low, b_val_low);
+        c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
+        c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
       });
     }
 
-    float32x4_t scale_val = vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(scales)));
+    float32x4_t scale_val = load_as_float32x4(scales);
     c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
       C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
     });
   }
 }
 
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const Half* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const Half* RESTRICT scales,
+    Half* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const BFloat16* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const BFloat16* RESTRICT scales,
+    BFloat16* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
+
+template <int BLOCK_M, int BLOCK_N>
+inline void tinygemm_kernel(
+    const float* RESTRICT A,
+    const int8_t* RESTRICT B,
+    const float* RESTRICT scales,
+    float* RESTRICT C,
+    int lda,
+    int ldb,
+    int ldc,
+    int K) {
+  tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
+}
 #endif
 
 // non-vectorized version
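For orientation, here is a hypothetical scalar reference (not part of the commit) of the tile product that the templated `tinygemm_kernel_` above accumulates; it assumes the same row-major layout and per-output-column scales as the NEON code, the name `tinygemm_reference` is invented, and `T` is assumed convertible to and from `float` (as `at::Half` and `at::BFloat16` are):

```cpp
// Hypothetical scalar reference (not in the PR) of the tile product the NEON
// kernel accumulates: C[m][n] = scales[n] * sum_k A[m][k] * B[n][k], where A
// holds fp32/fp16/bf16 activations promoted to float and B holds int8 weights.
#include <cstdint>

template <int BLOCK_M, int BLOCK_N, typename T>
void tinygemm_reference(
    const T* A,          // BLOCK_M x K activations, row stride lda
    const int8_t* B,     // BLOCK_N x K int8 weights, row stride ldb
    const T* scales,     // one dequantization scale per output column
    T* C,                // BLOCK_M x BLOCK_N output tile, row stride ldc
    int lda, int ldb, int ldc, int K) {
  for (int m = 0; m < BLOCK_M; ++m) {
    for (int n = 0; n < BLOCK_N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += static_cast<float>(A[m * lda + k]) * static_cast<float>(B[n * ldb + k]);
      }
      C[m * ldc + n] = static_cast<T>(acc * static_cast<float>(scales[n]));
    }
  }
}
```

The NEON version walks the k loop eight elements at a time, keeps one `float32x4_t` accumulator per output column, and finishes with the horizontal `reduce()` plus the per-column scale, but the arithmetic is the same.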
