151 changes: 151 additions & 0 deletions kernels/optimized/blas/BlasKernel.cpp
@@ -0,0 +1,151 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/optimized/blas/BlasKernel.h>

#ifdef __aarch64__
#include <arm_neon.h>
#endif

using torch::executor::BFloat16;

namespace executorch {
namespace cpublas {
namespace internal {
#ifdef __aarch64__
static inline float32x4_t f32_fma(float32x4_t a, float32x4_t b, float32x4_t c) {
#ifdef __ARM_FEATURE_FMA
return vfmaq_f32(a, b, c);
#else
return vaddq_f32(a, vmulq_f32(b, c));
#endif
}

// The below reduce() and bf16_dot_with_fp32_arith are adapted
// from llama.cpp's ggml_vec_dot_f32 and surrounding utility
// functions. See NOTE [ GGML Copyright Notice ] above for the
// required notice.

// We need the shift for reduce(), hence the extra constants.
static constexpr auto kF32ElementsPerIterationShift = 5;
static constexpr auto kF32ElementsPerIteration = 1
<< kF32ElementsPerIterationShift;
static_assert(kF32ElementsPerIteration == 32);

static constexpr auto kF32ElementsPerRegisterShift = 2;
static constexpr auto kF32ElementsPerRegister = 1
<< kF32ElementsPerRegisterShift;
static_assert(kF32ElementsPerRegister == 4);

static constexpr auto kF32RegisterPairsPerIteration = 4;
static constexpr auto kF32RegistersPerIteration =
kF32RegisterPairsPerIteration * 2;
static constexpr auto kF32RegistersPerIterationShift = 3;
static_assert(
kF32RegistersPerIteration ==
kF32ElementsPerIteration / kF32ElementsPerRegister);
static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift);

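// Horizontal sum of the kF32RegistersPerIteration accumulators: halve the
// register count each step with pairwise adds, then add across the lanes of
// the final register.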
static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) {
int offset = kF32RegistersPerIteration;
utils::ForcedUnroll<kF32RegistersPerIterationShift>{}(
[&offset, &x](auto idx) {
offset /= 2;
for (int i = 0; i < offset; ++i) {
x[i] = vaddq_f32(x[i], x[offset + i]);
}
});
return vaddvq_f32(x[0]);
}

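// Widens four bf16 values (passed as their raw uint16 bit patterns) to f32
// by shifting each into the high 16 bits of a 32-bit lane.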
static ET_INLINE float32x4_t to_bfloat16(uint16x4_t u16) {
int32x4_t shift = vdupq_n_s32(16);
return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift));
}

static ET_INLINE float32x4_t
f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) {
return f32_fma(a, to_bfloat16(b), to_bfloat16(c));
}

static ET_INLINE void dot_with_fp32_arith_main_inner_loop(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t sum[kF32RegistersPerIteration],
int registerPairIndex) {
// TODO: detect __ARM_FEATURE_BF16 intrinsic availability and use those
// intrinsics if they're available.
// Load a pair of f32 registers at a time.
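// A possible shape for this TODO (a sketch, not part of this PR; it assumes
// the +bf16 NEON intrinsics vld1q_bf16/vbfdotq_f32 are available on the
// target): the two u16 loads below could instead be loaded as bfloat16x8_t
// and fed to BFDOT, which accumulates all eight products into one f32
// register; reduce() still sums every register, so the total is unchanged:
//
//   const bfloat16x8_t v1 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
//       &vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
//   const bfloat16x8_t v2 = vld1q_bf16(reinterpret_cast<const bfloat16_t*>(
//       &vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));
//   sum[2 * registerPairIndex] =
//       vbfdotq_f32(sum[2 * registerPairIndex], v1, v2);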
const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec1[registerPairIndex * 2 * kF32ElementsPerRegister]));
const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast<const uint16_t*>(
&vec2[registerPairIndex * 2 * kF32ElementsPerRegister]));

sum[2 * registerPairIndex] = f32_fma_bf16(
sum[2 * registerPairIndex],
vget_low_u16(temp_vec1),
vget_low_u16(temp_vec2));
sum[2 * registerPairIndex + 1] = f32_fma_bf16(
sum[2 * registerPairIndex + 1],
vget_high_u16(temp_vec1),
vget_high_u16(temp_vec2));
}

static ET_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop(
const BFloat16* vec1,
const BFloat16* vec2,
float32x4_t* tailSum,
int idx) {
const auto temp_vec1 =
vld1_u16(reinterpret_cast<const uint16_t*>(&vec1[idx]));
const auto temp_vec2 =
vld1_u16(reinterpret_cast<const uint16_t*>(&vec2[idx]));
*tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2);
}

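// Computes the dot product of vec1 and vec2 with fp32 accumulation. A scalar
// sketch of the intended semantics (for clarity only; the helper name is
// hypothetical and not part of this PR):
//
//   float dot_reference(const T* vec1, const T* vec2, int64_t len) {
//     float acc = 0.0f;
//     for (int64_t i = 0; i < len; ++i) {
//       acc += static_cast<float>(vec1[i]) * static_cast<float>(vec2[i]);
//     }
//     return acc;
//   }
//
// The implementation below processes kF32ElementsPerIteration elements per
// fully unrolled iteration, then a 4-wide vectorized tail, then a scalar tail.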
template <typename T>
float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) {
float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)};
const auto len_aligned = len & ~(kF32ElementsPerIteration - 1);
for (int j = 0; j < len_aligned; j += kF32ElementsPerIteration) {
const auto* vec1_ = vec1 + j;
const auto* vec2_ = vec2 + j;
utils::ForcedUnroll<kF32RegisterPairsPerIteration>{}(
[vec1_, vec2_, &sum](auto k) {
dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k);
});
}
auto reducedSum = reduce(sum);

// First-tier tail fixup: make sure we handle workloads that can
// benefit from vectorization, but don't fit into our fully unrolled
// loop above.
float32x4_t tailSum = vdupq_n_f32(0);
const auto len_aligned_4 = len & ~3;
for (int j = len_aligned; j < len_aligned_4; j += 4) {
dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j);
}
auto reducedTail = vpaddq_f32(tailSum, tailSum);
reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0);

// Second-tier tail fixup: handle all workloads.
for (int j = len_aligned_4; j < len; ++j) {
reducedSum += vec1[j] * vec2[j];
}
return reducedSum;
}

float bf16_dot_with_fp32_arith(
const BFloat16* vec1,
const BFloat16* vec2,
int64_t len) {
return dot_with_fp32_arith(vec1, vec2, len);
}
#endif
} // namespace internal
} // namespace cpublas
} // namespace executorch
54 changes: 54 additions & 0 deletions kernels/optimized/blas/BlasKernel.h
@@ -11,8 +11,16 @@
#include <executorch/kernels/optimized/utils/math_utils.h>
#include <executorch/kernels/optimized/utils/unroll.h>

#include <executorch/runtime/core/portable_type/bfloat16.h>

#include <array>

namespace torch {
namespace executor {
struct BFloat16;
} // namespace executor
} // namespace torch

namespace executorch {
namespace cpublas {

@@ -154,6 +162,52 @@ void gemm_transa_(
a_ += lda;
}
}

#ifdef __aarch64__
namespace internal {
float bf16_dot_with_fp32_arith(const torch::executor::BFloat16* vec1, const torch::executor::BFloat16* vec2, int64_t len);
} // namespace internal

template <>
inline void gemm_transa_<torch::executor::BFloat16, torch::executor::BFloat16>(
int64_t m, int64_t n, int64_t k,
torch::executor::BFloat16 alpha,
const torch::executor::BFloat16 *a, int64_t lda,
const torch::executor::BFloat16 *b, int64_t ldb,
torch::executor::BFloat16 beta,
torch::executor::BFloat16 *c, int64_t ldc) {
// c = alpha * (a.T @ b) + beta * c
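// a and b are read as contiguous k-element columns (successive columns are
// lda and ldb apart), and c is written as c[j*ldc + i], i.e. the m x n
// result is stored column-major with leading dimension ldc.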
// parallel_for(0, m, 1, [&](int64_t begin, int64_t end) {
if (alpha == 1 && beta == 0) {
const auto *a_ = a;
for (int i = 0; i < m; ++i) {
const auto *b_ = b;
for (int j = 0; j < n; ++j) {
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
b_ += ldb;
c[j*ldc+i] = dot;
}
a_ += lda;
}
return;
}
const auto *a_ = a;
for (int i = 0; i < m; ++i) {
const auto *b_ = b;
for (int j = 0; j < n; ++j) {
const auto dot = internal::bf16_dot_with_fp32_arith(a_, b_, k);
b_ += ldb;
if (beta == 0) {
c[j*ldc+i] = alpha*dot;
} else {
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*dot;
}
}
a_ += lda;
}
}
#endif

// clang-format on

template <typename scalar_t, typename opmath_t>