diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 106dd098dccc3..4f2fe2615e5cf 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -132,7 +132,9 @@ fi if [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX-* || $TEST_CONFIG == 'nogpu_NO_AVX' ]]; then export ATEN_CPU_CAPABILITY=default elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX2-* || $TEST_CONFIG == 'nogpu_NO_AVX2' ]]; then - export ATEN_CPU_CAPABILITY=avx + export ATEN_CPU_CAPABILITY=default +elif [[ "${BUILD_ENVIRONMENT}" == *-NO_AVX512-* || $TEST_CONFIG == 'nogpu_NO_AVX512' ]]; then + export ATEN_CPU_CAPABILITY=avx2 fi if [ -n "$IN_PULL_REQUEST" ] && [[ "$BUILD_ENVIRONMENT" != *coverage* ]]; then diff --git a/aten.bzl b/aten.bzl index 6bce36ca904ca..c2fcee7323d8e 100644 --- a/aten.bzl +++ b/aten.bzl @@ -1,9 +1,8 @@ load("@rules_cc//cc:defs.bzl", "cc_library") -CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX", "AVX2"] +CPU_CAPABILITY_NAMES = ["DEFAULT", "AVX2"] CAPABILITY_COMPILER_FLAGS = { "AVX2": ["-mavx2", "-mfma"], - "AVX": ["-mavx"], "DEFAULT": [], } diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index cf4c192d5de44..3e160a201022f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -50,7 +50,7 @@ if(NOT BUILD_LITE_INTERPRETER) endif() EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/*.h" "quantized/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp") diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 43d81cbdbe334..750c90bb4c59f 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -108,12 +108,12 @@ std::string used_cpu_capability() { case native::CPUCapability::DEFAULT: ss << "NO AVX"; break; - case native::CPUCapability::AVX: - ss << "AVX"; - break; case native::CPUCapability::AVX2: ss << "AVX2"; break; + case native::CPUCapability::AVX512: + ss << "AVX512"; + break; #endif default: break; diff --git a/aten/src/ATen/cpu/FlushDenormal.cpp b/aten/src/ATen/cpu/FlushDenormal.cpp index 7c7df405be50c..c1d330f6a74c9 100644 --- a/aten/src/ATen/cpu/FlushDenormal.cpp +++ b/aten/src/ATen/cpu/FlushDenormal.cpp @@ -1,6 +1,5 @@ #include - -#include +#include #include namespace at { namespace cpu { diff --git a/aten/src/ATen/cpu/vec/functional.h b/aten/src/ATen/cpu/vec/functional.h index c9a9c443a1638..210ae9e9e883b 100644 --- a/aten/src/ATen/cpu/vec/functional.h +++ b/aten/src/ATen/cpu/vec/functional.h @@ -1 +1,6 @@ -#include +#pragma once + +#include +#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) +#include +#endif diff --git a/aten/src/ATen/cpu/vec/vec256/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h similarity index 99% rename from aten/src/ATen/cpu/vec/vec256/functional_base.h rename to aten/src/ATen/cpu/vec/functional_base.h index 519f1008788ea..7bd04e637c7e3 100644 --- a/aten/src/ATen/cpu/vec/vec256/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -3,7 +3,7 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! 
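Aside, not part of the patch: the test.sh, aten.bzl and Version.cpp hunks above move ATen's CPU-capability ladder from DEFAULT/AVX/AVX2 to DEFAULT/AVX2/AVX512, with ATEN_CPU_CAPABILITY acting as a cap on what runtime detection may use. The sketch below only illustrates that override idea; the function name, the "can only lower, never raise" rule, and the structure are assumptions, not ATen's actual DispatchStub code.

```cpp
#include <cstdlib>
#include <string>

enum class CPUCapability { DEFAULT = 0, AVX2 = 1, AVX512 = 2 };

// Simplified sketch: ATEN_CPU_CAPABILITY (values mirror test.sh above) caps
// the level that hardware detection would otherwise choose.
CPUCapability effective_capability(CPUCapability detected) {
  const char* env = std::getenv("ATEN_CPU_CAPABILITY");
  if (env == nullptr) {
    return detected;
  }
  const std::string requested(env);
  CPUCapability cap = detected;
  if (requested == "default") {
    cap = CPUCapability::DEFAULT;
  } else if (requested == "avx2") {
    cap = CPUCapability::AVX2;
  } else if (requested == "avx512") {
    cap = CPUCapability::AVX512;
  }
  // Never report more than the hardware actually supports.
  return static_cast<int>(cap) < static_cast<int>(detected) ? cap : detected;
}
```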
// See Note [Do not compile initializers with AVX] -#include +#include namespace at { namespace vec { diff --git a/aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h similarity index 97% rename from aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h rename to aten/src/ATen/cpu/vec/functional_bfloat16.h index 442c587a6c248..9efa7004090bb 100644 --- a/aten/src/ATen/cpu/vec/vec256/functional_bfloat16.h +++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h @@ -3,7 +3,7 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include +#include namespace at { namespace vec { @@ -15,26 +15,26 @@ template <> struct VecScalarType { using type = float; }; template using vec_scalar_t = typename VecScalarType::type; -// Note that we already have specializes member of Vectorized for BFloat16 -// so the following function would run smoothly: +// Note that we already have specialized member of Vectorized for BFloat16 +// so the following functions would run smoothly: // using Vec = Vectorized; // Vec one = Vec(BFloat16(1)); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // -// Why we still need to specializes "funtional"? +// Then why we still need to specialize "funtional"? // If we do specialization at Vectorized<> level, the above example would need 3 pairs of -// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". +// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". // If we do specialization at vec::map<>() level, we have only 1 pair of conversion -// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. +// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. // -// The following BFloat16 functionalities will only do data type conversion for input -// and output vector (reduce functionalities will only convert the final scalar back to bf16). +// The following BFloat16 functionality will only do data type conversion for input +// and output vector (reduce functionality will only convert the final scalar back to bf16). // Compared to Vectorized<> specialization, // 1. better performance since we have less data type conversion; // 2. less rounding error since immediate results are kept in fp32; // 3. accumulation done on data type of fp32. 
// -// If you plan to extend this file, make sure add unit test at +// If you plan to extend this file, please ensure adding unit tests at // aten/src/ATen/test/vec_test_all_types.cpp // template diff --git a/aten/src/ATen/cpu/vec/vec256/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h similarity index 86% rename from aten/src/ATen/cpu/vec/vec256/intrinsics.h rename to aten/src/ATen/cpu/vec/intrinsics.h index 5ac4d484ccdc3..a6a73e232e112 100644 --- a/aten/src/ATen/cpu/vec/vec256/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -1,6 +1,6 @@ #pragma once -#if defined(__clang__) && (defined(__x86_64__) || defined(__i386__)) -/* Clang-compatible compiler, targeting x86/x86-64 */ +#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) +/* GCC or clang-compatible compiler, targeting x86/x86-64 */ #include #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* Clang-compatible compiler, targeting arm neon */ @@ -14,9 +14,6 @@ #define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) #define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #endif -#elif defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__)) -/* GCC-compatible compiler, targeting x86/x86-64 */ -#include #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ #include diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index 5a041f2df7090..24b8818d2a8d8 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -1 +1,5 @@ +#if defined(CPU_CAPABILITY_AVX512) +#include +#else #include +#endif diff --git a/aten/src/ATen/cpu/vec/vec256/functional.h b/aten/src/ATen/cpu/vec/vec256/functional.h deleted file mode 100644 index e1ddc3c9cc760..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/functional.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -#include -#if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) -#include -#endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 4200e6729afcb..0d13458bc4c1c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! 
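Aside, not part of the patch: the functional_bfloat16.h comment above argues for specializing at the vec::map() level so that each BFloat16 vector pays exactly one bf16→fp32/fp32→bf16 conversion pair while all intermediate math stays in fp32. A minimal sketch of that idea, using the convert_bfloat16_float / convert_float_bfloat16 helpers that appear later in this patch; the free-function name, signature, and omitted tail handling are mine.

```cpp
#include <tuple>
// Assumes the ATen vec headers are already included.

// Sketch only: one conversion pair per BFloat16 vector, math stays in fp32.
template <typename Op>
void map_bfloat16(const Op& op, c10::BFloat16* out, const c10::BFloat16* in, int64_t n) {
  using bVec = at::vec::Vectorized<c10::BFloat16>;
  using fVec = at::vec::Vectorized<float>;
  int64_t i = 0;
  for (; i + bVec::size() <= n; i += bVec::size()) {
    bVec x = bVec::loadu(in + i);
    fVec lo, hi;
    std::tie(lo, hi) = at::vec::convert_bfloat16_float(x);   // bf16 -> fp32 once
    lo = op(lo);                                              // accumulate in fp32
    hi = op(hi);
    at::vec::convert_float_bfloat16(lo, hi).store(out + i);   // fp32 -> bf16 once
  }
  // Scalar tail (i < n) omitted for brevity.
}
```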
// See Note [Do not compile initializers with AVX] -#include +#include -#include +#include #if !defined(__VSX__) || !defined(CPU_CAPABILITY_VSX) #include #include @@ -68,9 +68,9 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { } -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template<> inline Vectorized cast(const Vectorized& src) { @@ -82,29 +82,6 @@ inline Vectorized cast(const Vectorized& src) { return _mm256_castps_pd(src); } -#if defined(CPU_CAPABILITY_AVX2) - -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -#define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \ -template<> \ -inline Vectorized cast(const Vectorized& src) { \ - return _mm256_castp ## float_ch ## _si256(src); \ -} \ -template<> \ -inline Vectorized cast(const Vectorized& src) { \ - return _mm256_castsi256_p ## float_ch (src); \ -} - -DEFINE_FLOAT_INT_CAST(int64_t, double, d) -DEFINE_FLOAT_INT_CAST(int32_t, double, d) -DEFINE_FLOAT_INT_CAST(int16_t, double, d) -DEFINE_FLOAT_INT_CAST(int64_t, float, s) -DEFINE_FLOAT_INT_CAST(int32_t, float, s) -DEFINE_FLOAT_INT_CAST(int16_t, float, s) - -#undef DEFINE_FLOAT_INT_CAST - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template @@ -243,8 +220,6 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 4 bits apart } -#endif // defined(CPU_CAPABILITY_AVX2) - -#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index 7d61d8a9d38e3..82a8200ce2c03 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! 
// See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -100,7 +100,7 @@ template <> class Vectorized { return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int16_t count) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); return loadu(tmp_values); } @@ -108,14 +108,14 @@ template <> class Vectorized { if (count == size()) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); } } template static Vectorized blend(const Vectorized& a, const Vectorized& b) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi16(b.values, 0); @@ -280,7 +280,7 @@ template <> class Vectorized { Vectorized erfinv() const { __m256 lo, hi; cvtbf16_fp32(values, lo, hi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); for (int64_t i = 0; i < size() / 2; i++) { @@ -318,7 +318,7 @@ template <> class Vectorized { Vectorized i0() const { __m256 lo, hi; cvtbf16_fp32(values, lo, hi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); for (int64_t i = 0; i < size() / 2; i++) { @@ -333,7 +333,7 @@ template <> class Vectorized { __m256 lo, hi; cvtbf16_fp32(values, lo, hi); constexpr auto sz = size(); - __at_align32__ float tmp1[sz / 2], tmp2[sz / 2]; + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); @@ -350,10 +350,10 @@ template <> class Vectorized { __m256 xlo, xhi; cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(x.values, xlo, xhi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); - __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); for (int64_t i = 0; i < size() / 2; ++i) { @@ -370,10 +370,10 @@ template <> class Vectorized { __m256 xlo, xhi; cvtbf16_fp32(values, lo, hi); cvtbf16_fp32(x.values, xlo, xhi); - __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmp1), lo); _mm256_storeu_ps(reinterpret_cast(tmp2), hi); - __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2]; + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; _mm256_storeu_ps(reinterpret_cast(tmpx1), xlo); _mm256_storeu_ps(reinterpret_cast(tmpx2), xhi); for (int64_t i = 0; i < size() / 2; ++i) { @@ -717,12 +717,13 @@ inline Vectorized convert_float_bfloat16(const Vectorized& a, c return cvtfp32_bf16(__m256(a), __m256(b)); } -#else 
//defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) + +#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { constexpr int64_t K = Vectorized::size(); - __at_align32__ float arr[K]; - __at_align32__ BFloat16 arr2[K]; + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; a.store(arr2); convert(arr2, arr, K); return std::make_tuple( @@ -732,15 +733,15 @@ inline std::tuple, Vectorized> convert_bfloat16_float(c inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { constexpr int64_t K = Vectorized::size(); - __at_align32__ float arr[K]; - __at_align32__ BFloat16 arr2[K]; + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; a.store(arr); b.store(arr + Vectorized::size()); convert(arr, arr2, K); return Vectorized::loadu(arr2); } -#endif +#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { @@ -759,7 +760,7 @@ void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vec } #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { - __at_align32__ float values[Vectorized::size()]; + __at_align__ float values[Vectorized::size()]; for (int k = 0; k < Vectorized::size(); ++k) { values[k] = data[k]; } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h index f96aea6e09ebd..40276ba8365d5 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h @@ -4,9 +4,10 @@ // See Note [Do not compile initializers with AVX] #include -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -15,7 +16,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: @@ -81,7 +82,7 @@ template <> class Vectorized> { if (count == size()) return _mm256_loadu_pd(reinterpret_cast(ptr)); - __at_align32__ double tmp_values[2*size()]; + __at_align__ double tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. 
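Aside, not part of the patch: the loadu(ptr, count) overloads touched above all follow the same pattern that the issue-32502 comment refers to: copy only `count` elements into an aligned temporary, then do a full-width load, so the vector never reads past the caller's buffer. Condensed to one free-standing example (double-precision case; assumes the ATen vec headers and <cstring>):

```cpp
// Load `count` doubles without reading past the end of `ptr`.
static at::vec::Vectorized<double> loadu_partial(const void* ptr, int64_t count) {
  __at_align__ double tmp_values[at::vec::Vectorized<double>::size()];
  // Zero-initialize with a loop; the comment above explains why "={0}" is avoided.
  for (int64_t i = 0; i < at::vec::Vectorized<double>::size(); ++i) {
    tmp_values[i] = 0.0;
  }
  std::memcpy(tmp_values, ptr, count * sizeof(double));
  return at::vec::Vectorized<double>::loadu(tmp_values);  // full-width load from the temporary
}
```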
@@ -106,7 +107,7 @@ template <> class Vectorized> { const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { - __at_align32__ c10::complex tmp[size()]; + __at_align__ c10::complex tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -288,8 +289,8 @@ template <> class Vectorized> { return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { - __at_align32__ c10::complex x_tmp[size()]; - __at_align32__ c10::complex y_tmp[size()]; + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h index 5494828b56501..f40196320022b 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h @@ -4,9 +4,9 @@ // See Note [Do not compile initializers with AVX] #include -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -15,7 +15,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized> { private: @@ -117,7 +117,7 @@ template <> class Vectorized> { if (count == size()) return _mm256_loadu_ps(reinterpret_cast(ptr)); - __at_align32__ float tmp_values[2*size()]; + __at_align__ float tmp_values[2*size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -142,7 +142,7 @@ template <> class Vectorized> { const c10::complex& operator[](int idx) const = delete; c10::complex& operator[](int idx) = delete; Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { - __at_align32__ c10::complex tmp[size()]; + __at_align__ c10::complex tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -323,8 +323,8 @@ template <> class Vectorized> { return sqrt().reciprocal(); } Vectorized> pow(const Vectorized> &exp) const { - __at_align32__ c10::complex x_tmp[size()]; - __at_align32__ c10::complex y_tmp[size()]; + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h index 1c575b9a28c7a..f92f44e562a9d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! 
// See Note [Do not compile initializers with AVX] -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -14,7 +14,8 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) + +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized { private: @@ -67,7 +68,7 @@ template <> class Vectorized { return _mm256_loadu_pd(reinterpret_cast(ptr)); - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -100,7 +101,7 @@ template <> class Vectorized { return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); } Vectorized map(double (*const f)(double)) const { - __at_align32__ double tmp[size()]; + __at_align__ double tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -175,8 +176,8 @@ template <> class Vectorized { return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ double tmp[size()]; - __at_align32__ double tmp_x[size()]; + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -185,8 +186,8 @@ template <> class Vectorized { return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ double tmp[size()]; - __at_align32__ double tmp_x[size()]; + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index 1f4c3f63477c1..deb9542984373 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -3,9 +3,9 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#include +#include +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) #include #endif @@ -14,7 +14,7 @@ namespace vec { // See Note [Acceptable use of anonymous namespace in header] namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) template <> class Vectorized { private: @@ -76,7 +76,7 @@ template <> class Vectorized { static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return _mm256_loadu_ps(reinterpret_cast(ptr)); - __at_align32__ float tmp_values[size()]; + __at_align__ float tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. 
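Aside, not part of the patch: __at_align32__ is renamed to __at_align__ throughout, since 32-byte alignment is no longer the only case once 64-byte AVX512 vectors exist. The macro's definition is not shown in this excerpt; one plausible shape (purely hypothetical, for illustration) would be:

```cpp
// Hypothetical definition; the real one lives in the vec base headers.
#if defined(CPU_CAPABILITY_AVX512)
#define __at_align__ __attribute__((aligned(64)))  // 64-byte alignment for 512-bit vectors
#else
#define __at_align__ __attribute__((aligned(32)))  // 32-byte alignment for 256-bit vectors
#endif
```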
@@ -107,7 +107,7 @@ template <> class Vectorized { return _mm256_cmp_ps(values, _mm256_set1_ps(0.0f), _CMP_UNORD_Q); } Vectorized map(float (*const f)(float)) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -213,8 +213,8 @@ template <> class Vectorized { return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -223,8 +223,8 @@ template <> class Vectorized { return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -412,12 +412,11 @@ inline void convert(const float* src, float* dst, int64_t n) { } } -#ifdef CPU_CAPABILITY_AVX2 + template <> Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return _mm256_fmadd_ps(a, b, c); } -#endif #endif diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h index b39d808a13a3c..2aac442d2123d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include // Sleef offers vectorized versions of some transcedentals // such as sin, cos, tan etc.. // However for now opting for STL, since we are not building @@ -220,7 +220,7 @@ template <> class Vectorized { return res; } else { - __at_align32__ float tmp_values[size()]; + __at_align__ float tmp_values[size()]; for (auto i = 0; i < size(); ++i) { tmp_values[i] = 0.0; } @@ -261,19 +261,19 @@ template <> class Vectorized { // Once we specialize that implementation for ARM // this should be removed. TODO (kimishpatel) float operator[](int idx) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); return tmp[idx]; } float operator[](int idx) { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); return tmp[idx]; } // For boolean version where we want to if any 1/all zero // etc. can be done faster in a different way. 
int zero_mask() const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); int mask = 0; for (int i = 0; i < size(); ++ i) { @@ -284,8 +284,8 @@ template <> class Vectorized { return mask; } Vectorized isnan() const { - __at_align32__ float tmp[size()]; - __at_align32__ float res[size()]; + __at_align__ float tmp[size()]; + __at_align__ float res[size()]; store(tmp); for (int i = 0; i < size(); i++) { if (_isnan(tmp[i])) { @@ -297,7 +297,7 @@ template <> class Vectorized { return loadu(res); }; Vectorized map(float (*const f)(float)) const { - __at_align32__ float tmp[size()]; + __at_align__ float tmp[size()]; store(tmp); for (int64_t i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -332,8 +332,8 @@ template <> class Vectorized { return map(std::atan); } Vectorized atan2(const Vectorized &exp) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_exp[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_exp[size()]; store(tmp); exp.store(tmp_exp); for (int64_t i = 0; i < size(); i++) { @@ -342,8 +342,8 @@ template <> class Vectorized { return loadu(tmp); } Vectorized copysign(const Vectorized &sign) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_sign[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_sign[size()]; store(tmp); sign.store(tmp_sign); for (size_type i = 0; i < size(); i++) { @@ -367,8 +367,8 @@ template <> class Vectorized { return map(std::expm1); } Vectorized fmod(const Vectorized& q) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_q[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_q[size()]; store(tmp); q.store(tmp_q); for (int64_t i = 0; i < size(); i++) { @@ -377,8 +377,8 @@ template <> class Vectorized { return loadu(tmp); } Vectorized hypot(const Vectorized &b) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_b[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { @@ -393,8 +393,8 @@ template <> class Vectorized { return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -403,8 +403,8 @@ template <> class Vectorized { return loadu(tmp); } Vectorized igammac(const Vectorized &x) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_x[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; store(tmp); x.store(tmp_x); for (int64_t i = 0; i < size(); i++) { @@ -425,8 +425,8 @@ template <> class Vectorized { return map(std::log2); } Vectorized nextafter(const Vectorized &b) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_b[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_b[size()]; store(tmp); b.store(tmp_b); for (int64_t i = 0; i < size(); i++) { @@ -490,8 +490,8 @@ template <> class Vectorized { return this->sqrt().reciprocal(); } Vectorized pow(const Vectorized &exp) const { - __at_align32__ float tmp[size()]; - __at_align32__ float tmp_exp[size()]; + __at_align__ float tmp[size()]; + __at_align__ float tmp_exp[size()]; store(tmp); exp.store(tmp_exp); for (int64_t i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 
6f85988bcf41b..86cf42556d192 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #include namespace at { @@ -55,7 +55,7 @@ class Vectorized : public Vectorizedi { } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi64(b.values, 0); @@ -93,7 +93,7 @@ class Vectorized : public Vectorizedi { return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int64_t count) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -109,7 +109,7 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int64_t tmp_values[size()]; + __at_align__ int64_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); } @@ -216,7 +216,7 @@ class Vectorized : public Vectorizedi { return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int32_t count) { - __at_align32__ int32_t tmp_values[size()]; + __at_align__ int32_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. 
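Aside, not part of the patch: the blend<mask>() specializations above spill `a` into an aligned temporary and overwrite lane i with the corresponding lane of `b` whenever bit i of the compile-time mask is set. A scalar, self-contained reference of that contract:

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Lane i comes from b when bit i of `mask` is set, otherwise from a.
template <int64_t mask, typename T, std::size_t N>
std::array<T, N> blend_reference(const std::array<T, N>& a, const std::array<T, N>& b) {
  std::array<T, N> out;
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = (mask & (int64_t(1) << i)) ? b[i] : a[i];
  }
  return out;
}
```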
@@ -232,7 +232,7 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int32_t tmp_values[size()]; + __at_align__ int32_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); } @@ -346,7 +346,7 @@ class Vectorized : public Vectorizedi { } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi16(b.values, 0); @@ -436,7 +436,7 @@ class Vectorized : public Vectorizedi { return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int16_t count) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. @@ -452,7 +452,7 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int16_t tmp_values[size()]; + __at_align__ int16_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); } @@ -527,7 +527,7 @@ class Vectorized : public Vectorizedi { } template static Vectorized blend(Vectorized a, Vectorized b) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; a.store(tmp_values); if (mask & 0x01) tmp_values[0] = _mm256_extract_epi8(b.values, 0); @@ -685,7 +685,7 @@ class Vectorized : public Vectorizedi { return _mm256_loadu_si256(reinterpret_cast(ptr)); } static Vectorized loadu(const void* ptr, int8_t count) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two // instructions while a loop would be compiled to one instruction. 
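Aside, not part of the patch: the NEON float file earlier in this patch (and several of the retained fallbacks) implements transcendental and binary ops by spilling to an aligned temporary, applying the scalar function per lane, and reloading. The recurring shape, written once as a free function for clarity; in the patch itself map() is a member of Vectorized:

```cpp
// Store / apply scalar f per lane / reload: the fallback used when no
// vectorized implementation of f is available. Assumes the ATen vec headers.
static at::vec::Vectorized<float> scalar_map(const at::vec::Vectorized<float>& v,
                                             float (*f)(float)) {
  __at_align__ float tmp[at::vec::Vectorized<float>::size()];
  v.store(tmp);
  for (int64_t i = 0; i < at::vec::Vectorized<float>::size(); ++i) {
    tmp[i] = f(tmp[i]);
  }
  return at::vec::Vectorized<float>::loadu(tmp);
}
```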
@@ -701,7 +701,7 @@ class Vectorized : public Vectorizedi { // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm256-storeu-si256.html _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); } else if (count > 0) { - __at_align32__ int8_t tmp_values[size()]; + __at_align__ int8_t tmp_values[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 874be68e5235d..dc5e833127327 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -3,8 +3,8 @@ // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] -#include -#include +#include +#include #include #include #include @@ -39,7 +39,7 @@ namespace at { namespace vec { namespace { -#if (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) +#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) struct Vectorizedqi { protected: @@ -53,7 +53,6 @@ struct Vectorizedqi { } }; -#if defined(CPU_CAPABILITY_AVX2) template __m256i pack_saturate_and_clamp( __m256i first, @@ -94,7 +93,6 @@ __m256i pack_saturate_and_clamp( _mm256_set1_epi8(min_val), _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); } -#endif template inline void __attribute__((always_inline)) QuantizeAvx2( @@ -103,7 +101,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2( int len, float inverse_scale, int64_t zero_point) { -#if defined(CPU_CAPABILITY_AVX2) constexpr int VLEN = 8; constexpr auto min_val = std::numeric_limits::min(); constexpr auto max_val = std::numeric_limits::max(); @@ -212,10 +209,6 @@ inline void __attribute__((always_inline)) QuantizeAvx2( std::min(std::max(transformed, float(min_val)), float(max_val)); dst[i] = clipped; } -#else - at::native::quantize_vec( - 1.0f / inverse_scale, zero_point, src, reinterpret_cast(dst), len); -#endif } template<> @@ -266,11 +259,7 @@ struct Vectorized : public Vectorizedqi { Vectorized zero_point, Vectorized scale_zp_premul) const { __m256 float_vals = _mm256_cvtepi32_ps(vals); -#if defined(CPU_CAPABILITY_AVX2) return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; -#else - return {scale * (Vectorized(float_vals) - zero_point)}; -#endif } static Vectorized quantize( @@ -286,39 +275,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epi32(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(int_vals.data()), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(b_vals.data()), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi32(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); 
- std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -328,65 +289,24 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi32( _mm256_max_epi32(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return {_mm256_sub_epi32(vals, b)}; -#else - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = int_vals[i] - b_vals[i]; - } - return {_mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals))}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); __m256 scaled = _mm256_mul_ps(_mm256_cvtepi32_ps(inp[0]), multiplier_v); __m256i rounded = _mm256_cvtps_epi32(scaled); return _mm256_add_epi32(rounded, zero_point_v); -#else - std::array inp_vals; - inp[0].store(inp_vals.data()); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = - nearbyint(static_cast(inp_vals[i]) * multiplier) + - zero_point; - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -411,43 +331,16 @@ template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_mullo_epi32(a, b); -#else - // Pray the compiler can autovectorize this - std::array::size()> a_vals; - std::array::size()> b_vals; - a.store(a_vals.data()); - b.store(b_vals.data()); - std::array::size()> result_vals; - for (size_t i = 0; i < std::decay_t::size(); ++i) { - result_vals[i] = a_vals[i] * b_vals[i]; - } - return Vectorized::loadu(result_vals.data()); -#endif } template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_add_epi32(a, b); -#else - // Pray the compiler can autovectorize this - std::array::size()> a_vals; - std::array::size()> b_vals; - a.store(a_vals.data()); - b.store(b_vals.data()); - std::array::size()> result_vals; - for (size_t i = 0; i < std::decay_t::size(); ++i) { - result_vals[i] = a_vals[i] + b_vals[i]; - } - return Vectorized::loadu(result_vals.data()); -#endif } -#ifdef CPU_CAPABILITY_AVX2 /* * Convert values from int32 back to int8/uint8 */ @@ -493,7 +386,6 @@ __m256i RequantizeAvx2( xyzw_clamped_v = _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); return xyzw_clamped_v; } -#endif 
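Aside, not part of the patch: with the non-AVX2 branches deleted, requantize_from_int keeps only the intrinsic path: multiply the widened int32 values by the float multiplier, round to nearest, and add the zero point (for the 8-bit types, clamping to the quantized range is handled later by the saturating pack). The removed scalar fallback did exactly the same per element, roughly:

```cpp
#include <cmath>
#include <cstdint>

// Per-element equivalent of the AVX2 requantization kept above.
static int32_t requantize_one(int32_t widened, float multiplier, int32_t zero_point) {
  return static_cast<int32_t>(std::nearbyint(static_cast<float>(widened) * multiplier)) +
      zero_point;
}
```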
template<> struct Vectorized : public Vectorizedqi { @@ -544,21 +436,7 @@ struct Vectorized : public Vectorizedqi { private: __m256i cvtepi8_epi32(__m128i epi8_vals) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_cvtepi8_epi32(epi8_vals); -#else // CPU_CAPABILITY_AVX2 - __m128i result_data[2]; - __m128i unpacked1 = _mm_unpacklo_epi8(epi8_vals, epi8_vals); - __m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, unpacked1); - __m128i shifted1 = _mm_srli_si128(epi8_vals, 4); - __m128i shifted2 = _mm_srai_epi32(unpacked2, 24); - result_data[0] = shifted2; - __m128i unpacked3 = _mm_unpacklo_epi8(shifted1, shifted1); - __m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, unpacked3); - __m128i shifted3 = _mm_srai_epi32(unpacked4, 24); - result_data[1] = shifted3; - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data)); -#endif } public: @@ -576,7 +454,6 @@ struct Vectorized : public Vectorizedqi { __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); -#if defined(CPU_CAPABILITY_AVX2) auto val0 = vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); auto val1 = @@ -585,12 +462,6 @@ struct Vectorized : public Vectorizedqi { vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); auto val3 = vec::fmadd(scale, Vectorized(float_val3), scale_neg_zp_premul); -#else - auto val0 = scale * (Vectorized(float_val0) - zero_point); - auto val1 = scale * (Vectorized(float_val1) - zero_point); - auto val2 = scale * (Vectorized(float_val2) - zero_point); - auto val3 = scale * (Vectorized(float_val3) - zero_point); -#endif return {val0, val1, val2, val3}; } @@ -607,39 +478,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epi8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -649,29 +492,11 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epi8( _mm256_max_epi8(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = 
std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); @@ -701,55 +526,15 @@ struct Vectorized : public Vectorizedqi { Vectorized(res_1), Vectorized(res_2), Vectorized(res_3)}; -#else - // Pray the compiler can autovectorize this - std::array int_vals; - store(int_vals.data()); - std::array b_vals; - b.store(b_vals.data()); - constexpr int elem_per_int_vec = size() / int_num_vecs(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - rv[i][j] = static_cast(int_vals[i * elem_per_int_vec + j]) - - static_cast(b_vals[i * elem_per_int_vec + j]); - } - } - return {Vectorized::loadu(rv[0]), - Vectorized::loadu(rv[1]), - Vectorized::loadu(rv[2]), - Vectorized::loadu(rv[3])}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); return RequantizeAvx2(inp, multiplier_v, zero_point_v); -#else - // Pray the compiler can autovectorize this - constexpr int elem_per_int_vec = size() / int_num_vecs(); - constexpr auto min_val = std::numeric_limits::min(); - constexpr auto max_val = std::numeric_limits::max(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - inp[i].store(rv[i]); - } - std::array result_vals; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - int32_t rounded = - nearbyint(static_cast(rv[i][j]) * multiplier) + zero_point; - result_vals[i * elem_per_int_vec + j] = - std::min(std::max(rounded, min_val), max_val); - } - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -817,20 +602,7 @@ struct Vectorized : public Vectorizedqi { private: __m256i cvtepu8_epi32(__m128i epu8_vals) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_cvtepu8_epi32(epu8_vals); -#else // CPU_CAPABILITY_AVX2 - __m128i result_data[2]; - __m128i zeros = _mm_setzero_si128(); - __m128i unpacked1 = _mm_unpacklo_epi8(epu8_vals, zeros); - __m128i unpacked2 = _mm_unpacklo_epi16(unpacked1, zeros); - result_data[0] = unpacked2; - __m128i shifted = _mm_srli_si128(epu8_vals, 4); - __m128i unpacked3 = _mm_unpacklo_epi8(shifted, zeros); - __m128i unpacked4 = _mm_unpacklo_epi16(unpacked3, zeros); - result_data[1] = unpacked4; - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_data)); -#endif } public: @@ -848,7 +620,6 @@ struct Vectorized : public Vectorizedqi { __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); -#if defined(CPU_CAPABILITY_AVX2) auto val0 = vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); auto val1 = @@ -857,12 +628,6 @@ struct Vectorized : public Vectorizedqi { vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); auto val3 = vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); -#else - auto val0 = scale * (Vectorized(float_val0) - zero_point); - auto val1 = scale * (Vectorized(float_val1) - zero_point); - auto val2 
= scale * (Vectorized(float_val2) - zero_point); - auto val3 = scale * (Vectorized(float_val3) - zero_point); -#endif return {val0, val1, val2, val3}; } @@ -879,39 +644,11 @@ struct Vectorized : public Vectorizedqi { } Vectorized maximum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_max_epu8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized minimum(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epu8(vals, b.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array b_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&b_vals), b.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min(int_vals[i], b_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } Vectorized relu(Vectorized zero_point) const { @@ -921,29 +658,11 @@ struct Vectorized : public Vectorizedqi { Vectorized relu6( Vectorized zero_point, Vectorized q_six) { -#ifdef CPU_CAPABILITY_AVX2 return _mm256_min_epu8( _mm256_max_epu8(vals, zero_point.vals), q_six.vals); -#else - // Pray the compiler can autovectorize this - std::array int_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - std::array zero_point_vals; - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); - std::array q_six_vals; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals); - std::array result_vals; - for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::min( - std::max(int_vals[i], zero_point_vals[i]), q_six_vals[i]); - } - return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); -#endif } int_vec_return_type widening_subtract(Vectorized b) const { -#ifdef CPU_CAPABILITY_AVX2 __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); @@ -972,55 +691,15 @@ struct Vectorized : public Vectorizedqi { Vectorized(res_1), Vectorized(res_2), Vectorized(res_3)}; -#else - // Pray the compiler can autovectorize this - std::array int_vals; - std::array b_vals; - store(int_vals.data()); - b.store(b_vals.data()); - static constexpr int elem_per_int_vec = size() / int_num_vecs(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - rv[i][j] = static_cast(int_vals[i * elem_per_int_vec + j]) - - static_cast(b_vals[i * elem_per_int_vec + j]); - } - } - return {Vectorized::loadu(rv[0]), - Vectorized::loadu(rv[1]), - Vectorized::loadu(rv[2]), - Vectorized::loadu(rv[3])}; -#endif } static Vectorized requantize_from_int( const int_vec_return_type& inp, float multiplier, int32_t zero_point) { -#ifdef CPU_CAPABILITY_AVX2 __m256 multiplier_v = _mm256_set1_ps(multiplier); __m256i zero_point_v = _mm256_set1_epi32(zero_point); return RequantizeAvx2(inp, multiplier_v, zero_point_v); -#else - // Pray the compiler can 
autovectorize this - constexpr int elem_per_int_vec = size() / int_num_vecs(); - constexpr auto min_val = std::numeric_limits::min(); - constexpr auto max_val = std::numeric_limits::max(); - int32_t rv[int_num_vecs()][elem_per_int_vec]; - for (size_t i = 0; i < int_num_vecs(); ++i) { - inp[i].store(rv[i]); - } - std::array result_vals; - for (size_t i = 0; i < int_num_vecs(); ++i) { - for (size_t j = 0; j < elem_per_int_vec; ++j) { - int32_t rounded = - nearbyint(static_cast(rv[i][j]) * multiplier) + zero_point; - result_vals[i * elem_per_int_vec + j] = - std::min(std::max(rounded, min_val), max_val); - } - } - return loadu(result_vals.data()); -#endif } void dump() const { @@ -1497,6 +1176,5 @@ Vectorized inline maximum(const Vectorized& a, const V return a.maximum(b); } -#endif // (defined(CPU_CAPABILITY_AVX) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) - +#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h index 3c42b60164511..3d798a7f62689 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h index ce59bae3f4ffc..888f2f0b932b2 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -1,6 +1,6 @@ #pragma once -#include -#include +#include +#include #include #include @@ -141,7 +141,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -153,7 +153,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); std::memcpy( @@ -165,7 +165,7 @@ class Vectorized { ComplexDbl& operator[](int idx) = delete; Vectorized map(ComplexDbl (*const f)(ComplexDbl)) const { - __at_align32__ ComplexDbl tmp[size()]; + __at_align__ ComplexDbl tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -174,7 +174,7 @@ class Vectorized { } Vectorized map(ComplexDbl (*const f)(const ComplexDbl&)) const { - __at_align32__ ComplexDbl tmp[size()]; + __at_align__ ComplexDbl tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -455,8 +455,8 @@ class Vectorized { } Vectorized pow(const Vectorized& exp) const { - __at_align32__ ComplexDbl x_tmp[size()]; - __at_align32__ ComplexDbl y_tmp[size()]; + __at_align__ ComplexDbl x_tmp[size()]; + __at_align__ ComplexDbl y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h index f96488964bb9f..0aa726b9bfdd6 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ 
-1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include @@ -196,7 +196,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -209,7 +209,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, reinterpret_cast(tmp_values)); vec_vsx_st(_vec1, offset16, reinterpret_cast(tmp_values)); std::memcpy( @@ -221,7 +221,7 @@ class Vectorized { ComplexFlt& operator[](int idx) = delete; Vectorized map(ComplexFlt (*const f)(ComplexFlt)) const { - __at_align32__ ComplexFlt tmp[size()]; + __at_align__ ComplexFlt tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -230,7 +230,7 @@ class Vectorized { } Vectorized map(ComplexFlt (*const f)(const ComplexFlt&)) const { - __at_align32__ ComplexFlt tmp[size()]; + __at_align__ ComplexFlt tmp[size()]; store(tmp); for (int i = 0; i < size(); i++) { tmp[i] = f(tmp[i]); @@ -434,8 +434,8 @@ class Vectorized { } Vectorized pow(const Vectorized& exp) const { - __at_align32__ ComplexFlt x_tmp[size()]; - __at_align32__ ComplexFlt y_tmp[size()]; + __at_align__ ComplexFlt x_tmp[size()]; + __at_align__ ComplexFlt y_tmp[size()]; store(x_tmp); exp.store(y_tmp); for (int i = 0; i < size(); i++) { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index ac0a131878a02..29616182fe12b 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include @@ -169,7 +169,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -179,7 +179,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 5fd1fb9afc80b..2427276bcea2c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include namespace at { @@ -180,7 +180,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -190,7 +190,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ 
value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 16a535fd1d10d..bd179883c9bfe 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -269,7 +269,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -279,7 +279,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy(ptr, tmp_values, std::min(count, size()) * sizeof(value_type)); diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index 759c497396547..460f49cbc8ddb 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -199,7 +199,7 @@ class Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -209,7 +209,7 @@ class Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index d2fbf4d51cf03..fea094029653b 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include namespace at { namespace vec { @@ -148,7 +148,7 @@ class Vectorized { (vint64)vec_vsx_ld(offset16, dptr)}; } - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return { @@ -161,7 +161,7 @@ class Vectorized { vec_vsx_st((vfloat64)_vec0, offset0, dptr); vec_vsx_st((vfloat64)_vec1, offset16, dptr); } else if (count > 0) { - __at_align32__ double tmp_values[size()]; + __at_align__ double tmp_values[size()]; vec_vsx_st((vfloat64)_vec0, offset0, tmp_values); vec_vsx_st((vfloat64)_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index ac2047d75ba26..ed457b9adefc8 100644 --- 
a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -81,7 +81,7 @@ struct Vectorized { vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; @@ -91,7 +91,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h index 728ff51d71d77..f2a8446cd0ed9 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -91,7 +91,7 @@ struct Vectorized { vec_vsx_ld(offset0, reinterpret_cast(ptr)), vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; } @@ -100,7 +100,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h index 4994abe9f13ab..c335ace0ced6f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -1,7 +1,7 @@ #pragma once -#include -#include +#include +#include #include #include #include @@ -92,7 +92,7 @@ struct Vectorized { vec_vsx_ld(offset0, reinterpret_cast(ptr)), vec_vsx_ld(offset16, reinterpret_cast(ptr))}; } - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; std::memcpy(tmp_values, ptr, std::min(count, size()) * sizeof(value_type)); return {vec_vsx_ld(offset0, tmp_values), vec_vsx_ld(offset16, tmp_values)}; } @@ -101,7 +101,7 @@ struct Vectorized { vec_vsx_st(_vec0, offset0, reinterpret_cast(ptr)); vec_vsx_st(_vec1, offset16, reinterpret_cast(ptr)); } else if (count > 0) { - __at_align32__ value_type tmp_values[size()]; + __at_align__ value_type tmp_values[size()]; vec_vsx_st(_vec0, offset0, tmp_values); vec_vsx_st(_vec1, offset16, tmp_values); std::memcpy( diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h index 8bc943d62dcb5..afd21b09b4584 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vsx_helpers.h @@ -1,7 +1,7 @@ #pragma once #include #include -#include +#include using vbool8 = __attribute__((altivec(vector__))) __attribute__((altivec(bool__))) char; using vbool16 = 
__attribute__((altivec(vector__))) __attribute__((altivec(bool__))) short; diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h new file mode 100644 index 0000000000000..6f53067c0efa7 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -0,0 +1,195 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +#include +#include +#include +#include +#include + +namespace at { +namespace vec { + +// See Note [Acceptable use of anonymous namespace in header] +namespace { + + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) { + stream << val.val_; + return stream; + } + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) { + stream << static_cast(val.val_); + return stream; + } + C10_UNUSED std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) { + stream << static_cast(val.val_); + return stream; + } + +template +std::ostream& operator<<(std::ostream& stream, const Vectorized& vec) { + T buf[Vectorized::size()]; + vec.store(buf); + stream << "vec["; + for (int i = 0; i != Vectorized::size(); i++) { + if (i != 0) { + stream << ", "; + } + stream << buf[i]; + } + stream << "]"; + return stream; +} + + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castpd_ps(src); +} + +template<> +inline Vectorized cast(const Vectorized& src) { + return _mm512_castps_pd(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline gather(const double* base_addr, const Vectorized& vindex) { + return _mm512_i64gather_pd(vindex, base_addr, scale); +} + +template +std::enable_if_t> +inline gather(const float* base_addr, const Vectorized& vindex) { + return _mm512_i32gather_ps(vindex, base_addr, scale); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const double* base_addr, + const Vectorized& vindex, const Vectorized& mask) { + auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF)); + auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale); +} + +template +std::enable_if_t> +inline mask_gather(const Vectorized& src, const float* base_addr, + const Vectorized& vindex, const Vectorized& mask) { + auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF)); + auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ); + return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return _mm512_cvtpd_epi64(src); +} + +template<> +Vectorized +inline convert_to_int_of_same_size(const Vectorized &src) { + return _mm512_cvttps_epi32(src); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a3, a3, a4, a5, a6, a7} + // b = {b0, b1, b2, b3, 
b4, b5, b6, b7} + // group cols crossing lanes: + // return {a0, b0, a1, b1, a2, b2, a3, b3} + // {a4, b4, a5, b5, a6, b6, a7, b7} + __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4); + return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> +inline interleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} + // b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} + // + // return: + // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} + __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, + 19, 3, 18, 2, 17, 1, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, + 27, 11, 26, 10, 25, 9, 24, 8); + return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3} + // b = {a4, b4, a5, b5, a6, b6, a7, b7} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7} + // {b0, b1, b2, b3, b4, b5, b6, b7} + // The members of indices have been written in binary format for better understandability + __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); +} + +template <> +std::pair, Vectorized> +inline deinterleave2(const Vectorized& a, const Vectorized& b) { + // inputs: + // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} + // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} + // output: + // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} + // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} + __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, + 14, 12, 10, 8, 6, 4, 2, 0); + __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, + 15, 13, 11, 9, 7, 5, 3, 1); + + return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h new file mode 100644 index 0000000000000..4a240bb36d301 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h @@ -0,0 +1,879 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
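For readers checking the permutation tables used by interleave2/deinterleave2 above, a scalar reference of the intended contract may help. This is an illustrative sketch only (names are made up and it operates on plain arrays rather than Vectorized); it is not part of the patch:

#include <array>
#include <cstddef>
#include <utility>

// interleave2_ref: zip the first halves of a and b into `lo` and the second
// halves into `hi`, mirroring what the _mm512_permutex2var index vectors do.
template <typename T, std::size_t N>
std::pair<std::array<T, N>, std::array<T, N>>
interleave2_ref(const std::array<T, N>& a, const std::array<T, N>& b) {
  std::array<T, N> lo{}, hi{};
  for (std::size_t i = 0; i < N / 2; ++i) {
    lo[2 * i]     = a[i];          // {a0, b0, a1, b1, ...}
    lo[2 * i + 1] = b[i];
    hi[2 * i]     = a[N / 2 + i];  // {aN/2, bN/2, ...}
    hi[2 * i + 1] = b[N / 2 + i];
  }
  return {lo, hi};
}

// deinterleave2_ref would be the inverse: it splits the zipped
// {a0, b0, a1, b1, ...} / {a8, b8, ...} pair back into separate a and b.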
+// See Note [Do not compile initializers with AVX] + +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +static inline void cvtbf16_fp32(const __m256i& a, __m512& o) { + o = _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)); +} + +static inline void cvtbf16_fp32(const __m512i& a, __m512& o1, __m512& o2) { + __m256i lo = _mm512_extracti32x8_epi32(a, 0); + __m256i hi = _mm512_extracti32x8_epi32(a, 1); + cvtbf16_fp32(lo, o1); + cvtbf16_fp32(hi, o2); +} + +static inline __m512i cvtfp32_bf16(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_lo = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + auto mask_hi = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_lo = _mm512_and_si512(_mm512_srli_epi32(lo, 16), ones); + auto t_hi = _mm512_and_si512(_mm512_srli_epi32(hi, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_lo = _mm512_add_epi32(t_lo, vec_bias); + t_hi = _mm512_add_epi32(t_hi, vec_bias); + // input += rounding_bias; + t_lo = _mm512_add_epi32(t_lo, lo); + t_hi = _mm512_add_epi32(t_hi, hi); + // input = input >> 16; + t_lo = _mm512_srli_epi32(t_lo, 16); + t_hi = _mm512_srli_epi32(t_hi, 16); + // Check NaN before converting back to bf16 + t_lo = _mm512_mask_blend_epi32(mask_lo, nan, t_lo); + t_hi = _mm512_mask_blend_epi32(mask_hi, nan, t_hi); + + t_lo = _mm512_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4] + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, t_lo); +} + +static inline __m512i merge_compare_result(const __m512& a, const __m512& b) { + __m512i lo = _mm512_castps_si512(a); + __m512i hi = _mm512_castps_si512(b); + lo = _mm512_srli_epi32(lo, 16); + hi = _mm512_srli_epi32(hi, 16); + auto out = _mm512_packus_epi32(lo, hi); + __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); + return _mm512_permutexvar_epi64(idx, out); +} + +template <> class Vectorized { +private: + __m512i values; +public: + using value_type = uint16_t; + using size_type = int; + static constexpr size_type size() { + return 32; + } + Vectorized() {} + Vectorized(__m512i v) : values(v) {} + Vectorized(BFloat16 val) { + value_type uw = val.x; + values = _mm512_set1_epi16(uw); + } + Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4, + BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8, + BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12, + BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16, + BFloat16 val17, BFloat16 val18, BFloat16 val19, BFloat16 val20, + BFloat16 val21, BFloat16 val22, BFloat16 val23, BFloat16 val24, + BFloat16 val25, BFloat16 val26, BFloat16 val27, BFloat16 val28, + BFloat16 val29, BFloat16 val30, BFloat16 val31, BFloat16 val32) { + values = _mm512_set_epi16( + val32.x, val31.x, val30.x, val29.x, val28.x, val27.x, val26.x, val25.x, + val24.x, val23.x, val22.x, val21.x, val20.x, val19.x, val18.x, val17.x, + val16.x, val15.x, val14.x, val13.x, val12.x, val11.x, val10.x, val9.x, + val8.x, val7.x, val6.x, val5.x, val4.x, val3.x, val2.x, val1.x); + } + operator __m512i() const { + return values; + } + 
BFloat16& operator[](int idx) = delete; + const BFloat16& operator[](int idx) const = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + return _mm512_cmpeq_epi16_mask(values, _mm512_set1_epi16(0)); + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + __at_align__ int16_t tmp_values[size()]; + a.store(tmp_values); + if (mask & 0x01) + tmp_values[0] = b.values[31]; + if (mask & 0x02) + tmp_values[1] = b.values[30]; + if (mask & 0x04) + tmp_values[2] = b.values[29]; + if (mask & 0x08) + tmp_values[3] = b.values[28]; + if (mask & 0x10) + tmp_values[4] = b.values[27]; + if (mask & 0x20) + tmp_values[5] = b.values[26]; + if (mask & 0x40) + tmp_values[6] = b.values[25]; + if (mask & 0x80) + tmp_values[7] = b.values[24]; + if (mask & 0x100) + tmp_values[8] = b.values[23]; + if (mask & 0x200) + tmp_values[9] = b.values[22]; + if (mask & 0x400) + tmp_values[10] = b.values[21]; + if (mask & 0x800) + tmp_values[11] = b.values[20]; + if (mask & 0x1000) + tmp_values[12] = b.values[19]; + if (mask & 0x2000) + tmp_values[13] = b.values[18]; + if (mask & 0x4000) + tmp_values[14] = b.values[17]; + if (mask & 0x8000) + tmp_values[15] = b.values[16]; + if (mask & 0x10000) + tmp_values[16] = b.values[15]; + if (mask & 0x20000) + tmp_values[17] = b.values[14]; + if (mask & 0x40000) + tmp_values[18] = b.values[13]; + if (mask & 0x80000) + tmp_values[19] = b.values[12]; + if (mask & 0x100000) + tmp_values[20] = b.values[11]; + if (mask & 0x200000) + tmp_values[21] = b.values[10]; + if (mask & 0x400000) + tmp_values[22] = b.values[9]; + if (mask & 0x800000) + tmp_values[23] = b.values[8]; + if (mask & 0x1000000) + tmp_values[24] = b.values[7]; + if (mask & 0x2000000) + tmp_values[25] = b.values[6]; + if (mask & 0x4000000) + tmp_values[26] = b.values[5]; + if (mask & 0x8000000) + tmp_values[27] = b.values[4]; + if (mask & 0x10000000) + tmp_values[28] = b.values[3]; + if (mask & 0x20000000) + tmp_values[29] = b.values[2]; + if (mask & 0x40000000) + tmp_values[30] = b.values[1]; + if (mask & 0x80000000) + tmp_values[31] = b.values[0]; + return loadu(tmp_values); + } + static Vectorized blendv(const Vectorized& a, + const Vectorized& b, const Vectorized& mask) { + auto all_ones = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange(BFloat16 base = 0.f, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * 
step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step); + } + static Vectorized set(const Vectorized& a, + const Vectorized& b, int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + case 16: + return blend<65535>(a, b); + case 17: + return blend<131071>(a, b); + case 18: + return blend<262143>(a, b); + case 19: + return blend<524287>(a, b); + case 20: + return blend<1048575>(a, b); + case 21: + return blend<2097151>(a, b); + case 22: + return blend<4194303>(a, b); + case 23: + return blend<8388607>(a, b); + case 24: + return blend<16777215>(a, b); + case 25: + return blend<33554431>(a, b); + case 26: + return blend<67108863>(a, b); + case 27: + return blend<134217727>(a, b); + case 28: + return blend<268435455>(a, b); + case 29: + return blend<536870911>(a, b); + case 30: + return blend<1073741823>(a, b); + case 31: + return blend<2147483647>(a, b); + } + return b; + } + Vectorized map(const __m512 (*const vop)(__m512)) const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + const auto o1 = vop(lo); + const auto o2 = vop(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized abs() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + const auto mask = _mm512_set1_ps(-0.f); + const auto o1 = _mm512_andnot_ps(mask, lo); + const auto o2 = _mm512_andnot_ps(mask, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized angle() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto angle_lambda = [](__m512 values) { + const auto zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto non_nan_mask_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec), + not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(non_nan_mask_vec), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + }; + auto o1 = angle_lambda(lo); + auto o2 = angle_lambda(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return map(Sleef_acosf16_u10); + } + Vectorized asin() const { + return map(Sleef_asinf16_u10); + } + Vectorized atan() const { + return map(Sleef_atanf16_u10); + } + Vectorized atan2(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_atan2f16_u10(lo, b1); + auto o2 = Sleef_atan2f16_u10(hi, b2); + return 
cvtfp32_bf16(o1, o2); + } + Vectorized copysign(const Vectorized &sign) const { + // copy sign bit (0x8000) from sign and remaining bits from values + __m512i mask_value = _mm512_set1_epi32(~0x80008000); + __m512i mask_signbit = _mm512_set1_epi32(0x80008000); + return Vectorized( + _mm512_or_si512( + _mm512_and_si512(values, mask_value), + _mm512_and_si512(sign, mask_signbit))); + } + Vectorized erf() const { + return map(Sleef_erff16_u10); + } + Vectorized erfc() const { + return map(Sleef_erfcf16_u15); + } + Vectorized erfinv() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_erfinv(tmp1[i]); + tmp2[i] = calc_erfinv(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized exp() const { + return map(Sleef_expf16_u10); + } + Vectorized expm1() const { + return map(Sleef_expm1f16_u10); + } + Vectorized fmod(const Vectorized & q) const { + __m512 x_lo, x_hi; + cvtbf16_fp32(values, x_lo, x_hi); + __m512 q_lo, q_hi; + cvtbf16_fp32(q.values, q_lo, q_hi); + auto o1 = Sleef_fmodf16(x_lo, q_lo); + auto o2 = Sleef_fmodf16(x_hi, q_hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized hypot(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_hypotf16_u05(lo, b1); + auto o2 = Sleef_hypotf16_u05(hi, b2); + return cvtfp32_bf16(o1, o2); + } + Vectorized i0() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized i0e() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + constexpr auto sz = size(); + __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + + for (auto i = decltype(sz){0}; i < sz / 2; i++) { + tmp1[i] = calc_i0e(tmp1[i]); + tmp2[i] = calc_i0e(tmp2[i]); + } + const auto o1 = _mm512_loadu_ps(tmp1); + const auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized igamma(const Vectorized &x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + _mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + + Vectorized igammac(const Vectorized &x) const { + __m512 lo, hi; + __m512 xlo, xhi; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(x.values, xlo, xhi); + __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmp1), lo); + 
_mm512_storeu_ps(reinterpret_cast(tmp2), hi); + __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; + _mm512_storeu_ps(reinterpret_cast(tmpx1), xlo); + _mm512_storeu_ps(reinterpret_cast(tmpx2), xhi); + for (int64_t i = 0; i < size() / 2; ++i) { + tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); + tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); + } + auto o1 = _mm512_loadu_ps(tmp1); + auto o2 = _mm512_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } + Vectorized log() const { + return map(Sleef_logf16_u10); + } + Vectorized log2() const { + return map(Sleef_log2f16_u10); + } + Vectorized log10() const { + return map(Sleef_log10f16_u10); + } + Vectorized log1p() const { + return map(Sleef_log1pf16_u10); + } + Vectorized frac() const; + Vectorized sin() const { + return map(Sleef_sinf16_u10); + } + Vectorized sinh() const { + return map(Sleef_sinhf16_u10); + } + Vectorized cos() const { + return map(Sleef_cosf16_u10); + } + Vectorized cosh() const { + return map(Sleef_coshf16_u10); + } + Vectorized ceil() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_ceil_ps(lo); + auto o2 = _mm512_ceil_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized floor() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_floor_ps(lo); + auto o2 = _mm512_floor_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized neg() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto mask = _mm512_set1_ps(-0.f); + auto o1 = _mm512_xor_ps(mask, lo); + auto o2 = _mm512_xor_ps(mask, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized round() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + return cvtfp32_bf16(o1, o2); + } + Vectorized tan() const { + return map(Sleef_tanf16_u10); + } + Vectorized tanh() const { + return map(Sleef_tanhf16_u10); + } + Vectorized trunc() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + auto o2 = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + return cvtfp32_bf16(o1, o2); + } + Vectorized lgamma() const { + return map(Sleef_lgammaf16_u10); + } + Vectorized sqrt() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto o1 = _mm512_sqrt_ps(lo); + auto o2 = _mm512_sqrt_ps(hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized reciprocal() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, lo); + auto o2 = _mm512_div_ps(ones, hi); + return cvtfp32_bf16(o1, o2); + } + Vectorized rsqrt() const { + __m512 lo, hi; + cvtbf16_fp32(values, lo, hi); + auto ones = _mm512_set1_ps(1); + auto o1 = _mm512_div_ps(ones, _mm512_sqrt_ps(lo)); + auto o2 = _mm512_div_ps(ones, _mm512_sqrt_ps(hi)); + return cvtfp32_bf16(o1, o2); + } + Vectorized pow(const Vectorized &b) const { + __m512 lo, hi; + __m512 b1, b2; + cvtbf16_fp32(values, lo, hi); + cvtbf16_fp32(b.values, b1, b2); + auto o1 = Sleef_powf16_u10(lo, b1); + auto o2 = Sleef_powf16_u10(hi, b2); + return cvtfp32_bf16(o1, o2); + } + + Vectorized inline operator>(const Vectorized& other) const; + Vectorized inline operator<(const Vectorized& other) const; + Vectorized inline operator>=(const Vectorized& other) const; + Vectorized inline operator<=(const Vectorized& other) const; + Vectorized inline operator==(const Vectorized& other) const; + 
Vectorized inline operator!=(const Vectorized& other) const; + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template +Vectorized static inline bfloat16_binary_op_as_fp32(const Vectorized& a, + const Vectorized& b, Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return cvtfp32_bf16(o1, o2); +} + +template +Vectorized static inline bfloat16_compare_as_fp32(const Vectorized& a, + const Vectorized& b, Op op) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto o1 = op(a_lo, b_lo); + auto o2 = op(a_hi, b_hi); + return merge_compare_result(o1, o2); +} + +Vectorized inline Vectorized::operator>(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator<(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator>=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_GE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator<=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator==(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} +Vectorized inline Vectorized::operator!=(const Vectorized& other) const { + return bfloat16_compare_as_fp32(*this, other, [](__m512 x, __m512 y) { + auto zero_vec = _mm512_set1_epi32(0); + auto cmp = _mm512_cmp_ps_mask(x, y, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, cmp, 0xFFFFFFFF)); + }); +} + +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_add_ps(x, y); }); +} +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_sub_ps(x, y); }); +} +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_mul_ps(x, 
y); }); +} +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return bfloat16_binary_op_as_fp32(a, b, [](const __m512& x, const __m512& y) { return _mm512_div_ps(x, y); }); +} + +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +// frac. Implement this here so we can use subtraction +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto max_lo = _mm512_max_ps(a_lo, b_lo); + auto max_hi = _mm512_max_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_set1_epi32(nan_lo_mask)); + auto nan_hi = _mm512_castsi512_ps(_mm512_set1_epi32(nan_hi_mask)); + // Exploit the fact that all-ones is a NaN. + auto o1 = _mm512_or_ps(max_lo, nan_lo); + auto o2 = _mm512_or_ps(max_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512i zero_vec = _mm512_set1_epi32(0); + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + auto min_lo = _mm512_min_ps(a_lo, b_lo); + auto min_hi = _mm512_min_ps(a_hi, b_hi); + auto nan_lo_mask = _mm512_cmp_ps_mask(a_lo, b_lo, _CMP_UNORD_Q); + auto nan_hi_mask = _mm512_cmp_ps_mask(a_hi, b_hi, _CMP_UNORD_Q); + auto nan_lo = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_lo_mask, + 0xFFFFFFFF)); + auto nan_hi = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, nan_hi_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
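As a side note on the "all-ones is a NaN" trick used by maximum/minimum here: a lane whose 32 bits are all set has an all-ones exponent and a non-zero mantissa, so ORing the min/max result with the unordered-compare mask forces NaN into exactly the lanes where either input was NaN. A minimal scalar sketch of one lane (illustrative only, assuming IEEE-754 binary32; not part of the patch):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Emulates one lane of _mm512_or_ps(result, nan_mask): when the mask lane is
// all ones, the result bit pattern becomes 0xFFFFFFFF, which is a quiet NaN.
inline float or_lane_with_mask(float result, bool input_had_nan) {
  std::uint32_t bits;
  std::memcpy(&bits, &result, sizeof(bits));
  if (input_had_nan) {
    bits |= 0xFFFFFFFFu;
  }
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// e.g. assert(std::isnan(or_lane_with_mask(2.5f, true)));
//      assert(or_lane_with_mask(2.5f, false) == 2.5f);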
+ auto o1 = _mm512_or_ps(min_lo, nan_lo); + auto o2 = _mm512_or_ps(min_hi, nan_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp(const Vectorized& a, + const Vectorized& min, const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, _mm512_max_ps(min_lo, a_lo)); + auto o2 = _mm512_min_ps(max_hi, _mm512_max_ps(min_hi, a_hi)); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + __m512 a_lo, a_hi; + __m512 max_lo, max_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(max), max_lo, max_hi); + auto o1 = _mm512_min_ps(max_lo, a_lo); + auto o2 = _mm512_min_ps(max_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + __m512 a_lo, a_hi; + __m512 min_lo, min_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(min), min_lo, min_hi); + auto o1 = _mm512_max_ps(min_lo, a_lo); + auto o2 = _mm512_max_ps(min_hi, a_hi); + return cvtfp32_bf16(o1, o2); +} + +template <> +inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto vsrc = _mm512_loadu_si512(reinterpret_cast<__m512i*>((void*)(src + i))); + _mm512_storeu_si512(reinterpret_cast<__m512i*>((void*)(dst + i)), vsrc); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, + const Vectorized& b, const Vectorized& c) { + __m512 a_lo, a_hi; + __m512 b_lo, b_hi; + __m512 c_lo, c_hi; + cvtbf16_fp32(__m512i(a), a_lo, a_hi); + cvtbf16_fp32(__m512i(b), b_lo, b_hi); + cvtbf16_fp32(__m512i(c), c_lo, c_hi); + auto o1 = _mm512_fmadd_ps(a_lo, b_lo, c_lo); + auto o2 = _mm512_fmadd_ps(a_hi, b_hi, c_hi); + return cvtfp32_bf16(o1, o2); +} + +inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { + __m512 o1, o2; + cvtbf16_fp32(__m512i(a), o1, o2); + return std::make_tuple(o1, o2); +} + +inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { + return cvtfp32_bf16(__m512(a), __m512(b)); +} + +#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +inline std::tuple, Vectorized> convert_bfloat16_float(const Vectorized& a) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr2); + convert(arr2, arr, K); + return std::make_tuple( + Vectorized::loadu(arr), + Vectorized::loadu(arr + Vectorized::size())); +} + +inline Vectorized convert_float_bfloat16(const Vectorized& a, const Vectorized& b) { + constexpr int64_t K = Vectorized::size(); + __at_align__ float arr[K]; + __at_align__ BFloat16 arr2[K]; + a.store(arr); + b.store(arr + Vectorized::size()); + convert(arr, arr2, K); + return Vectorized::loadu(arr2); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { + auto values = _mm256_loadu_si256(reinterpret_cast(data)); + __m512 out_values; + cvtbf16_fp32(values, out_values); + out = out_values; +} + +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vectorized& out2) { + auto vec = 
Vectorized::loadu(data); + __m512 out1_values, out2_values; + cvtbf16_fp32(vec, out1_values, out2_values); + out1 = out1_values; + out2 = out2_values; +} +#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out) { + __at_align__ float values[Vectorized::size()]; + for (int k = 0; k < Vectorized::size(); ++k) { + values[k] = data[k]; + } + out = Vectorized::loadu(values); +} + +void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized& out1, Vectorized& out2) { + load_fp32_from_bf16(data, out1); + data += Vectorized::size(); + load_fp32_from_bf16(data, out2); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h new file mode 100644 index 0000000000000..6fc22f0f7d336 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h @@ -0,0 +1,526 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized> { +private: + __m512d values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(c10::complex val) { + double real_value = val.real(); + double imag_value = val.imag(); + values = _mm512_setr_pd(real_value, imag_value, real_value, imag_value, + real_value, imag_value, real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2, + c10::complex val3, c10::complex val4) { + values = _mm512_setr_pd(val1.real(), val1.imag(), + val2.real(), val2.imag(), + val3.real(), val3.imag(), + val4.real(), val4.imag()); + } + operator __m512d() const { + return values; + } + template + static Vectorized> blend(const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_pd(0x03, a.values, b.values); //b0000 0001 = b0000 0011 + case 2: + return _mm512_mask_blend_pd(0x0C, a.values, b.values); //b0000 0010 = b0000 1100 + case 3: + return _mm512_mask_blend_pd(0x0F, a.values, b.values); //b0000 0011 = b0000 1111 + case 4: + return _mm512_mask_blend_pd(0x30, a.values, b.values); //b0000 0100 = b0011 0000 + case 5: + return _mm512_mask_blend_pd(0x33, a.values, b.values); //b0000 0101 = b0011 0011 + case 6: + return _mm512_mask_blend_pd(0x3C, a.values, b.values); //b0000 0110 = b0011 1100 + case 7: + return _mm512_mask_blend_pd(0x3F, a.values, b.values); //b0000 0111 = b0011 1111 + case 8: + return _mm512_mask_blend_pd(0xC0, a.values, b.values); //b0000 1000 = b1100 0000 + case 9: + return _mm512_mask_blend_pd(0xC3, a.values, b.values); //b0000 1001 = b1100 0011 + case 10: + return _mm512_mask_blend_pd(0xCC, a.values, b.values); //b0000 1010 = b1100 1100 + case 11: + return _mm512_mask_blend_pd(0xCF, a.values, b.values); //b0000 1011 = b1100 1111 + case 12: + return _mm512_mask_blend_pd(0xF0, a.values, b.values); //b0000 1100 = b1111 0000 + case 13: + return _mm512_mask_blend_pd(0xF3, 
a.values, b.values); //b0000 1101 = b1111 0011 + case 14: + return _mm512_mask_blend_pd(0xFC, a.values, b.values); //b0000 1110 = b1111 1100 + case 15: + return _mm512_mask_blend_pd(0xFF, a.values, b.values); //b0000 1111 = b1111 1111 + } + return b; + } + static Vectorized> blendv(const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_pd(mask.values, mask.values); + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized> arange(c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, + base + c10::complex(1)*step, + base + c10::complex(2)*step, + base + c10::complex(3)*step); + } + static Vectorized> set(const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + } + return b; + } + static Vectorized> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + __at_align__ double tmp_values[2*size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < 2*size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[2*size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. 
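Since AVX512 has no horizontal add/sub instruction, the helpers that follow emulate them with two cross-register permutes plus one add/sub. A scalar reference of what hadd_pd is meant to compute, derived from the idx1/idx2 tables below (illustrative sketch, not part of the patch):

#include <array>

// Even output lanes hold pairwise sums of a, odd output lanes pairwise sums
// of b -- the layout the complex kernels rely on (e.g. abs_2_() expects
// a*a + b*b duplicated across each pair).
inline std::array<double, 8> hadd_pd_ref(const std::array<double, 8>& a,
                                         const std::array<double, 8>& b) {
  std::array<double, 8> out{};
  for (int i = 0; i < 4; ++i) {
    out[2 * i]     = a[2 * i] + a[2 * i + 1];
    out[2 * i + 1] = b[2 * i] + b[2 * i + 1];
  }
  return out;
}

hsub_pd follows the same pattern with subtraction, i.e. a[2i] - a[2i+1] in even lanes and b[2i] - b[2i+1] in odd lanes.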
+ static inline __m512d hadd_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_add_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + static inline __m512d hsub_pd(__m512d a, __m512d b) { + __m512i idx1 = _mm512_set_epi64(14, 6, 12, 4, 10, 2, 8, 0); + __m512i idx2 = _mm512_set_epi64(15, 7, 13, 5, 11, 3, 9, 1); + return _mm512_sub_pd(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b), + _mm512_mask_permutex2var_pd(a, 0xff, idx2, b)); + } + __m512d abs_2_() const { + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + return hadd_pd(val_2, val_2); // a*a+b*b a*a+b*b + } + __m512d abs_() const { + return _mm512_sqrt_pd(abs_2_()); // abs abs + } + Vectorized> abs() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm512_and_pd(abs_(), real_mask); // abs 0 + } + __m512d angle_() const { + //angle = atan2(b/a) + auto b_a = _mm512_permute_pd(values, 0x55); // b a + return Sleef_atan2d8_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + auto angle = _mm512_permute_pd(angle_(), 0x55); // angle 90-angle + return _mm512_and_pd(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_pd(); + auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ); + auto mask_vec = _mm512_mask_set1_epi64(_mm512_castpd_si512(zero), mask, + 0xFFFFFFFFFFFFFFFF); + auto abs_val = Vectorized(abs); + + auto div = values / abs_val.values; // x / abs(x) + + return blendv(div, zero, _mm512_castsi512_pd(mask_vec)); + } + __m512d real_() const { + const __m512d real_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000, + 0xFFFFFFFFFFFFFFFF, 0x0000000000000000)); + return _mm512_and_pd(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512d imag_() const { + const __m512d imag_mask = _mm512_castsi512_pd(_mm512_setr_epi64(0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, + 0x0000000000000000, 0xFFFFFFFFFFFFFFFF)); + return _mm512_and_pd(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_pd(imag_(), 0x55); //b a + } + __m512d conj_() const { + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_pd(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number performance. 
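The log()/sin()/cos()-style members that follow lean on map(), which simply spills the vector to a temporary array, applies the scalar std:: routine element-wise, and reloads. A stripped-down sketch of that fallback shape (illustrative only, not part of the patch; it uses a plain std::array stand-in rather than Vectorized):

#include <array>
#include <complex>
#include <cstddef>

// Apply a scalar function to every element: the store / per-element loop /
// loadu shape used by the complex map() member defined earlier in this class.
template <std::size_t N, typename F>
std::array<std::complex<double>, N> map_ref(
    const std::array<std::complex<double>, N>& in, F f) {
  std::array<std::complex<double>, N> out{};  // plays the role of store() + loadu()
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = f(in[i]);
  }
  return out;
}

// usage sketch:
//   auto logs = map_ref<4>(vals, [](std::complex<double> z) { return std::log(z); });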
+ return map(std::log); + } + Vectorized> log2() const { + const __m512d log2_ = _mm512_set1_pd(std::log(2)); + return _mm512_div_pd(log(), log2_); + } + Vectorized> log10() const { + const __m512d log10_ = _mm512_set1_pd(std::log(10)); + return _mm512_div_pd(log(), log10_); + } + Vectorized> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + const __m512d one = _mm512_set1_pd(1); + + auto conj = conj_(); + auto b_a = _mm512_permute_pd(conj, 0x55); //-b a + auto ab = _mm512_mul_pd(conj, b_a); //-ab -ab + auto im = _mm512_add_pd(ab, ab); //-2ab -2ab + + auto val_2 = _mm512_mul_pd(values, values); // a*a b*b + auto re = hsub_pd(val_2, _mm512_permute_pd(val_2, 0x55)); // a*a-b*b b*b-a*a + re = _mm512_sub_pd(one, re); + + auto root = Vectorized(_mm512_mask_blend_pd(0xAA, re, im)).sqrt(); //sqrt(re + i*im) + auto ln = Vectorized(_mm512_add_pd(b_a, root)).log(); //ln(iz + sqrt()) + return Vectorized(_mm512_permute_pd(ln.values, 0x55)).conj(); //-i*ln() + } + Vectorized> acos() const { + // acos(x) = pi/2 - asin(x) + constexpr auto pi_2d = c10::pi / 2; + const __m512d pi_2 = _mm512_setr_pd(pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0, pi_2d, 0.0); + return _mm512_sub_pd(pi_2, asin()); + } + Vectorized> atan() const; + Vectorized> atan2(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> exp() const { + //exp(a + bi) + // = exp(a)*(cos(b) + sin(b)i) + auto exp = Sleef_expd8_u10(values); //exp(a) exp(b) + exp = _mm512_mask_blend_pd(0xAA, exp, _mm512_permute_pd(exp, 0x55)); //exp(a) exp(a) + + auto sin_cos = Sleef_sincosd8_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] + auto cos_sin = _mm512_mask_blend_pd(0xAA, _mm512_permute_pd(sin_cos.y, 0x55), + sin_cos.x); //cos(b) sin(b) + return _mm512_mul_pd(exp, cos_sin); + } + Vectorized> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized> floor() const { + return _mm512_floor_pd(values); + } + Vectorized> hypot(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igamma(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igammac(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_pd(); + return _mm512_sub_pd(zero, values); + } + Vectorized> nextafter(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> round() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; 
+ Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow(const Vectorized> &exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==(const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator!=(const Vectorized>& other) const { + auto mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized> operator<(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq(const Vectorized>& other) const; + Vectorized> ne(const Vectorized>& other) const; + Vectorized> lt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> le(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> gt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> ge(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +template <> Vectorized> inline operator+(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_add_pd(a, b); +} + +template <> Vectorized> inline operator-(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_sub_pd(a, b); +} + +template <> Vectorized> inline operator*(const Vectorized> &a, + const Vectorized> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_pd(a, b); //ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); //d c + d_c = _mm512_xor_pd(sign_mask, d_c); //d -c + auto ad_bc = _mm512_mul_pd(a, d_c); //ad -bc + + auto ret = Vectorized>::hsub_pd(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> Vectorized> inline operator/(const Vectorized> &a, + const Vectorized> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m512d sign_mask = _mm512_setr_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); + auto ac_bd = _mm512_mul_pd(a, b); //ac bd + + auto d_c = _mm512_permute_pd(b, 0x55); //d c + d_c = _mm512_xor_pd(sign_mask, d_c); //-d c + auto ad_bc = _mm512_mul_pd(a, d_c); //-ad bc + + auto re_im = Vectorized>::hadd_pd(ac_bd, ad_bc);//ac + bd bc - ad + return _mm512_div_pd(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. 
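A quick scalar cross-check of the identity reciprocal() implements below, 1/(c + di) = (c - di)/(c^2 + d^2), i.e. the conjugate divided by the squared magnitude (illustrative sketch, not part of the patch):

#include <cassert>
#include <complex>

// Same math as the xor-with-sign-mask plus divide-by-abs_2_() sequence below,
// written out element-wise.
inline std::complex<double> reciprocal_ref(std::complex<double> z) {
  const double abs_2 = z.real() * z.real() + z.imag() * z.imag();
  return {z.real() / abs_2, -z.imag() / abs_2};
}

// e.g. reciprocal_ref({3.0, 4.0}) == {0.12, -0.16}, matching
//      1.0 / std::complex<double>(3.0, 4.0) up to rounding.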
+Vectorized> Vectorized>::reciprocal() const{ + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto c_d = _mm512_xor_pd(sign_mask, values); //c -d + return _mm512_div_pd(c_d, abs_2_()); +} + +Vectorized> Vectorized>::atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + const __m512d i = _mm512_setr_pd(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + const Vectorized i_half = _mm512_setr_pd(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); + + auto sum = Vectorized(_mm512_add_pd(i, values)); // a 1+b + auto sub = Vectorized(_mm512_sub_pd(i, values)); // -a 1-b + auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) + return i_half*ln; // i/2*ln() +} + +template <> +Vectorized> inline maximum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(max, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline minimum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vec = _mm512_set1_epi64(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_pd(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_pd_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF); + return _mm512_or_pd(min, _mm512_castsi512_pd(isnan)); +} + +template <> +Vectorized> inline operator&(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized> inline operator|(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized> inline operator^(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_pd(a, b); +} + +Vectorized> Vectorized>::eq(const Vectorized>& other) const { + return (*this == other) & Vectorized>(_mm512_set1_pd(1.0)); +} + +Vectorized> Vectorized>::ne(const Vectorized>& other) const { + return (*this != other) & Vectorized>(_mm512_set1_pd(1.0)); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h new file mode 100644 index 0000000000000..dfd070604c40c --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h @@ -0,0 +1,1030 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized> { +private: + __m512 values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = c10::complex; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(c10::complex val) { + float real_value = val.real(); + float imag_value = val.imag(); + values = _mm512_setr_ps(real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value, + real_value, imag_value); + } + Vectorized(c10::complex val1, c10::complex val2, + c10::complex val3, c10::complex val4, + c10::complex val5, c10::complex val6, + c10::complex val7, c10::complex val8) { + values = _mm512_setr_ps(val1.real(), val1.imag(), + val2.real(), val2.imag(), + val3.real(), val3.imag(), + val4.real(), val4.imag(), + val5.real(), val5.imag(), + val6.real(), val6.imag(), + val7.real(), val7.imag(), + val8.real(), val8.imag()); + } + operator __m512() const { + return values; + } + template + static Vectorized> blend(const Vectorized>& a, + const Vectorized>& b) { + // convert c10::complex index mask to V index mask: xy -> xxyy + // NOLINTNEXTLINE(clang-diagnostic-warning) + // The compiler would hopefully convert this switch condition + // into a jump table + switch (mask) { + case 0: + return a; + case 1: + return _mm512_mask_blend_ps(0x03, a.values, b.values); + case 2: + return _mm512_mask_blend_ps(0x0C, a.values, b.values); + case 3: + return _mm512_mask_blend_ps(0x0F, a.values, b.values); + case 4: + return _mm512_mask_blend_ps(0x30, a.values, b.values); + case 5: + return _mm512_mask_blend_ps(0x33, a.values, b.values); + case 6: + return _mm512_mask_blend_ps(0x3C, a.values, b.values); + case 7: + return _mm512_mask_blend_ps(0x3F, a.values, b.values); + case 8: + return _mm512_mask_blend_ps(0xC0, a.values, b.values); + case 9: + return _mm512_mask_blend_ps(0xC3, a.values, b.values); + case 10: + return _mm512_mask_blend_ps(0xCC, a.values, b.values); + case 11: + return _mm512_mask_blend_ps(0xCF, a.values, b.values); + case 12: + return _mm512_mask_blend_ps(0xF0, a.values, b.values); + case 13: + return _mm512_mask_blend_ps(0xF3, a.values, b.values); + case 14: + return _mm512_mask_blend_ps(0xFC, a.values, b.values); + case 15: + return _mm512_mask_blend_ps(0xFF, a.values, b.values); + case 16: + return _mm512_mask_blend_ps(0x300, a.values, b.values); + case 17: + return _mm512_mask_blend_ps(0x303, a.values, b.values); + case 18: + return _mm512_mask_blend_ps(0x30C, a.values, b.values); + case 19: + return _mm512_mask_blend_ps(0x30F, a.values, b.values); + case 20: + return _mm512_mask_blend_ps(0x330, a.values, b.values); + case 21: + return _mm512_mask_blend_ps(0x333, a.values, b.values); + case 22: + return _mm512_mask_blend_ps(0x33C, a.values, b.values); + case 23: + return _mm512_mask_blend_ps(0x33F, a.values, b.values); + case 24: + return _mm512_mask_blend_ps(0x3C0, a.values, b.values); + case 25: + return _mm512_mask_blend_ps(0x3C3, a.values, b.values); + case 26: + return _mm512_mask_blend_ps(0x3CC, a.values, b.values); + case 27: + 
return _mm512_mask_blend_ps(0x3CF, a.values, b.values); + case 28: + return _mm512_mask_blend_ps(0x3F0, a.values, b.values); + case 29: + return _mm512_mask_blend_ps(0x3F3, a.values, b.values); + case 30: + return _mm512_mask_blend_ps(0x3FC, a.values, b.values); + case 31: + return _mm512_mask_blend_ps(0x3FF, a.values, b.values); + case 32: + return _mm512_mask_blend_ps(0xC00, a.values, b.values); + case 33: + return _mm512_mask_blend_ps(0xC03, a.values, b.values); + case 34: + return _mm512_mask_blend_ps(0xC0C, a.values, b.values); + case 35: + return _mm512_mask_blend_ps(0xC0F, a.values, b.values); + case 36: + return _mm512_mask_blend_ps(0xC30, a.values, b.values); + case 37: + return _mm512_mask_blend_ps(0xC33, a.values, b.values); + case 38: + return _mm512_mask_blend_ps(0xC3C, a.values, b.values); + case 39: + return _mm512_mask_blend_ps(0xC3F, a.values, b.values); + case 40: + return _mm512_mask_blend_ps(0xCC0, a.values, b.values); + case 41: + return _mm512_mask_blend_ps(0xCC3, a.values, b.values); + case 42: + return _mm512_mask_blend_ps(0xCCC, a.values, b.values); + case 43: + return _mm512_mask_blend_ps(0xCCF, a.values, b.values); + case 44: + return _mm512_mask_blend_ps(0xCF0, a.values, b.values); + case 45: + return _mm512_mask_blend_ps(0xCF3, a.values, b.values); + case 46: + return _mm512_mask_blend_ps(0xCFC, a.values, b.values); + case 47: + return _mm512_mask_blend_ps(0xCFF, a.values, b.values); + case 48: + return _mm512_mask_blend_ps(0xF00, a.values, b.values); + case 49: + return _mm512_mask_blend_ps(0xF03, a.values, b.values); + case 50: + return _mm512_mask_blend_ps(0xF0C, a.values, b.values); + case 51: + return _mm512_mask_blend_ps(0xF0F, a.values, b.values); + case 52: + return _mm512_mask_blend_ps(0xF30, a.values, b.values); + case 53: + return _mm512_mask_blend_ps(0xF33, a.values, b.values); + case 54: + return _mm512_mask_blend_ps(0xF3C, a.values, b.values); + case 55: + return _mm512_mask_blend_ps(0xF3F, a.values, b.values); + case 56: + return _mm512_mask_blend_ps(0xFC0, a.values, b.values); + case 57: + return _mm512_mask_blend_ps(0xFC3, a.values, b.values); + case 58: + return _mm512_mask_blend_ps(0xFCC, a.values, b.values); + case 59: + return _mm512_mask_blend_ps(0xFCF, a.values, b.values); + case 60: + return _mm512_mask_blend_ps(0xFF0, a.values, b.values); + case 61: + return _mm512_mask_blend_ps(0xFF3, a.values, b.values); + case 62: + return _mm512_mask_blend_ps(0xFFC, a.values, b.values); + case 63: + return _mm512_mask_blend_ps(0xFFF, a.values, b.values); + case 64: + return _mm512_mask_blend_ps(0x3000, a.values, b.values); + case 65: + return _mm512_mask_blend_ps(0x3003, a.values, b.values); + case 66: + return _mm512_mask_blend_ps(0x300C, a.values, b.values); + case 67: + return _mm512_mask_blend_ps(0x300F, a.values, b.values); + case 68: + return _mm512_mask_blend_ps(0x3030, a.values, b.values); + case 69: + return _mm512_mask_blend_ps(0x3033, a.values, b.values); + case 70: + return _mm512_mask_blend_ps(0x303C, a.values, b.values); + case 71: + return _mm512_mask_blend_ps(0x303F, a.values, b.values); + case 72: + return _mm512_mask_blend_ps(0x30C0, a.values, b.values); + case 73: + return _mm512_mask_blend_ps(0X30C3, a.values, b.values); + case 74: + return _mm512_mask_blend_ps(0x30CC, a.values, b.values); + case 75: + return _mm512_mask_blend_ps(0x30CF, a.values, b.values); + case 76: + return _mm512_mask_blend_ps(0x30F0, a.values, b.values); + case 77: + return _mm512_mask_blend_ps(0x30F3, a.values, b.values); + case 78: + return 
_mm512_mask_blend_ps(0x30FC, a.values, b.values); + case 79: + return _mm512_mask_blend_ps(0x30FF, a.values, b.values); + case 80: + return _mm512_mask_blend_ps(0x3300, a.values, b.values); + case 81: + return _mm512_mask_blend_ps(0X3303, a.values, b.values); + case 82: + return _mm512_mask_blend_ps(0x330C, a.values, b.values); + case 83: + return _mm512_mask_blend_ps(0x330F, a.values, b.values); + case 84: + return _mm512_mask_blend_ps(0x3330, a.values, b.values); + case 85: + return _mm512_mask_blend_ps(0x3333, a.values, b.values); + case 86: + return _mm512_mask_blend_ps(0x333C, a.values, b.values); + case 87: + return _mm512_mask_blend_ps(0X333F, a.values, b.values); + case 88: + return _mm512_mask_blend_ps(0x33C0, a.values, b.values); + case 89: + return _mm512_mask_blend_ps(0x33C3, a.values, b.values); + case 90: + return _mm512_mask_blend_ps(0x33CC, a.values, b.values); + case 91: + return _mm512_mask_blend_ps(0x33CF, a.values, b.values); + case 92: + return _mm512_mask_blend_ps(0x33F0, a.values, b.values); + case 93: + return _mm512_mask_blend_ps(0x33F3, a.values, b.values); + case 94: + return _mm512_mask_blend_ps(0x33FC, a.values, b.values); + case 95: + return _mm512_mask_blend_ps(0x33FF, a.values, b.values); + case 96: + return _mm512_mask_blend_ps(0X3C00, a.values, b.values); + case 97: + return _mm512_mask_blend_ps(0x3C03, a.values, b.values); + case 98: + return _mm512_mask_blend_ps(0x3C0C, a.values, b.values); + case 99: + return _mm512_mask_blend_ps(0x3C0F, a.values, b.values); + case 100: + return _mm512_mask_blend_ps(0x3C30, a.values, b.values); + case 101: + return _mm512_mask_blend_ps(0x3C33, a.values, b.values); + case 102: + return _mm512_mask_blend_ps(0x3C3C, a.values, b.values); + case 103: + return _mm512_mask_blend_ps(0x3C3F, a.values, b.values); + case 104: + return _mm512_mask_blend_ps(0x3CC0, a.values, b.values); + case 105: + return _mm512_mask_blend_ps(0x3CC3, a.values, b.values); + case 106: + return _mm512_mask_blend_ps(0x3CCC, a.values, b.values); + case 107: + return _mm512_mask_blend_ps(0x3CCF, a.values, b.values); + case 108: + return _mm512_mask_blend_ps(0x3CF0, a.values, b.values); + case 109: + return _mm512_mask_blend_ps(0x3CF3, a.values, b.values); + case 110: + return _mm512_mask_blend_ps(0x3CFC, a.values, b.values); + case 111: + return _mm512_mask_blend_ps(0x3CFF, a.values, b.values); + case 112: + return _mm512_mask_blend_ps(0x3F00, a.values, b.values); + case 113: + return _mm512_mask_blend_ps(0x3F03, a.values, b.values); + case 114: + return _mm512_mask_blend_ps(0x3F0C, a.values, b.values); + case 115: + return _mm512_mask_blend_ps(0x3F0F, a.values, b.values); + case 116: + return _mm512_mask_blend_ps(0x3F30, a.values, b.values); + case 117: + return _mm512_mask_blend_ps(0x3F33, a.values, b.values); + case 118: + return _mm512_mask_blend_ps(0x3F3C, a.values, b.values); + case 119: + return _mm512_mask_blend_ps(0x3F3F, a.values, b.values); + case 120: + return _mm512_mask_blend_ps(0x3FC0, a.values, b.values); + case 121: + return _mm512_mask_blend_ps(0x3FC3, a.values, b.values); + case 122: + return _mm512_mask_blend_ps(0x3FCC, a.values, b.values); + case 123: + return _mm512_mask_blend_ps(0x3FCF, a.values, b.values); + case 124: + return _mm512_mask_blend_ps(0x3FF0, a.values, b.values); + case 125: + return _mm512_mask_blend_ps(0x3FF3, a.values, b.values); + case 126: + return _mm512_mask_blend_ps(0x3FFC, a.values, b.values); + case 127: + return _mm512_mask_blend_ps(0x3FFF, a.values, b.values); + case 128: + return 
_mm512_mask_blend_ps(0xC000, a.values, b.values); + case 129: + return _mm512_mask_blend_ps(0xC003, a.values, b.values); + case 130: + return _mm512_mask_blend_ps(0xC00C, a.values, b.values); + case 131: + return _mm512_mask_blend_ps(0xC00F, a.values, b.values); + case 132: + return _mm512_mask_blend_ps(0xC030, a.values, b.values); + case 133: + return _mm512_mask_blend_ps(0xC033, a.values, b.values); + case 134: + return _mm512_mask_blend_ps(0xC03C, a.values, b.values); + case 135: + return _mm512_mask_blend_ps(0xC03F, a.values, b.values); + case 136: + return _mm512_mask_blend_ps(0xC0C0, a.values, b.values); + case 137: + return _mm512_mask_blend_ps(0xC0C3, a.values, b.values); + case 138: + return _mm512_mask_blend_ps(0xC0CC, a.values, b.values); + case 139: + return _mm512_mask_blend_ps(0xC0CF, a.values, b.values); + case 140: + return _mm512_mask_blend_ps(0xC0F0, a.values, b.values); + case 141: + return _mm512_mask_blend_ps(0xC0F3, a.values, b.values); + case 142: + return _mm512_mask_blend_ps(0xC0FC, a.values, b.values); + case 143: + return _mm512_mask_blend_ps(0xC0FF, a.values, b.values); + case 144: + return _mm512_mask_blend_ps(0xC300, a.values, b.values); + case 145: + return _mm512_mask_blend_ps(0xC303, a.values, b.values); + case 146: + return _mm512_mask_blend_ps(0xC30C, a.values, b.values); + case 147: + return _mm512_mask_blend_ps(0xC30F, a.values, b.values); + case 148: + return _mm512_mask_blend_ps(0xC330, a.values, b.values); + case 149: + return _mm512_mask_blend_ps(0xC333, a.values, b.values); + case 150: + return _mm512_mask_blend_ps(0xC33C, a.values, b.values); + case 151: + return _mm512_mask_blend_ps(0xC33F, a.values, b.values); + case 152: + return _mm512_mask_blend_ps(0xC3C0, a.values, b.values); + case 153: + return _mm512_mask_blend_ps(0xC3C3, a.values, b.values); + case 154: + return _mm512_mask_blend_ps(0xC3CC, a.values, b.values); + case 155: + return _mm512_mask_blend_ps(0xC3CF, a.values, b.values); + case 156: + return _mm512_mask_blend_ps(0xC3F0, a.values, b.values); + case 157: + return _mm512_mask_blend_ps(0xC3F3, a.values, b.values); + case 158: + return _mm512_mask_blend_ps(0xC3FC, a.values, b.values); + case 159: + return _mm512_mask_blend_ps(0xC3FF, a.values, b.values); + case 160: + return _mm512_mask_blend_ps(0xCC00, a.values, b.values); + case 161: + return _mm512_mask_blend_ps(0xCC03, a.values, b.values); + case 162: + return _mm512_mask_blend_ps(0xCC0C, a.values, b.values); + case 163: + return _mm512_mask_blend_ps(0xCC0F, a.values, b.values); + case 164: + return _mm512_mask_blend_ps(0xCC30, a.values, b.values); + case 165: + return _mm512_mask_blend_ps(0xCC33, a.values, b.values); + case 166: + return _mm512_mask_blend_ps(0xCC3C, a.values, b.values); + case 167: + return _mm512_mask_blend_ps(0xCC3F, a.values, b.values); + case 168: + return _mm512_mask_blend_ps(0xCCC0, a.values, b.values); + case 169: + return _mm512_mask_blend_ps(0xCCC3, a.values, b.values); + case 170: + return _mm512_mask_blend_ps(0xCCCC, a.values, b.values); + case 171: + return _mm512_mask_blend_ps(0xCCCF, a.values, b.values); + case 172: + return _mm512_mask_blend_ps(0xCCF0, a.values, b.values); + case 173: + return _mm512_mask_blend_ps(0xCCF3, a.values, b.values); + case 174: + return _mm512_mask_blend_ps(0xCCFC, a.values, b.values); + case 175: + return _mm512_mask_blend_ps(0xCCFF, a.values, b.values); + case 176: + return _mm512_mask_blend_ps(0xCF00, a.values, b.values); + case 177: + return _mm512_mask_blend_ps(0xCF03, a.values, b.values); + case 178: + return 
_mm512_mask_blend_ps(0xCF0C, a.values, b.values); + case 179: + return _mm512_mask_blend_ps(0xCF0F, a.values, b.values); + case 180: + return _mm512_mask_blend_ps(0xCF30, a.values, b.values); + case 181: + return _mm512_mask_blend_ps(0xCF33, a.values, b.values); + case 182: + return _mm512_mask_blend_ps(0xCF3C, a.values, b.values); + case 183: + return _mm512_mask_blend_ps(0xCF3F, a.values, b.values); + case 184: + return _mm512_mask_blend_ps(0xCFC0, a.values, b.values); + case 185: + return _mm512_mask_blend_ps(0xCFC3, a.values, b.values); + case 186: + return _mm512_mask_blend_ps(0xCFCC, a.values, b.values); + case 187: + return _mm512_mask_blend_ps(0xCFCF, a.values, b.values); + case 188: + return _mm512_mask_blend_ps(0xCFF0, a.values, b.values); + case 189: + return _mm512_mask_blend_ps(0xCFF3, a.values, b.values); + case 190: + return _mm512_mask_blend_ps(0xCFFC, a.values, b.values); + case 191: + return _mm512_mask_blend_ps(0xCFFF, a.values, b.values); + case 192: + return _mm512_mask_blend_ps(0xF000, a.values, b.values); + case 193: + return _mm512_mask_blend_ps(0xF003, a.values, b.values); + case 194: + return _mm512_mask_blend_ps(0xF00C, a.values, b.values); + case 195: + return _mm512_mask_blend_ps(0xF00F, a.values, b.values); + case 196: + return _mm512_mask_blend_ps(0xF030, a.values, b.values); + case 197: + return _mm512_mask_blend_ps(0xF033, a.values, b.values); + case 198: + return _mm512_mask_blend_ps(0xF03C, a.values, b.values); + case 199: + return _mm512_mask_blend_ps(0xF03F, a.values, b.values); + case 200: + return _mm512_mask_blend_ps(0XF0C0, a.values, b.values); + case 201: + return _mm512_mask_blend_ps(0xF0C3, a.values, b.values); + case 202: + return _mm512_mask_blend_ps(0xF0CC, a.values, b.values); + case 203: + return _mm512_mask_blend_ps(0xF0CF, a.values, b.values); + case 204: + return _mm512_mask_blend_ps(0xF0F0, a.values, b.values); + case 205: + return _mm512_mask_blend_ps(0xF0F3, a.values, b.values); + case 206: + return _mm512_mask_blend_ps(0xF0FC, a.values, b.values); + case 207: + return _mm512_mask_blend_ps(0xF0FF, a.values, b.values); + case 208: + return _mm512_mask_blend_ps(0XF300, a.values, b.values); + case 209: + return _mm512_mask_blend_ps(0xF303, a.values, b.values); + case 210: + return _mm512_mask_blend_ps(0xF30C, a.values, b.values); + case 211: + return _mm512_mask_blend_ps(0xF30F, a.values, b.values); + case 212: + return _mm512_mask_blend_ps(0xF330, a.values, b.values); + case 213: + return _mm512_mask_blend_ps(0xF333, a.values, b.values); + case 214: + return _mm512_mask_blend_ps(0XF33C, a.values, b.values); + case 215: + return _mm512_mask_blend_ps(0xF33F, a.values, b.values); + case 216: + return _mm512_mask_blend_ps(0xF3C0, a.values, b.values); + case 217: + return _mm512_mask_blend_ps(0xF3C3, a.values, b.values); + case 218: + return _mm512_mask_blend_ps(0xF3CC, a.values, b.values); + case 219: + return _mm512_mask_blend_ps(0xF3CF, a.values, b.values); + case 220: + return _mm512_mask_blend_ps(0xF3F0, a.values, b.values); + case 221: + return _mm512_mask_blend_ps(0xF3F3, a.values, b.values); + case 222: + return _mm512_mask_blend_ps(0xF3FC, a.values, b.values); + case 223: + return _mm512_mask_blend_ps(0XF3FF, a.values, b.values); + case 224: + return _mm512_mask_blend_ps(0xFC00, a.values, b.values); + case 225: + return _mm512_mask_blend_ps(0xFC03, a.values, b.values); + case 226: + return _mm512_mask_blend_ps(0xFC0C, a.values, b.values); + case 227: + return _mm512_mask_blend_ps(0xFC0F, a.values, b.values); + case 228: + return 
_mm512_mask_blend_ps(0xFC30, a.values, b.values); + case 229: + return _mm512_mask_blend_ps(0xFC33, a.values, b.values); + case 230: + return _mm512_mask_blend_ps(0xFC3C, a.values, b.values); + case 231: + return _mm512_mask_blend_ps(0xFC3F, a.values, b.values); + case 232: + return _mm512_mask_blend_ps(0xFCC0, a.values, b.values); + case 233: + return _mm512_mask_blend_ps(0xFCC3, a.values, b.values); + case 234: + return _mm512_mask_blend_ps(0xFCCC, a.values, b.values); + case 235: + return _mm512_mask_blend_ps(0xFCCF, a.values, b.values); + case 236: + return _mm512_mask_blend_ps(0xFCF0, a.values, b.values); + case 237: + return _mm512_mask_blend_ps(0xFCF3, a.values, b.values); + case 238: + return _mm512_mask_blend_ps(0xFCFC, a.values, b.values); + case 239: + return _mm512_mask_blend_ps(0xFCFF, a.values, b.values); + case 240: + return _mm512_mask_blend_ps(0xFF00, a.values, b.values); + case 241: + return _mm512_mask_blend_ps(0xFF03, a.values, b.values); + case 242: + return _mm512_mask_blend_ps(0xFF0C, a.values, b.values); + case 243: + return _mm512_mask_blend_ps(0xFF0F, a.values, b.values); + case 244: + return _mm512_mask_blend_ps(0xFF30, a.values, b.values); + case 245: + return _mm512_mask_blend_ps(0xFF33, a.values, b.values); + case 246: + return _mm512_mask_blend_ps(0xFF3C, a.values, b.values); + case 247: + return _mm512_mask_blend_ps(0xFF3F, a.values, b.values); + case 248: + return _mm512_mask_blend_ps(0xFFC0, a.values, b.values); + case 249: + return _mm512_mask_blend_ps(0xFFC3, a.values, b.values); + case 250: + return _mm512_mask_blend_ps(0xFFCC, a.values, b.values); + case 251: + return _mm512_mask_blend_ps(0xFFCF, a.values, b.values); + case 252: + return _mm512_mask_blend_ps(0xFFF0, a.values, b.values); + case 253: + return _mm512_mask_blend_ps(0xFFF3, a.values, b.values); + case 254: + return _mm512_mask_blend_ps(0xFFFC, a.values, b.values); + } + return b; + } + static Vectorized> blendv(const Vectorized>& a, + const Vectorized>& b, + const Vectorized>& mask) { + // convert c10::complex index mask to V index mask: xy -> xxyy + auto mask_ = _mm512_unpacklo_ps(mask.values, mask.values); + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask_), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized> arange(c10::complex base = 0., + step_t step = static_cast(1)) { + return Vectorized>(base, + base + step, + base + c10::complex(2)*step, + base + c10::complex(3)*step, + base + c10::complex(4)*step, + base + c10::complex(5)*step, + base + c10::complex(6)*step, + base + c10::complex(7)*step); + } + static Vectorized> set(const Vectorized>& a, + const Vectorized>& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized> loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + + __at_align__ float tmp_values[2*size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. 
We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < 2*size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(c10::complex)); + return _mm512_load_ps(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[2*size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(c10::complex)); + } + } + // AVX512 doesn't have horizontal add & horizontal sub instructions. + // TODO: hadd_pd() & hsub_pd() may have scope for improvement. + static inline __m512 hadd_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_add_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + static inline __m512 hsub_ps(__m512 a, __m512 b) { + __m512i idx1 = _mm512_set_epi32(30, 14, 28, 12, 26, 10, 24, 8, 22, 6, 20, 4, 18, 2, 16, 0); + __m512i idx2 = _mm512_set_epi32(31, 15, 29, 13, 27, 11, 25, 9, 23, 7, 21, 5, 19, 3, 17, 1); + return _mm512_sub_ps(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b), + _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); + } + const c10::complex& operator[](int idx) const = delete; + c10::complex& operator[](int idx) = delete; + Vectorized> map(c10::complex (*const f)(const c10::complex &)) const { + __at_align__ c10::complex tmp[size()]; + store(tmp); + for (int i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + __m512 abs_2_() const { + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto ret = hadd_ps(val_2, val_2); // a*a+b*b a*a+b*b + return ret; + } + __m512 abs_() const { + return _mm512_sqrt_ps(abs_2_()); // abs abs + } + Vectorized> abs() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm512_and_ps(abs_(), real_mask); // abs 0 + } + __m512 angle_() const { + //angle = atan2(b/a) + auto b_a = _mm512_permute_ps(values, 0xB1); // b a + return Sleef_atan2f16_u10(values, b_a); // 90-angle angle + } + Vectorized> angle() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000)); + auto angle = _mm512_permute_ps(angle_(), 0xB1); // angle 90-angle + return _mm512_and_ps(angle, real_mask); // angle 0 + } + Vectorized> sgn() const { + auto abs = abs_(); + auto zero = _mm512_setzero_ps(); + auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ); + auto abs_val = Vectorized(abs); + + auto div = values / abs_val.values; // x / abs(x) + + return _mm512_mask_blend_ps(mask, div, zero); + } + __m512 real_() const { + const __m512 real_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, 0x00000000, + 0xFFFFFFFF, 
0x00000000, 0xFFFFFFFF, 0x00000000)); + return _mm512_and_ps(values, real_mask); + } + Vectorized> real() const { + return real_(); + } + __m512 imag_() const { + const __m512 imag_mask = _mm512_castsi512_ps(_mm512_setr_epi32(0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF, + 0x00000000, 0xFFFFFFFF, 0x00000000, 0xFFFFFFFF)); + return _mm512_and_ps(values, imag_mask); + } + Vectorized> imag() const { + return _mm512_permute_ps(imag_(), 0xB1); //b a + } + __m512 conj_() const { + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + return _mm512_xor_ps(values, sign_mask); // a -b + } + Vectorized> conj() const { + return conj_(); + } + Vectorized> log() const { + // Most trigonomic ops use the log() op to improve complex number performance. + return map(std::log); + } + Vectorized> log2() const { + const __m512 log2_ = _mm512_set1_ps(std::log(2)); + return _mm512_div_ps(log(), log2_); + } + Vectorized> log10() const { + const __m512 log10_ = _mm512_set1_ps(std::log(10)); + return _mm512_div_ps(log(), log10_); + } + Vectorized> log1p() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> asin() const { + // asin(x) + // = -i*ln(iz + sqrt(1 -z^2)) + // = -i*ln((ai - b) + sqrt(1 - (a + bi)*(a + bi))) + // = -i*ln((-b + ai) + sqrt(1 - (a**2 - b**2) - 2*abi)) + const __m512 one = _mm512_set1_ps(1); + + auto conj = conj_(); + auto b_a = _mm512_permute_ps(conj, 0xB1); //-b a + auto ab = _mm512_mul_ps(conj, b_a); //-ab -ab + auto im = _mm512_add_ps(ab, ab); //-2ab -2ab + + auto val_2 = _mm512_mul_ps(values, values); // a*a b*b + auto re = hsub_ps(val_2, _mm512_permute_ps(val_2, 0xB1)); // a*a-b*b b*b-a*a + re = _mm512_sub_ps(one, re); + + auto root = Vectorized(_mm512_mask_blend_ps(0xAAAA, re, im)).sqrt(); //sqrt(re + i*im) + auto ln = Vectorized(_mm512_add_ps(b_a, root)).log(); //ln(iz + sqrt()) + return Vectorized(_mm512_permute_ps(ln.values, 0xB1)).conj(); //-i*ln() + } + Vectorized> acos() const { + return map(std::acos); + } + Vectorized> atan() const; + Vectorized> atan2(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erf() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> erfc() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> exp() const { + //exp(a + bi) + // = exp(a)*(cos(b) + sin(b)i) + auto exp = Sleef_expf16_u10(values); //exp(a) exp(b) + exp = _mm512_mask_blend_ps(0xAAAA, exp, _mm512_permute_ps(exp, 0xB1)); //exp(a) exp(a) + + auto sin_cos = Sleef_sincosf16_u10(values); //[sin(a), cos(a)] [sin(b), cos(b)] + auto cos_sin = _mm512_mask_blend_ps(0xAAAA, _mm512_permute_ps(sin_cos.y, 0xB1), + sin_cos.x); //cos(b) sin(b) + return _mm512_mul_ps(exp, cos_sin); + } + Vectorized> expm1() const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> sin() const { + return map(std::sin); + } + Vectorized> sinh() const { + return map(std::sinh); + } + Vectorized> cos() const { + return map(std::cos); + } + Vectorized> cosh() const { + return map(std::cosh); + } + Vectorized> ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized> floor() const { + return _mm512_floor_ps(values); + } + Vectorized> hypot(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> igamma(const Vectorized> &x) const { + AT_ERROR("not supported for complex 
numbers"); + } + Vectorized> igammac(const Vectorized> &x) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> neg() const { + auto zero = _mm512_setzero_ps(); + return _mm512_sub_ps(zero, values); + } + Vectorized> nextafter(const Vectorized> &b) const { + AT_ERROR("not supported for complex numbers"); + } + Vectorized> round() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized> tan() const { + return map(std::tan); + } + Vectorized> tanh() const { + return map(std::tanh); + } + Vectorized> trunc() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized> sqrt() const { + return map(std::sqrt); + } + Vectorized> reciprocal() const; + Vectorized> rsqrt() const { + return sqrt().reciprocal(); + } + Vectorized> pow(const Vectorized> &exp) const { + __at_align__ c10::complex x_tmp[size()]; + __at_align__ c10::complex y_tmp[size()]; + store(x_tmp); + exp.store(y_tmp); + for (int i = 0; i < size(); i++) { + x_tmp[i] = std::pow(x_tmp[i], y_tmp[i]); + } + return loadu(x_tmp); + } + // Comparison using the _CMP_**_OQ predicate. + // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized> operator==(const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator!=(const Vectorized>& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF)); + } + Vectorized> operator<(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator<=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> operator>=(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + + Vectorized> eq(const Vectorized>& other) const; + Vectorized> ne(const Vectorized>& other) const; + Vectorized> lt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> le(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> gt(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } + Vectorized> ge(const Vectorized>& other) const { + TORCH_CHECK(false, "not supported for complex numbers"); + } +}; + +template <> Vectorized> inline operator+(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_add_ps(a, b); +} + +template <> Vectorized> inline operator-(const Vectorized> &a, + const Vectorized> &b) { + return _mm512_sub_ps(a, b); +} + +template <> Vectorized> inline operator*(const Vectorized> &a, + const Vectorized> &b) { + //(a + bi) * (c + di) = (ac - bd) + (ad + bc)i + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto ac_bd = _mm512_mul_ps(a, b); //ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); //d c + d_c = _mm512_xor_ps(sign_mask, d_c); //d -c + auto ad_bc = _mm512_mul_ps(a, d_c); //ad -bc + + auto ret = Vectorized>::hsub_ps(ac_bd, ad_bc); //ac - bd ad + bc + return ret; +} + +template <> 
Vectorized> inline operator/(const Vectorized> &a, + const Vectorized> &b) { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() + //im = (bc - ad)/abs_2() + const __m512 sign_mask = _mm512_setr_ps(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, + -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0); + auto ac_bd = _mm512_mul_ps(a, b); //ac bd + + auto d_c = _mm512_permute_ps(b, 0xB1); //d c + d_c = _mm512_xor_ps(sign_mask, d_c); //-d c + auto ad_bc = _mm512_mul_ps(a, d_c); //-ad bc + + auto re_im = Vectorized>::hadd_ps(ac_bd, ad_bc);//ac + bd bc - ad + return _mm512_div_ps(re_im, b.abs_2_()); +} + +// reciprocal. Implement this here so we can use multiplication. +Vectorized> Vectorized>::reciprocal() const { + //re + im*i = (a + bi) / (c + di) + //re = (ac + bd)/abs_2() = c/abs_2() + //im = (bc - ad)/abs_2() = d/abs_2() + const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, + 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0); + auto c_d = _mm512_xor_ps(sign_mask, values); //c -d + return _mm512_div_ps(c_d, abs_2_()); +} + +Vectorized> Vectorized>::atan() const { + // atan(x) = i/2 * ln((i + z)/(i - z)) + const __m512 i = _mm512_setr_ps(0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0); + const Vectorized i_half = _mm512_setr_ps(0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, + 0.0, 0.5, 0.0, 0.5, 0.0, 0.5, 0.0, 0.5); + + auto sum = Vectorized(_mm512_add_ps(i, values)); // a 1+b + auto sub = Vectorized(_mm512_sub_ps(i, values)); // -a 1-b + auto ln = (sum/sub).log(); // ln((i + z)/(i - z)) + return i_half*ln; // i/2*ln() +} + +template <> +Vectorized> inline maximum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_LT_OQ); + auto max = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. + auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(max, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline minimum(const Vectorized>& a, + const Vectorized>& b) { + auto zero_vector = _mm512_set1_epi32(0); + auto abs_a = a.abs_2_(); + auto abs_b = b.abs_2_(); + auto mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_GT_OQ); + auto min = _mm512_mask_blend_ps(mask, a, b); + // Exploit the fact that all-ones is a NaN. 
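+ // (An all-ones 32-bit pattern has an all-ones exponent and a non-zero mantissa,
+ // so it decodes as a quiet NaN; OR-ing 0xFFFFFFFF into a lane therefore forces
+ // that lane to NaN exactly where the unordered compare below reports a NaN operand.)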
+ auto isnan_mask = _mm512_cmp_ps_mask(abs_a, abs_b, _CMP_UNORD_Q); + auto isnan = _mm512_mask_set1_epi32(zero_vector, isnan_mask, 0xFFFFFFFF); + return _mm512_or_ps(min, _mm512_castsi512_ps(isnan)); +} + +template <> +Vectorized> inline operator&(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized> inline operator|(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized> inline operator^(const Vectorized>& a, + const Vectorized>& b) { + return _mm512_xor_ps(a, b); +} + +Vectorized> Vectorized>::eq( + const Vectorized>& other) const { + return (*this == other) & Vectorized>(_mm512_set1_ps(1.0f)); +} + +Vectorized> Vectorized>::ne( + const Vectorized>& other) const { + return (*this != other) & Vectorized>(_mm512_set1_ps(1.0f)); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h new file mode 100644 index 0000000000000..7128219748a06 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h @@ -0,0 +1,454 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + // values needs to be public for compilation with clang + // as vec512.h uses it + __m512d values; + using value_type = double; + using size_type = int; + static constexpr size_type size() { + return 8; + } + Vectorized() {} + Vectorized(__m512d v) : values(v) {} + Vectorized(double val) { + values = _mm512_set1_pd(val); + } + Vectorized(double val1, double val2, double val3, double val4, + double val5, double val6, double val7, double val8) { + values = _mm512_setr_pd(val1, val2, val3, val4, val5, val6, val7, val8); + } + operator __m512d() const { + return values; + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_pd(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mmask = _mm512_cmp_epi64_mask(_mm512_castpd_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_pd(mmask, a.values, b.values); + } + template + static Vectorized arange(double base = 0., step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, + base + 7 * step); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_pd(reinterpret_cast(ptr)); + + + __at_align__ double tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See 
https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(double)); + return _mm512_load_pd(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + _mm512_storeu_pd(reinterpret_cast(ptr), values); + } else if (count > 0) { + double tmp_values[size()]; + _mm512_storeu_pd(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(double)); + } + } + const double& operator[](int idx) const = delete; + double& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + __mmask8 cmp = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto cmp_mask = _mm512_cmp_pd_mask(values, _mm512_set1_pd(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + Vectorized map(double (*const f)(double)) const { + __at_align__ double tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_pd(-0.f); + return _mm512_andnot_pd(mask, values); + } + Vectorized angle() const { + const auto zero_vec = _mm512_castsi512_pd(zero_vector); + const auto nan_vec = _mm512_set1_pd(NAN); + const auto not_nan_mask = _mm512_cmp_pd_mask(values, values, _CMP_EQ_OQ); + const auto not_nan = _mm512_mask_set1_epi64(zero_vector, not_nan_mask, + 0xFFFFFFFFFFFFFFFF); + const auto nan_mask = _mm512_cmp_pd_mask(_mm512_castsi512_pd(not_nan), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_pd(c10::pi); + + const auto neg_mask = _mm512_cmp_pd_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_pd(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_pd(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_pd(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosd8_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asind8_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atand8_u10(values)); + } + Vectorized atan2(const Vectorized &b) const { + return Vectorized(Sleef_atan2d8_u10(values, b)); + } + Vectorized copysign(const Vectorized &sign) const { + return Vectorized(Sleef_copysignd8(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erfd8_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcd8_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expd8_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1d8_u10(values)); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodd8(values, q)); + } + Vectorized hypot(const Vectorized &b) const { + return Vectorized(Sleef_hypotd8_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized igamma(const 
Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ double tmp[size()]; + __at_align__ double tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized log() const { + return Vectorized(Sleef_logd8_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2d8_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10d8_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pd8_u10(values)); + } + Vectorized sin() const { + return Vectorized(Sleef_sind8_u10(values)); + } + Vectorized sinh() const { + return Vectorized(Sleef_sinhd8_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosd8_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshd8_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_pd(values); + } + Vectorized floor() const { + return _mm512_floor_pd(values); + } + Vectorized frac() const; + Vectorized neg() const { + return _mm512_xor_pd(_mm512_set1_pd(-0.), values); + } + Vectorized nextafter(const Vectorized &b) const { + return Vectorized(Sleef_nextafterd8(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tand8_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhd8_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammad8_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_pd(values); + } + Vectorized reciprocal() const { + return _mm512_div_pd(_mm512_set1_pd(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_pd(_mm512_set1_pd(1), _mm512_sqrt_pd(values)); + } + Vectorized pow(const Vectorized &b) const { + return Vectorized(Sleef_powd8_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto cmp_mask = _mm512_cmp_pd_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vector, cmp_mask, + 0xFFFFFFFFFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_pd(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_pd(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mul_pd(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return _mm512_div_pd(a, b); +} + +// frac. Implement this here so we can use subtraction. +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized max = _mm512_max_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_pd(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi64(0); + Vectorized min = _mm512_min_pd(a, b); + auto isnan_mask = _mm512_cmp_pd_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_pd(_mm512_mask_set1_epi64(zero_vec, isnan_mask, + 0xFFFFFFFFFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
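+ // (0xFFFFFFFFFFFFFFFF decodes as a quiet NaN in binary64, so the OR below
+ // overwrites exactly the lanes where _CMP_UNORD_Q saw a NaN operand, giving the
+ // NaN propagation required by the IEEE 754 minimum described above.)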
+ return _mm512_or_pd(min, isnan); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return _mm512_min_pd(max, _mm512_max_pd(min, a)); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return _mm512_max_pd(min, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return _mm512_min_pd(max, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_pd(a, b); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_pd(a, b); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_pd(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0); +} + +template <> +inline void convert(const double* src, double* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + _mm512_storeu_pd(dst + i, _mm512_loadu_pd(src + i)); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return _mm512_fmadd_pd(a, b, c); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h new file mode 100644 index 0000000000000..1a2b113de9d36 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -0,0 +1,469 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) +#include +#endif + +namespace at { +namespace vec { +// See Note [Acceptable use of anonymous namespace in header] +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +template <> class Vectorized { +private: + static constexpr __m512i zero_vec {0, 0, 0, 0, 0, 0, 0, 0}; +public: + __m512 values; + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 16; + } + Vectorized() {} + Vectorized(__m512 v) : values(v) {} + Vectorized(float val) { + values = _mm512_set1_ps(val); + } + Vectorized(float val1, float val2, float val3, float val4, + float val5, float val6, float val7, float val8, + float val9, float val10, float val11, float val12, + float val13, float val14, float val15, float val16) { + values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8, + val9, val10, val11, val12, val13, val14, val15, val16); + } + operator __m512() const { + return values; + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + return _mm512_mask_blend_ps(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto all_ones = _mm512_set1_epi32(0xFFFFFFFF); + auto mmask = _mm512_cmp_epi32_mask(_mm512_castps_si512(mask.values), all_ones, _MM_CMPINT_EQ); + return _mm512_mask_blend_ps(mmask, a.values, b.values); + } + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) + return _mm512_loadu_ps(reinterpret_cast(ptr)); + __at_align__ float tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. 
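+ // Partial load path: zero-fill a stack buffer, copy the count valid floats over
+ // it, then issue one full-width load so lanes past count read as 0.0f.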
+ for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, reinterpret_cast(ptr), count * sizeof(float)); + return _mm512_loadu_ps(tmp_values); + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + _mm512_storeu_ps(reinterpret_cast(ptr), values); + } else if (count > 0) { + float tmp_values[size()]; + _mm512_storeu_ps(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + const float& operator[](int idx) const = delete; + float& operator[](int idx) = delete; + int zero_mask() const { + // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + __mmask16 cmp = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_EQ_OQ); + return static_cast(cmp); + } + Vectorized isnan() const { + auto mask = _mm512_cmp_ps_mask(values, _mm512_set1_ps(0.0), _CMP_UNORD_Q); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + auto mask = _mm512_set1_ps(-0.f); + return _mm512_andnot_ps(mask, values); + } + Vectorized angle() const { + __m512 zero_vec = _mm512_set1_ps(0.f); + const auto nan_vec = _mm512_set1_ps(NAN); + const auto not_nan_mask = _mm512_cmp_ps_mask(values, values, _CMP_EQ_OQ); + const auto not_nan_vec = _mm512_mask_set1_epi32(_mm512_castps_si512(zero_vec), + not_nan_mask, 0xFFFFFFFF); + const auto nan_mask = _mm512_cmp_ps_mask(_mm512_castsi512_ps(not_nan_vec), + zero_vec, _CMP_EQ_OQ); + const auto pi = _mm512_set1_ps(c10::pi); + + const auto neg_mask = _mm512_cmp_ps_mask(values, zero_vec, _CMP_LT_OQ); + auto angle = _mm512_mask_blend_ps(neg_mask, zero_vec, pi); + angle = _mm512_mask_blend_ps(nan_mask, angle, nan_vec); + return angle; + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_ps(0); + } + Vectorized conj() const { + return *this; + } + Vectorized acos() const { + return Vectorized(Sleef_acosf16_u10(values)); + } + Vectorized asin() const { + return Vectorized(Sleef_asinf16_u10(values)); + } + Vectorized atan() const { + return Vectorized(Sleef_atanf16_u10(values)); + } + Vectorized atan2(const Vectorized &b) const { + return Vectorized(Sleef_atan2f16_u10(values, b)); + } + Vectorized copysign(const Vectorized &sign) const { + return Vectorized(Sleef_copysignf16(values, sign)); + } + Vectorized erf() const { + return Vectorized(Sleef_erff16_u10(values)); + } + Vectorized erfc() const { + return Vectorized(Sleef_erfcf16_u15(values)); + } + Vectorized erfinv() const { + return map(calc_erfinv); + } + Vectorized exp() const { + return Vectorized(Sleef_expf16_u10(values)); + } + Vectorized expm1() const { + return Vectorized(Sleef_expm1f16_u10(values)); + } + Vectorized fmod(const Vectorized& q) const { + return Vectorized(Sleef_fmodf16(values, q)); + } + Vectorized log() const { + return Vectorized(Sleef_logf16_u10(values)); + } + Vectorized log2() const { + return Vectorized(Sleef_log2f16_u10(values)); + } + Vectorized log10() const { + return Vectorized(Sleef_log10f16_u10(values)); + } + Vectorized log1p() const { + return Vectorized(Sleef_log1pf16_u10(values)); + } + Vectorized frac() const; + Vectorized sin() const { + return Vectorized(Sleef_sinf16_u10(values)); + } + Vectorized sinh() const { + return 
Vectorized(Sleef_sinhf16_u10(values)); + } + Vectorized cos() const { + return Vectorized(Sleef_cosf16_u10(values)); + } + Vectorized cosh() const { + return Vectorized(Sleef_coshf16_u10(values)); + } + Vectorized ceil() const { + return _mm512_ceil_ps(values); + } + Vectorized floor() const { + return _mm512_floor_ps(values); + } + Vectorized hypot(const Vectorized &b) const { + return Vectorized(Sleef_hypotf16_u05(values, b)); + } + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized igamma(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igamma(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized igammac(const Vectorized &x) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_x[size()]; + store(tmp); + x.store(tmp_x); + for (int64_t i = 0; i < size(); i++) { + tmp[i] = calc_igammac(tmp[i], tmp_x[i]); + } + return loadu(tmp); + } + Vectorized neg() const { + return _mm512_xor_ps(_mm512_set1_ps(-0.f), values); + } + Vectorized nextafter(const Vectorized &b) const { + return Vectorized(Sleef_nextafterf16(values, b)); + } + Vectorized round() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + Vectorized tan() const { + return Vectorized(Sleef_tanf16_u10(values)); + } + Vectorized tanh() const { + return Vectorized(Sleef_tanhf16_u10(values)); + } + Vectorized trunc() const { + return _mm512_roundscale_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); + } + Vectorized lgamma() const { + return Vectorized(Sleef_lgammaf16_u10(values)); + } + Vectorized sqrt() const { + return _mm512_sqrt_ps(values); + } + Vectorized reciprocal() const { + return _mm512_div_ps(_mm512_set1_ps(1), values); + } + Vectorized rsqrt() const { + return _mm512_div_ps(_mm512_set1_ps(1), _mm512_sqrt_ps(values)); + } + Vectorized pow(const Vectorized &b) const { + return Vectorized(Sleef_powf16_u10(values, b)); + } + // Comparison using the _CMP_**_OQ predicate. 
+ // `O`: get false if an operand is NaN + // `Q`: do not raise if an operand is NaN + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_EQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_NEQ_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_LE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GT_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmp_ps_mask(values, other.values, _CMP_GE_OQ); + return _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, mask, + 0xFFFFFFFF)); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_ps(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_ps(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mul_ps(a, b); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return _mm512_div_ps(a, b); +} + +// frac. Implement this here so we can use subtraction +Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto max = _mm512_max_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. + return _mm512_or_ps(max, isnan); +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + auto zero_vec = _mm512_set1_epi32(0); + auto min = _mm512_min_ps(a, b); + auto isnan_mask = _mm512_cmp_ps_mask(a, b, _CMP_UNORD_Q); + auto isnan = _mm512_castsi512_ps(_mm512_mask_set1_epi32(zero_vec, isnan_mask, + 0xFFFFFFFF)); + // Exploit the fact that all-ones is a NaN. 
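+    // (A lane of all 1-bits has every exponent bit set and a non-zero mantissa,
+    // so reinterpreted as float it is a NaN; OR-ing `isnan` into `min` therefore
+    // forces the NaN lanes of the result to NaN.)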
+ return _mm512_or_ps(min, isnan); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return _mm512_min_ps(max, _mm512_max_ps(min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return _mm512_min_ps(max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return _mm512_max_ps(min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_ps(a, b); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_ps(a, b); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_ps(a, b); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, float* dst, int64_t n) { + int64_t i; +#pragma unroll + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + _mm512_storeu_ps(dst + i, _mm512_loadu_ps(src + i)); + } +#pragma unroll + for (; i < n; i++) { + dst[i] = src[i]; + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return _mm512_fmadd_ps(a, b, c); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h new file mode 100644 index 0000000000000..cc866c065bfba --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -0,0 +1,1173 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! 
+// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +namespace at { +namespace vec { +namespace { + +#ifdef CPU_CAPABILITY_AVX512 + +struct Vectorizedi { +protected: + __m512i values; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static inline __m512i invert(const __m512i& v) { + const auto ones = _mm512_set1_epi64(-1); + return _mm512_xor_si512(ones, v); + } +public: + Vectorizedi() {} + Vectorizedi(__m512i v) : values(v) {} + operator __m512i() const { + return values; + } +}; + +#else + +struct Vectorizedi {}; // dummy definition to make Vectorizedi always defined + +#endif // CPU_CAPABILITY_AVX512 + +#ifdef CPU_CAPABILITY_AVX512 + +template <> +class Vectorized : public Vectorizedi { +private: + static const Vectorized ones; +public: + using value_type = int64_t; + using size_type = int; + static constexpr size_type size() { + return 8; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int64_t v) { values = _mm512_set1_epi64(v); } + Vectorized(int64_t val1, int64_t val2, int64_t val3, int64_t val4, + int64_t val5, int64_t val6, int64_t val7, int64_t val8) { + values = _mm512_setr_epi64(val1, val2, val3, val4, + val5, val6, val7, val8); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi64(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi64(0xFFFFFFFFFFFFFFFF); + auto mask_ = _mm512_cmp_epi64_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi64(mask_, a.values, b.values); + } + template + static Vectorized arange(int64_t base = 0, step_t step = static_cast(1)) { + return Vectorized(base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int64_t count) { + __at_align__ int64_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int64_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int64_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int64_t)); + } + } + const int64_t& operator[](int idx) const = delete; + int64_t& operator[](int idx) = delete; + Vectorized abs() const { + auto is_larger_mask = _mm512_cmpgt_epi64_mask(zero_vector, values); + auto is_larger = _mm512_mask_set1_epi64(zero_vector, is_larger_mask, 0xFFFFFFFFFFFFFFFF); + auto inverse = _mm512_xor_si512(values, is_larger); + return _mm512_sub_epi64(inverse, is_larger); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi64(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi64_mask(values, other.values); + return _mm512_mask_set1_epi64(zero_vector, mask, 0xFFFFFFFFFFFFFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +class Vectorized : public Vectorizedi { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; +public: + using value_type = int32_t; + static constexpr int size() { + return 16; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int32_t v) { values = _mm512_set1_epi32(v); } + Vectorized(int32_t val1, int32_t val2, int32_t val3, int32_t val4, + int32_t val5, int32_t val6, int32_t val7, int32_t val8, + int32_t val9, int32_t val10, int32_t val11, int32_t val12, + int32_t val13, int32_t val14, int32_t val15, int32_t val16) { + values = _mm512_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8, + val9, val10, val11, val12, val13, val14, val15, val16); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi32(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const 
Vectorized& mask) { + auto msb_one = _mm512_set1_epi32(0xFFFFFFFF); + auto mask_ = _mm512_cmp_epi32_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi32(mask_, a.values, b.values); + } + template + static Vectorized arange(int32_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int32_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<1>(a, b); + case 2: + return blend<3>(a, b); + case 3: + return blend<7>(a, b); + case 4: + return blend<15>(a, b); + case 5: + return blend<31>(a, b); + case 6: + return blend<63>(a, b); + case 7: + return blend<127>(a, b); + case 8: + return blend<255>(a, b); + case 9: + return blend<511>(a, b); + case 10: + return blend<1023>(a, b); + case 11: + return blend<2047>(a, b); + case 12: + return blend<4095>(a, b); + case 13: + return blend<8191>(a, b); + case 14: + return blend<16383>(a, b); + case 15: + return blend<32767>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int32_t count) { + __at_align__ int32_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int32_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int32_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int32_t)); + } + } + void dump() const { + for (size_t i = 0; i < size(); ++i) { + std::cout << (int)((value_type*)&values)[i] << " "; + } + std::cout << std::endl; + } + const int32_t& operator[](int idx) const = delete; + int32_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi32(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi32(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi32_mask(values, other.values); + return _mm512_mask_set1_epi32(zero_vector, mask, 0xFFFFFFFF); + } + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +inline void convert(const int32_t *src, float *dst, int64_t n) { + int64_t i; + // int32_t and float have same size +#ifndef _MSC_VER +# pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto input_vec = _mm512_loadu_si512(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_ps(input_vec); + _mm512_storeu_ps(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +# pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t *src, double *dst, int64_t n) { + int64_t i; + // int32_t has half the size of double +#ifndef _MSC_VER +# pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + auto input_256_vec = _mm256_loadu_si256(reinterpret_cast(src + i)); + auto output_vec = _mm512_cvtepi32_pd(input_256_vec); + _mm512_storeu_pd(reinterpret_cast(dst + i), output_vec); + } +#ifndef _MSC_VER +# pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template 
<> +class Vectorized : public Vectorizedi { +private: + static const Vectorized ones; + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; +public: + using value_type = int16_t; + static constexpr int size() { + return 32; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int16_t v) { values = _mm512_set1_epi16(v); } + Vectorized(int16_t val1, int16_t val2, int16_t val3, int16_t val4, + int16_t val5, int16_t val6, int16_t val7, int16_t val8, + int16_t val9, int16_t val10, int16_t val11, int16_t val12, + int16_t val13, int16_t val14, int16_t val15, int16_t val16, + int16_t val17, int16_t val18, int16_t val19, int16_t val20, + int16_t val21, int16_t val22, int16_t val23, int16_t val24, + int16_t val25, int16_t val26, int16_t val27, int16_t val28, + int16_t val29, int16_t val30, int16_t val31, int16_t val32) { + values = _mm512_set_epi16(val32, val31, val30, val29, val28, val27, val26, val25, + val24, val23, val22, val21, val20, val19, val18, val17, + val16, val15, val14, val13, val12, val11, val10, val9, + val8, val7, val6, val5, val4, val3, val2, val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi16(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi16(0xFFFF); + auto mask_ = _mm512_cmp_epi16_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi16(mask_, a.values, b.values); + } + template + static Vectorized arange(int16_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step + ); + } + static Vectorized + set(Vectorized a, Vectorized b, int16_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: 
+ return blend<0x7FFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int16_t count) { + __at_align__ int16_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (auto i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int16_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); + } + } + const int16_t& operator[](int idx) const = delete; + int16_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi16(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi16(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi16_mask(values, other.values); + return _mm512_mask_set1_epi16(zero_vector, mask, 0xFFFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +class Vectorized : public Vectorizedi { +private: + static constexpr __m512i zero_vector {0, 0, 0, 0, 0, 0, 0, 0}; + static const Vectorized ones; +public: + using value_type = int8_t; + static constexpr int size() { + return 64; + } + using Vectorizedi::Vectorizedi; + Vectorized() {} + Vectorized(int8_t v) { values = _mm512_set1_epi8(v); } + Vectorized(int8_t val1, int8_t val2, int8_t val3, int8_t val4, + int8_t val5, 
int8_t val6, int8_t val7, int8_t val8, + int8_t val9, int8_t val10, int8_t val11, int8_t val12, + int8_t val13, int8_t val14, int8_t val15, int8_t val16, + int8_t val17, int8_t val18, int8_t val19, int8_t val20, + int8_t val21, int8_t val22, int8_t val23, int8_t val24, + int8_t val25, int8_t val26, int8_t val27, int8_t val28, + int8_t val29, int8_t val30, int8_t val31, int8_t val32, + int8_t val33, int8_t val34, int8_t val35, int8_t val36, + int8_t val37, int8_t val38, int8_t val39, int8_t val40, + int8_t val41, int8_t val42, int8_t val43, int8_t val44, + int8_t val45, int8_t val46, int8_t val47, int8_t val48, + int8_t val49, int8_t val50, int8_t val51, int8_t val52, + int8_t val53, int8_t val54, int8_t val55, int8_t val56, + int8_t val57, int8_t val58, int8_t val59, int8_t val60, + int8_t val61, int8_t val62, int8_t val63, int8_t val64){ + values = _mm512_set_epi8(val64, val63, val62, val61, val60, val59, val58, val57, + val56, val55, val54, val53,val52, val51, val50, val49, + val48, val47, val46, val45, val44, val43, val42, val41, + val40, val39, val38, val37, val36, val35, val34, val33, + val32, val31, val30, val29, val28, val27, val26, val25, + val24, val23, val22, val21, val20, val19, val18, val17, + val16, val15, val14, val13, val12, val11, val10, val9, + val8, val7, val6, val5, val4, val3, val2, val1); + } + template + static Vectorized blend(Vectorized a, Vectorized b) { + return _mm512_mask_blend_epi8(mask, a.values, b.values); + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epi8_mask(mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, a.values, b.values); + } + template + static Vectorized arange(int8_t base = 0, step_t step = static_cast(1)) { + return Vectorized( + base, base + step, base + 2 * step, base + 3 * step, + base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, + base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, + base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step, + base + 16 * step, base + 17 * step, base + 18 * step, base + 19 * step, + base + 20 * step, base + 21 * step, base + 22 * step, base + 23 * step, + base + 24 * step, base + 25 * step, base + 26 * step, base + 27 * step, + base + 28 * step, base + 29 * step, base + 30 * step, base + 31 * step, + base + 32 * step, base + 33 * step, base + 34 * step, base + 35 * step, + base + 36 * step, base + 37 * step, base + 38 * step, base + 39 * step, + base + 40 * step, base + 41 * step, base + 42 * step, base + 43 * step, + base + 44 * step, base + 45 * step, base + 46 * step, base + 47 * step, + base + 48 * step, base + 49 * step, base + 50 * step, base + 51 * step, + base + 52 * step, base + 53 * step, base + 54 * step, base + 55 * step, + base + 56 * step, base + 57 * step, base + 58 * step, base + 59 * step, + base + 60 * step, base + 61 * step, base + 62 * step, base + 63 * step); + } + static Vectorized + set(Vectorized a, Vectorized b, int8_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + return blend<0x1>(a, b); + case 2: + return blend<0x3>(a, b); + case 3: + return blend<0x7>(a, b); + case 4: + return blend<0xF>(a, b); + case 5: + return blend<0x1F>(a, b); + case 6: + return blend<0x3F>(a, b); + case 7: + return blend<0x7F>(a, b); + case 8: + return blend<0xFF>(a, b); + case 9: + return blend<0x1FF>(a, b); + case 10: + return blend<0x3FF>(a, b); + case 11: + return blend<0x7FF>(a, b); + case 
12: + return blend<0xFFF>(a, b); + case 13: + return blend<0x1FFF>(a, b); + case 14: + return blend<0x3FFF>(a, b); + case 15: + return blend<0x7FFF>(a, b); + case 16: + return blend<0xFFFF>(a, b); + case 17: + return blend<0x1FFFF>(a, b); + case 18: + return blend<0x3FFFF>(a, b); + case 19: + return blend<0x7FFFF>(a, b); + case 20: + return blend<0xFFFFF>(a, b); + case 21: + return blend<0x1FFFFF>(a, b); + case 22: + return blend<0x3FFFFF>(a, b); + case 23: + return blend<0x7FFFFF>(a, b); + case 24: + return blend<0xFFFFFF>(a, b); + case 25: + return blend<0x1FFFFFF>(a, b); + case 26: + return blend<0x3FFFFFF>(a, b); + case 27: + return blend<0x7FFFFFF>(a, b); + case 28: + return blend<0xFFFFFFF>(a, b); + case 29: + return blend<0x1FFFFFFF>(a, b); + case 30: + return blend<0x3FFFFFFF>(a, b); + case 31: + return blend<0x7FFFFFFF>(a, b); + case 32: + return blend<0xFFFFFFFF>(a, b); + case 33: + return blend<0x1FFFFFFFF>(a, b); + case 34: + return blend<0x3FFFFFFFF>(a, b); + case 35: + return blend<0x7FFFFFFFF>(a, b); + case 36: + return blend<0xFFFFFFFFF>(a, b); + case 37: + return blend<0x1FFFFFFFFF>(a, b); + case 38: + return blend<0x3FFFFFFFFF>(a, b); + case 39: + return blend<0x7FFFFFFFFF>(a, b); + case 40: + return blend<0xFFFFFFFFFF>(a, b); + case 41: + return blend<0x1FFFFFFFFFF>(a, b); + case 42: + return blend<0x3FFFFFFFFFF>(a, b); + case 43: + return blend<0x7FFFFFFFFFF>(a, b); + case 44: + return blend<0xFFFFFFFFFFF>(a, b); + case 45: + return blend<0x1FFFFFFFFFFF>(a, b); + case 46: + return blend<0x3FFFFFFFFFFF>(a, b); + case 47: + return blend<0x7FFFFFFFFFFF>(a, b); + case 48: + return blend<0xFFFFFFFFFFFF>(a, b); + case 49: + return blend<0x1FFFFFFFFFFFF>(a, b); + case 50: + return blend<0x3FFFFFFFFFFFF>(a, b); + case 51: + return blend<0x7FFFFFFFFFFFF>(a, b); + case 52: + return blend<0xFFFFFFFFFFFFF>(a, b); + case 53: + return blend<0x1FFFFFFFFFFFFF>(a, b); + case 54: + return blend<0x3FFFFFFFFFFFFF>(a, b); + case 55: + return blend<0x7FFFFFFFFFFFFF>(a, b); + case 56: + return blend<0xFFFFFFFFFFFFFF>(a, b); + case 57: + return blend<0x1FFFFFFFFFFFFFF>(a, b); + case 58: + return blend<0x3FFFFFFFFFFFFFF>(a, b); + case 59: + return blend<0x7FFFFFFFFFFFFFF>(a, b); + case 60: + return blend<0xFFFFFFFFFFFFFFF>(a, b); + case 61: + return blend<0x1FFFFFFFFFFFFFFF>(a, b); + case 62: + return blend<0x3FFFFFFFFFFFFFFF>(a, b); + case 63: + return blend<0x7FFFFFFFFFFFFFFF>(a, b); + } + return b; + } + static Vectorized loadu(const void* ptr) { + return _mm512_loadu_si512(reinterpret_cast(ptr)); + } + static Vectorized loadu(const void* ptr, int8_t count) { + __at_align__ int8_t tmp_values[size()]; + // Ensure uninitialized memory does not change the output value See https://github.com/pytorch/pytorch/issues/32502 + // for more details. We do not initialize arrays to zero using "={0}" because gcc would compile it to two + // instructions while a loop would be compiled to one instruction. + for (size_t i = 0; i < size(); ++i) { + tmp_values[i] = 0; + } + std::memcpy(tmp_values, ptr, count * sizeof(int8_t)); + return loadu(tmp_values); + } + void store(void* ptr, int count = size()) const { + if (count == size()) { + // ptr need not to be aligned here. 
See + // https://software.intel.com/content/www/us/en/develop/documentation/cpp-compiler-developer-guide-and-reference/top/compiler-reference/intrinsics/intrinsics-for-intel-advanced-vector-extensions/intrinsics-for-load-and-store-operations-1/mm512-storeu-si512.html + _mm512_storeu_si512(reinterpret_cast<__m512i*>(ptr), values); + } else if (count > 0) { + __at_align__ int8_t tmp_values[size()]; + _mm512_storeu_si512(reinterpret_cast<__m512i*>(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(int8_t)); + } + } + const int8_t& operator[](int idx) const = delete; + int8_t& operator[](int idx) = delete; + Vectorized abs() const { + return _mm512_abs_epi8(values); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return _mm512_set1_epi8(0); + } + Vectorized conj() const { + return *this; + } + Vectorized frac() const; + Vectorized neg() const; + Vectorized operator==(const Vectorized& other) const { + auto mask = _mm512_cmpeq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator!=(const Vectorized& other) const { + auto mask = _mm512_cmpneq_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<(const Vectorized& other) const { + auto mask = _mm512_cmplt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator<=(const Vectorized& other) const { + auto mask = _mm512_cmple_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>(const Vectorized& other) const { + auto mask = _mm512_cmpgt_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + Vectorized operator>=(const Vectorized& other) const { + auto mask = _mm512_cmpge_epi8_mask(values, other.values); + return _mm512_mask_set1_epi8(zero_vector, mask, 0xFF); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi64(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi16(a, b); +} + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return _mm512_add_epi8(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi64(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi32(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return _mm512_sub_epi8(a, b); +} + +// Negation. 
Defined here so we can utilize operator- +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +Vectorized Vectorized::neg() const { + return Vectorized(0) - *this; +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi64(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return _mm512_mullo_epi16(a, b); +} + +template +Vectorized inline int_elementwise_binary_512(const Vectorized& a, const Vectorized& b, Op op) { + T values_a[Vectorized::size()]; + T values_b[Vectorized::size()]; + a.store(values_a); + b.store(values_b); + for (int i = 0; i != Vectorized::size(); i++) { + values_a[i] = op(values_a[i], values_b[i]); + } + return Vectorized::loadu(values_a); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + // We don't have an instruction for multiplying int8_t + return int_elementwise_binary_512(a, b, std::multiplies()); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi64(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi32(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi16(a, b); +} + +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return _mm512_min_epi8(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi64(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi32(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi16(a, b); +} + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return _mm512_max_epi8(a, b); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi64(max_val, _mm512_max_epi64(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi32(max_val, _mm512_max_epi32(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi16(max_val, _mm512_max_epi16(a, min_val)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min_val, const Vectorized& max_val) { + return _mm512_min_epi8(max_val, _mm512_max_epi8(a, min_val)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi64(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi32(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi16(max_val, a); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max_val) { + return _mm512_min_epi8(max_val, a); +} + 
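As a rough illustration of how the integer Vectorized<> helpers above compose, here is a minimal editorial sketch (not part of this patch) of an element-wise clamp kernel; it assumes the header is consumed through ATen's usual vec headers, and the function name is hypothetical:

// Editorial sketch only -- illustrates clamp() together with the count-aware
// loadu()/store() overloads defined above; "clamp_i32_example" is a made-up name.
inline void clamp_i32_example(const int32_t* src, int32_t* dst, int64_t n,
                              int32_t lo, int32_t hi) {
  using Vec = Vectorized<int32_t>;
  const Vec min_v(lo);
  const Vec max_v(hi);
  int64_t i = 0;
  // Full 16-lane vectors.
  for (; i + Vec::size() <= n; i += Vec::size()) {
    clamp(Vec::loadu(src + i), min_v, max_v).store(dst + i);
  }
  // Tail: the count-aware overloads touch only the remaining live lanes.
  if (i < n) {
    clamp(Vec::loadu(src + i, n - i), min_v, max_v).store(dst + i, n - i);
  }
}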
+template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi64(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi32(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi16(min_val, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min_val) { + return _mm512_max_epi8(min_val, a); +} + +template +Vectorized inline convert_to_int32(const T* ptr) { + return Vectorized::loadu(ptr); +} + +template<> +Vectorized inline convert_to_int32(const int8_t* ptr) { + return _mm512_cvtepi8_epi32(_mm_loadu_si128(reinterpret_cast(ptr))); +} + +template<> +Vectorized inline convert_to_int32(const uint8_t* ptr) { + return _mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast(ptr))); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return int_elementwise_binary_512(a, b, std::divides()); +} + +template>::value, int> = 0> +inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { + return _mm512_and_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { + return _mm512_or_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { + return _mm512_xor_si512(a, b); +} +template>::value, int> = 0> +inline Vectorized operator~(const Vectorized& a) { + return _mm512_xor_si512(a, _mm512_set1_epi32(-1)); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return 
(*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1); +} + +Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1); +} + +Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1); +} + +Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1); +} + +Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1); +} + +Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1); +} + +#endif + +}}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h new file mode 100644 index 0000000000000..5b5ac195f3caa --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -0,0 +1,1195 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include +#include +#include +#include + +#include + +// This file defines Vectorized<> for the quantized types. +// +// +// Currently, we simply use these classes as efficient converters between +// the quantized types and Vectorized, usually in bandwidth-bound cases +// where doing the arithmetic in full-precision is acceptable (e.g. +// elementwise operators). +// +// +// Conversions are as follows: +// Vectorized -> 4x Vectorized +// Vectorized -> 4x Vectorized +// Vectorized -> 1x Vectorized +// +// The size of the returned float vector is specified by the special +// constexpr function float_num_vecs. The type of the value returned +// from dequantize (and expected as an argument to quantize) is +// specified by float_vec_return_type. +// +// When writing kernels with these vectors, it is expected that floating- +// point operations will be carried out in a loop over Vectorized::float_num_vecs +// iterations. 
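The comment above spells out the intended kernel pattern: dequantize to float vectors, do the float math in a loop over float_num_vecs(), then quantize back. A minimal editorial sketch of that pattern follows; it is not part of this patch, it assumes the quantized Vectorized<> types defined below are available through the usual ATen headers, the kernel name and the choice of tanh are arbitrary, and the scalar tail is omitted for brevity.

// Editorial sketch only: the dequantize -> float ops -> quantize loop described above.
inline void tanh_q8_example(const c10::qint8* src, c10::qint8* dst, int64_t n,
                            float scale, int32_t zero_point) {
  using QVec = Vectorized<c10::qint8>;
  using FVec = Vectorized<float>;
  const FVec scale_v(scale);
  const FVec zp_v(static_cast<float>(zero_point));
  // Premultiplied scale * (-zero_point), as expected by dequantize() below.
  const FVec scale_neg_zp_premul = scale_v * FVec(-static_cast<float>(zero_point));
  const float inverse_scale = 1.0f / scale;
  for (int64_t i = 0; i + QVec::size() <= n; i += QVec::size()) {
    QVec qx = QVec::loadu(src + i);
    // 1x Vectorized<c10::qint8> -> 4x Vectorized<float>.
    auto fx = qx.dequantize(scale_v, zp_v, scale_neg_zp_premul);
    for (int j = 0; j < QVec::float_num_vecs(); ++j) {
      fx[j] = fx[j].tanh();  // the actual float math goes here
    }
    // 4x Vectorized<float> -> 1x Vectorized<c10::qint8>, then store the full vector.
    QVec::quantize(fx, scale, zero_point, inverse_scale).store(dst + i);
  }
  // The remaining (n % QVec::size()) elements would use a scalar fallback or
  // the count-aware store(); omitted here.
}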
+ +namespace at { +namespace vec { +namespace { + +#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) + +struct Vectorizedqi { + protected: + __m512i vals __attribute__((aligned(64))); + + public: + Vectorizedqi() {} + Vectorizedqi(__m512i v) : vals(v) {} + operator __m512i() const { + return vals; + } +}; + + +template +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + T min_val, + T max_val); + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int32_t min_val, + int32_t max_val) { + // This function is for linkage only, will not be used + AT_ERROR("pack_saturate_and_clamp is not supported"); +} + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + int8_t min_val, + int8_t max_val) { + __m512i packed_and_sat = _mm512_packs_epi16(first, second); + return _mm512_max_epi8( + _mm512_set1_epi8(min_val), + _mm512_min_epi8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + +template <> +__m512i pack_saturate_and_clamp( + __m512i first, + __m512i second, + uint8_t min_val, + uint8_t max_val) { + __m512i packed_and_sat = _mm512_packus_epi16(first, second); + return _mm512_max_epu8( + _mm512_set1_epi8(min_val), + _mm512_min_epu8(packed_and_sat, _mm512_set1_epi8(max_val))); +} + + +template +inline void __attribute__((always_inline)) QuantizeAvx512( + const float* src, + typename T::underlying* dst, + int len, + float inverse_scale, + int64_t zero_point) { + constexpr int VLEN = 16; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + const __m512i min_v = _mm512_set1_epi32(min_val); + const __m512i max_v = _mm512_set1_epi32(max_val); + // This is the largest int32 value < int32_max exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits::max() - 127; + int i = 0; + __m512 inverse_scale_v = _mm512_set1_ps(inverse_scale); + // clang-format off + static const __m512i shuffle_mask_v = _mm512_set_epi8( + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on + __m512i permute_mask_v = + _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00); + __m512i permute_mask_l8_v = + _mm512_set_epi32(0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x0c, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + // If the floating point value is greater than int32_max, + // _mm512_cvtps_epi32 converts them to -ve. Clip at int32_float_max_val to + // Clip at int32_float_max_val to avoid this. 
+ x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // y + __m512 y_vals = _mm512_load_ps(src + i + VLEN); + __m512 y_transformed_v = _mm512_mul_ps(y_vals, inverse_scale_v); + y_transformed_v = + _mm512_min_ps(y_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // z + __m512 z_vals = _mm512_load_ps(src + i + 2 * VLEN); + __m512 z_transformed_v = _mm512_mul_ps(z_vals, inverse_scale_v); + z_transformed_v = + _mm512_min_ps(z_transformed_v, _mm512_set1_ps(int32_float_max_val)); + // w + __m512 w_vals = _mm512_load_ps(src + i + 3 * VLEN); + __m512 w_transformed_v = _mm512_mul_ps(w_vals, inverse_scale_v); + w_transformed_v = + _mm512_min_ps(w_transformed_v, _mm512_set1_ps(int32_float_max_val)); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_transformed_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_transformed_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_transformed_v); + + // add zero point + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + y_rounded_v = _mm512_add_epi32(y_rounded_v, _mm512_set1_epi32(zero_point)); + z_rounded_v = _mm512_add_epi32(z_rounded_v, _mm512_set1_epi32(zero_point)); + w_rounded_v = _mm512_add_epi32(w_rounded_v, _mm512_set1_epi32(zero_point)); + + __m512i xy_packed_v = _mm512_packs_epi32(x_rounded_v, y_rounded_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_rounded_v, w_rounded_v); + __m512i xyzw_clamped_v = pack_saturate_and_clamp( + xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = + _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + i), xyzw_clamped_v); + } + + // Additional 8-lane AVX512 version to take advantage when len is smaller + // based on fbgemm::QuantizeAvx2 (https://github.com/pytorch/FBGEMM) + for (; i < len / VLEN * VLEN; i += VLEN) { + __m512 x_vals = _mm512_load_ps(src + i); + __m512 x_transformed_v = _mm512_mul_ps(x_vals, inverse_scale_v); + x_transformed_v = + _mm512_min_ps(x_transformed_v, _mm512_set1_ps(int32_float_max_val)); + __m512i x_rounded_v = _mm512_cvtps_epi32(x_transformed_v); + x_rounded_v = _mm512_add_epi32(x_rounded_v, _mm512_set1_epi32(zero_point)); + __m512i x_clipped_v = + _mm512_max_epi32(min_v, _mm512_min_epi32(max_v, x_rounded_v)); + + x_clipped_v = _mm512_shuffle_epi8(x_clipped_v, shuffle_mask_v); + x_clipped_v = _mm512_permutexvar_epi32(permute_mask_l8_v, x_clipped_v); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(dst + i), + _mm512_castsi512_si128(x_clipped_v)); + } + + for (; i < len; ++i) { + float transformed = src[i] * inverse_scale; + + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. 
+ transformed = zero_point + nearbyint(transformed); + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + dst[i] = clipped; + } +} + +template<> +struct Vectorized : public Vectorizedqi { + using size_type = int; + static constexpr size_type size() { + return 16; + } + + static constexpr int float_num_vecs() { + return 1; + } + + static constexpr int int_num_vecs() { + return 1; + } + + using float_vec_return_type = std::array, 1>; + using int_vec_return_type = std::array, 1>; + using value_type = c10::qint32::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::qint32& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi32(uw); + } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m512 float_vals = _mm512_cvtepi32_ps(vals); + return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vectorized retval; + auto rhs_data = (__m512)rhs[0]; + at::native::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 16); + return retval; + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi32(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi32(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi32( + _mm512_max_epi32(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + return {_mm512_sub_epi32(vals, b)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + + __m512 scaled = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier_v); + __m512i rounded = _mm512_cvtps_epi32(scaled); + return _mm512_add_epi32(rounded, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < 16; ++i) { + std::cout << ((int32_t*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + return _mm512_mullo_epi32(a, b); +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + return _mm512_add_epi32(a, b); +} + +/* + * Convert values from int32 back to int8/uint8 + */ +template +__m512i RequantizeAvx512( + const std::array, 4>& inp, + __m512 multiplier, + __m512i zp) { + static_assert( + std::is_same::value || std::is_same::value, + "Only int8_t/uint8_t are supported"); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = 
std::numeric_limits::max(); + __m512i permute_mask_v = + _mm512_set_epi32(0x0f, 0x0b, 0x07, 0x03, 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, 0x0c, 0x08, 0x04, 0x00); + __m512 x_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[0]), multiplier); + __m512 y_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[1]), multiplier); + __m512 z_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[2]), multiplier); + __m512 w_scaled_v = _mm512_mul_ps(_mm512_cvtepi32_ps(inp[3]), multiplier); + + __m512i x_rounded_v = _mm512_cvtps_epi32(x_scaled_v); + __m512i y_rounded_v = _mm512_cvtps_epi32(y_scaled_v); + __m512i z_rounded_v = _mm512_cvtps_epi32(z_scaled_v); + __m512i w_rounded_v = _mm512_cvtps_epi32(w_scaled_v); + + /* Add zero point */ + __m512i x_v = _mm512_add_epi32(x_rounded_v, zp); + __m512i y_v = _mm512_add_epi32(y_rounded_v, zp); + __m512i z_v = _mm512_add_epi32(z_rounded_v, zp); + __m512i w_v = _mm512_add_epi32(w_rounded_v, zp); + + /* Pack to int16_t and saturate */ + __m512i xy_packed_v = _mm512_packs_epi32(x_v, y_v); + __m512i zw_packed_v = _mm512_packs_epi32(z_v, w_v); + + __m512i xyzw_clamped_v = + pack_saturate_and_clamp(xy_packed_v, zw_packed_v, min_val, max_val); + + /* + * xyzw_clamped_v has results in the following layout so we need to + * permute: x0-3 y0-3 z0-3 w0-3 x4-7 y4-7 z4-7 w4-7 x8-11 y8-11 z8-11 w8-11 x12-15 y12-15 z12-15 w12-15 + */ + xyzw_clamped_v = _mm512_permutexvar_epi32(permute_mask_v, xyzw_clamped_v); + return xyzw_clamped_v; +} + +template<> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::qint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + + Vectorized() {} + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::qint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + // This is needed because the compiler emits awful code for the default + // constructor for moving the enum + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + private: + __m512i cvtepi8_epi32(__m128i epi8_vals) const { + return _mm512_cvtepi8_epi32(epi8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_neg_zp_premul) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_neg_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_neg_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_neg_zp_premul); + auto val3 = + vec::fmadd(scale, 
Vectorized(float_val3), scale_neg_zp_premul); + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epi8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epi8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epi8( + _mm512_max_epi8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512i int32_val0 = cvtepi8_epi32(int_val0); + __m512i int32_val1 = cvtepi8_epi32(int_val1); + __m512i int32_val2 = cvtepi8_epi32(int_val2); + __m512i int32_val3 = cvtepi8_epi32(int_val3); + + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); + + __m512i int32_b0 = cvtepi8_epi32(int_b0); + __m512i int32_b1 = cvtepi8_epi32(int_b1); + __m512i int32_b2 = cvtepi8_epi32(int_b2); + __m512i int32_b3 = cvtepi8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + + return {Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < size(); ++i) { + std::cout << (int)((value_type*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template<> +struct Vectorized : public Vectorizedqi { + static constexpr int size() { + return 64; + } + + static constexpr int float_num_vecs() { + return 4; + } + + static constexpr int int_num_vecs() { + return 4; + } + + using float_vec_return_type = std::array, 4>; + using int_vec_return_type = std::array, 4>; + using value_type = typename c10::quint8::underlying; + + public: + using Vectorizedqi::Vectorizedqi; + Vectorized() {} + + Vectorized(__m512i vals_) { vals = vals_;} + + // Broadcast constructor + Vectorized(const c10::quint8& val) { + value_type uw = val.val_; + vals = _mm512_set1_epi8(uw); + } + + Vectorized(const Vectorized& other) : Vectorizedqi(other.vals) { } + + void store(void* ptr, int count = size()) const { + if (count != size()) { + memcpy(ptr, &vals, count * 
sizeof(value_type)); + } else { + _mm512_storeu_si512((__m512i*)ptr, vals); + } + } + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + private: + __m512i cvtepu8_epi32(__m128i epu8_vals) const { + return _mm512_cvtepu8_epi32(epu8_vals); + } + + public: + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3)); + + auto val0 = + vec::fmadd(scale, Vectorized(float_val0), scale_zp_premul); + auto val1 = + vec::fmadd(scale, Vectorized(float_val1), scale_zp_premul); + auto val2 = + vec::fmadd(scale, Vectorized(float_val2), scale_zp_premul); + auto val3 = + vec::fmadd(scale, Vectorized(float_val3), scale_zp_premul); + + return {val0, val1, val2, val3}; + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[64]; + QuantizeAvx512( + rhs_data, quantized_values, 64, inverse_scale, zero_point); + return Vectorized::loadu(quantized_values); + } + + Vectorized maximum(Vectorized b) const { + return _mm512_max_epu8(vals, b.vals); + } + + Vectorized minimum(Vectorized b) const { + return _mm512_min_epu8(vals, b.vals); + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + return _mm512_min_epu8( + _mm512_max_epu8(vals, zero_point.vals), q_six.vals); + } + + int_vec_return_type widening_subtract(Vectorized b) const { + __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]); + __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]); + __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]); + __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]); + + __m512i int32_val0 = cvtepu8_epi32(int_val0); + __m512i int32_val1 = cvtepu8_epi32(int_val1); + __m512i int32_val2 = cvtepu8_epi32(int_val2); + __m512i int32_val3 = cvtepu8_epi32(int_val3); + + __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]); + __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]); + __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]); + __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]); + + __m512i int32_b0 = cvtepu8_epi32(int_b0); + __m512i int32_b1 = cvtepu8_epi32(int_b1); + __m512i int32_b2 = cvtepu8_epi32(int_b2); + __m512i int32_b3 = cvtepu8_epi32(int_b3); + + __m512i res_0 = _mm512_sub_epi32(int32_val0, int32_b0); + __m512i res_1 = _mm512_sub_epi32(int32_val1, int32_b1); + __m512i res_2 = _mm512_sub_epi32(int32_val2, int32_b2); + __m512i res_3 = _mm512_sub_epi32(int32_val3, int32_b3); + return {Vectorized(res_0), + Vectorized(res_1), + Vectorized(res_2), + Vectorized(res_3)}; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + __m512 multiplier_v = _mm512_set1_ps(multiplier); + __m512i zero_point_v = _mm512_set1_epi32(zero_point); + return RequantizeAvx512(inp, multiplier_v, zero_point_v); + } + + void dump() const { + for (size_t i = 0; i < size(); 
++i) { + std::cout << (int)((value_type*)&vals)[i] << " "; + } + std::cout << std::endl; + } + private: + + // Load from memory constructor + Vectorized(const void* ptr) { + vals = _mm512_loadu_si512((const __m512i*)ptr); + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#else + +// NOTE: These are low-performance implementations that we fall back on. + +template < + typename T, + typename float_vec_return_type_, + typename int_vec_return_type_, + int size_> +struct VectorizedQuantizedConverter { + static constexpr int size() { + return size_; + } + + static constexpr int float_num_vecs() { + return size() / 8; + } + + static constexpr int int_num_vecs() { + return size() / 8; + } + + using float_vec_return_type = float_vec_return_type_; + using int_vec_return_type = int_vec_return_type_; + + using value_type = typename T::underlying; + std::array vals; + + VectorizedQuantizedConverter(T val) { + for (size_t i = 0; i < size(); ++i) { + vals[i] = val.val_; + } + } + + VectorizedQuantizedConverter(const void* ptr) { + memcpy(vals.data(), ptr, sizeof(value_type) * size()); + } + + void store(void* ptr, int count = size()) const { + memcpy(ptr, vals.data(), count * sizeof(value_type)); + } + + float_vec_return_type dequantize( + Vectorized scale, + Vectorized zero_point, + Vectorized scale_zp_premul) const { + float_vec_return_type rv; + for (int i = 0; i < float_num_vecs(); ++i) { + float tmp_vals[16]; + for (int j = 0; j < 16; ++j) { + tmp_vals[j] = at::native::dequantize_val( + scale[j], zero_point[j], T(vals[16 * i + j])); + } + rv[i] = Vectorized(tmp_vals[0], + tmp_vals[1], + tmp_vals[2], + tmp_vals[3], + tmp_vals[4], + tmp_vals[5], + tmp_vals[6], + tmp_vals[7], + tmp_vals[8], + tmp_vals[9], + tmp_vals[10], + tmp_vals[11], + tmp_vals[12], + tmp_vals[13], + tmp_vals[14], + tmp_vals[15]); + } + return rv; + } + + void dump() const { + for (int i = 0; i < size(); ++i) { + std::cout << vals[i] << " "; + } + std::cout << std::endl; + } + + protected: + VectorizedQuantizedConverter() {} +}; + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>() {} + Vectorized(c10::qint32 val) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint32, + std::array, 1>, + std::array, 1>, + 16>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint32*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return 
maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + for (size_t i = 0; i < size(); ++i) { + retval[0].vals[i] = vals[i] - b.vals[i]; + } + return retval; + } + + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = + nearbyint(static_cast(inp[0].vals[i]) * multiplier) + + zero_point; + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +Vectorized inline operator*( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] * b.vals[i]; + } + return retval; +} + +template <> +Vectorized inline operator+( + const Vectorized& a, + const Vectorized& b) { + Vectorized retval; + for (size_t i = 0; i < std::decay_t::size(); ++i) { + retval.vals[i] = a.vals[i] + b.vals[i]; + } + return retval; +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::qint8 val) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::qint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::qint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + 
int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +template <> +struct Vectorized : public VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64> { + Vectorized() + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>() {} + Vectorized(c10::quint8 val) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(val) {} + Vectorized(const void* ptr) + : VectorizedQuantizedConverter< + c10::quint8, + std::array, 4>, + std::array, 4>, + 64>(ptr) {} + + static Vectorized loadu(const void* ptr) { + return Vectorized(ptr); + } + + static Vectorized quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + std::array qvals; + std::array float_vals; + + for (int i = 0; i < float_num_vecs(); ++i) { + rhs[i].store(&float_vals[i * 16], 16); + } + + at::native::quantize_vec( + scale, + zero_point, + float_vals.data(), + (c10::quint8*)qvals.data(), + 16 * float_num_vecs()); + + return Vectorized::loadu(qvals.data()); + } + + Vectorized maximum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::max(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized minimum(Vectorized b) const { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min(vals[i], b.vals[i]); + } + return retval; + } + + Vectorized relu(Vectorized zero_point) const { + return maximum(zero_point); + } + + + Vectorized relu6( + Vectorized zero_point, + Vectorized q_six) { + Vectorized retval; + for (size_t i = 0; i < size(); ++i) { + retval.vals[i] = std::min( + std::max(vals[i], zero_point.vals[i]), q_six.vals[i]); + } + return retval; + } + + int_vec_return_type widening_subtract(Vectorized b) const { + int_vec_return_type retval; + constexpr int elem_per_int_vec = size() / int_num_vecs(); + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + retval[i].vals[j] = + static_cast(vals[i * elem_per_int_vec + j]) - + static_cast(b.vals[i * elem_per_int_vec + j]); + } + } + return retval; + } + static Vectorized requantize_from_int( + const int_vec_return_type& inp, + float multiplier, + int32_t zero_point) { + constexpr int elem_per_int_vec = size() / int_num_vecs(); + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + Vectorized retval; + for (size_t i = 0; i < int_num_vecs(); ++i) { + for (size_t j = 0; j < elem_per_int_vec; ++j) { + int32_t rounded = + nearbyint(static_cast(inp[i].vals[j]) * multiplier) + + zero_point; + retval.vals[i * elem_per_int_vec + j] = + std::min(std::max(rounded, min_val), max_val); + } + } + return retval; + } +}; + +template <> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + return a.maximum(b); +} + +#endif // defined(CPU_CAPABILITY_AVX512) && !defined(MSVC) + +}}} 
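The requantization that RequantizeAvx512 and the scalar requantize_from_int fallbacks above implement reduces to a simple per-element computation: scale the int32 accumulator by the requantization multiplier, round to the nearest integer, add the output zero point, and saturate to the output type's range. The following standalone scalar sketch (illustrative only, not part of the patch) spells that out for int8:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Scalar counterpart of RequantizeAvx512 / requantize_from_int for a single
// int32 accumulator value, producing a saturated int8 result.
inline int8_t requantize_one(int32_t acc, float multiplier, int32_t zero_point) {
  int32_t rounded =
      static_cast<int32_t>(std::nearbyint(static_cast<float>(acc) * multiplier)) + zero_point;
  constexpr int32_t lo = std::numeric_limits<int8_t>::min();
  constexpr int32_t hi = std::numeric_limits<int8_t>::max();
  return static_cast<int8_t>(std::min(std::max(rounded, lo), hi));
}

The AVX512 path above performs the same work on 16 int32 lanes at a time and relies on _mm512_packs_epi32 plus a final permute to saturate and reorder the results into one 64-byte register.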
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_base.h b/aten/src/ATen/cpu/vec/vec_base.h similarity index 84% rename from aten/src/ATen/cpu/vec/vec256/vec256_base.h rename to aten/src/ATen/cpu/vec/vec_base.h index 596dac67c2cd3..da5f318bf530c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -32,13 +32,28 @@ #include #include +// These macros helped us unify vec_base.h +#ifdef CPU_CAPABILITY_AVX512 #if defined(__GNUC__) -#define __at_align32__ __attribute__((aligned(32))) +#define __at_align__ __attribute__((aligned(64))) #elif defined(_WIN32) -#define __at_align32__ __declspec(align(32)) +#define __at_align__ __declspec(align(64)) #else -#define __at_align32__ +#define __at_align__ #endif +#define VECTOR_WIDTH 64 +#define int_vector __m512i +#else // CPU_CAPABILITY_AVX512 +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(32))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(32)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 32 +#define int_vector __m256i +#endif // CPU_CAPABILITY_AVX512 namespace at { namespace vec { @@ -70,11 +85,11 @@ using int_same_size_t = typename int_of_size::type; // NOTE: If you specialize on a type, you must define all operations! -// emulates vectorized types +// emulates Vectorized types template struct Vectorized { private: - __at_align32__ T values[32 / sizeof(T)]; + __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; public: using value_type = T; using size_type = int; @@ -111,7 +126,7 @@ struct Vectorized { // identifier is odr-used or not, and in any case it's hard to tell if // a variable is odr-used or not. So best to just cut the problem at the root. 
static constexpr size_type size() { - return 32 / sizeof(T); + return VECTOR_WIDTH / sizeof(T); } Vectorized() : values{0} {} Vectorized(T val) { @@ -134,60 +149,60 @@ struct Vectorized { template static Vectorized blend(const Vectorized& a, const Vectorized& b) { int64_t mask = mask_; - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { if (mask & 0x01) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } mask = mask >> 1; } - return vec; + return vector; } static Vectorized blendv(const Vectorized& a, const Vectorized& b, const Vectorized& mask) { - Vectorized vec; + Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); for (int64_t i = 0; i < size(); i++) { if (buffer[i] & 0x01) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } } - return vec; + return vector; } template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { - vec.values[i] = base + i * step; + vector.values[i] = base + i * step; } - return vec; + return vector; } static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i < size(); i++) { if (i < count) { - vec[i] = b[i]; + vector[i] = b[i]; } else { - vec[i] = a[i]; + vector[i] = a[i]; } } - return vec; + return vector; } static Vectorized loadu(const void* ptr) { - Vectorized vec; - std::memcpy(vec.values, ptr, 32); - return vec; + Vectorized vector; + std::memcpy(vector.values, ptr, VECTOR_WIDTH); + return vector; } static Vectorized loadu(const void* ptr, int64_t count) { - Vectorized vec; - std::memcpy(vec.values, ptr, count * sizeof(T)); - return vec; + Vectorized vector; + std::memcpy(vector.values, ptr, count * sizeof(T)); + return vector; } void store(void* ptr, int count = size()) const { std::memcpy(ptr, values, count * sizeof(T)); @@ -203,15 +218,15 @@ struct Vectorized { return mask; } Vectorized isnan() const { - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (_isnan(values[i])) { - std::memset(static_cast(vec.values + i), 0xFF, sizeof(T)); + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { - std::memset(static_cast(vec.values + i), 0, sizeof(T)); + std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } - return vec; + return vector; } Vectorized map(T (*const f)(T)) const { Vectorized ret; @@ -488,15 +503,15 @@ struct Vectorized { template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. - Vectorized vec; + Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (op(values[i], other.values[i])) { - std::memset(static_cast(vec.values + i), 0xFF, sizeof(T)); + std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { - std::memset(static_cast(vec.values + i), 0, sizeof(T)); + std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } - return vec; + return vector; } public: @@ -511,11 +526,11 @@ struct Vectorized { template inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { // 1 if the pred is true, otherwise 0. 
- Vectorized vec; + Vectorized vector; for (int i = 0; i != size(); ++ i) { - vec[i] = bool(op(values[i], other.values[i])); + vector[i] = bool(op(values[i], other.values[i])); } - return vec; + return vector; } public: @@ -668,41 +683,62 @@ Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_ struct Vectorizedi; -#ifdef CPU_CAPABILITY_AVX2 - +#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { - __m256i buffer; - __m256i a_buffer = _mm256_loadu_si256(reinterpret_cast((const T*)a)); - __m256i b_buffer = _mm256_loadu_si256(reinterpret_cast((const T*)b)); + int_vector buffer; +#if defined(CPU_CAPABILITY_AVX2) + int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); +#elif defined(CPU_CAPABILITY_AVX512) + int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); +#endif buffer = op(a_buffer, b_buffer); - __at_align32__ T results[Vectorized::size()]; - _mm256_storeu_si256(reinterpret_cast<__m256i*>(results), buffer); + __at_align__ T results[Vectorized::size()]; + +#if defined(CPU_CAPABILITY_AVX2) + _mm256_store_si256(reinterpret_cast(results), buffer); +#elif defined(CPU_CAPABILITY_AVX512) + _mm512_store_si512(reinterpret_cast(results), buffer); +#endif return Vectorized::loadu(results); } template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_and_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_and_si256(a, b); }); + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); +#endif } template>::value, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_or_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_or_si256(a, b); }); + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); +#endif } template>::value, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { - // We enclose _mm256_xor_si256 with lambda because it is always_inline - return bitwise_binary_op(a, b, [](__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }); + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline +#if defined(CPU_CAPABILITY_AVX2) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); +#elif defined(CPU_CAPABILITY_AVX512) + return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); +#endif } #else template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op 
op) { - static constexpr uint32_t element_no = 32 / sizeof(intmax_t); - __at_align32__ intmax_t buffer[element_no]; + static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); + __at_align__ intmax_t buffer[element_no]; const intmax_t *a_ptr = reinterpret_cast((const T*) a); const intmax_t *b_ptr = reinterpret_cast((const T*) b); for (uint32_t i = 0U; i < element_no; ++ i) { @@ -724,12 +760,12 @@ inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_xor()); } -#endif +#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template>::value, int> = 0> inline Vectorized operator~(const Vectorized& a) { Vectorized ones; // All bits are 1 - memset((T*) ones, 0xFF, 32); + memset((T*) ones, 0xFF, VECTOR_WIDTH); return a ^ ones; } @@ -802,7 +838,9 @@ inline mask_gather(const Vectorized& src, T const* base_addr, } // Cast a given vector to another type without changing the bits representation. -// So a Vec of 256 bits containing all ones can be cast to a +// So a Vectorized of 512 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative 1s). +// A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). namespace { // There is a struct here because we don't have static_if and I can't @@ -840,10 +878,16 @@ inline Vectorized> convert_to_int_of_same_size(const Vectoriz return Vectorized>::loadu(static_cast(buffer)); } -// E.g., inputs: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} -// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} -// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} -// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// Example inputs for AVX512: +// a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// returns: +// Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} +// b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} template inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { @@ -866,8 +910,14 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { } // inverse operation of deinterleave2 -// E.g., inputs: a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} -// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// Example inputs for AVX512: +// a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} +// returns, for AVX512: +// Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} +// Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} +// Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} +// b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} template diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index 08dca99de045b..b9cc47f3fe73b 100644 --- 
a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -35,21 +35,8 @@ #include #endif -// [Note SSE-AVX transitions] -// There is a bug in Glibc2.23 -// https://bugs.launchpad.net/ubuntu/+source/glibc/+bug/1663280. Calling zeroall -// when using AVX/AVX2 code resolves this. -#if defined(CPU_CAPABILITY_AVX) && defined(__GLIBC__) && __GLIBC_MINOR__ == 23 -#define DL_RUNTIME_BUG(op, type_) \ - using value_t = typename c10::scalar_value_type::type;\ - volatile value_t x = (value_t)(1); \ - x = std::op(x); \ - _mm256_zeroall(); -#define DL_RUNTIME_BUG_BFLOAT16() _mm256_zeroall(); -#else #define DL_RUNTIME_BUG(op, type_) #define DL_RUNTIME_BUG_BFLOAT16() -#endif namespace at { namespace vml { @@ -117,36 +104,36 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { }); \ } -IMPLEMENT_VML_BUG(abs) -IMPLEMENT_VML_BUG(acos) -IMPLEMENT_VML_BUG(asin) -IMPLEMENT_VML_BUG(atan) -IMPLEMENT_VML_BUG(ceil) -IMPLEMENT_VML_BUG(cos) +IMPLEMENT_VML(abs) +IMPLEMENT_VML(acos) +IMPLEMENT_VML(asin) +IMPLEMENT_VML(atan) +IMPLEMENT_VML(ceil) +IMPLEMENT_VML(cos) // IMPLEMENT_VML_BUG(cosh) -IMPLEMENT_VML_BUG(erf) -IMPLEMENT_VML_BUG(erfc) +IMPLEMENT_VML(erf) +IMPLEMENT_VML(erfc) IMPLEMENT_VML(erfinv) -IMPLEMENT_VML_BUG(exp) -IMPLEMENT_VML_BUG(expm1) -IMPLEMENT_VML_BUG(floor) +IMPLEMENT_VML(exp) +IMPLEMENT_VML(expm1) +IMPLEMENT_VML(floor) IMPLEMENT_VML(i0) IMPLEMENT_VML(i0e) IMPLEMENT_VML(reciprocal) -IMPLEMENT_VML_BUG(log) -IMPLEMENT_VML_BUG(log10) -IMPLEMENT_VML_BUG(log1p) -IMPLEMENT_VML_BUG(log2) +IMPLEMENT_VML(log) +IMPLEMENT_VML(log10) +IMPLEMENT_VML(log1p) +IMPLEMENT_VML(log2) IMPLEMENT_VML(neg) -IMPLEMENT_VML_BUG(sin) +IMPLEMENT_VML(sin) // IMPLEMENT_VML_BUG(sinh) -IMPLEMENT_VML_BUG(sqrt) -IMPLEMENT_VML_BUG(round) +IMPLEMENT_VML(sqrt) +IMPLEMENT_VML(round) IMPLEMENT_VML(rsqrt) -IMPLEMENT_VML_BUG(tan) -IMPLEMENT_VML_BUG(tanh) -IMPLEMENT_VML_BUG(trunc) -IMPLEMENT_VML_BUG(lgamma) +IMPLEMENT_VML(tan) +IMPLEMENT_VML(tanh) +IMPLEMENT_VML(trunc) +IMPLEMENT_VML(lgamma) #if AT_MKL_ENABLED() && !defined(__APPLE__) diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index b801f819f3e2a..16b3d1b28740b 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -952,91 +952,109 @@ void lu_solve_kernel(const Tensor& b, const Tensor& lu, const Tensor& pivots) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(cholesky_stub, DEFAULT, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(cholesky_stub, &cholesky_kernel); +REGISTER_AVX512_DISPATCH(cholesky_stub, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(cholesky_stub, &cholesky_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(cholesky_stub, &cholesky_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(cholesky_inverse_stub, DEFAULT, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +REGISTER_AVX512_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); +// 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(linalg_eigh_stub, DEFAULT, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +REGISTER_AVX512_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(geqrf_stub, DEFAULT, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(geqrf_stub, &geqrf_kernel); +REGISTER_AVX512_DISPATCH(geqrf_stub, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(geqrf_stub, &geqrf_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(geqrf_stub, &geqrf_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_AVX512_DISPATCH(orgqr_stub, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(ormqr_stub, DEFAULT, &ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(ormqr_stub, &ormqr_kernel); +REGISTER_AVX512_DISPATCH(ormqr_stub, &ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(ormqr_stub, &ormqr_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(ormqr_stub, 
&ormqr_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lstsq_stub, DEFAULT, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(lstsq_stub, &lstsq_kernel); +REGISTER_AVX512_DISPATCH(lstsq_stub, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lstsq_stub, &lstsq_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lstsq_stub, &lstsq_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(triangular_solve_stub, DEFAULT, &triangular_solve_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +REGISTER_AVX512_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(triangular_solve_stub, &triangular_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lu_stub, DEFAULT, &lu_kernel); -REGISTER_AVX_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lu_stub, &lu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_ARCH_DISPATCH(lu_solve_stub, DEFAULT, &lu_solve_kernel); -REGISTER_AVX_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(lu_solve_stub, &lu_solve_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(lu_solve_stub, &lu_solve_kernel); }} // namespace at::native diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 7e3666ef98493..ada5ed5ee7552 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -16,12 +16,12 @@ static CPUCapability compute_cpu_capability() { return CPUCapability::VSX; } #else + if (strcmp(envar, "avx512") == 0) { + return CPUCapability::AVX512; + } if (strcmp(envar, "avx2") == 0) { return CPUCapability::AVX2; } - if (strcmp(envar, "avx") == 0) { - return CPUCapability::AVX; - } #endif if (strcmp(envar, "default") == 0) { return CPUCapability::DEFAULT; @@ -31,12 +31,13 @@ static CPUCapability compute_cpu_capability() { #if !defined(__powerpc__) && !defined(__s390x__) if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \ + cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) { + return CPUCapability::AVX512; + } if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) { return CPUCapability::AVX2; } - if (cpuinfo_has_x86_avx()) { - return CPUCapability::AVX; - } } #endif #ifdef HAVE_VSX_CPU_DEFINITION @@ -54,8 +55,8 @@ CPUCapability get_cpu_capability() { void* DispatchStubImpl::get_call_ptr( DeviceType 
device_type , void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -72,8 +73,8 @@ void* DispatchStubImpl::get_call_ptr( if (!fptr) { fptr = choose_cpu_impl( DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , AVX2 @@ -102,8 +103,8 @@ void* DispatchStubImpl::get_call_ptr( void* DispatchStubImpl::choose_cpu_impl( void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -114,18 +115,26 @@ void* DispatchStubImpl::choose_cpu_impl( ) { auto capability = static_cast(get_cpu_capability()); (void)capability; +#ifdef HAVE_AVX512_CPU_DEFINITION + if (capability >= static_cast(CPUCapability::AVX512)) { + // Quantization kernels have also been disabled on Windows + // for AVX512 because some of their tests are flaky on Windows. + // Ideally, we should have AVX512 kernels for all kernels. + if (C10_UNLIKELY(!AVX512)) { + // dispatch to AVX2, since the AVX512 kernel is missing + TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); + return AVX2; + } else { + return AVX512; + } + } +#endif #ifdef HAVE_AVX2_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX2)) { TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); return AVX2; } #endif -#ifdef HAVE_AVX_CPU_DEFINITION - if (capability >= static_cast(CPUCapability::AVX)) { - TORCH_INTERNAL_ASSERT(AVX, "DispatchStub: missing AVX kernel"); - return AVX; - } -#endif #ifdef HAVE_VSX_CPU_DEFINITION if (capability >= static_cast(CPUCapability::VSX)) { TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel"); diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 315f5007dbdbd..94a2dc421a6ca 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -9,8 +9,8 @@ // Implements instruction set specific function dispatch. // -// Kernels that may make use of specialized instruction sets (e.g. AVX) are -// compiled multiple times with different compiler flags (e.g. -mavx). A +// Kernels that may make use of specialized instruction sets (e.g. AVX2) are +// compiled multiple times with different compiler flags (e.g. -mavx2). A // DispatchStub contains a table of function pointers for a kernel. At runtime, // the fastest available kernel is chosen based on the features reported by // cpuinfo. 
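In summary, the capability selection above works as follows: the ATEN_CPU_CAPABILITY environment variable (avx512, avx2, default) takes precedence; otherwise cpuinfo picks AVX512 only when AVX512-VL/BW/DQ and FMA3 are all present, falling back to AVX2 and then to the default build; and a stub whose AVX512 slot was registered as a null pointer is transparently served by its AVX2 kernel. A condensed, simplified sketch of that selection order (not the actual DispatchStub code):

#include <cstdlib>
#include <cstring>

enum class CPUCapability { DEFAULT = 0, AVX2 = 1, AVX512 = 2 };

// Mirrors compute_cpu_capability(): env override first, then cpuinfo features.
CPUCapability compute_capability(bool has_avx512_vl_bw_dq_fma, bool has_avx2_fma) {
  if (const char* envar = std::getenv("ATEN_CPU_CAPABILITY")) {
    if (std::strcmp(envar, "avx512") == 0) return CPUCapability::AVX512;
    if (std::strcmp(envar, "avx2") == 0) return CPUCapability::AVX2;
    if (std::strcmp(envar, "default") == 0) return CPUCapability::DEFAULT;
  }
  if (has_avx512_vl_bw_dq_fma) return CPUCapability::AVX512;
  if (has_avx2_fma) return CPUCapability::AVX2;
  return CPUCapability::DEFAULT;
}

// Mirrors choose_cpu_impl(): a missing AVX512 kernel silently degrades to AVX2.
void* choose_impl(CPUCapability cap, void* def, void* avx2, void* avx512) {
  if (cap >= CPUCapability::AVX512) return avx512 ? avx512 : avx2;
  if (cap >= CPUCapability::AVX2) return avx2;
  return def;
}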
@@ -50,8 +50,8 @@ enum class CPUCapability { #ifdef HAVE_VSX_CPU_DEFINITION VSX = 1, #else - AVX = 1, - AVX2 = 2, + AVX2 = 1, + AVX512 = 2, #endif NUM_OPTIONS }; @@ -71,8 +71,8 @@ struct TORCH_API DispatchStubImpl { void* get_call_ptr( DeviceType device_type , void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -89,8 +89,8 @@ struct TORCH_API DispatchStubImpl { */ void* choose_cpu_impl( void *DEFAULT -#ifdef HAVE_AVX_CPU_DEFINITION - , void *AVX +#ifdef HAVE_AVX512_CPU_DEFINITION + , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 @@ -126,8 +126,8 @@ struct DispatchStub { return reinterpret_cast( impl.get_call_ptr(device_type , reinterpret_cast(DEFAULT) -#ifdef HAVE_AVX_CPU_DEFINITION - , reinterpret_cast(AVX) +#ifdef HAVE_AVX512_CPU_DEFINITION + , reinterpret_cast(AVX512) #endif #ifdef HAVE_AVX2_CPU_DEFINITION , reinterpret_cast(AVX2) @@ -155,8 +155,8 @@ struct DispatchStub { } static FnPtr DEFAULT; -#ifdef HAVE_AVX_CPU_DEFINITION - static FnPtr AVX; +#ifdef HAVE_AVX512_CPU_DEFINITION + static FnPtr AVX512; #endif #ifdef HAVE_AVX2_CPU_DEFINITION static FnPtr AVX2; @@ -203,10 +203,10 @@ struct RegisterHIPDispatch { #define REGISTER_ARCH_DISPATCH(name, arch, fn) \ template <> decltype(fn) DispatchStub::arch = fn; -#ifdef HAVE_AVX_CPU_DEFINITION -#define REGISTER_AVX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX, fn) +#ifdef HAVE_AVX512_CPU_DEFINITION +#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn) #else -#define REGISTER_AVX_DISPATCH(name, fn) +#define REGISTER_AVX512_DISPATCH(name, fn) #endif #ifdef HAVE_AVX2_CPU_DEFINITION @@ -223,8 +223,8 @@ struct RegisterHIPDispatch { #define REGISTER_NO_CPU_DISPATCH(name, fn_type) \ REGISTER_ARCH_DISPATCH(name, DEFAULT, static_cast(nullptr)) \ - REGISTER_AVX_DISPATCH(name, static_cast(nullptr)) \ - REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_AVX512_DISPATCH(name, static_cast(nullptr)) \ + REGISTER_AVX2_DISPATCH(name, static_cast(nullptr)) \ REGISTER_VSX_DISPATCH(name, static_cast(nullptr)) #define REGISTER_CUDA_DISPATCH(name, fn) \ @@ -244,6 +244,8 @@ struct RegisterHIPDispatch { // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) #elif defined(CPU_CAPABILITY) #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) +#define REGISTER_NO_AVX512_DISPATCH(name, fn_type) \ + REGISTER_AVX512_DISPATCH(name, static_cast(nullptr)) #endif diff --git a/aten/src/ATen/native/SegmentReduce.cpp b/aten/src/ATen/native/SegmentReduce.cpp index b8b9405528c78..ff9529d32dd4d 100644 --- a/aten/src/ATen/native/SegmentReduce.cpp +++ b/aten/src/ATen/native/SegmentReduce.cpp @@ -275,10 +275,10 @@ REGISTER_ARCH_DISPATCH( DEFAULT, &_segment_reduce_cpu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_AVX2_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_AVX512_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_VSX_DISPATCH(_segment_reduce_stub, &_segment_reduce_cpu_kernel); // Currently some computation is being duplicated across forward and backward. 
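The REGISTER_NO_AVX512_DISPATCH macro defined above is intended for kernels that should keep running their AVX2 build even on AVX512-capable machines. Below is a usage sketch with a hypothetical my_stub/my_kernel_impl (illustrative names only; the real instance of this pattern appears later in this patch for normal_stub in UnaryOpsKernel.cpp, and this sketch only compiles inside ATen, where DECLARE_DISPATCH and the per-capability build flags exist):

// Kernel translation unit compiled once per CPU_CAPABILITY value.
namespace at { namespace native {
namespace CPU_CAPABILITY {
void my_kernel_impl(TensorIterator& iter) {
  // vectorized body written against at::vec::Vectorized<T>
}
} // namespace CPU_CAPABILITY

#ifdef CPU_CAPABILITY_AVX512
// Leave the AVX512 slot null; DispatchStub falls back to the AVX2 kernel.
REGISTER_NO_AVX512_DISPATCH(my_stub, void(*)(TensorIterator&));
#else
REGISTER_DISPATCH(my_stub, &CPU_CAPABILITY::my_kernel_impl);
#endif
}} // namespace at::native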
@@ -319,7 +319,7 @@ REGISTER_ARCH_DISPATCH( DEFAULT, &_segment_reduce_cpu_backward_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_AVX_DISPATCH( +REGISTER_AVX512_DISPATCH( _segment_reduce_backward_stub, &_segment_reduce_cpu_backward_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/aten/src/ATen/native/cpu/README.md b/aten/src/ATen/native/cpu/README.md index abf83b560ba44..b3fbd171ad470 100644 --- a/aten/src/ATen/native/cpu/README.md +++ b/aten/src/ATen/native/cpu/README.md @@ -4,7 +4,7 @@ The most important things to know: compiled multiple times for different instruction sets.** Yes, this folder is named `cpu`, but that doesn't mean put any old CPU kernel it. Only put CPU kernels which need to be compiled -multiple times to take advantage of AVX/SSE instructions, but +multiple times to take advantage of AVX512/AVX2/SSE instructions, but only on processors that support them. **Ensure that all implementations in this folder are put in an @@ -52,14 +52,14 @@ All of the `*.cpp` files in this folder will be compiled under all compiler flags specified by `CPU_CAPABILITY_FLAGS` in `aten/src/ATen/CMakeLists.txt`. The purpose of this is to allow the compilation with various compiler -flags to enable features such as AVX instructions, while using runtime -dispatch, which makes sure only valid instructions will be used on any +flags to enable features such as AVX2 or AVX512 instructions, while using +runtime dispatch, which makes sure only valid instructions will be used on any given platform. -Vectorized.h provides a generic implementation of a vec type that allows +vec.h provides a generic implementation of a vec type that allows the programmer to write code packing various primitives (such as floats) -within 256bit registers. vec defines various operators such as + and * -and provides functions to allow operations such as max, min, etc. +within 256-bit and 512-bit registers. vec defines various operators such as + and *, and provides functions to allow operations such as max, min, etc. As an example `ReduceOpsKernel.cpp` implements a generic `kernel_` that reduces an entire array using a given associative binary operation such as +. @@ -74,5 +74,5 @@ generic code, which will be compiled under multipled compilation settings. `../ReduceOps.cpp` now includes the header `ReduceOpsKernel.h`, which contains a generic definition of `sumImplAll`. This function allows the user to reduce over a dimension or all dimensions. The appropiate capability is chosen at -runtime using cpuinfo. If the current platform has AVX, `sumImpl` will be set -to `sumImplAll`. +runtime using cpuinfo. If the current platform has AVX2, `sumImpl` will be set +to `sumImplAll`.
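To make the README's description concrete, here is a minimal kernel-style sketch written against at::vec::Vectorized<float>. It assumes the post-rename include paths ATen/cpu/vec/vec.h and ATen/cpu/vec/functional.h and the vec_reduce_all helper from functional_base.h, and is an illustration rather than code from this patch. The same source compiles to 8-lane code under AVX2 and 16-lane code under AVX512, with runtime dispatch choosing the build that matches the host CPU:

#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <cstdint>

// Sums n floats using whatever lane width the current CPU_CAPABILITY build of
// Vectorized<float> provides (8 under AVX2, 16 under AVX512).
float sum_all(const float* data, int64_t n) {
  using Vec = at::vec::Vectorized<float>;
  Vec acc(0.0f);
  int64_t i = 0;
  for (; i + Vec::size() <= n; i += Vec::size()) {
    acc = acc + Vec::loadu(data + i);
  }
  // Horizontally reduce the accumulator lanes, then pick up the scalar tail.
  float total = at::vec::vec_reduce_all<float>(
      [](Vec x, Vec y) { return x + y; }, acc, Vec::size());
  for (; i < n; ++i) {
    total += data[i];
  }
  return total;
}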
diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 30a7dc64a3a05..d97e81f43673f 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -32,7 +32,8 @@ static inline bool is_outer_reduction(const int64_t* strides) { } template -static inline void reduction128(char** data, int64_t n, int64_t stride, func_t op, vec_func_t vop, bool reduce) { +static inline void vectorized_reduction(char** data, int64_t n, int64_t stride, + func_t op, vec_func_t vop, bool reduce) { VEC_LOOP_HEADER(func_t, data) const char* in1_ptr = data[1]; Vec acc[4]; @@ -80,7 +81,7 @@ static inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t count = n / (4 * Vec::size()); if (count > 0) { - reduction128(data, count, vector_stride, op, vop, /*reduce=*/true); + vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); } char* ptrs[3] = { data[0], data[0], data[1] }; int64_t strides[] = { 0, 0, sizeof(scalar_t) }; @@ -92,10 +93,14 @@ template static inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) - // reduce down each column of 4 * Vec::size() elements (128 bytes) + // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) +#if defined(CPU_CAPABILITY_AVX512) + int64_t outer_stride[2] = { 256, 256 }; +#else int64_t outer_stride[2] = { 128, 128 }; +#endif UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { - reduction128(data, size0, inner_stride, op, vop, /*reduce=*/false); + vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); }); // reduce down the remaining columns diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 7591de27aee5d..ae0dcec19d4a6 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -395,7 +395,6 @@ static void argmin_kernel_impl(TensorIterator &iter) { } // anonymous namespace -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(std_var_stub, &std_var_kernel_impl); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(prod_stub, &prod_kernel_impl); diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 3fdde8c07f1cc..6318543e7eb4c 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -219,9 +219,15 @@ inline void _vec_softmax( int64_t outer_stride = dim_size * dim_stride; int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); int vectorized_step = Vec().size(); // Currently, we only support scalar_t with double or float32 - TORCH_CHECK( +#ifdef CPU_CAPABILITY_AVX512 + TORCH_INTERNAL_ASSERT( + (vectorized_step == 16) || (vectorized_step == 8), + "vectorized_step must be 16 with dtype float or 8 with dtype double"); +#else + TORCH_INTERNAL_ASSERT( (vectorized_step == 8) || (vectorized_step == 4), "vectorized_step must be 8 with dtype float or 4 with dtype double"); +#endif parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { int64_t idx = begin; diff --git a/aten/src/ATen/native/cpu/SumKernel.cpp b/aten/src/ATen/native/cpu/SumKernel.cpp index 4ae4b3585f64c..db8b077e22eae 100644 --- a/aten/src/ATen/native/cpu/SumKernel.cpp 
+++ b/aten/src/ATen/native/cpu/SumKernel.cpp @@ -611,7 +611,6 @@ void nansum_kernel_impl(TensorIterator &iter) { } // namespace (anonymous) -// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(sum_stub, &sum_kernel_impl); REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl); diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index c4cdc3d99389b..2d6fcd226e8b7 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -715,8 +715,15 @@ REGISTER_DISPATCH(exponential_stub, &CPU_CAPABILITY::exponential_kernel); REGISTER_DISPATCH(geometric_stub, &CPU_CAPABILITY::geometric_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(log_normal_stub, &CPU_CAPABILITY::log_normal_kernel); +#ifdef CPU_CAPABILITY_AVX512 +// normal_stub isn't being dispatched to AVX512 because it exposes +// flakiness in test_sgd of test/test_optim.py +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(normal_stub, void(*)(Tensor&, const double, const double, c10::optional)); +#else // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(normal_stub, &CPU_CAPABILITY::normal_kernel); +#endif // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(uniform_stub, &CPU_CAPABILITY::uniform_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/aten/src/ATen/native/cpu/avx_mathfun.h b/aten/src/ATen/native/cpu/avx_mathfun.h index 33f8569e6e8ad..080cd833d3a10 100644 --- a/aten/src/ATen/native/cpu/avx_mathfun.h +++ b/aten/src/ATen/native/cpu/avx_mathfun.h @@ -32,26 +32,17 @@ #include -/* yes I know, the top of this file is quite ugly */ +/* The original source of this file has been modified. */ +#if defined(CPU_CAPABILITY_AVX2) + #if defined(__GNUC__) # define ALIGN32_BEG __attribute__((aligned(32))) #elif defined(_WIN32) # define ALIGN32_BEG __declspec(align(32)) #endif -/* __m128 is ugly to write */ -typedef __m256 v8sf; // vector of 8 float (avx) -typedef __m256i v8si; // vector of 8 int (avx) -typedef __m128i v4si; // vector of 8 int (avx) - -#define _PI32AVX_CONST(Name, Val) \ - static const ALIGN32_BEG int _pi32avx_##Name[4] = { Val, Val, Val, Val } - -_PI32AVX_CONST(1, 1); -_PI32AVX_CONST(inv1, ~1); -_PI32AVX_CONST(2, 2); -_PI32AVX_CONST(4, 4); - +typedef __m256 v8sf; // vector of 8 float (avx2) +typedef __m256i v8si; // vector of 8 int (avx2) /* declare some AVX constants -- why can't I figure a better way to do that? 
*/ #define _PS256_CONST(Name, Val) \ @@ -91,67 +82,6 @@ _PS256_CONST(cephes_log_p8, + 3.3333331174E-1); _PS256_CONST(cephes_log_q1, -2.12194440e-4); _PS256_CONST(cephes_log_q2, 0.693359375); -#ifndef CPU_CAPABILITY_AVX2 - -typedef union imm_xmm_union { - v8si imm; - v4si xmm[2]; -} imm_xmm_union; - -#define COPY_IMM_TO_XMM(imm_, xmm0_, xmm1_) { \ - imm_xmm_union u __attribute__((aligned(32))); \ - u.imm = imm_; \ - xmm0_ = u.xmm[0]; \ - xmm1_ = u.xmm[1]; \ -} - -#define COPY_XMM_TO_IMM(xmm0_, xmm1_, imm_) { \ - imm_xmm_union u __attribute__((aligned(32))); \ - u.xmm[0]=xmm0_; u.xmm[1]=xmm1_; imm_ = u.imm; \ - } - - -#define AVX2_BITOP_USING_SSE2(fn) \ -static inline v8si _mm256_##fn(v8si x, int a) \ -{ \ - /* use SSE2 instruction to perform the bitop AVX2 */ \ - v4si x1, x2; \ - v8si ret; \ - COPY_IMM_TO_XMM(x, x1, x2); \ - x1 = _mm_##fn(x1,a); \ - x2 = _mm_##fn(x2,a); \ - COPY_XMM_TO_IMM(x1, x2, ret); \ - return(ret); \ -} - -#warning "Using SSE2 to perform AVX2 bitshift ops" -AVX2_BITOP_USING_SSE2(slli_epi32) -AVX2_BITOP_USING_SSE2(srli_epi32) - -#define AVX2_INTOP_USING_SSE2(fn) \ -static inline v8si _mm256_##fn(v8si x, v8si y) \ -{ \ - /* use SSE2 instructions to perform the AVX2 integer operation */ \ - v4si x1, x2; \ - v4si y1, y2; \ - v8si ret; \ - COPY_IMM_TO_XMM(x, x1, x2); \ - COPY_IMM_TO_XMM(y, y1, y2); \ - x1 = _mm_##fn(x1,y1); \ - x2 = _mm_##fn(x2,y2); \ - COPY_XMM_TO_IMM(x1, x2, ret); \ - return(ret); \ -} - -#warning "Using SSE2 to perform AVX2 integer ops" -AVX2_INTOP_USING_SSE2(and_si128) -AVX2_INTOP_USING_SSE2(andnot_si128) -AVX2_INTOP_USING_SSE2(cmpeq_epi32) -AVX2_INTOP_USING_SSE2(sub_epi32) -AVX2_INTOP_USING_SSE2(add_epi32) - -#endif /* CPU_CAPABILITY_AVX2 */ - /* natural logarithm computed for 8 simultaneous float return NaN for x <= 0 @@ -326,11 +256,6 @@ inline v8sf sin256_ps(v8sf x) { // any x v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, sign_bit, y; v8si imm0, imm2; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; -#endif - sign_bit = x; /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); @@ -346,7 +271,6 @@ inline v8sf sin256_ps(v8sf x) { // any x If we don't have AVX, let's perform them using SSE2 directives */ -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in mm0 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ @@ -366,35 +290,6 @@ inline v8sf sin256_ps(v8sf x) { // any x */ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2,*(v8si*)_pi32_256_0); -#else - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf swap_sign_bit = _mm256_castsi256_ps(imm0); v8sf poly_mask = 
_mm256_castsi256_ps(imm2); @@ -453,18 +348,12 @@ inline v8sf cos256_ps(v8sf x) { // any x v8sf xmm1, xmm2 = _mm256_setzero_ps(), xmm3, y; v8si imm0, imm2; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; -#endif - /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); /* scale by 4/Pi */ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in mm0 */ imm2 = _mm256_cvttps_epi32(y); /* j=(j+1) & (~1) (see the cephes sources) */ @@ -479,39 +368,6 @@ inline v8sf cos256_ps(v8sf x) { // any x /* get the polynom selection mask */ imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); -#else - - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm2_1 = _mm_sub_epi32(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_sub_epi32(imm2_2, *(v4si*)_pi32avx_2); - - imm0_1 = _mm_andnot_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_andnot_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, _mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf sign_bit = _mm256_castsi256_ps(imm0); v8sf poly_mask = _mm256_castsi256_ps(imm2); @@ -571,12 +427,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { v8sf xmm1, xmm2, xmm3 = _mm256_setzero_ps(), sign_bit_sin, y; v8si imm0, imm2, imm4; -#ifndef CPU_CAPABILITY_AVX2 - v4si imm0_1, imm0_2; - v4si imm2_1, imm2_2; - v4si imm4_1, imm4_2; -#endif - sign_bit_sin = x; /* take the absolute value */ x = _mm256_and_ps(x, *(v8sf*)_ps256_inv_sign_mask); @@ -586,7 +436,6 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { /* scale by 4/Pi */ y = _mm256_mul_ps(x, *(v8sf*)_ps256_cephes_FOPI); -#ifdef CPU_CAPABILITY_AVX2 /* store the integer part of y in imm2 */ imm2 = _mm256_cvttps_epi32(y); @@ -606,38 +455,7 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { imm2 = _mm256_and_si256(imm2, *(v8si*)_pi32_256_2); imm2 = _mm256_cmpeq_epi32(imm2, *(v8si*)_pi32_256_0); //v8sf poly_mask = _mm256_castsi256_ps(imm2); -#else - /* we use SSE2 routines to perform the integer ops */ - COPY_IMM_TO_XMM(_mm256_cvttps_epi32(y),imm2_1,imm2_2); - - imm2_1 = _mm_add_epi32(imm2_1, *(v4si*)_pi32avx_1); - imm2_2 = _mm_add_epi32(imm2_2, *(v4si*)_pi32avx_1); - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_inv1); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_inv1); - - COPY_XMM_TO_IMM(imm2_1,imm2_2,imm2); - y = _mm256_cvtepi32_ps(imm2); - - imm4_1 = imm2_1; - imm4_2 = imm2_2; - - imm0_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_4); - imm0_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_4); - - imm0_1 = _mm_slli_epi32(imm0_1, 29); - imm0_2 = _mm_slli_epi32(imm0_2, 29); - - COPY_XMM_TO_IMM(imm0_1, imm0_2, imm0); - - imm2_1 = _mm_and_si128(imm2_1, *(v4si*)_pi32avx_2); - imm2_2 = _mm_and_si128(imm2_2, *(v4si*)_pi32avx_2); - - imm2_1 = _mm_cmpeq_epi32(imm2_1, 
_mm_setzero_si128()); - imm2_2 = _mm_cmpeq_epi32(imm2_2, _mm_setzero_si128()); - - COPY_XMM_TO_IMM(imm2_1, imm2_2, imm2); -#endif v8sf swap_sign_bit_sin = _mm256_castsi256_ps(imm0); v8sf poly_mask = _mm256_castsi256_ps(imm2); @@ -653,22 +471,9 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { x = _mm256_add_ps(x, xmm2); x = _mm256_add_ps(x, xmm3); -#ifdef CPU_CAPABILITY_AVX2 imm4 = _mm256_sub_epi32(imm4, *(v8si*)_pi32_256_2); imm4 = _mm256_andnot_si256(imm4, *(v8si*)_pi32_256_4); imm4 = _mm256_slli_epi32(imm4, 29); -#else - imm4_1 = _mm_sub_epi32(imm4_1, *(v4si*)_pi32avx_2); - imm4_2 = _mm_sub_epi32(imm4_2, *(v4si*)_pi32avx_2); - - imm4_1 = _mm_andnot_si128(imm4_1, *(v4si*)_pi32avx_4); - imm4_2 = _mm_andnot_si128(imm4_2, *(v4si*)_pi32avx_4); - - imm4_1 = _mm_slli_epi32(imm4_1, 29); - imm4_2 = _mm_slli_epi32(imm4_2, 29); - - COPY_XMM_TO_IMM(imm4_1, imm4_2, imm4); -#endif v8sf sign_bit_cos = _mm256_castsi256_ps(imm4); @@ -713,3 +518,5 @@ inline void sincos256_ps(v8sf x, v8sf *s, v8sf *c) { *s = _mm256_xor_ps(xmm1, sign_bit_sin); *c = _mm256_xor_ps(xmm2, sign_bit_cos); } + +#endif // CPU_CAPABILITY_AVX2 diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 7c3635ed54f27..f146ffdc2ca46 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -149,8 +149,8 @@ static void _fft_fill_with_conjugate_symmetry_cpu_( // Register this one implementation for all cpu types instead of compiling multiple times REGISTER_ARCH_DISPATCH(fft_fill_with_conjugate_symmetry_stub, DEFAULT, &_fft_fill_with_conjugate_symmetry_cpu_) -REGISTER_AVX_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) REGISTER_AVX2_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) +REGISTER_AVX512_DISPATCH(fft_fill_with_conjugate_symmetry_stub, &_fft_fill_with_conjugate_symmetry_cpu_) // _out variants can be shared between PocketFFT and MKL Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index ef0b2720a362e..5efd68601c2f9 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -197,7 +198,28 @@ int64_t hsum(const uint8_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_v = _mm512_setzero_si512(); + __m512i one_epi16_v = _mm512_set1_epi16(1); + __m512i one_epi8_v = _mm512_set1_epi8(1); + // vectorized + for (; i < len / 64 * 64; i += 64) { + __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + sum_v = _mm512_add_epi32( + sum_v, + _mm512_madd_epi16( + // first argument is unsigned, second is signed + _mm512_maddubs_epi16(src_v, one_epi8_v), + one_epi16_v) + ); + } + + alignas(64) int32_t temp[16]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -233,7 +255,28 @@ int64_t hsum(const int8_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif 
defined(CPU_CAPABILITY_AVX512) + __m512i sum_v = _mm512_setzero_si512(); + __m512i one_epi16_v = _mm512_set1_epi16(1); + __m512i one_epi8_v = _mm512_set1_epi8(1); + // vectorized + for (; i < len / 64 * 64; i += 64) { + __m512i src_v = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + sum_v = _mm512_add_epi32( + sum_v, + _mm512_madd_epi16( + // first argument is unsigned, second is signed + _mm512_maddubs_epi16(one_epi8_v, src_v), + one_epi16_v) + ); + } + + alignas(64) int32_t temp[16]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -255,7 +298,7 @@ int64_t hsum(const int32_t* A, int len) { __m256i src_epi32 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); // widen __m128i src_lo_epi32 = _mm256_castsi256_si128(src_epi32); - __m128i src_hi_epi32 = _mm256_extractf128_si256(src_epi32, 1); + __m128i src_hi_epi32 = _mm256_extracti128_si256(src_epi32, 1); __m256i src_lo_epi64 = _mm256_cvtepi32_epi64(src_lo_epi32); __m256i src_hi_epi64 = _mm256_cvtepi32_epi64(src_hi_epi32); // add @@ -268,7 +311,27 @@ int64_t hsum(const int32_t* A, int len) { for (const auto k : c10::irange(4)) { row_sum += temp[k]; } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_epi64 = _mm512_setzero_si512(); + // vectorized + for (; i < len / 16 * 16; i += 16) { + __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + // widen + __m256i src_lo_epi32 = _mm512_castsi512_si256(src_epi32); + __m256i src_hi_epi32 = _mm512_extracti32x8_epi32(src_epi32, 1); + __m512i src_lo_epi64 = _mm512_cvtepi32_epi64(src_lo_epi32); + __m512i src_hi_epi64 = _mm512_cvtepi32_epi64(src_hi_epi32); + // add + sum_epi64 = _mm512_add_epi64(sum_epi64, src_lo_epi64); + sum_epi64 = _mm512_add_epi64(sum_epi64, src_hi_epi64); + } + + alignas(64) int64_t temp[8]; + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_epi64); + for (const auto k : c10::irange(8)) { + row_sum += temp[k]; + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -313,7 +376,36 @@ int64_t hsum_sq(const uint8_t* A, int len) { } sum_v_epu32 = _mm256_setzero_si256(); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512i sum_v_epu32 = _mm512_setzero_si512(); + alignas(64) int32_t temp[16]; + int overflow_threshold = 262144; // 2147483647(max of int32)/(512*512)*8 = 262144 + int loop = len / overflow_threshold + 1; + for(int j=0; j<=loop; j++){ + for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) { + // (i31, ..., i0) + __m256i src_epu8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); + __m512i src_epu16 = _mm512_cvtepu8_epi16(src_epu8); + // (i31 ^ 2, ..., i0 ^ 2) + __m512i sq_epu16 = _mm512_mullo_epi16(src_epu16, src_epu16); + // (i15 ^ 2, ..., i0 ^ 2) + __m256i sq_lo_epu16 = _mm512_castsi512_si256(sq_epu16); + // (i31 ^ 2, ..., i16 ^ 2) + __m256i sq_hi_epu16 = _mm512_extracti32x8_epi32(sq_epu16, 1); + // widen to epu32 + __m512i sq_lo_epu32 = _mm512_cvtepu16_epi32(sq_lo_epu16); + __m512i sq_hi_epu32 = _mm512_cvtepu16_epi32(sq_hi_epu16); + // add to running sum + sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_lo_epu32); + sum_v_epu32 = _mm512_add_epi32(sum_v_epu32, sq_hi_epu32); + } + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epu32); + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } + 
sum_v_epu32 = _mm512_setzero_si512(); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -361,7 +453,40 @@ int64_t hsum_sq(const int8_t* A, int len) { } sum_v_epi32 = _mm256_setzero_si256(); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + // vectorized + __m512i sum_v_epi32 = _mm512_setzero_si512(); + alignas(64) int32_t temp[16]; + + int overflow_threshold = 1048576; //2147483647/(256*256)*8 = 1048576 + int loop = len / overflow_threshold + 1; + + for(int j=0; j<=loop; j++){ + for (; ((i < overflow_threshold * j) && (i < len / 32 * 32)); i += 32) { + // (i31, ..., i0) + __m256i src_epi8 = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i)); + __m512i src_epi16 = _mm512_cvtepi8_epi16(src_epi8); + // (i31 ^ 2, ..., i0 ^ 2) + __m512i sq_epi16 = _mm512_mullo_epi16(src_epi16, src_epi16); + // (i15 ^ 2, ..., i0 ^ 2) + __m256i sq_lo_epi16 = _mm512_castsi512_si256(sq_epi16); + // (i31 ^ 2, ..., i16 ^ 2) + __m256i sq_hi_epi16 = _mm512_extracti32x8_epi32(sq_epi16, 1); + // widen to epi32 + __m512i sq_lo_epi32 = _mm512_cvtepi16_epi32(sq_lo_epi16); + __m512i sq_hi_epi32 = _mm512_cvtepi16_epi32(sq_hi_epi16); + // add to running sum + sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_lo_epi32); + sum_v_epi32 = _mm512_add_epi32(sum_v_epi32, sq_hi_epi32); + } + _mm512_store_si512(reinterpret_cast<__m512i*>(temp), sum_v_epi32); + + for (const auto k : c10::irange(16)) { + row_sum += temp[k]; + } + sum_v_epi32 = _mm512_setzero_si512(); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -391,7 +516,21 @@ float hsum_sq(const int32_t* A, int len) { for (const auto k : c10::irange(8)) { row_sum += static_cast(temp[k]); } -#endif // CPU_CAPABILITY_AVX2 +#elif defined(CPU_CAPABILITY_AVX512) + __m512 sum_ps = _mm512_setzero_ps(); + // vectorized + for (; i < len / 16 * 16; i += 16) { + __m512i src_epi32 = _mm512_loadu_si512(reinterpret_cast<__m512i const*>(A + i)); + __m512 src_ps = _mm512_cvtepi32_ps(src_epi32); + sum_ps = _mm512_add_ps(sum_ps, _mm512_mul_ps(src_ps, src_ps)); + } + + alignas(64) float temp[16]; + _mm512_store_ps(temp, sum_ps); + for (const auto k : c10::irange(16)) { + row_sum += static_cast(temp[k]); + } +#endif // CPU_CAPABILITY_AVX2 or CPU_CAPABILITY_AVX512 // scalar for (; i < len; ++i) { @@ -1239,7 +1378,7 @@ void qmaxpool_2d_nhwc_kernel( } template -void do_avg_pool_nhwc_on_AVX2( +void do_avg_pool_nhwc_on_AVX_n( const typename T::underlying* i_p, typename T::underlying* o_p, int& c_start, @@ -1256,17 +1395,25 @@ void do_avg_pool_nhwc_on_AVX2( int hsize, int wsize, int csize) { -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) // buffer for channel accumulator, used to interchange channel-loop // to inner-most, so that memory access of the input tensor data is // continuous. 
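// For illustration of the cb_size split that follows (a sketch, not part of the patch):
// assuming Vectorized<T>::size() is 32 under AVX2 and 64 under AVX512 for the 8-bit
// quantized types used here, and Vectorized<int32_t> is 32 B / 64 B respectively:
//   AVX2:   vec_width = 32 / 4 = 8,  cb_size = 16 -> cb_step = 128 channels, acc_buffer = 16 x 32 B = 512 B
//   AVX512: vec_width = 64 / 4 = 16, cb_size = 8  -> cb_step = 128 channels, acc_buffer = 8 x 64 B = 512 B
// so both builds block the channel loop in groups of 128 and keep the same stack
// footprint for the accumulator buffers.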
+#ifdef CPU_CAPABILITY_AVX2 constexpr int cb_size = 16; +#else + constexpr int cb_size = 8; +#endif constexpr int vec_width = Vectorized::size() / 4; constexpr int cb_step = cb_size * vec_width; Vectorized acc_buffer[cb_size]; Vectorized acc_buffer_fp[cb_size]; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (int c = c_start; c < csize; c += cb_step) { int cend = std::min(cb_size, (csize - c) / vec_width); // initialize loop @@ -1292,14 +1439,23 @@ void do_avg_pool_nhwc_on_AVX2( // convert int32 accumulative to fp32 vec::convert((int*)acc_buffer, (float*)acc_buffer_fp, cend * vec_width); - // first quantize using AVX using 32 lanes, then 8, finally falls + // first quantize using AVX2 or AVX512 using 32 lanes, then 8, finally falls // back to single +#ifdef CPU_CAPABILITY_AVX2 QuantizeAvx2( (float*)acc_buffer_fp, o_p + c, cend * vec_width, multiplier, output_zero_point); +#else + QuantizeAvx512( + (float*)acc_buffer_fp, + o_p + c, + cend * vec_width, + multiplier, + output_zero_point); +#endif } c_start = csize / vec_width * vec_width; } @@ -1307,7 +1463,7 @@ void do_avg_pool_nhwc_on_AVX2( } template -void do_avg_pool_on_AVX2( +void do_avg_pool_on_AVX_n( typename T::underlying* i_p, typename T::underlying* o_p, int64_t& c, @@ -1326,9 +1482,13 @@ void do_avg_pool_on_AVX2( int64_t stride_D, int64_t stride_H, int64_t stride_W) { -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) - constexpr auto vec_width = Vectorized::size() / 4; +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) + constexpr int vec_width = Vectorized::size() / 4; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (; c + vec_width <= channel_size; c += vec_width) { int64_t tcntr = 0; @@ -1416,10 +1576,10 @@ void _qadaptive_avg_pool_kernel( istartH * istrideH + istartW * istrideW; - // Note: If AVX is not available, `do_avg_pool_on_AVX2 is a noop. + // Note: If AVX is not available, `do_avg_pool_on_AVX_n is a noop. // In that case, the following loop takes over // TODO: more vectorization with loop interleaving - do_avg_pool_on_AVX2( + do_avg_pool_on_AVX_n( internal_i_p, o_p, c, @@ -1438,7 +1598,6 @@ void _qadaptive_avg_pool_kernel( istrideD, istrideH, istrideW); - // 1) The following loop handles the remaining channels // 2) It also handles the Non-AVX2 path for (; c < sizeC; ++c) { @@ -1610,7 +1769,7 @@ void _qavg_pool_nhwc_kernel( // For int8 quantization, we implicitly use int32 as accumulation // Or else, it will go to the slow path // TODO: support 16bit, 32bit, and etc. 
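// For context, do_avg_pool_nhwc_on_AVX_n (defined earlier in this file) only does work when
// the vector width matches the ISA it was compiled for; a simplified sketch of its guard:
//   constexpr int vec_width = Vectorized<T>::size() / 4;   // 8 on AVX2, 16 on AVX512
//   #ifdef CPU_CAPABILITY_AVX2
//   if (vec_width == 8)  { /* vectorized channel loop, advances c_start */ }
//   #else
//   if (vec_width == 16) { /* vectorized channel loop, advances c_start */ }
//   #endif
// When neither branch fires (or on MSVC, where the whole body is compiled out), c_start is
// not advanced and the scalar per-channel loop after the call takes over.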
- do_avg_pool_nhwc_on_AVX2( + do_avg_pool_nhwc_on_AVX_n( i_p, o_p, c_start, @@ -1744,7 +1903,7 @@ void qavg_pool3d_nhwc_kernel( } template -int64_t do_quantized_bilinear_on_AVX2( +int64_t do_quantized_bilinear_on_AVX_n( const typename T::underlying*& pos1, typename T::underlying*& pos2, int64_t input_height, @@ -1762,9 +1921,13 @@ int64_t do_quantized_bilinear_on_AVX2( const int64_t h1p, const int64_t w1p) { int64_t c = 0; -#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) +#if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) constexpr auto vec_width = Vectorized::size() / 4; +#ifdef CPU_CAPABILITY_AVX2 if (vec_width == 8) { +#else + if (vec_width == 16) { +#endif for (; c + vec_width <= channels; c += vec_width) { Vectorized pos1_fp_v[4]; Vectorized pos1_int_v[4]; @@ -1861,7 +2024,7 @@ void qupsample_bilinear2d_nhwc_kernel( o_p + (h2 * output_width + w2) * channels; // We have to isolate this function out because the VS does not // expand the macro correctly. - c = do_quantized_bilinear_on_AVX2( + c = do_quantized_bilinear_on_AVX_n( pos1, pos2, input_height, @@ -1989,7 +2152,7 @@ void q_batch_norm_kernel( reinterpret_cast(input.data_ptr()); scalar_t::underlying* Y = reinterpret_cast(output.data_ptr()); - constexpr int kVLen = 8; + constexpr int kVLen = Vectorized::size(); const int64_t outer_size = N * HxW; using Vec = Vectorized; // Hoisted variables @@ -2285,7 +2448,7 @@ void quantized_normalize_kernel( float y_scale = Y->q_scale(); float y_inv_scale = 1.0f / y_scale; - constexpr int kFloatVLen = 8; + constexpr int kFloatVLen = fVec::size(); int64_t kIntVLen = kFloatVLen * qVec::float_num_vecs(); int64_t kNumIntVecInLayer = N / kIntVLen; int64_t kNonVecRemInLayer = N % kIntVLen; @@ -3088,6 +3251,114 @@ void dequantize_tensor_per_tensor_affine_sub_byte_cpu( } // namespace +// Some quantization tests are flaky on Windows with AVX512. If --continue-through-error +// is used, only one fails. But if the failing test is skipped, another one fails. +// If the second test is also skipped, a third one fails. +// So, until Quantization support for Windows is fixed for AVX512, +// AVX2 kernels would be used instead. Ref: GH 56992. 
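// As used throughout this patch, REGISTER_NO_AVX512_DISPATCH takes a stub and its
// function-pointer type (not a kernel), and appears to leave the AVX512 dispatch slot
// empty so that the runtime dispatcher falls back to the AVX2 kernel even on
// AVX512-capable CPUs. A sketch of the pattern, mirroring the normal_stub registration
// in UnaryOpsKernel.cpp earlier in this patch (example_stub, example_fn_type and
// example_kernel are hypothetical names):
//   #ifdef CPU_CAPABILITY_AVX512
//   REGISTER_NO_AVX512_DISPATCH(example_stub, example_fn_type);
//   #else
//   REGISTER_DISPATCH(example_stub, &CPU_CAPABILITY::example_kernel);
//   #endif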
+#if defined(CPU_CAPABILITY_AVX512) && defined(_WIN32) +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_affine_stub, + dequantize_tensor_per_channel_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_stub, + dequantize_tensor_per_tensor_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub, + dequantize_tensor_per_channel_float_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_tensor_stub, + fake_quant_learnable_grad_tensor_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_per_channel_cachemask_stub, + fake_quant_per_channel_cachemask_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_stub, + fake_quant_tensor_cachemask_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_stub, + fake_quant_tensor_cachemask_tensor_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool2d_nhwc_stub, + qadaptive_avg_pool2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadaptive_avg_pool3d_ndhwc_stub, + qadaptive_avg_pool3d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_relu_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_scalar_relu_stub, qadd_scalar_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_scalar_stub, qadd_scalar_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qadd_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qavg_pool2d_nhwc_stub, qavg_pool2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qavg_pool3d_nhwc_stub, qavg_pool3d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qbatch_norm_relu_stub, qbatch_norm_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qbatch_norm_stub, qbatch_norm_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qcat_nhwc_stub, qcat_nhwc_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qcat_relu_nhwc_stub, qcat_nhwc_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_stub, qclamp_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_min_stub, qclamp_minmax_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qclamp_max_stub, qclamp_minmax_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qelu_stub, qelu_fn); +// 
NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qhardsigmoid_stub, qhardsigmoid_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qhardswish_stub, qhardswish_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmaxpool_2d_nhwc_stub, qmaxpool_2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmul_relu_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qmul_stub, qbinary_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu6_stub, qrelu_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu_leaky_stub, qrelu_leaky_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qrelu_stub, qrelu_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qsigmoid_stub, qsigmoid_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qtanh_stub, qtanh_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qthreshold_stub, qthreshold_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qtopk_stub, qtopk_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(fake_quant_grad_learnable_channel_stub, + fake_quant_learnable_per_channel_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_stub, + quantize_tensor_per_tensor_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_affine_stub, + quantize_tensor_per_channel_affine_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_channel_float_qparams_stub, + quantize_tensor_per_channel_float_qparams_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantized_normalize_stub, qnormalize_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(qupsample_bilinear2d_nhwc_stub, qupsample_bilinear2d_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub, + quantize_tensor_per_tensor_affine_sub_byte_fn); +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +REGISTER_NO_AVX512_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub, + dequantize_tensor_per_tensor_affine_sub_byte_fn); +#else // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub, &dequantize_tensor_per_channel_affine_cpu); @@ -3167,7 +3438,8 @@ REGISTER_DISPATCH(qthreshold_stub, &qthreshold_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) -REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, &fake_quantize_learnable_channel_grad_kernel_cpu); 
+REGISTER_DISPATCH(fake_quant_grad_learnable_channel_stub, + &fake_quantize_learnable_channel_grad_kernel_cpu); // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) REGISTER_DISPATCH( quantize_tensor_per_tensor_affine_stub, @@ -3193,7 +3465,7 @@ REGISTER_DISPATCH( REGISTER_DISPATCH( dequantize_tensor_per_tensor_affine_sub_byte_stub, &dequantize_tensor_per_tensor_affine_sub_byte_cpu); - +#endif // CPU_CAPABILITY_AVX512 && _WIN32 } // namespace native } // namespace at diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 77bc02334db54..4ee0596da6e7b 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -1071,13 +1071,17 @@ namespace { // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(ComplexTests, TestComplexFloatImagRealConj) { // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28 }; + float aa[] = { 1.5488e-28,2.5488e-28,3.5488e-28,4.5488e-28,5.5488e-28,6.5488e-28,7.5488e-28,8.5488e-28, + 9.5488e-28,10.5488e-28,11.5488e-28,12.5488e-28,13.5488e-28,14.5488e-28,15.5488e-28,16.5488e-28}; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0 }; + float exp[] = { aa[0],0,aa[2],0,aa[4],0,aa[6],0,aa[8],0,aa[10],0,aa[12],0,aa[14],0 }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0 }; + float exp3[] = { aa[1],0,aa[3],0,aa[5],0,aa[7],0,aa[9],0,aa[11],0,aa[13],0,aa[15],0 }; // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28,5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28 }; + float exp4[] = { 1.5488e-28, -2.5488e-28,3.5488e-28,-4.5488e-28, + 5.5488e-28,-6.5488e-28,7.5488e-28,-8.5488e-28, + 9.5488e-28,-10.5488e-28,11.5488e-28,-12.5488e-28, + 13.5488e-28,-14.5488e-28,15.5488e-28,-16.5488e-28 }; auto a = vcomplex::loadu(aa); auto actual1 = a.real(); auto actual3 = a.imag(); @@ -1304,6 +1308,7 @@ namespace { }, test_case); } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TYPED_TEST(FunctionalTests, Map) { using vec = TypeParam; using VT = ValueType; @@ -1339,15 +1344,16 @@ namespace { at::vec::map3([](vec x1, vec x2, vec x3) { return x1 + x2 + x3; }, y, x1, x2, x3, N); for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i]; } cmp(y, ref_y); - // test map3: y = x1 + x2 + x3 + x4 + // test map4: y = x1 + x2 + x3 + x4 at::vec::map4([](vec x1, vec x2, vec x3, vec x4) { return x1 + x2 + x3 + x4; }, y, x1, x2, x3, x4, N); for (int64_t i = 0; i < N; i++) { ref_y[i] = x1[i] + x2[i] + x3[i] + x4[i]; } cmp(y, ref_y); } - TYPED_TEST(FunctionalBF16Tests, Reduce) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) + TYPED_TEST(FunctionalBF16Tests, Reduce) { using vec = TypeParam; // Can't use ValueType here: - // Vectorized::value_type returns uint16_t on AVX2 + // Vectorized::value_type returns uint16_t on AVX2/AVX512 using VT = c10::BFloat16; using RT = float; // reference constexpr auto R = 2LL; // residual @@ -1394,7 +1400,6 @@ namespace { auto y2 = at::vec::map_reduce_all([](auto x) { return x - 
x.exp(); }, sum, x_b1, len); ASSERT_TRUE(cmp(y1, y2)) << "Failure Details:\nTest Seed to reproduce: " << seed << "\nmap_reduce_all, Length: " << len << "; fp32: " << y1 << "; bf16: " << RT(y2); - } // Map2ReduceAll for (int64_t len = 1; len <= N; len++) { diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 0c9496e166ca1..8b0854866a946 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -13,7 +13,13 @@ #include #include #include + +#if defined(CPU_CAPABILITY_AVX512) +#define CACHE_LINE 64 +#else #define CACHE_LINE 32 +#endif + #if defined(__GNUC__) #define CACHE_ALIGN __attribute__((aligned(CACHE_LINE))) #define not_inline __attribute__((noinline)) @@ -26,7 +32,7 @@ CACHE_ALIGN #define #endif #if defined(CPU_CAPABILITY_DEFAULT) || defined(_MSC_VER) #define TEST_AGAINST_DEFAULT 1 -#elif !defined(CPU_CAPABILITY_AVX) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) +#elif !defined(CPU_CAPABILITY_AVX512) && !defined(CPU_CAPABILITY_AVX2) && !defined(CPU_CAPABILITY_VSX) #define TEST_AGAINST_DEFAULT 1 #else #undef TEST_AGAINST_DEFAULT @@ -41,7 +47,8 @@ CACHE_ALIGN #define return __VA_ARGS__(std::forward(args)...); \ } -#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) && (defined(__GNUC__) || defined(__GNUG__)) +#if defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_AVX2) || \ + defined(CPU_CAPABILITY_AVX512) && (defined(__GNUC__) || defined(__GNUG__)) #undef CHECK_DEQUANT_WITH_LOW_PRECISION #define CHECK_WITH_FMA 1 #elif !defined(CPU_CAPABILITY_VSX) && !defined(CPU_CAPABILITY_AVX2) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 17b1a688c5eaf..429821496b3b7 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -722,44 +722,43 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) list(APPEND Caffe2_CPU_SRCS ${TORCH_SRCS}) endif() -# NOTE [ Linking AVX and non-AVX files ] +# NOTE [ Linking AVX-n and non-AVX-n files ] # -# Regardless of the CPU capabilities, we build some files with AVX and AVX2 +# Regardless of the CPU capabilities, we build some files with AVX2, and AVX512 # instruction set. If the host CPU doesn't support those, we simply ignore their # functions at runtime during dispatch. # # We must make sure that those files are at the end of the input list when # linking the torch_cpu library. Otherwise, the following error scenario might # occur: -# 1. A non-AVX and an AVX file both call a function defined with the `inline` +# 1. A non-AVX2 and an AVX2 file both call a function defined with the `inline` # keyword # 2. The compiler decides not to inline this function # 3. Two different versions of the machine code are generated for this function: -# one without AVX instructions and one with AVX. -# 4. When linking, the AVX version is found earlier in the input object files, +# one without AVX2 instructions and one with AVX2. +# 4. When linking, the AVX2 version is found earlier in the input object files, # so the linker makes the entire library use it, even in code not guarded by # the dispatcher. -# 5. A CPU without AVX support executes this function, encounters an AVX +# 5. A CPU without AVX2 support executes this function, encounters an AVX2 # instruction and crashes. # # Thus we organize the input files in the following order: -# 1. All files with no AVX support -# 2. All files with AVX support (conveniently, they all have names ending with -# 'AVX.cpp') -# 3. All files with AVX2 support ('*AVX2.cpp') +# 1. 
All files with no AVX-n support +# 2. All files with AVX2 support ('*AVX2.cpp') +# 3. All files with AVX512 support ('*AVX512.cpp') set(Caffe2_CPU_SRCS_NON_AVX) -set(Caffe2_CPU_SRCS_AVX) set(Caffe2_CPU_SRCS_AVX2) +set(Caffe2_CPU_SRCS_AVX512) foreach(input_filename ${Caffe2_CPU_SRCS}) - if(${input_filename} MATCHES "AVX\\.cpp") - list(APPEND Caffe2_CPU_SRCS_AVX ${input_filename}) - elseif(${input_filename} MATCHES "AVX2\\.cpp") + if(${input_filename} MATCHES "AVX2\\.cpp") list(APPEND Caffe2_CPU_SRCS_AVX2 ${input_filename}) + elseif(${input_filename} MATCHES "AVX512\\.cpp") + list(APPEND Caffe2_CPU_SRCS_AVX512 ${input_filename}) else() list(APPEND Caffe2_CPU_SRCS_NON_AVX ${input_filename}) endif() endforeach(input_filename) -set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX} ${Caffe2_CPU_SRCS_AVX2}) +set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS_NON_AVX} ${Caffe2_CPU_SRCS_AVX2} ${Caffe2_CPU_SRCS_AVX512}) # ========================================================== # END formerly-libtorch sources diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 19579b9a32b13..aeeaf64193c03 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -63,14 +63,6 @@ if(INTERN_BUILD_ATEN_OPS) endif() endif(MSVC) - if(C_AVX_FOUND) - if(MSVC) - set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}") - else(MSVC) - set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${OPT_FLAG} ${CXX_AVX_FLAGS}") - endif(MSVC) - endif(C_AVX_FOUND) - if(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang") set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/MapAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp") endif() @@ -80,15 +72,16 @@ if(INTERN_BUILD_ATEN_OPS) list(APPEND CPU_CAPABILITY_NAMES "DEFAULT") list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}") - if(CXX_AVX_FOUND) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX_CPU_DEFINITION") - list(APPEND CPU_CAPABILITY_NAMES "AVX") + + if(CXX_AVX512_FOUND) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX512_CPU_DEFINITION") + list(APPEND CPU_CAPABILITY_NAMES "AVX512") if(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") else(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx") + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx512f -mavx512bw -mavx512vl -mavx512dq -mfma") endif(MSVC) - endif(CXX_AVX_FOUND) + endif(CXX_AVX512_FOUND) if(CXX_AVX2_FOUND) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DHAVE_AVX2_CPU_DEFINITION") @@ -103,11 +96,24 @@ if(INTERN_BUILD_ATEN_OPS) endif(COMPILER_SUPPORTS_NO_AVX256_SPLIT) list(APPEND CPU_CAPABILITY_NAMES "AVX2") - if(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2") - else(MSVC) - list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma ${CPU_NO_AVX256_SPLIT_FLAGS}") - endif(MSVC) + if(DEFINED ENV{ATEN_AVX512_256}) + if($ENV{ATEN_AVX512_256} MATCHES "TRUE") + if(CXX_AVX512_FOUND) + message("-- ATen AVX2 kernels will use 32 ymm registers") + if(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX512") + else(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -march=native ${CPU_NO_AVX256_SPLIT_FLAGS}") + endif(MSVC) + endif(CXX_AVX512_FOUND) + endif() + else() + if(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG}/arch:AVX2") + else(MSVC) + list(APPEND CPU_CAPABILITY_FLAGS "${OPT_FLAG} -mavx2 -mfma 
${CPU_NO_AVX256_SPLIT_FLAGS}") + endif(MSVC) + endif() endif(CXX_AVX2_FOUND) if(CXX_VSX_FOUND) diff --git a/cmake/Modules/FindAVX.cmake b/cmake/Modules/FindAVX.cmake index 7d472eb662cf4..c04427cbad850 100644 --- a/cmake/Modules/FindAVX.cmake +++ b/cmake/Modules/FindAVX.cmake @@ -12,6 +12,25 @@ SET(AVX_CODE " } ") +SET(AVX512_CODE " + #include + + int main() + { + __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0); + __m512i b = a; + __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ); + return 0; + } +") + SET(AVX2_CODE " #include @@ -56,6 +75,8 @@ ENDMACRO() CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2") +CHECK_SSE(C "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512") CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX") CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2") +CHECK_SSE(CXX "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512") diff --git a/setup.py b/setup.py index 5264fe7fa4600..73d7d11adaeea 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,12 @@ # BUILD_BINARY # enables the additional binaries/ build # +# ATEN_AVX512_256=TRUE +# ATen AVX2 kernels can use 32 ymm registers, instead of the default 16. +# This option can be used if AVX512 doesn't perform well on a machine. +# The FBGEMM library also uses AVX512_256 kernels on Xeon D processors, +# but it also has some (optimized) assembly code. +# # PYTORCH_BUILD_VERSION # PYTORCH_BUILD_NUMBER # specify the version of PyTorch, rather than the hard-coded version @@ -928,6 +934,7 @@ def print_box(msg): 'include/ATen/*.h', 'include/ATen/cpu/*.h', 'include/ATen/cpu/vec/vec256/*.h', + 'include/ATen/cpu/vec/vec512/*.h', 'include/ATen/cpu/vec/*.h', 'include/ATen/core/*.h', 'include/ATen/cuda/*.cuh', diff --git a/test/cpp/api/dispatch.cpp b/test/cpp/api/dispatch.cpp index e5bc35177dd48..6416fe3e80915 100644 --- a/test/cpp/api/dispatch.cpp +++ b/test/cpp/api/dispatch.cpp @@ -29,19 +29,19 @@ TEST_F(DispatchTest, TestAVX2) { } } -TEST_F(DispatchTest, TestAVX) { +TEST_F(DispatchTest, TestAVX512) { const std::vector ints {1, 2, 3, 4}; const std::vector result {1, 4, 27, 256}; const auto vals_tensor = torch::tensor(ints); const auto pows_tensor = torch::tensor(ints); #ifdef _WIN32 - _putenv("ATEN_CPU_CAPABILITY=avx"); + _putenv("ATEN_CPU_CAPABILITY=avx512"); #else - setenv("ATEN_CPU_CAPABILITY", "avx", 1); + setenv("ATEN_CPU_CAPABILITY", "avx512", 1); #endif - const auto actual_pow_avx = vals_tensor.pow(pows_tensor); + const auto actual_pow_avx512 = vals_tensor.pow(pows_tensor); for (int i = 0; i < 4; i++) { - ASSERT_EQ(result[i], actual_pow_avx[i].item()); + ASSERT_EQ(result[i], actual_pow_avx512[i].item()); } } diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py index f9c24c7898007..65253869ddc81 100644 --- a/test/quantization/bc/test_backward_compatibility.py +++ b/test/quantization/bc/test_backward_compatibility.py @@ -2,7 +2,7 @@ import sys import os - +import unittest # torch import torch import torch.nn as nn @@ -11,7 +11,7 @@ import torch.nn.intrinsic.quantized as nniq # Testing utils -from torch.testing._internal.common_utils import TestCase +from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED from torch.testing._internal.common_quantized import override_qengines, 
qengine_is_fbgemm def remove_prefix(text, prefix): @@ -238,6 +238,7 @@ def test_conv3d_relu(self): # TODO: graph mode quantized conv3d module @override_qengines + @unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098") def test_lstm(self): class LSTMModule(torch.nn.Module): def __init__(self): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 94f72b1ee415b..fd30756dca5eb 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -339,6 +339,15 @@ def run_tests(argv=UNITTEST_ARGS): IS_MACOS = sys.platform == "darwin" IS_PPC = platform.machine() == "ppc64le" +def is_avx512_vnni_supported(): + if sys.platform != 'linux': + return False + with open("/proc/cpuinfo", encoding="ascii") as f: + lines = f.read() + return "avx512vnni" in lines + +IS_AVX512_VNNI_SUPPORTED = is_avx512_vnni_supported() + if IS_WINDOWS: @contextmanager def TemporaryFileName(*args, **kwargs):