Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,17 @@ def get_fbgemm_avx2_srcs(msvc = False):
"src/UtilsAvx2.cc",
]

def get_fbgemm_inline_avx2_srcs(msvc = False):
    """Returns the one-element source list for the FP16 AVX2 micro-kernels.

    Args:
        msvc: when truthy, the intrinsics source file is returned in place of
            the inline-assembly source file.

    Returns:
        A list holding a single source path.
    """

    # The FP16 kernels use inline assembly, and MSVC's inline assembly syntax
    # differs, so MSVC builds take the intrinsics implementation.
    if msvc:
        return ["src/FbgemmFP16UKernelsIntrinsicAvx2.cc"]
    return ["src/FbgemmFP16UKernelsAvx2.cc"]
def get_fbgemm_inline_avx2_srcs(msvc = False, buck = False):
    """Chooses the FP16 AVX2 micro-kernel sources for the requested toolchain.

    Args:
        msvc: when truthy, the intrinsics sources are used as the default
            choice instead of the inline-assembly sources.
        buck: when truthy, the result is wrapped in a select() so that the
            "ovr_config//cpu:arm64" key always maps to the intrinsics sources.

    Returns:
        A list of source paths, or a select() over such lists when buck is set.
    """

    # FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
    inline_asm = ["src/FbgemmFP16UKernelsAvx2.cc"]
    intrinsics = ["src/FbgemmFP16UKernelsIntrinsicAvx2.cc"]

    # Decide once which variant applies outside of the arm64 override.
    default_srcs = intrinsics if msvc else inline_asm

    if not buck:
        return default_srcs
    return select({
        "DEFAULT": default_srcs,
        "ovr_config//cpu:arm64": intrinsics,
    })

def get_fbgemm_avx512_srcs(msvc = False):
return [
Expand All @@ -116,12 +122,21 @@ def get_fbgemm_avx512_srcs(msvc = False):
"src/UtilsAvx512.cc",
]

def get_fbgemm_inline_avx512_srcs(msvc = False):
return [
#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
"src/FbgemmFP16UKernelsAvx512.cc" if not msvc else "src/FbgemmFP16UKernelsIntrinsicAvx512.cc",
"src/FbgemmFP16UKernelsAvx512_256.cc" if not msvc else "src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc",
def get_fbgemm_inline_avx512_srcs(msvc = False, buck = False):
    """Chooses the FP16 AVX512 micro-kernel sources for the requested toolchain.

    Args:
        msvc: when truthy, the intrinsics sources are used as the default
            choice instead of the inline-assembly sources.
        buck: when truthy, the result is wrapped in a select() so that the
            "ovr_config//cpu:arm64" key always maps to the intrinsics sources.

    Returns:
        A list of source paths, or a select() over such lists when buck is set.
    """
    inline_asm = [
        "src/FbgemmFP16UKernelsAvx512.cc",
        "src/FbgemmFP16UKernelsAvx512_256.cc",
    ]
    intrinsics = [
        "src/FbgemmFP16UKernelsIntrinsicAvx512.cc",
        "src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc",
    ]

    # Decide once which variant applies outside of the arm64 override.
    default_srcs = intrinsics if msvc else inline_asm

    if not buck:
        return default_srcs
    return select({
        "DEFAULT": default_srcs,
        "ovr_config//cpu:arm64": intrinsics,
    })

def get_fbgemm_tests(skip_tests = []):
    """Globs the test sources matching test/*Test.cc.

    Args:
        skip_tests: patterns forwarded to glob's `exclude` argument.

    Returns:
        Whatever native.glob yields for the pattern with the exclusions applied.
    """
    test_patterns = ["test/*Test.cc"]
    return native.glob(test_patterns, exclude = skip_tests)
10 changes: 8 additions & 2 deletions src/FbgemmBfloat16Convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ namespace fbgemm {
void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#ifndef __aarch64__
if (fbgemmHasAvx512Support()) {
FloatToBfloat16_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
} else
#endif
if (fbgemmHasAvx2Support()) {
FloatToBfloat16_avx2(src, dst, size);
} else {
FloatToBfloat16_ref(src, dst, size);
Expand All @@ -59,9 +62,12 @@ void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) {
void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
#ifndef __aarch64__
if (fbgemmHasAvx512Support()) {
Bfloat16ToFloat_avx512(src, dst, size);
} else if (fbgemmHasAvx2Support()) {
} else
#endif
if (fbgemmHasAvx2Support()) {
Bfloat16ToFloat_avx2(src, dst, size);
} else {
Bfloat16ToFloat_ref(src, dst, size);
Expand Down
7 changes: 6 additions & 1 deletion src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ constexpr kernel_array_t<float16> kernel_fp16_avx512_256 = {
gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0};

constexpr kernel_array_t<float16> kernel_fp16_avx512 = {
#ifndef __aarch64__
nullptr,
gemmkernel_1x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx512_fp16_fA0fB0fC0,
Expand All @@ -63,7 +64,11 @@ constexpr kernel_array_t<float16> kernel_fp16_avx512 = {
gemmkernel_11x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_12x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_13x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_14x2_Avx512_fp16_fA0fB0fC0};
gemmkernel_14x2_Avx512_fp16_fA0fB0fC0
#else
nullptr
#endif
};

} // namespace

Expand Down
2 changes: 0 additions & 2 deletions src/FbgemmFP16UKernelsIntrinsicAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

#ifdef _MSC_VER
#if defined(__x86_64__) || defined(__i386__) || \
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
Expand Down Expand Up @@ -115,4 +114,3 @@ void NOINLINE gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
}

} // namespace fbgemm
#endif // _MSC_VER
2 changes: 0 additions & 2 deletions src/FbgemmFP16UKernelsIntrinsicAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

#ifdef _MSC_VER
#if defined(__x86_64__) || defined(__i386__) || \
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
Expand Down Expand Up @@ -140,4 +139,3 @@ void NOINLINE gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
}

} // namespace fbgemm
#endif // _MSC_VER
2 changes: 0 additions & 2 deletions src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

#ifdef _MSC_VER
#if defined(__x86_64__) || defined(__i386__) || \
(defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
Expand Down Expand Up @@ -121,4 +120,3 @@ void NOINLINE gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
}

} // namespace fbgemm
#endif // _MSC_VER
3 changes: 3 additions & 0 deletions src/QuantUtilsAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ void NO_SANITIZE("address") FusedQuantizeDequantizeAvx2(
float inverse_scale = 1.f / qparams.scale;
constexpr int32_t min_val = std::numeric_limits<T>::min();
constexpr int32_t max_val = std::numeric_limits<T>::max();
(void)inverse_scale; // Suppress unused variable warning
(void)min_val; // Suppress unused variable warning
(void)max_val; // Suppress unused variable warning
#if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))

constexpr int VLEN = 8;
Expand Down