From 277152c00f7bf0d76772e128c18eb9e194508476 Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Tue, 12 May 2020 03:33:35 +0800
Subject: [PATCH 1/7] Ref implementation of FP16 fused rowwise sparse adagrad

Reuse the rowwise_sparse_adagrad_fused_ref(...) API with float16 data type
as a template parameter. Implemented float16 weights update with stochastic
rounding. Random integer generation is based on xoshiro128++. A 64-byte
per-thread global random buffer is maintained and updated each time a new
random number is generated.

Signed-off-by: Yong Wu
---
 src/RefImplementations.cc | 191 ++++++++++++++++++++++++++++++++------
 src/RefImplementations.h | 9 +-
 2 files changed, 168 insertions(+), 32 deletions(-)

diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 991c752b31..4130d16bd4 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -16,11 +16,65 @@
 #include
 #include
 #include
+#include

 using namespace std;

 namespace fbgemm {

+// Thread-safe random number generator
+//
+// Return a random 32-bit integer using xoshiro128++
+// http://prng.di.unimi.it/xoshiro128plusplus.c
+inline uint32_t rnd128_next(int v, int vlen) {
+  constexpr int VLEN_MAX = 16; // max vector size
+  alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX];
+  static thread_local bool g_rnd128_initialized = false;
+
+  // Splitmix64: http://prng.di.unimi.it/splitmix64.c
+  auto rnd128_init_next = [](uint64_t &x) {
+    uint64_t z = (x += 0x9e3779b97f4a7c15);
+    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
+    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
+    return z ^ (z >> 31);
+  };
+
+  auto rotl = [](const uint32_t x, int k) {
+    return (x << k) | (x >> (32 - k));
+  };
+
+  if (!g_rnd128_initialized) {
+    // Initialize rand buffer with unique values per thread
+    uint64_t h0 = std::hash{}(std::this_thread::get_id());
+    for (auto i = 0; i < 4; ++i) {
+      // Use thread hash as seed
+      g_rnd128_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
+      uint64_t h1 = g_rnd128_buffer[i * VLEN_MAX];
+      for (auto v = 1; v < VLEN_MAX; ++v) {
+        g_rnd128_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
+      }
+    }
+    g_rnd128_initialized = true;
+  }
+
+  const uint32_t result =
+      rotl(g_rnd128_buffer[v] + g_rnd128_buffer[3 * vlen + v], 7) +
+      g_rnd128_buffer[v];
+
+  const uint32_t t = g_rnd128_buffer[1 * vlen + v] << 9;
+
+  g_rnd128_buffer[2 * vlen + v] ^= g_rnd128_buffer[0 * vlen + v];
+  g_rnd128_buffer[3 * vlen + v] ^= g_rnd128_buffer[1 * vlen + v];
+  g_rnd128_buffer[1 * vlen + v] ^= g_rnd128_buffer[2 * vlen + v];
+  g_rnd128_buffer[0 * vlen + v] ^= g_rnd128_buffer[3 * vlen + v];
+
+  g_rnd128_buffer[2 * vlen + v] ^= t;
+
+  g_rnd128_buffer[3 * vlen + v] = rotl(g_rnd128_buffer[3 * vlen + v], 11);
+
+  return result;
+}
+
 void FloatToFloat16_ref(
     const float* src,
     float16* dst,
@@ -1129,20 +1183,35 @@ int rowwise_sparse_adagrad_ref(
   return num_rows;
 }

-template
+template
 int rowwise_sparse_adagrad_fused_ref(
     int64_t block_size,
     int64_t output_size,
     int64_t index_size,
     int64_t data_size,
-    float* w,
+    DataType* w,
     const float* g,
     float* h,
     const IndexType* indices,
     const OffsetType* offsets_or_lengths,
     float epsilon,
     float lr,
-    bool use_offsets) {
+    bool use_offsets,
+    bool use_stochastic_rounding,
+    int emu_vector_size) {
+  constexpr bool isFloat16w = std::is_same::value;
+  // TODO: warning on vector-size not 8/16
+  int vlen = emu_vector_size;
+  // Local random buffer to emulate SIMD vector
+  // R: generated 32bit base random numbers
+  // r: extracted 8-bit for rounding
+  uint32_t *R = nullptr, *r = nullptr;
+  if (isFloat16w &&
use_stochastic_rounding) { + // Random vector buffer for stochastic rounding + R = new uint32_t[vlen]; + r = new uint32_t[vlen]; + } + int64_t current = 0; for (int m = 0; m < output_size; ++m) { int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m] @@ -1160,11 +1229,11 @@ int rowwise_sparse_adagrad_fused_ref( // float gj = g_[j]; // final_sum += gj * gj; // } - constexpr int VLEN = 8; - array partial_sum = {0.0f}; + constexpr int VLEN_AVX2 = 8; + array partial_sum = {0.0f}; for (auto j = 0; j < block_size; ++j) { float gj = g_[j]; - partial_sum[j % VLEN] += gj * gj; + partial_sum[j % VLEN_AVX2] += gj * gj; } float final_sum = ((partial_sum[0] + partial_sum[1]) + (partial_sum[2] + partial_sum[3])) + @@ -1178,16 +1247,75 @@ int rowwise_sparse_adagrad_fused_ref( } float* h_ = h + idx; - float* w_ = w + idx * block_size; + DataType* w_ = w + idx * block_size; float hi = *h_ = *h_ + final_sum; float float_step = lr / (std::sqrt(hi) + epsilon); - for (int j = 0; j < block_size; ++j) { - w_[j] += g_[j] * float_step; + int nvec = (block_size + vlen - 1) / vlen; + int rem = (block_size % vlen) ? (block_size % vlen) : vlen; + + // Emulate JIT behavior of stochastic rounding with vector-length + // + // Generate R buffer every 4 steps of nvec loop. Each 8-bit in R + // (uint32_t) will be used once. It is shifted to bits[5..13] then + // added to FP32 weights before FP16 conversion. + // + // The shifted 8 bit region + // +-------+--------+--------+--------+ + // | | | xxxxx|xxx | + // 31 23 15 7 0 + // + // Half float has 10 bits of mantissa, and float has 23, we are shifting + // the bits to cover the region where half floats can't represent data. + // This is bit 13-23 of the mantissa of fp32. + // This will be effectively adding a random variable of [0,1] + + for (int n = 0; n < nvec; ++n) { + int len = (n == nvec - 1) ? 
rem : vlen; + int sr_idx = n % 4; + + if (isFloat16w && use_stochastic_rounding) { + if (sr_idx == 0) { + for (int v = 0; v < vlen; ++v) { + R[v] = rnd128_next(v, vlen); + r[v] = (R[v] << 24) >> 19; + } + } else if (sr_idx == 1) { + for (int v = 0; v < vlen; ++v) + r[v] = ((R[v] >> 8) << 24) >> 19; + } else if (sr_idx == 2) { + for (int v = 0; v < vlen; ++v) + r[v] = ((R[v] << 8) >> 24) << 5; + } else { // 3 + for (int v = 0; v < vlen; ++v) + r[v] = (R[v] >> 24) << 5; + } + } + + for (int v = 0; v < len; ++v) { + int j = n * vlen + v; + if (isFloat16w) { + union { + float w_f32; + uint32_t w_i32; + }; + w_f32 = cpu_half2float(w_[j]); + w_f32 = std::fma(float_step, g_[j], w_f32); + if (use_stochastic_rounding) + w_i32 += r[v]; + w_[j] = cpu_float2half_rn(w_f32); + } else { // float + w_[j] += g_[j] * float_step; + } + } } } } + if (R != nullptr) + delete[] R; + if (r != nullptr) + delete[] r; return current == index_size; } @@ -1336,27 +1464,32 @@ template FBGEMM_API int rowwise_sparse_adagrad_ref( float lr, float weight_decay); -#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE) \ - template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \ - int64_t block_size, \ - int64_t output_size, \ - int64_t index_size, \ - int64_t data_size, \ - float* w, \ - const float* g, \ - float* h, \ - const INDEX_TYPE* indices, \ - const OFFSET_TYPE* offsets_or_lengths, \ - float epsilon, \ - float lr, \ - bool use_offsets); +#define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \ + template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \ + int64_t block_size, \ + int64_t output_size, \ + int64_t index_size, \ + int64_t data_size, \ + DATA_TYPE* w, \ + const float* g, \ + float* h, \ + const INDEX_TYPE* indices, \ + const OFFSET_TYPE* offsets_or_lengths, \ + float epsilon, \ + float lr, \ + bool use_offsets, \ + bool use_stochastic_rounding, \ + int emu_vector_size); + +#define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \ + INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t) \ + INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int64_t) + +INSTANTIATE_SPMDM_OFFSET_T(float, int32_t) +INSTANTIATE_SPMDM_OFFSET_T(float, int64_t) +INSTANTIATE_SPMDM_OFFSET_T(float16, int32_t) +INSTANTIATE_SPMDM_OFFSET_T(float16, int64_t) -#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, int32_t) \ - INSTANTIATE_SPMDM_BASE(INDEX_TYPE, int64_t) - -INSTANTIATE_SPMDM_OFFSET_T(int32_t) -INSTANTIATE_SPMDM_OFFSET_T(int64_t) #undef INSTANTIATE_SPMDM_OFFSET_T #undef INSTANTIATE_SPMDM_BASE diff --git a/src/RefImplementations.h b/src/RefImplementations.h index c6de1e25c8..dcaa43cb11 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -9,6 +9,7 @@ #include #include +#include "fbgemm/Types.h" #include "fbgemm/ConvUtils.h" #include "fbgemm/FbgemmI8Spmdm.h" @@ -311,19 +312,21 @@ FBGEMM_API int rowwise_sparse_adagrad_ref( float lr, float weight_decay = 0.f); -template +template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( std::int64_t block_size, std::int64_t output_size, std::int64_t index_size, std::int64_t data_size, - float* w, // input/output parameters + DataType* w, // input/output parameters const float* g, // inupt gradients float* h, // input/output momentums const IndexType* indices, const OffsetType* offsets_or_lengths, float epsilon, float lr, - bool use_offsets = true); + bool use_offsets = true, + bool use_stochastic_rounding = true, // For DataType=float16 + int emu_vector_size = 8); } // namespace fbgemm From 
98e16c95130b3ee9533ccc543b74889707860e8b Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Tue, 12 May 2020 05:26:42 +0800
Subject: [PATCH 2/7] JIT implementation of FP16 fused rowwise sparse adagrad

Support the GenerateRowWiseSparseAdaGradFused(...) JIT API with float16
data type as a template parameter and keep backward compatibility.
Implement float16 weights update with stochastic rounding (SR). Random
integer generation is based on xoshiro128++. A 64-byte per-thread global
random buffer is maintained and updated each time a new random number is
generated. Support both AVX2 and AVX512 random number generation and
float16 weights update with SR.

Signed-off-by: Yong Wu
---
 include/fbgemm/FbgemmEmbedding.h | 15 +-
 src/RowWiseSparseAdagradFused.cc | 459 +++++++++++++++++++++++++------
 2 files changed, 385 insertions(+), 89 deletions(-)

diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h
index ceb5f60aee..b7f3fb0b65 100644
--- a/include/fbgemm/FbgemmEmbedding.h
+++ b/include/fbgemm/FbgemmEmbedding.h
@@ -161,14 +161,16 @@ GenerateSparseAdaGrad(
     float weight_decay = 0.0f);

 // RowWiseSparseAdaGrad fused with SLS gradient
-template
+// Weights can be either float or float16
+template
 class RowWiseSparseAdaGradFusedSignature {
  public:
  using Type = std::function;
 };
-template
+template
 FBGEMM_API
-    typename RowWiseSparseAdaGradFusedSignature::Type
+    typename
+    RowWiseSparseAdaGradFusedSignature::Type
    GenerateRowWiseSparseAdaGradFused(
        int block_size, // number of parameters per row
        int prefetch = 16,
-        bool use_offsets = true);
+        bool use_offsets = true,
+        bool use_stochastic_rounding = true);

 namespace internal {
 // Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc
index 072a908d7b..9f158bafad 100644
--- a/src/RowWiseSparseAdagradFused.cc
+++ b/src/RowWiseSparseAdagradFused.cc
@@ -23,34 +23,36 @@ namespace fbgemm {
 namespace {
 namespace x86 = asmjit::x86;

-template
+template
 class ReturnFunctionSignature {
  public:
  using jit_sparse_adagrad_kernel = bool (*)(
      int64_t output_size,
      int64_t index_size,
      int64_t data_size, // number of rows in w
-      float* w, // input/output parameters
+      dataType* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const indxType* indices, // indices of each row
      const offsetType* offsets_or_lengths,
      float epsilon,
      float lr,
-      const int* mask_avx2);
+      uint32_t* rand_buffer);
 };

 template <
    typename indxType,
    typename offsetType,
+    typename dataType,
    inst_set_t instSet = inst_set_t::avx2>
 class GenRowWiseSparseAdagradFused {
  public:
  GenRowWiseSparseAdagradFused() {}

-  typename ReturnFunctionSignature::
+  typename ReturnFunctionSignature::
      jit_sparse_adagrad_kernel
-      getOrCreate(int block_size, int prefetch, bool use_offsets);
+      getOrCreate(const int* mask_avx2, int block_size, int prefetch,
+                  bool use_offsets, bool use_stochastic_rounding);

 private:
  static asmjit::JitRuntime& runtime() {
@@ -63,45 +65,58 @@ class GenRowWiseSparseAdagradFused {
  // The hash depends on embedding dimension (block size), prefetch distance,
  // and use_offsets
  static CodeCache<
-      tuple,
-      typename ReturnFunctionSignature::
+      tuple,
+      typename ReturnFunctionSignature::
          jit_sparse_adagrad_kernel>
      codeCache_; ///< JIT Code Cache for reuse.
}; // class GenRowWiseSparseAdagradFused -template -mutex GenRowWiseSparseAdagradFused::rtMutex_; +template +mutex GenRowWiseSparseAdagradFused:: +rtMutex_; -template +template CodeCache< - tuple, - typename ReturnFunctionSignature:: + tuple, + typename ReturnFunctionSignature:: jit_sparse_adagrad_kernel> - GenRowWiseSparseAdagradFused::codeCache_; + GenRowWiseSparseAdagradFused + ::codeCache_; -template -typename ReturnFunctionSignature:: +template +typename ReturnFunctionSignature:: jit_sparse_adagrad_kernel - GenRowWiseSparseAdagradFused::getOrCreate( + GenRowWiseSparseAdagradFused:: + getOrCreate( + const int* mask_avx2, // runtime constant int block_size, int prefetch, - bool use_offsets) { - tuple kernelSig = - make_tuple(block_size, prefetch, use_offsets); + bool use_offsets, + bool use_stochastic_rounding) { + tuple kernelSig = + make_tuple(mask_avx2, block_size, prefetch, use_offsets, + use_stochastic_rounding); return codeCache_.getOrCreate( kernelSig, [&]() -> typename ReturnFunctionSignature< indxType, - offsetType>::jit_sparse_adagrad_kernel { + offsetType, + dataType>::jit_sparse_adagrad_kernel { asmjit::CodeHolder code; code.init(runtime().codeInfo()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); bool areIndices64b = is_same::value; + bool areWeightsFp16 = is_same::value; #if defined(FBGEMM_LOG_CODE) string filename = "RowWiseSparseAdagradFused"; filename += "_emd_dim_" + to_string(block_size); + filename += "_wei_float"; + filename += areWeightsFp16 ? "16" : "32"; filename += areIndices64b ? "_64bit" : "_32bit"; filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2"; if (prefetch) { @@ -123,10 +138,10 @@ typename ReturnFunctionSignature:: x86::Gp lengths = a->gpz(11); x86::Xmm epsilon(0); x86::Xmm lr(1); - x86::Gp mask_avx2 = a->gpz(12); + x86::Gp rand_buffer = a->gpz(12); - // reuse mask_avx2 because mask_avx2 is used only at the beginning - x86::Gpd lengths_R = a->gpz(12).r32(); + // FP32 weights does not need rand_buffer + x86::Gpd lengths_R = areWeightsFp16 ? 
a->zbx().r32() : a->gpz(12).r32(); x86::Gp scratchReg1 = a->gpz(13); x86::Gp scratchReg2 = a->gpz(14); // for prefetching @@ -136,14 +151,14 @@ typename ReturnFunctionSignature:: int64_t, // output_size int64_t, // index_size int64_t, // data_size - float*, // w + dataType*, // w const float*, // g float*, // h const indxType*, // indices const int*, // lengths float, // epsilon - float, // lr then mask_avx2 - const int*>(asmjit::CallConv::kIdHost)); + float, // lr then rand_buffer + uint32_t*>(asmjit::CallConv::kIdHost)); asmjit::FuncFrame frame; frame.init(func); @@ -162,10 +177,16 @@ typename ReturnFunctionSignature:: asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); } - // TODO - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + if (areWeightsFp16) { + // Random buffer will use R12 exclusively + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(3, 8, 9, 10, 11, 12, 13, 14)); + } else { + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + } asmjit::FuncArgsAssignment args(&func); args.assignAll( @@ -179,7 +200,7 @@ typename ReturnFunctionSignature:: lengths, epsilon, lr, - mask_avx2); + rand_buffer); args.updateFuncFrame(frame); frame.finalize(); @@ -210,6 +231,51 @@ typename ReturnFunctionSignature:: vec_reg_t lr_vreg = vec_reg_t(first_available_vec_reg_id); ++first_available_vec_reg_id; + a->vpbroadcastd(epsilon_vreg, epsilon); + a->vpbroadcastd(lr_vreg, lr); + + // Reserve vector registers for random buffer generating + // S0...S3: global random buffer state + // R: generated random number in uint32_t + // r0: extracted random byte (uint8_t) shifted to bits[5...13] + // r1: temp + vec_reg_t R_vreg, S0_vreg, S1_vreg, S2_vreg, S3_vreg, r0_vreg, r1_vreg; + if (areWeightsFp16 && use_stochastic_rounding) { + R_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + S0_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + S1_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + S2_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + S3_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + r0_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + r1_vreg = vec_reg_t(first_available_vec_reg_id); + first_available_vec_reg_id++; + + // Load random buffer for FP16 stochastic rounding + if (instSet == inst_set_t::avx2) { + a->vmovdqa(S0_vreg.ymm(), x86::dword_ptr(rand_buffer)); + a->vmovdqa(S1_vreg.ymm(), x86::dword_ptr(rand_buffer, + 1 * vlen * sizeof(uint32_t))); + a->vmovdqa(S2_vreg.ymm(), x86::dword_ptr(rand_buffer, + 2 * vlen * sizeof(uint32_t))); + a->vmovdqa(S3_vreg.ymm(), x86::dword_ptr(rand_buffer, + 3 * vlen * sizeof(uint32_t))); + } else { // AVX512 + a->vmovdqa32(S0_vreg, x86::dword_ptr(rand_buffer)); + a->vmovdqa32(S1_vreg, x86::dword_ptr(rand_buffer, + 1 * vlen * sizeof(uint32_t))); + a->vmovdqa32(S2_vreg, x86::dword_ptr(rand_buffer, + 2 * vlen * sizeof(uint32_t))); + a->vmovdqa32(S3_vreg, x86::dword_ptr(rand_buffer, + 3 * vlen * sizeof(uint32_t))); + } + } + if (remainder) { if (instSet == inst_set_t::avx2) { src_vreg = vec_reg_t(first_available_vec_reg_id); @@ -217,11 +283,12 @@ typename ReturnFunctionSignature:: mask_vreg = x86::Ymm(first_available_vec_reg_id); ++first_available_vec_reg_id; - + // Use scratchReg1 as temp + a->mov(scratchReg1, asmjit::imm(mask_avx2)); a->vmovups( mask_vreg, 
x86::ymmword_ptr( - mask_avx2, (vlen - remainder) % vlen * sizeof(int32_t))); + scratchReg1, (vlen - remainder) % vlen * sizeof(int32_t))); } else { a->mov(scratchReg1, (1 << remainder) - 1); a->kmovw(x86::k(1), scratchReg1); @@ -239,9 +306,6 @@ typename ReturnFunctionSignature:: int unroll_factor = NUM_VEC_REG - first_available_vec_reg_id; - a->vpbroadcastd(epsilon_vreg, epsilon); - a->vpbroadcastd(lr_vreg, lr); - // Compute the end address of indices a->imul( scratchReg1, @@ -398,22 +462,19 @@ typename ReturnFunctionSignature:: } a->bind(pref_dist_reset_end); - a->imul(scratchReg2, static_cast(sizeof(float))); } a->add(indices, static_cast(sizeof(indxType))); - a->imul(scratchReg1, static_cast(sizeof(float))); - if (prefetch) { - a->prefetchw(x86::dword_ptr(h, scratchReg2)); + a->prefetchw(x86::dword_ptr(h, scratchReg2, 2)); } // load h - a->movss(float_step_xmm, x86::dword_ptr(h, scratchReg1)); + a->movss(float_step_xmm, x86::dword_ptr(h, scratchReg1, 2)); // *h + final_sum a->addss(float_step_xmm, partial_sum_xmm); // store h - a->movss(x86::dword_ptr(h, scratchReg1), float_step_xmm); + a->movss(x86::dword_ptr(h, scratchReg1, 2), float_step_xmm); // sqrt(hi) a->sqrtss(float_step_xmm, float_step_xmm); // bcast partial to all of ymm/zmm reg @@ -438,35 +499,170 @@ typename ReturnFunctionSignature:: auto g_ptr = x86::dword_ptr(g, (vec_idx + v) * vlen * sizeof(float)); - auto w_ptr = x86::dword_ptr( - w, scratchReg1, 0, (vec_idx + v) * vlen * sizeof(float)); - if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { - if (instSet == inst_set_t::avx2) { - a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr); - a->vmulps(src_vreg, float_step_vreg, src_vreg); - - a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr); - a->vaddps(out_vreg, src_vreg, out_vreg); + if (!areWeightsFp16) { // float weights + auto w_ptr = x86::dword_ptr( + w, scratchReg1, 2, + (vec_idx + v) * vlen * sizeof(dataType)); + if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { + if (instSet == inst_set_t::avx2) { + a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr); + a->vmulps(src_vreg, float_step_vreg, src_vreg); + + a->vmaskmovps(out_vreg.ymm(), mask_vreg, w_ptr); + a->vaddps(out_vreg, src_vreg, out_vreg); + + a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm()); + } else { + a->k(x86::k(1)).vmulps(out_vreg, float_step_vreg, g_ptr); + a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr); + a->k(x86::k(1)).vmovups(w_ptr, out_vreg); + } + } else { + a->vmulps(out_vreg, float_step_vreg, g_ptr); + a->vaddps(out_vreg, out_vreg, w_ptr); + a->vmovups(w_ptr, out_vreg); + } + } else { // float16 weights + auto w_ptr = x86::word_ptr( + w, scratchReg1, 1, + (vec_idx + v) * vlen * sizeof(dataType)); + + if (use_stochastic_rounding) { + // Index [0..3] for extracted bytes + // Each int32 has 4 8-bit rand byte + int sr_idx = (vec_idx + v) % 4; + + if (sr_idx == 0) { + // Generate R buffer every 4 steps of num_vec_regs_per_block + // loop. Each 8-bit in R (uint32_t) will be used once. It is + // shifted to the bits [5-13] then added to FP32 weights + // before FP16 conversion. + // + // The shifted 8 bit region + // +-------+--------+--------+--------+ + // | | | xxxxx|xxx | + // 31 23 15 7 0 + // + // Half float has 10 bits of mantissa, and float has 23, we + // are shifting the bits to cover the region where half + // floats can't represent data. This is bits[13..23] of the + // mantissa of FP32. 
This will be effectively adding a random + // variable of [0,1] + + // Random generator using xoshiro128++ + // Ref: http://prng.di.unimi.it/xoshiro128plusplus.c + a->vpaddd(r0_vreg, S0_vreg, S3_vreg); + a->vpslld(r1_vreg, r0_vreg, 7); + a->vpsrld(r0_vreg, r0_vreg, 25); + if (instSet == inst_set_t::avx2) + a->vpor(R_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm()); + else + a->vpord(R_vreg, r0_vreg, r1_vreg); + a->vpaddd(R_vreg, R_vreg, S0_vreg); + + a->vpslld(r0_vreg, S1_vreg, 9); + + if (instSet == inst_set_t::avx2) { + a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), S0_vreg.ymm()); + a->vpxor(S3_vreg.ymm(), S3_vreg.ymm(), S1_vreg.ymm()); + a->vpxor(S1_vreg.ymm(), S1_vreg.ymm(), S2_vreg.ymm()); + a->vpxor(S0_vreg.ymm(), S0_vreg.ymm(), S3_vreg.ymm()); + + a->vpxor(S2_vreg.ymm(), S2_vreg.ymm(), r0_vreg.ymm()); + } else { + a->vpxord(S2_vreg, S2_vreg, S0_vreg); + a->vpxord(S3_vreg, S3_vreg, S1_vreg); + a->vpxord(S1_vreg, S1_vreg, S2_vreg); + a->vpxord(S0_vreg, S0_vreg, S3_vreg); + + a->vpxord(S2_vreg, S2_vreg, r0_vreg); + } + a->vpslld(r0_vreg, S3_vreg, 11); + a->vpsrld(r1_vreg, S3_vreg, 21); + if (instSet == inst_set_t::avx2) + a->vpor(S3_vreg.ymm(), r0_vreg.ymm(), r1_vreg.ymm()); + else + a->vpord(S3_vreg, r0_vreg, r1_vreg); + + // Extract byte 0 and shift to bits[5..13] + a->vpslld(r0_vreg, R_vreg, 24); + a->vpsrld(r0_vreg, r0_vreg, 19); + } else if (sr_idx == 1) { + // Extract byte 1 and shift to bits[[5..13] + a->vpsrld(r0_vreg, R_vreg, 8); + a->vpslld(r0_vreg, r0_vreg, 24); + a->vpsrld(r0_vreg, r0_vreg, 19); + } else if (sr_idx == 2) { + // Extract byte 2 and shift to bits[5..13] + a->vpslld(r0_vreg, R_vreg, 8); + a->vpsrld(r0_vreg, r0_vreg, 24); + a->vpslld(r0_vreg, r0_vreg, 5); + } else { // sr_idx == 3 + // Extract byte 3 and shift to bits[5..13] + a->vpsrld(r0_vreg, R_vreg, 24); + a->vpslld(r0_vreg, r0_vreg, 5); + } + } - a->vmaskmovps(w_ptr, mask_vreg, out_vreg.ymm()); + if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { + if (instSet == inst_set_t::avx2) { + a->vmaskmovps(src_vreg.ymm(), mask_vreg, g_ptr); + // No AVX2 mask load/store for 16bit + // Copy input to stack using loop instead and reuse GPR for h + a->lea(x86::rsp, x86::ptr(x86::rsp, -8)); + a->mov(x86::ptr(x86::rsp), h); + a->lea(x86::rsp, x86::ptr + (x86::rsp, static_cast(-vlen * sizeof(float16)))); + for (size_t r = 0; r < remainder; ++r) { + a->mov(h.r16(), x86::word_ptr + (w, scratchReg1, 1, + ((vec_idx + v) * vlen + r) * sizeof(dataType))); + a->mov(x86::ptr(x86::rsp, sizeof(dataType) * r), h.r16()); + } + a->vcvtph2ps(out_vreg, x86::word_ptr(x86::rsp)); + a->vfmadd231ps(out_vreg, float_step_vreg, src_vreg); + if (use_stochastic_rounding) { + a->vpaddd(out_vreg, r0_vreg, out_vreg); + } + a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 0); + // Copy results back + for (size_t r = 0; r < remainder; ++r) { + a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r)); + a->mov(x86::word_ptr + (w, scratchReg1, 1, + ((vec_idx + v) * vlen + r) * sizeof(dataType)), + h.r16()); + } + a->lea(x86::rsp, x86::ptr + (x86::rsp, static_cast(vlen * sizeof(float16)))); + a->mov(h, x86::ptr(x86::rsp)); + a->lea(x86::rsp, x86::ptr(x86::rsp, 8)); + } else { + a->k(x86::k(1)).vcvtph2ps(out_vreg, w_ptr); + a->k(x86::k(1)).vfmadd231ps(out_vreg, float_step_vreg, g_ptr); + if (use_stochastic_rounding) { + a->vpaddd(out_vreg, r0_vreg, out_vreg); + } + a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 0); + } } else { - a->k(x86::k(1)).vmulps(out_vreg, float_step_vreg, g_ptr); - a->k(x86::k(1)).vaddps(out_vreg, out_vreg, w_ptr); - 
a->k(x86::k(1)).vmovups(w_ptr, out_vreg); + a->vcvtph2ps(out_vreg, w_ptr); + a->vfmadd231ps(out_vreg, float_step_vreg, g_ptr); + if (use_stochastic_rounding) { + a->vpaddd(out_vreg, r0_vreg, out_vreg); + } + a->vcvtps2ph(w_ptr, out_vreg, 0); } - } else { - a->vmulps(out_vreg, float_step_vreg, g_ptr); - a->vaddps(out_vreg, out_vreg, w_ptr); - a->vmovups(w_ptr, out_vreg); } constexpr int CACHE_LINE_LEN = 64; - constexpr int BYTES_PER_VLOAD = vlen * sizeof(float); + constexpr int BYTES_PER_VLOAD = vlen * sizeof(dataType); constexpr int VLOAD_PER_CACHE_LINE = CACHE_LINE_LEN / BYTES_PER_VLOAD; if (prefetch && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) { a->prefetchw(x86::dword_ptr( - w, scratchReg2, 0, (vec_idx + v) * BYTES_PER_VLOAD)); + w, scratchReg2, areWeightsFp16 ? 1 : 2, + (vec_idx + v) * BYTES_PER_VLOAD)); } } } @@ -487,10 +683,30 @@ typename ReturnFunctionSignature:: a->bind(error); a->mov(x86::eax, false); a->bind(exit); + + if (areWeightsFp16 && use_stochastic_rounding) { + if (instSet == inst_set_t::avx2) { + a->vmovdqa(x86::dword_ptr(rand_buffer), S0_vreg.ymm()); + a->vmovdqa(x86::dword_ptr(rand_buffer, + 1 * vlen * sizeof(uint32_t)), S1_vreg.ymm()); + a->vmovdqa(x86::dword_ptr(rand_buffer, + 2 * vlen * sizeof(uint32_t)), S2_vreg.ymm()); + a->vmovdqa(x86::dword_ptr(rand_buffer, + 3 * vlen * sizeof(uint32_t)), S3_vreg.ymm()); + } else { + a->vmovdqa32(x86::dword_ptr(rand_buffer), S0_vreg); + a->vmovdqa32(x86::dword_ptr(rand_buffer, + 1 * vlen * sizeof(uint32_t)), S1_vreg); + a->vmovdqa32(x86::dword_ptr(rand_buffer, + 2 * vlen * sizeof(uint32_t)), S2_vreg); + a->vmovdqa32(x86::dword_ptr(rand_buffer, + 3 * vlen * sizeof(uint32_t)), S3_vreg); + } + } a->emitEpilog(frame); // jit_fused8bitembedding_kernel fn; - typename ReturnFunctionSignature:: + typename ReturnFunctionSignature:: jit_sparse_adagrad_kernel fn; asmjit::Error err; { @@ -512,33 +728,70 @@ typename ReturnFunctionSignature:: } // namespace -template +// Per-thread global buffer for random number generating, with max vector size +constexpr size_t VLEN_MAX = simd_info::WIDTH_32BIT_ELEMS; +alignas(64) static thread_local uint32_t g_rnd128v_buffer[4 * VLEN_MAX]; +static thread_local bool g_rnd128v_initialized = false; + +template FBGEMM_API - typename RowWiseSparseAdaGradFusedSignature::Type + typename + RowWiseSparseAdaGradFusedSignature::Type GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch, - bool use_offsets) { + bool use_offsets, + bool use_stochastic_rounding) { if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } // Always use avx2 because avx512 doesn't provide speedups if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { - static GenRowWiseSparseAdagradFused - kernel_generator; + + static GenRowWiseSparseAdagradFused kernel_generator; const auto original_func = - kernel_generator.getOrCreate(block_size, prefetch, use_offsets); + kernel_generator.getOrCreate(internal::avx2_ps_or_epi32_combined_mask, + block_size, + prefetch, + use_offsets, + use_stochastic_rounding); const auto lambda_func = [=](int64_t output_size, int64_t index_size, int64_t data_size, - float* w, + DataType* w, const float* g, float* h, const IndexType* indices, const OffsetType* offsets_or_lengths, float epsilon, float lr) { + // Initialize random buffer in the first execution + // TODO: JIT + if (std::is_same::value) { + // Splitmix64: http://prng.di.unimi.it/splitmix64.c + auto rnd128_init_next = [](uint64_t &x) { + uint64_t z = (x += 0x9e3779b97f4a7c15); + z 
= (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9; + z = (z ^ (z >> 27)) * 0x94d049bb133111eb; + return z ^ (z >> 31); + }; + + if (!g_rnd128v_initialized) { + uint64_t h0 = + std::hash{}(std::this_thread::get_id()); + for (auto i = 0; i < 4; ++i) { + g_rnd128v_buffer[i * VLEN_MAX] = rnd128_init_next(h0); + uint64_t h1 = g_rnd128v_buffer[i * VLEN_MAX]; + for (auto v = 1; v < VLEN_MAX; ++v) { + g_rnd128v_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1); + } + } + g_rnd128v_initialized = true; + } + } + return original_func( output_size, index_size, @@ -550,14 +803,14 @@ FBGEMM_API offsets_or_lengths, epsilon, lr, - internal::avx2_ps_or_epi32_combined_mask); + g_rnd128v_buffer); }; return lambda_func; } else { return [=](int64_t output_size, int64_t index_size, int64_t data_size, - float* w, + DataType* w, const float* g, float* h, const IndexType* indices, @@ -575,37 +828,75 @@ FBGEMM_API indices, offsets_or_lengths, epsilon, - lr); + lr, + use_offsets, + use_stochastic_rounding); }; } } template FBGEMM_API - typename RowWiseSparseAdaGradFusedSignature::Type - GenerateRowWiseSparseAdaGradFused( + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( + int block_size, // number of parameters per row + int prefetch, + bool use_offsets, + bool use_stochastic_rounding); + +template FBGEMM_API + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( + int block_size, // number of parameters per row + int prefetch, + bool use_offsets, + bool use_stochastic_rounding); + +template FBGEMM_API + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( + int block_size, // number of parameters per row + int prefetch, + bool use_offsets, + bool use_stochastic_rounding); + +template FBGEMM_API + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( + int block_size, // number of parameters per row + int prefetch, + bool use_offsets, + bool use_stochastic_rounding); + +template FBGEMM_API + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch, - bool use_offsets); + bool use_offsets, + bool use_stochastic_rounding); template FBGEMM_API - typename RowWiseSparseAdaGradFusedSignature::Type - GenerateRowWiseSparseAdaGradFused( + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch, - bool use_offsets); + bool use_offsets, + bool use_stochastic_rounding); template FBGEMM_API - typename RowWiseSparseAdaGradFusedSignature::Type - GenerateRowWiseSparseAdaGradFused( + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch, - bool use_offsets); + bool use_offsets, + bool use_stochastic_rounding); template FBGEMM_API - typename RowWiseSparseAdaGradFusedSignature::Type - GenerateRowWiseSparseAdaGradFused( + typename RowWiseSparseAdaGradFusedSignature::Type + GenerateRowWiseSparseAdaGradFused( int block_size, // number of parameters per row int prefetch, - bool use_offsets); + bool use_offsets, + bool use_stochastic_rounding); } // namespace fbgemm From 8f0240d9704eddbe1cac5f2ca077a1d89aab49ed Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 12 May 2020 05:36:44 +0800 Subject: [PATCH 3/7] Extend RowWiseSparseAdagradFusedTest with float16 weights update using stochastic rounding Signed-off-by: Yong Wu --- 
test/RowWiseSparseAdagradFusedTest.cc | 213 +++++++++++--------------- 1 file changed, 93 insertions(+), 120 deletions(-) diff --git a/test/RowWiseSparseAdagradFusedTest.cc b/test/RowWiseSparseAdagradFusedTest.cc index bc2cba9541..b899aae70a 100644 --- a/test/RowWiseSparseAdagradFusedTest.cc +++ b/test/RowWiseSparseAdagradFusedTest.cc @@ -52,13 +52,14 @@ namespace { class RowWiseSparseAdagradFusedTest : public testing::TestWithParam< - tuple> {}; + tuple> {}; }; // namespace INSTANTIATE_TEST_CASE_P( InstantiationName, RowWiseSparseAdagradFusedTest, ::testing::Combine( + ::testing::Bool(), // isWeightFp16 ::testing::Bool(), // isIndex64b ::testing::Bool(), // isOffset64b ::testing::ValuesIn(prefetch_distances), @@ -71,10 +72,11 @@ INSTANTIATE_TEST_CASE_P( TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) { vector> inputs(GetInputs_()); - bool isIndex64b, isOffset64b, use_offsets; + bool isWeightFp16, isIndex64b, isOffset64b, use_offsets; int prefetch; EmbeddingSpMDMCornerCase corner_case; - tie(isIndex64b, isOffset64b, prefetch, use_offsets, corner_case) = GetParam(); + tie(isWeightFp16, isIndex64b, isOffset64b, prefetch, use_offsets, corner_case) + = GetParam(); for (auto input : inputs) { int batch_size = input[0]; @@ -85,10 +87,12 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) { // Create embedding table vector w(num_rows * embedding_dim), w_ref(num_rows * embedding_dim), h(num_rows), h_ref(num_rows), g(batch_size * embedding_dim); + vector w_fp16(w.size()), w_fp16_ref(w.size()); default_random_engine generator; uniform_real_distribution values_gen(0, 2); for (int i = 0; i < w.size(); ++i) { w_ref[i] = w[i] = values_gen(generator); + w_fp16_ref[i] = w_fp16[i] = cpu_float2half_rn(w[i]); } for (int i = 0; i < h.size(); ++i) { h_ref[i] = h[i] = values_gen(generator); @@ -121,122 +125,82 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) { float epsilon = 1e-5; float lr = 0.5; - bool success, success_ref; - if (isOffset64b) { - if (isIndex64b) { - success_ref = rowwise_sparse_adagrad_fused_ref( - embedding_dim, - batch_size, - lengths_sum, - num_rows, - w_ref.data(), - g.data(), - h_ref.data(), - corner_case == EMPTY_INDICES ? nullptr : indices.data(), - offsets_or_lengths, - epsilon, - lr, - use_offsets); - - auto kernel = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch, use_offsets); - success = kernel( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - corner_case == EMPTY_INDICES ? nullptr : indices.data(), - offsets_or_lengths, - epsilon, - lr); - } else { // 32 bit indices - success_ref = rowwise_sparse_adagrad_fused_ref( - embedding_dim, - batch_size, - lengths_sum, - num_rows, - w_ref.data(), - g.data(), - h_ref.data(), - corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - offsets_or_lengths, - epsilon, - lr, - use_offsets); +#define REF(Weights, Indices, Offsets) do { \ + success_ref = rowwise_sparse_adagrad_fused_ref( \ + embedding_dim, \ + batch_size, \ + lengths_sum, \ + num_rows, \ + Weights, \ + g.data(), \ + h_ref.data(), \ + corner_case == EMPTY_INDICES ? nullptr : Indices, \ + Offsets, \ + epsilon, \ + lr, \ + use_offsets); \ +} while(0) + +#define JIT(WeightType, IndexType, OffsetType, Weights, Indices, Offsets) do { \ + auto kernel = GenerateRowWiseSparseAdaGradFused< \ + IndexType, OffsetType, WeightType>(embedding_dim, prefetch, use_offsets); \ + success = kernel( \ + batch_size, \ + lengths_sum, \ + num_rows, \ + Weights, \ + g.data(), \ + h.data(), \ + corner_case == EMPTY_INDICES ? 
nullptr : Indices, \ + Offsets, \ + epsilon, \ + lr); \ +} while(0) - auto kernel = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch, use_offsets); - success = kernel( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - offsets_or_lengths, - epsilon, - lr); + bool success, success_ref; + if (isWeightFp16) { + if (isOffset64b) { + if (isIndex64b) { + REF(w_fp16_ref.data(), indices.data(), offsets_or_lengths); + JIT(float16, int64_t, int64_t, + w_fp16.data(), indices.data(), offsets_or_lengths); + } else { // 32 bit indices + REF(w_fp16_ref.data(), indices_32.data(), offsets_or_lengths); + JIT(float16, int32_t, int64_t, + w_fp16.data(), indices_32.data(), offsets_or_lengths); + } + } else { // 32 bit offset + if (isIndex64b) { + REF(w_fp16_ref.data(), indices.data(), offsets_or_lengths_32); + JIT(float16, int64_t, int32_t, + w_fp16.data(), indices.data(), offsets_or_lengths_32); + } else { // 32 bit indices + REF(w_fp16_ref.data(), indices_32.data(), offsets_or_lengths_32); + JIT(float16, int32_t, int32_t, + w_fp16.data(), indices_32.data(), offsets_or_lengths_32); + } } - } else { - if (isIndex64b) { - success_ref = rowwise_sparse_adagrad_fused_ref( - embedding_dim, - batch_size, - lengths_sum, - num_rows, - w_ref.data(), - g.data(), - h_ref.data(), - corner_case == EMPTY_INDICES ? nullptr : indices.data(), - offsets_or_lengths, - epsilon, - lr, - use_offsets); - - auto kernel = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch, use_offsets); - success = kernel( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - corner_case == EMPTY_INDICES ? nullptr : indices.data(), - offsets_or_lengths_32, - epsilon, - lr); - } else { // 32 bit indices - success_ref = rowwise_sparse_adagrad_fused_ref( - embedding_dim, - batch_size, - lengths_sum, - num_rows, - w_ref.data(), - g.data(), - h_ref.data(), - corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - offsets_or_lengths, - epsilon, - lr, - use_offsets); - - auto kernel = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch, use_offsets); - success = kernel( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - corner_case == EMPTY_INDICES ? 
nullptr : indices_32.data(), - offsets_or_lengths_32, - epsilon, - lr); + } else { // 32 bit of weights + if (isOffset64b) { + if (isIndex64b) { + REF(w_ref.data(), indices.data(), offsets_or_lengths); + JIT(float, int64_t, int64_t, + w.data(), indices.data(), offsets_or_lengths); + } else { // 32 bit indices + REF(w_ref.data(), indices_32.data(), offsets_or_lengths); + JIT(float, int32_t, int64_t, + w.data(), indices_32.data(), offsets_or_lengths); + } + } else { // 32 bit offset + if (isIndex64b) { + REF(w_ref.data(), indices.data(), offsets_or_lengths_32); + JIT(float, int64_t, int32_t, + w.data(), indices.data(), offsets_or_lengths_32); + } else { // 32 bit indices + REF(w_ref.data(), indices_32.data(), offsets_or_lengths_32); + JIT(float, int32_t, int32_t, + w.data(), indices_32.data(), offsets_or_lengths_32); + } } } @@ -249,10 +213,19 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) { << "results for h differ at (" << i << ") reference: " << h_ref[i] << ", FBGEMM: " << h[i] << " emb dim :" << embedding_dim; } + for (int i = 0; i < w.size(); ++i) { - EXPECT_EQ(w[i], w_ref[i]) - << "results for w differ at (" << i << ") reference: " << w_ref[i] - << ", FBGEMM: " << w[i] << " emb dim :" << embedding_dim; + float w_, w_ref_; + if (isWeightFp16) { + w_ = cpu_half2float(w_fp16[i]); + w_ref_ = cpu_half2float(w_fp16_ref[i]); + } else { + w_ = w[i]; + w_ref_ = w_ref[i]; + } + EXPECT_EQ(w_, w_ref_) + << "results for w differ at (" << i << ") reference: " << w_ref_ + << ", FBGEMM: " << w_ << " emb dim :" << embedding_dim; } } } From 28d46bc07d4b27f5387cea4c4f02f2dc5df00c2f Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 12 May 2020 05:44:56 +0800 Subject: [PATCH 4/7] Extend RowwiseAdagradFusedBenchmark with float16 weights update using stochastic rounding Signed-off-by: Yong Wu --- bench/RowwiseAdagradFusedBenchmark.cc | 152 ++++++++++++++++---------- 1 file changed, 95 insertions(+), 57 deletions(-) diff --git a/bench/RowwiseAdagradFusedBenchmark.cc b/bench/RowwiseAdagradFusedBenchmark.cc index c926408f55..151be9d4ac 100644 --- a/bench/RowwiseAdagradFusedBenchmark.cc +++ b/bench/RowwiseAdagradFusedBenchmark.cc @@ -45,14 +45,14 @@ void run_benchmark( int num_rows, int embedding_dim, int average_len, + bool use_fp16_weights, bool use_32_bit_indices = false, bool prefetch = false) { vector llc(64L * 1024L * 1024L, 1.0); vector g(batch_size * embedding_dim); // gradients vector h(num_rows); // input momentums vector w(num_rows * embedding_dim); // input params - vector h_ref(h.size()); - vector w_ref(w.size()); + vector w_fp16(w.size()); // input params default_random_engine generator; // normal_distribution h_w_distribution; @@ -62,10 +62,10 @@ void run_benchmark( g[i] = 4 + i; // h_w_distribution(generator); } for (int i = 0; i < h.size(); ++i) { - h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator); + h[i] = 2 + i; // h_w_distribution(generator); } for (int i = 0; i < w.size(); ++i) { - w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator); + w[i] = 3 + i; // h_w_distribution(generator); } // Generate lengths @@ -104,47 +104,81 @@ void run_benchmark( constexpr int NUM_ITER = 10; // Only counts the number of bytes for reading embedding table and ignore // others. Should be good enough as long as embdding_dim is big enough. - double bytes = lengths_sum * - ((embedding_dim + 1) * sizeof(float) * 2 + - (use_32_bit_indices ? 
4 : 8)) + - batch_size * (embedding_dim * sizeof(float) + sizeof(int)); - double bytes_padded = lengths_sum * - (((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 + - (use_32_bit_indices ? 4 : 8)) + - batch_size * (embedding_dim * sizeof(float) + sizeof(int)); - - auto kernel_i32 = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch ? 16 : 0); - auto kernel_i64 = GenerateRowWiseSparseAdaGradFused( - embedding_dim, prefetch ? 16 : 0); + double bytes = + lengths_sum * ((embedding_dim + 1) * + (use_fp16_weights ? sizeof(float16) : sizeof(float)) * 2 + + (use_32_bit_indices ? 4 : 8)) + + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); + // FIXME: float16 is counted as float for effective byte loading + double bytes_padded = + lengths_sum * (((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 + + (use_32_bit_indices ? 4 : 8)) + + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); + + auto kernel_i32 = GenerateRowWiseSparseAdaGradFused + (embedding_dim, prefetch ? 16 : 0); + auto kernel_i64 = GenerateRowWiseSparseAdaGradFused + (embedding_dim, prefetch ? 16 : 0); + auto kernel_fp16_i32 = GenerateRowWiseSparseAdaGradFused + (embedding_dim, prefetch ? 16 : 0); + auto kernel_fp16_i64 = GenerateRowWiseSparseAdaGradFused + (embedding_dim, prefetch ? 16 : 0); for (bool flush_cache : {false, true}) { double t = measureWithWarmup( [&]() { - if (use_32_bit_indices) { - kernel_i32( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - indices_32.data(), - lengths.data(), - epsilon, - lr); + if (use_fp16_weights) { + if (use_32_bit_indices) { + kernel_fp16_i32( + batch_size, + lengths_sum, + num_rows, + w_fp16.data(), + g.data(), + h.data(), + indices_32.data(), + lengths.data(), + epsilon, + lr); + } else { + kernel_fp16_i64( + batch_size, + lengths_sum, + num_rows, + w_fp16.data(), + g.data(), + h.data(), + indices.data(), + lengths.data(), + epsilon, + lr); + } } else { - kernel_i64( - batch_size, - lengths_sum, - num_rows, - w.data(), - g.data(), - h.data(), - indices.data(), - lengths.data(), - epsilon, - lr); + if (use_32_bit_indices) { + kernel_i32( + batch_size, + lengths_sum, + num_rows, + w.data(), + g.data(), + h.data(), + indices_32.data(), + lengths.data(), + epsilon, + lr); + } else { + kernel_i64( + batch_size, + lengths_sum, + num_rows, + w.data(), + g.data(), + h.data(), + indices.data(), + lengths.data(), + epsilon, + lr); + } } }, NUM_WARMUP, @@ -183,23 +217,27 @@ int main() { << embedding_dim << setw(16) << "avg length" << setw(6) << average_len << endl; - for (bool use_32_bit_indices : {false, true}) { - for (bool prefetch : {false, true}) { - // args: batch sz, num rows, emb dim, avg len, use 32b, prefetch - cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices"; - if (prefetch) { - cout << " with prefetching"; - } - cout << ", "; - run_benchmark( - batch_size, - num_rows, - embedding_dim, - average_len, - use_32_bit_indices, - prefetch); - } // prefetch - } // use_32_bit_indices + for (bool use_fp16_weights : {false, true}) { + for (bool use_32_bit_indices : {false, true}) { + for (bool prefetch : {false, true}) { + // args: batch sz, num rows, emb dim, avg len, use 32b, prefetch + cout << (use_fp16_weights ? " float16" : " float32") << " weights"; + cout << (use_32_bit_indices ? 
" 32" : " 64") << " bit indices"; + if (prefetch) { + cout << " with prefetching"; + } + cout << ", "; + run_benchmark( + batch_size, + num_rows, + embedding_dim, + average_len, + use_fp16_weights, + use_32_bit_indices, + prefetch); + } // prefetch + } // use_32_bit_indices + } // use_fp16_weights } // for each input return 0; From 2b1e1cab45bc01e8674fc358ceacd14b67f0e6e4 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 13 May 2020 08:56:04 +0800 Subject: [PATCH 5/7] Fix misc. issues during code review REF path: - fix local variable shadow - fix memory leak - fix FP32->FP16 rounding - unify shift logic JIT path: - fix comments - fix FP32->FP16 rounding - GPR allocation: RAX for rand_buffer and R12 for length_R Bench: - fix bytes_padded caculation Signed-off-by: Yong Wu --- BUILD.bazel | 3 ++ CMakeLists.txt | 1 + bench/RowwiseAdagradFusedBenchmark.cc | 4 +- src/RefImplementations.cc | 57 +++++++++++++-------------- src/RowWiseSparseAdagradFused.cc | 39 +++++++++--------- 5 files changed, 52 insertions(+), 52 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 4d27a768e7..66a7d14db4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -8,6 +8,9 @@ cc_library( includes = [ "src", ], + copts = [ + "-mf16c", + ], deps = [ ":fbgemm_headers", "@cpuinfo", diff --git a/CMakeLists.txt b/CMakeLists.txt index 946855c60b..c0e32c47b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,6 +111,7 @@ if(MSVC) target_compile_options(fbgemm_avx2 PRIVATE "/arch:AVX2") target_compile_options(fbgemm_avx512 PRIVATE "/arch:AVX512") else(MSVC) + target_compile_options(fbgemm_generic PRIVATE "-mf16c") target_compile_options(fbgemm_avx2 PRIVATE "-m64" "-mavx2" "-mf16c" "-mfma" "-masm=intel") target_compile_options(fbgemm_avx512 PRIVATE diff --git a/bench/RowwiseAdagradFusedBenchmark.cc b/bench/RowwiseAdagradFusedBenchmark.cc index 151be9d4ac..352bb2d9e1 100644 --- a/bench/RowwiseAdagradFusedBenchmark.cc +++ b/bench/RowwiseAdagradFusedBenchmark.cc @@ -109,9 +109,9 @@ void run_benchmark( (use_fp16_weights ? sizeof(float16) : sizeof(float)) * 2 + (use_32_bit_indices ? 4 : 8)) + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); - // FIXME: float16 is counted as float for effective byte loading double bytes_padded = - lengths_sum * (((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 + + lengths_sum * (((embedding_dim * (use_fp16_weights ? sizeof(float16) : + sizeof(float)) + 63) / 64 + 1) * 64 * 2 + (use_32_bit_indices ? 
4 : 8)) + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 4130d16bd4..7891c5175a 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -17,6 +17,8 @@ #include #include #include +#include +#include using namespace std; @@ -26,7 +28,7 @@ namespace fbgemm { // // Return a random 32bit integer using xoshiro128++ // http://prng.di.unimi.it/xoshiro128plusplus.c -inline uint32_t rnd128_next(int v, int vlen) { +inline uint32_t rnd128_next(int idx, int vlen) { constexpr int VLEN_MAX = 16; // max vector size alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX]; static thread_local bool g_rnd128_initialized = false; @@ -58,19 +60,19 @@ inline uint32_t rnd128_next(int v, int vlen) { } const uint32_t result = - rotl(g_rnd128_buffer[v] + g_rnd128_buffer[3 * vlen + v], 7) - + g_rnd128_buffer[v]; + rotl(g_rnd128_buffer[idx] + g_rnd128_buffer[3 * vlen + idx], 7) + + g_rnd128_buffer[idx]; - const uint32_t t = g_rnd128_buffer[1 * vlen + v] << 9; + const uint32_t t = g_rnd128_buffer[1 * vlen + idx] << 9; - g_rnd128_buffer[2 * vlen + v] ^= g_rnd128_buffer[0 * vlen + v]; - g_rnd128_buffer[3 * vlen + v] ^= g_rnd128_buffer[1 * vlen + v]; - g_rnd128_buffer[1 * vlen + v] ^= g_rnd128_buffer[2 * vlen + v]; - g_rnd128_buffer[0 * vlen + v] ^= g_rnd128_buffer[3 * vlen + v]; + g_rnd128_buffer[2 * vlen + idx] ^= g_rnd128_buffer[0 * vlen + idx]; + g_rnd128_buffer[3 * vlen + idx] ^= g_rnd128_buffer[1 * vlen + idx]; + g_rnd128_buffer[1 * vlen + idx] ^= g_rnd128_buffer[2 * vlen + idx]; + g_rnd128_buffer[0 * vlen + idx] ^= g_rnd128_buffer[3 * vlen + idx]; - g_rnd128_buffer[2 * vlen + v] ^= t; + g_rnd128_buffer[2 * vlen + idx] ^= t; - g_rnd128_buffer[3 * vlen + v] = rotl(g_rnd128_buffer[3 * vlen + v], 11); + g_rnd128_buffer[3 * vlen + idx] = rotl(g_rnd128_buffer[3 * vlen + idx], 11); return result; } @@ -1200,16 +1202,16 @@ int rowwise_sparse_adagrad_fused_ref( bool use_stochastic_rounding, int emu_vector_size) { constexpr bool isFloat16w = std::is_same::value; - // TODO: warning on vector-size not 8/16 - int vlen = emu_vector_size; // Local random buffer to emulate SIMD vector // R: generated 32bit base random numbers // r: extracted 8-bit for rounding - uint32_t *R = nullptr, *r = nullptr; - if (isFloat16w && use_stochastic_rounding) { - // Random vector buffer for stochastic rounding - R = new uint32_t[vlen]; - r = new uint32_t[vlen]; + constexpr int VLEN_MAX = 16; + uint32_t R[VLEN_MAX], r[VLEN_MAX]; + int vlen = emu_vector_size; + if (vlen != 8 && vlen != 16) { + // Raise error as it may cause buffer overflow + cout << "Not supported emu_vector_size: " << emu_vector_size << endl; + return false; } int64_t current = 0; @@ -1272,39 +1274,40 @@ int rowwise_sparse_adagrad_fused_ref( // This will be effectively adding a random variable of [0,1] for (int n = 0; n < nvec; ++n) { - int len = (n == nvec - 1) ? rem : vlen; + int cur_vlen = (n == nvec - 1) ? 
rem : vlen; int sr_idx = n % 4; if (isFloat16w && use_stochastic_rounding) { if (sr_idx == 0) { for (int v = 0; v < vlen; ++v) { R[v] = rnd128_next(v, vlen); - r[v] = (R[v] << 24) >> 19; + r[v] = (R[v] & 0xFFU) << 5; } } else if (sr_idx == 1) { for (int v = 0; v < vlen; ++v) - r[v] = ((R[v] >> 8) << 24) >> 19; + r[v] = ((R[v] & 0xFF00U) >> 8) << 5; } else if (sr_idx == 2) { for (int v = 0; v < vlen; ++v) - r[v] = ((R[v] << 8) >> 24) << 5; + r[v] = ((R[v] & 0xFF0000U) >> 16) << 5; } else { // 3 for (int v = 0; v < vlen; ++v) - r[v] = (R[v] >> 24) << 5; + r[v] = ((R[v] & 0xFF000000U) >> 24) << 5; } } - for (int v = 0; v < len; ++v) { + for (int v = 0; v < cur_vlen; ++v) { int j = n * vlen + v; if (isFloat16w) { union { float w_f32; uint32_t w_i32; }; - w_f32 = cpu_half2float(w_[j]); + w_f32 = _cvtsh_ss(w_[j]); w_f32 = std::fma(float_step, g_[j], w_f32); if (use_stochastic_rounding) w_i32 += r[v]; - w_[j] = cpu_float2half_rn(w_f32); + // Use truncate rounding to 'counterwork' the random added part + w_[j] = _cvtss_sh(w_f32, _MM_FROUND_TO_ZERO |_MM_FROUND_NO_EXC); } else { // float w_[j] += g_[j] * float_step; } @@ -1312,10 +1315,6 @@ int rowwise_sparse_adagrad_fused_ref( } } } - if (R != nullptr) - delete[] R; - if (r != nullptr) - delete[] r; return current == index_size; } diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc index 9f158bafad..03debe8048 100644 --- a/src/RowWiseSparseAdagradFused.cc +++ b/src/RowWiseSparseAdagradFused.cc @@ -62,8 +62,9 @@ class GenRowWiseSparseAdagradFused { static mutex rtMutex_; /// Controll access to runtime; - // The hash depends on embedding dimension (block size), prefetch distance, - // and use_offsets + // The hash depends on: + // avx2 mask array, embedding dimension (block size), prefetch distance, + // use_offsets and use_stochastic_rouding switch static CodeCache< tuple, typename ReturnFunctionSignature:: @@ -128,6 +129,7 @@ typename ReturnFunctionSignature:: code.setLogger(codeLogger); #endif + x86::Gp rand_buffer = a->zax(); x86::Gp output_size = a->zdi(); x86::Gp index_size = a->zsi(); x86::Gp data_size = a->zdx(); @@ -138,10 +140,7 @@ typename ReturnFunctionSignature:: x86::Gp lengths = a->gpz(11); x86::Xmm epsilon(0); x86::Xmm lr(1); - x86::Gp rand_buffer = a->gpz(12); - - // FP32 weights does not need rand_buffer - x86::Gpd lengths_R = areWeightsFp16 ? 
a->zbx().r32() : a->gpz(12).r32(); + x86::Gpd lengths_R = a->gpz(12).r32(); x86::Gp scratchReg1 = a->gpz(13); x86::Gp scratchReg2 = a->gpz(14); // for prefetching @@ -177,16 +176,9 @@ typename ReturnFunctionSignature:: asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); } - if (areWeightsFp16) { - // Random buffer will use R12 exclusively - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(3, 8, 9, 10, 11, 12, 13, 14)); - } else { - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - } + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); args.assignAll( @@ -624,7 +616,8 @@ typename ReturnFunctionSignature:: if (use_stochastic_rounding) { a->vpaddd(out_vreg, r0_vreg, out_vreg); } - a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 0); + // Truncate rounding to 'counterwork' the random added part + a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 0b11); // Copy results back for (size_t r = 0; r < remainder; ++r) { a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r)); @@ -643,7 +636,8 @@ typename ReturnFunctionSignature:: if (use_stochastic_rounding) { a->vpaddd(out_vreg, r0_vreg, out_vreg); } - a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 0); + // Truncate rounding + a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 0b11); } } else { a->vcvtph2ps(out_vreg, w_ptr); @@ -651,7 +645,8 @@ typename ReturnFunctionSignature:: if (use_stochastic_rounding) { a->vpaddd(out_vreg, r0_vreg, out_vreg); } - a->vcvtps2ph(w_ptr, out_vreg, 0); + // Truncate rounding + a->vcvtps2ph(w_ptr, out_vreg, 0b11); } } @@ -678,10 +673,10 @@ typename ReturnFunctionSignature:: a->cmp(indices, index_size); a->jne(error); - a->mov(x86::eax, true); + a->mov(scratchReg1.r32(), true); a->jmp(exit); a->bind(error); - a->mov(x86::eax, false); + a->mov(scratchReg1.r32(), false); a->bind(exit); if (areWeightsFp16 && use_stochastic_rounding) { @@ -703,6 +698,8 @@ typename ReturnFunctionSignature:: 3 * vlen * sizeof(uint32_t)), S3_vreg); } } + + a->mov(x86::eax, scratchReg1.r32()); a->emitEpilog(frame); // jit_fused8bitembedding_kernel fn; From ec268b361c84f46775dd5e8a773c19366fa3aac4 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 13 May 2020 09:38:43 +0800 Subject: [PATCH 6/7] Fix missing header x86intrin.h in MSVC build Use portable immintrin.h instead. Signed-off-by: Yong Wu --- src/RefImplementations.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 7891c5175a..6872cd35c5 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -18,7 +18,7 @@ #include #include #include -#include +#include using namespace std; From 02cfbee6c46b21e4f3384e5ffd520454cdaf8c25 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 13 May 2020 09:57:05 +0800 Subject: [PATCH 7/7] Fix MSVC build on missing symbol of _cvtss_sh/_cvtsh_ss and suppress exceptions in JIT. 
Signed-off-by: Yong Wu
---
 src/RefImplementations.cc | 15 ++++++++++++++-
 src/RowWiseSparseAdagradFused.cc | 6 +++---
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 6872cd35c5..8b5ad9ee83 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -18,7 +18,20 @@
 #include
 #include
 #include
-#include
+#include // for _cvtss_sh/_cvtsh_ss
+
+#ifdef _MSC_VER
+// MSVC does not provide _cvtsh_ss/_cvtss_sh
+#define _cvtsh_ss(a) \
+  _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(a)))
+
+// FIXME:
+// MSVC only accepts rounding values 0...7, so _MM_FROUND_NO_EXC (which is
+// 0x8) will be lost if it is set.
+#define _cvtss_sh(a, rounding) static_cast \
+  (_mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(a), ((rounding) & 0x7U))))
+
+#endif

 using namespace std;

diff --git a/src/RowWiseSparseAdagradFused.cc b/src/RowWiseSparseAdagradFused.cc
index 03debe8048..e5d20300fb 100644
--- a/src/RowWiseSparseAdagradFused.cc
+++ b/src/RowWiseSparseAdagradFused.cc
@@ -617,7 +617,7 @@ typename ReturnFunctionSignature::
               a->vpaddd(out_vreg, r0_vreg, out_vreg);
             }
             // Truncate rounding to 'counterwork' the random added part
-            a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 0b11);
+            a->vcvtps2ph(x86::word_ptr(x86::rsp), out_vreg, 11);
             // Copy results back
             for (size_t r = 0; r < remainder; ++r) {
               a->mov(h.r16(), x86::ptr(x86::rsp, sizeof(dataType) * r));
@@ -637,7 +637,7 @@ typename ReturnFunctionSignature::
                 a->vpaddd(out_vreg, r0_vreg, out_vreg);
               }
               // Truncate rounding
-              a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 0b11);
+              a->k(x86::k(1)).vcvtps2ph(w_ptr, out_vreg, 11);
             }
           } else {
             a->vcvtph2ps(out_vreg, w_ptr);
@@ -646,7 +646,7 @@ typename ReturnFunctionSignature::
               a->vpaddd(out_vreg, r0_vreg, out_vreg);
             }
             // Truncate rounding
-            a->vcvtps2ph(w_ptr, out_vreg, 0b11);
+            a->vcvtps2ph(w_ptr, out_vreg, 11);
           }
         }
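
Note on the rounding scheme (not part of the patches): the trick used throughout this
series is to add an 8-bit random value, shifted into bits [5..13] of the FP32 bit
pattern, and then truncate to FP16 so that only the random carry decides the rounding
direction. The standalone C++ sketch below is illustrative only; it assumes an
F16C-capable x86 CPU, a compiler that provides _cvtss_sh/_cvtsh_ss (MSVC needs the
workaround from patch 7), and uses std::mt19937 as a stand-in for the xoshiro128++
buffer added in patch 1.

    #include <cstdint>
    #include <cstdio>
    #include <immintrin.h>
    #include <random>

    // Convert one FP32 value to FP16 with stochastic rounding: add an 8-bit
    // random value shifted into bits [5..13] of the FP32 bit pattern, then
    // truncate (round toward zero) so only the random carry decides rounding.
    static uint16_t fp32_to_fp16_stochastic(float x, uint32_t rand_byte) {
      union {
        float f;
        uint32_t i;
      } u;
      u.f = x;
      // FP16 keeps 10 mantissa bits vs. 23 for FP32, so the low 13 mantissa
      // bits are dropped on conversion; the random addend covers that range
      // in steps of 32 (multiples of 2^5 in [0, 2^13)), as in the ref/JIT code.
      u.i += (rand_byte & 0xFFu) << 5;
      // _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC == 11, the immediate used by
      // the JIT code after patch 7.
      return _cvtss_sh(u.f, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC);
    }

    int main() {
      std::mt19937 rng(42); // stand-in for the per-thread xoshiro128++ buffer
      const float x = 0.1f; // not exactly representable in FP16
      double acc = 0.0;
      const int n = 100000;
      for (int i = 0; i < n; ++i) {
        acc += _cvtsh_ss(fp32_to_fp16_stochastic(x, rng()));
      }
      // The mean of the stochastically rounded values stays close to x,
      // whereas deterministic round-to-nearest always returns the same value.
      std::printf("mean over %d conversions: %.8f (input %.8f)\n", n, acc / n, x);
      return 0;
    }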