From a47fa51ecc9797e0406e50758f98f2105da97d3b Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Mon, 9 Sep 2024 13:26:17 -0700
Subject: [PATCH] [ExecuTorch] port bf16 dot product kernel from ATen CPUBlas

This makes bf16 stories110M go faster.

Differential Revision: [D62154018](https://our.internmc.facebook.com/intern/diff/D62154018/)

[ghstack-poisoned]
---
 kernels/optimized/blas/BlasKernel.cpp | 151 ++++++++++++++++++++++++++
 kernels/optimized/blas/BlasKernel.h   |  54 +++++++++
 2 files changed, 205 insertions(+)
 create mode 100644 kernels/optimized/blas/BlasKernel.cpp

diff --git a/kernels/optimized/blas/BlasKernel.cpp b/kernels/optimized/blas/BlasKernel.cpp
new file mode 100644
index 00000000000..cc619b7061a
--- /dev/null
+++ b/kernels/optimized/blas/BlasKernel.cpp
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/blas/BlasKernel.h>
+
+#ifdef __aarch64__
+#include <arm_neon.h>
+#endif
+
+using torch::executor::BFloat16;
+
+namespace executorch {
+namespace cpublas {
+namespace internal {
+#ifdef __aarch64__
+static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
+#ifdef __ARM_FEATURE_FMA
+  return vfmaq_f32(a, b, c);
+#else
+  return vaddq_f32(a, vmulq_f32(b, c));
+#endif
+}
+
+// The below reduce overload and fp16_dot_with_fp32_arith are adapted
+// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
+// functions. See NOTE [ GGML Copyright Notice ] above for the
+// required notice.
+
+// We need the shift for reduce(), hence the extra constants.
+static constexpr auto kF32ElementsPerIterationShift = 5;
+static constexpr auto kF32ElementsPerIteration = 1
+    << kF32ElementsPerIterationShift;
+static_assert(kF32ElementsPerIteration == 32);
+
+static constexpr auto kF32ElementsPerRegisterShift = 2;
+static constexpr auto kF32ElementsPerRegister = 1
+    << kF32ElementsPerRegisterShift;
+static_assert(kF32ElementsPerRegister == 4);
+
+static constexpr auto kF32RegisterPairsPerIteration = 4;
+static constexpr auto kF32RegistersPerIteration =
+    kF32RegisterPairsPerIteration * 2;
+static constexpr auto kF32RegistersPerIterationShift = 3;
+static_assert(
+    kF32RegistersPerIteration ==
+    kF32ElementsPerIteration / kF32ElementsPerRegister);
+static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);
+
+static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
+  int offset = kF32RegistersPerIteration;
+  utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
+      [&offset, &x](auto idx) {
+        offset /= 2;
+        for (int i = 0; i < offset; ++i) {
+          x[i] = vaddq_f32(x[i], x[offset + i]);
+        }
+      });
+  return vaddvq_f32(x[0]);
+}
+
+static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
+  int32x4_t shift = vdupq_n_s32(16);
+  return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
+}
+
+static ET_INLINE float32x4_t
+f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
+  return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
+}
+
+static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t sum[kF32RegistersPerIteration],
+    int registerPairIndex) {
+  // TODO: detect intrinsic availability, use them if they're available.
+  // __ARM_FEATURE_BF16 Load a pair of f32 registers at a time.
+  const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
+  const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
+      &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
+
+  sum[2 * registerPairIndex] = f32_fma_bf16(
+      sum[2 * registerPairIndex],
+      vget_low_u16(temp_vec1),
+      vget_low_u16(temp_vec2));
+  sum[2 * registerPairIndex + 1] = f32_fma_bf16(
+      sum[2 * registerPairIndex + 1],
+      vget_high_u16(temp_vec1),
+      vget_high_u16(temp_vec2));
+}
+
+static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    float32x4_t* tailSum,
+    int idx) {
+  const auto temp_vec1 =
+      vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
+  const auto temp_vec2 =
+      vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
+  *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
+}
+
+template <typename T>
+float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
+  float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
+  const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
+  for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
+    const auto* vec1_ = vec1 + j;
+    const auto* vec2_ = vec2 + j;
+    utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
+        [vec1_, vec2_, &sum](auto k) {
+          dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
+        });
+  }
+  auto reducedSum = reduce(sum);
+
+  // First-tier tail fixup: make sure we handle workloads that can
+  // benefit from vectorization, but don't fit into our fully unrolled
+  // loop above.
+  float32x4_t tailSum = vdupq_n_f32(0);
+  const auto len_aligned_4 = len & ~3;
+  for (int j = len_aligned; j < len_aligned_4; j += 4) {
+    dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
+  }
+  auto reducedTail = vpaddq_f32(tailSum, tailSum);
+  reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);
+
+  // Second-tier tail fixup: handle all workloads.
+  for (int j = len_aligned_4; j < len; ++j) {
+    reducedSum += vec1[j] * vec2[j];
+  }
+  return reducedSum;
+}
+
+float bf16_dot_with_fp32_arith(
+    const BFloat16* vec1,
+    const BFloat16* vec2,
+    int64_t len) {
+  return dot_with_fp32_arith(vec1, vec2, len);
+}
+#endif
+} // namespace internal
+} // namespace cpublas
+} // namespace executorch
diff --git a/kernels/optimized/blas/BlasKernel.h b/kernels/optimized/blas/BlasKernel.h
index 10b568c50d3..f594d1748e7 100644
--- a/kernels/optimized/blas/BlasKernel.h
+++ b/kernels/optimized/blas/BlasKernel.h
@@ -11,8 +11,16 @@
 #include
 #include
+#include
+
 #include

+namespace torch {
+namespace executor {
+struct BFloat16;
+} // namespace executor
+} // namespace torch
+
 namespace executorch {
 namespace cpublas {

@@ -154,6 +162,52 @@ void gemm_transa_(
     a_ += lda;
   }
 }
+
+#ifdef __aarch64__
+namespace internal {
+float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
+} // namespace internal
+
+template <>
+inline void gemm_transa_(
+    int64_t m, int64_t n, int64_t k,
+    torch::executor::BFloat16 alpha,
+    const torch::executor::BFloat16 *a, int64_t lda,
+    const torch::executor::BFloat16 *b, int64_t ldb,
+    torch::executor::BFloat16 beta,
+    torch::executor::BFloat16 *c, int64_t ldc) {
+  // c = alpha * (a.T @ b) + beta * c
+// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
+  if (alpha == 1 && beta == 0) {
+    const auto *a_ = a;
+    for (int i = 0; i < m; ++i) {
+      const auto *b_ = b;
+      for (int j = 0; j < n; ++j) {
+        const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
+        b_ += ldb;
+        c[j*ldc+i] = dot;
+      }
+      a_ += lda;
+    }
+    return;
+  }
+  const auto *a_ = a;
+  for (int i = 0; i < m; ++i) {
+    const auto *b_ = b;
+    for (int j = 0; j < n; ++j) {
+      const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
+      b_ += ldb;
+      if (beta == 0) {
+        c[j*ldc+i] = alpha*dot;
+      } else {
+        c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
+      }
+    }
+    a_ += lda;
+  }
+}
+#endif
+
 // clang-format on

 template
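As a quick sanity check outside the patch itself, a minimal standalone harness along these lines could exercise the new internal::bf16_dot_with_fp32_arith entry point on aarch64 and compare it against a scalar fp32 accumulation. The bfloat16.h include path and the float<->BFloat16 conversions on torch::executor::BFloat16 are assumptions about the surrounding runtime; the scalar reference is for illustration only, not part of the change.

// bf16_dot_check.cpp: sketch of a manual check for the new kernel.
#include <cstdint>
#include <cstdio>
#include <vector>

#include <executorch/kernels/optimized/blas/BlasKernel.h>
#include <executorch/runtime/core/portable_type/bfloat16.h> // assumed path to BFloat16

using torch::executor::BFloat16;

int main() {
#ifdef __aarch64__
  // 71 = 2 * 32 + 4 + 3 elements: exercises the fully unrolled main loop,
  // the 4-wide first-tier tail, and the scalar second-tier tail.
  constexpr int64_t len = 71;
  std::vector<BFloat16> vec1(len), vec2(len);
  for (int64_t i = 0; i < len; ++i) {
    // Values chosen to be exactly representable in bf16 (assumes BFloat16
    // is constructible from float).
    vec1[i] = BFloat16(0.125f * static_cast<float>(i % 7));
    vec2[i] = BFloat16(0.25f * static_cast<float>(i % 5));
  }

  const float fast = executorch::cpublas::internal::bf16_dot_with_fp32_arith(
      vec1.data(), vec2.data(), len);

  // Scalar reference: widen each bf16 input to fp32 and accumulate in fp32,
  // mirroring what the vectorized kernel does.
  float ref = 0.0f;
  for (int64_t i = 0; i < len; ++i) {
    ref += static_cast<float>(vec1[i]) * static_cast<float>(vec2[i]);
  }
  std::printf("kernel=%f reference=%f\n", fast, ref);
#endif
  return 0;
}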