Skip to content

Commit

Permalink
Enable x86 CPU vectorization on windows [submodule sleef] (#118980)
Browse files Browse the repository at this point in the history
Enable VEC on Windows OS.
1. Fix some type defination gap between Windows and Linux.
2. Fix some operator not support on Windows, such as [], /.
3. Enable static sleef library build on Windows.
4. Disable unsupported function overloading on MSVC.
5. Upgrade submodule sleef lib, which fixed build issue on Windows.
6. Fixed bazel build issues.
7. Fix test app not link to sleef on Windows.

Pull Request resolved: #118980
Approved by: https://github.com/jgong5, https://github.com/ezyang, https://github.com/malfet
  • Loading branch information
xuhancn authored and pytorchmergebot committed Mar 20, 2024
1 parent 666d629 commit aa74a8b
Show file tree
Hide file tree
Showing 19 changed files with 173 additions and 94 deletions.
48 changes: 19 additions & 29 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -419,32 +419,25 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
endif()

if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
# Preserve values for the main build
set(__aten_sleef_build_shared_libs ${BUILD_SHARED_LIBS})
set(__aten_sleef_build_tests ${BUILD_TESTS})

# Unset our restrictive C++ flags here and reset them later.
# Remove this once we use proper target_compile_options.
set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(CMAKE_CXX_FLAGS)

# Bump up optimization level for sleef to -O1, since at -O0 the compiler
# excessively spills intermediate vector registers to the stack
# and makes things run impossibly slowly
set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
else()
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
if(NOT MSVC)
# Bump up optimization level for sleef to -O1, since at -O0 the compiler
# excessively spills intermediate vector registers to the stack
# and makes things run impossibly slowly
set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
else()
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
endif()
endif()

if(NOT USE_SYSTEM_SLEEF)
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
Expand All @@ -465,12 +458,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
list(APPEND ATen_CPU_DEPENDENCY_LIBS sleef)

set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})

# Set these back. TODO: Use SLEEF_ to pass these instead
set(BUILD_SHARED_LIBS ${__aten_sleef_build_shared_libs} CACHE BOOL "Build shared libs" FORCE)
set(BUILD_TESTS ${__aten_sleef_build_tests} CACHE BOOL "Build tests" FORCE)
if(NOT MSVC)
set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
endif()
endif()

if(USE_CUDA AND NOT USE_ROCM)
Expand Down
14 changes: 8 additions & 6 deletions aten/src/ATen/cpu/vec/vec256/vec256.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}


#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -94,7 +94,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
Expand All @@ -106,9 +107,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm256_i32gather_ps(base_addr, vindex, scale);
}

#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
Expand All @@ -122,7 +124,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}

#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
Expand Down Expand Up @@ -302,6 +304,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}

#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // (defined(CPU_CAPABILITY_AVX2)

}} // namepsace at::vec::CPU_CAPABILITY
16 changes: 9 additions & 7 deletions aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

Expand All @@ -18,7 +19,7 @@ namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

// bfloat16 conversion
static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
Expand Down Expand Up @@ -265,7 +266,8 @@ static_assert(
}
return b;
}
Vectorized<T> map(const __m256 (*const vop)(__m256)) const {

Vectorized<T> map(SLEEF_CONST __m256 (*vop)(__m256)) const {
__m256 lo, hi;
cvt_to_fp32<T>(values, lo, hi);
const auto o1 = vop(lo);
Expand Down Expand Up @@ -1026,7 +1028,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_VECTORIZED_INIT(Half, half);

#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX2)

#define CONVERT_NON_VECTORIZED_INIT(type, name) \
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
Expand All @@ -1051,9 +1053,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
CONVERT_NON_VECTORIZED_INIT(Half, half);

#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // defined(CPU_CAPABILITY_AVX2)

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); \
Expand All @@ -1072,7 +1074,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
LOAD_FP32_VECTORIZED_INIT(Half, fp16);

#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#else // defined(CPU_CAPABILITY_AVX2)
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
__at_align__ float values[Vectorized<float>::size()]; \
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

template <> class Vectorized<c10::complex<double>> {
private:
Expand Down Expand Up @@ -145,7 +146,7 @@ template <> class Vectorized<c10::complex<double>> {
auto abs = abs_();
auto zero = _mm256_setzero_pd();
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm256_div_pd(values, abs);
return _mm256_blendv_pd(div, zero, mask);
}
__m256d real_() const {
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
#include <c10/util/irange.h>
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

template <> class Vectorized<c10::complex<float>> {
private:
Expand Down Expand Up @@ -180,7 +181,7 @@ template <> class Vectorized<c10::complex<float>> {
auto abs = abs_();
auto zero = _mm256_setzero_ps();
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
auto div = values / abs;
auto div = _mm256_div_ps(values, abs);
return _mm256_blendv_ps(div, zero, mask);
}
__m256 real_() const {
Expand Down
5 changes: 3 additions & 2 deletions aten/src/ATen/cpu/vec/vec256/vec256_double.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

Expand All @@ -15,7 +16,7 @@ namespace at::vec {
inline namespace CPU_CAPABILITY {


#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

template <> class Vectorized<double> {
private:
Expand Down
15 changes: 8 additions & 7 deletions aten/src/ATen/cpu/vec/vec256/vec256_float.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)
#define SLEEF_STATIC_LIBS
#include <sleef.h>
#endif

namespace at::vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

template <> class Vectorized<float> {
private:
Expand Down Expand Up @@ -226,14 +227,14 @@ template <> class Vectorized<float> {
static __m256 vec_factorial_5 =
_mm256_set1_ps(0.00828929059f); // 1/factorial(5)
static __m256 vec_exp_log2ef =
(__m256)_mm256_set1_epi32(0x3fb8aa3b); // log2(e)
_mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)
static __m256 vec_half = _mm256_set1_ps(0.5f);
static __m256 vec_one = _mm256_set1_ps(1.f);
static __m256 vec_zero = _mm256_set1_ps(0.f);
static __m256 vec_two = _mm256_set1_ps(2.f);
static __m256 vec_ln2f = (__m256)_mm256_set1_epi32(0x3f317218); // ln(2)
static __m256 vec_ln_flt_min = (__m256)_mm256_set1_epi32(0xc2aeac50);
static __m256 vec_ln_flt_max = (__m256)_mm256_set1_epi32(0x42b17218);
static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2)
static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
static int n_mantissa_bits = 23;

Expand Down Expand Up @@ -266,7 +267,7 @@ template <> class Vectorized<float> {
auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
auto vec_two_pow_n = (__m256)vec_two_pow_n_i;
auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i);
vec_two_pow_n =
_mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);

Expand Down
12 changes: 9 additions & 3 deletions aten/src/ATen/cpu/vec/vec256/vec256_qint.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,17 @@
namespace at::vec {
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX2)

#ifdef _MSC_VER
__declspec(align(64)) struct Vectorizedqi {
protected:
__m256i vals;
#else
struct Vectorizedqi {
protected:
__m256i vals __attribute__((aligned(64)));
#endif

public:
Vectorizedqi() {}
Expand Down Expand Up @@ -133,7 +139,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
}

template <typename T>
inline void __attribute__((always_inline)) QuantizeAvx2(
__FORCE_INLINE void QuantizeAvx2(
const float* src,
T* dst,
int len,
Expand Down Expand Up @@ -1331,5 +1337,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
return a.maximum(b);
}

#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#endif // if defined(CPU_CAPABILITY_AVX2)
}} // namespace at::vec::CPU_CAPABILITY
14 changes: 8 additions & 6 deletions aten/src/ATen/cpu/vec/vec512/vec512.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
}


#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#if defined(CPU_CAPABILITY_AVX512)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand All @@ -80,7 +80,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
Expand All @@ -92,9 +93,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
return _mm512_i32gather_ps(vindex, base_addr, scale);
}

#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

#ifndef _MSC_VER
// MSVC is not working well on complex function overload.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
Expand All @@ -112,7 +114,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
}

#endif
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
Expand Down Expand Up @@ -270,6 +272,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
return flip8(v);
}

#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
#endif // defined(CPU_CAPABILITY_AVX512)

}}}

1 comment on commit aa74a8b

@pytorchmergebot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted #118980 on behalf of https://github.com/huydhn due to Sorry for revert your change one more time but the hard part is that it breaks lot of internal builds (comment)

Please sign in to comment.