Improve numerical stability of GroupNorm (#54921)
Summary:
Pull Request resolved: #54921

Improve numerical stability of GroupNorm

Test Plan: buck test mode/dev-nosan //caffe2/test:nn -- "GroupNorm"

Reviewed By: ngimel

Differential Revision: D27414438

fbshipit-source-id: 815517240ca5ea3e2beb77ced3bd862e9c83d445
xiaomengy authored and facebook-github-bot committed Jun 13, 2021
1 parent 095cd6a commit ff15d93
Showing 10 changed files with 319 additions and 92 deletions.
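For context (not shown in the diff itself): the previous CPU GroupNorm kernel accumulated sum(x) and sum(x*x) with vectorized adds and derived the variance as E[x^2] - (E[x])^2, which cancels catastrophically when the mean is large relative to the standard deviation. The new code computes the moments with Welford's algorithm plus cascade (pairwise) summation (see moments_utils.h below). A minimal single-pass Welford sketch, not taken from the commit (the helper name welford_mean_var is illustrative):

#include <cstdint>
#include <utility>

// Keeps a running mean and the sum of squared deviations (m2), so the variance
// is never formed as the difference of two large, nearly equal numbers.
template <typename T>
std::pair<T, T> welford_mean_var(const T* x, int64_t n) {
  T mean = 0;
  T m2 = 0;
  for (int64_t i = 0; i < n; ++i) {
    const T delta = x[i] - mean;
    mean += delta / static_cast<T>(i + 1);
    m2 += delta * (x[i] - mean);
  }
  // Biased (population) variance, matching what the GroupNorm kernel needs.
  return std::make_pair(mean, n > 0 ? m2 / static_cast<T>(n) : T(0));
}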
13 changes: 10 additions & 3 deletions aten/src/ATen/native/SharedReduceOps.h
@@ -80,8 +80,15 @@ struct WelfordData {
scalar_t m2;
index_t n;
combine_t nf;
C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}
C10_DEVICE WelfordData(scalar_t mean, scalar_t m2, index_t n, combine_t nf) : mean(mean), m2(m2), n(n), nf(nf) {}

C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

C10_HOST_DEVICE WelfordData(
scalar_t mean,
scalar_t m2,
index_t n,
combine_t nf)
: mean(mean), m2(m2), n(n), nf(nf) {}
};


@@ -145,7 +152,7 @@ struct WelfordOps {
};
}
#endif
WelfordOps(index_t correction, bool take_sqrt)
C10_HOST_DEVICE WelfordOps(index_t correction, bool take_sqrt)
: correction(correction), take_sqrt(take_sqrt) {}
};

21 changes: 5 additions & 16 deletions aten/src/ATen/native/cpu/SumKernel.cpp
@@ -1,11 +1,11 @@
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/native/cpu/Reduce.h>
#include <c10/util/llvmMathExtras.h>

#include <algorithm>

#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Reduce.h>
#include <ATen/native/cpu/utils.h>

namespace at {
namespace native {
@@ -48,17 +48,6 @@ void accumulate_result(char * C10_RESTRICT data, int64_t stride, int64_t index,
}
}

int64_t ceil_log2(int64_t x) {
if (x <= 2) {
return 1;
}

auto ux = static_cast<uint64_t>(x);
// Last set bit is floor(log2(x)), floor + 1 is ceil
// except when x is an exact power of 2, so subtract 1 first
return static_cast<int64_t>(llvm::findLastSet(ux - 1)) + 1;
}

/** Simultaneously sum over n rows at once
This algorithm calculates the sum without loss of precision over large axes. It
@@ -101,7 +90,7 @@ std::array<scalar_t, nrows> multi_row_sum(
constexpr int64_t num_levels = 4;

const int64_t level_power =
std::max(int64_t(4), ceil_log2(size) / num_levels);
std::max(int64_t(4), utils::CeilLog2(size) / num_levels);
const int64_t level_step = (1 << level_power);
const int64_t level_mask = level_step - 1;

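The ceil_log2 helper deleted above is replaced by utils::CeilLog2 from ATen/native/cpu/utils.h, which keeps the same llvm::findLastSet bit trick. A rough portable sketch of what it computes, not from the commit (the name ceil_log2_sketch is illustrative); for x > 2 it returns the smallest k with (1 << k) >= x:

#include <cstdint>

int64_t ceil_log2_sketch(int64_t x) {
  if (x <= 2) {
    return 1;
  }
  // The highest set bit of (x - 1) is floor(log2(x - 1)); its bit length is
  // therefore ceil(log2(x)), including when x is an exact power of 2.
  int64_t k = 0;
  for (uint64_t ux = static_cast<uint64_t>(x) - 1; ux != 0; ux >>= 1) {
    ++k;
  }
  return k;
}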
59 changes: 23 additions & 36 deletions aten/src/ATen/native/cpu/group_norm_kernel.cpp
@@ -8,6 +8,7 @@
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/moments_utils.h>

namespace at {
namespace native {
@@ -38,47 +39,33 @@ void GroupNormKernelImplInternal(
T* Y_data = Y.data_ptr<T>();
T* mean_data = mean.data_ptr<T>();
T* rstd_data = rstd.data_ptr<T>();
const T s = T(1) / static_cast<T>(D * HxW);
const bool gamma_null = (gamma_data == nullptr);
const bool beta_null = beta_data == nullptr;
const int64_t inner_size = D * HxW;

at::parallel_for(0, N * G, 1, [&](int64_t start, int64_t end) {
constexpr int64_t K = vec::Vectorized<T>::size();
const int64_t inner_size = D * HxW / K * K;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<T, K> mean_arr;
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<T, K> rstd_arr;
for (int64_t i = start; i < end; ++i) {
const T* X_ptr = X_data + i * D * HxW;
vec::Vectorized<T> mean_vec(0);
vec::Vectorized<T> rstd_vec(0);
for (int64_t j = 0; j < inner_size; j += K) {
const vec::Vectorized<T> x_vec = vec::Vectorized<T>::loadu(X_ptr + j);
mean_vec = mean_vec + x_vec;
rstd_vec = rstd_vec + x_vec * x_vec;
}
mean_vec.store(mean_arr.data());
rstd_vec.store(rstd_arr.data());
T mean_val = std::accumulate(mean_arr.cbegin(), mean_arr.cend(), T(0));
T rstd_val = std::accumulate(rstd_arr.cbegin(), rstd_arr.cend(), T(0));
for (int64_t j = inner_size; j < D * HxW; ++j) {
mean_val += X_ptr[j];
rstd_val += X_ptr[j] * X_ptr[j];
}
mean_val *= s;
rstd_val = std::max(rstd_val * s - mean_val * mean_val, T(0));
rstd_val = T(1) / std::sqrt(rstd_val + eps);

const int64_t g = i % G;
for (int64_t j = 0; j < D; ++j) {
const int64_t c = g * D + j;
const T scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]);
const T bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]);
X_ptr = X_data + (i * D + j) * HxW;
T* Y_ptr = Y_data + (i * D + j) * HxW;
for (int64_t k = 0; k < HxW; ++k) {
Y_ptr[k] = scale * X_ptr[k] + bias;
const T* X_ptr = X_data + i * inner_size;
T mean_val;
T rstd_val;
std::tie(mean_val, rstd_val) = utils::RowwiseMoments(X_ptr, inner_size);
rstd_val = T(1) / std::sqrt(std::max(rstd_val, T(0)) + eps);
if (gamma_null && beta_null) {
T* Y_ptr = Y_data + i * inner_size;
for (int j = 0; j < inner_size; ++j) {
Y_ptr[j] = (X_ptr[j] - mean_val) * rstd_val;
}
} else {
const int64_t g = i % G;
for (int64_t j = 0; j < D; ++j) {
const int64_t c = g * D + j;
const T scale = rstd_val * (gamma_null ? T(1) : gamma_data[c]);
const T bias = -scale * mean_val + (beta_null ? T(0) : beta_data[c]);
X_ptr = X_data + (i * D + j) * HxW;
T* Y_ptr = Y_data + (i * D + j) * HxW;
for (int64_t k = 0; k < HxW; ++k) {
Y_ptr[k] = scale * X_ptr[k] + bias;
}
}
}
mean_data[i] = mean_val;
147 changes: 147 additions & 0 deletions aten/src/ATen/native/cpu/moments_utils.h
@@ -0,0 +1,147 @@
#pragma once

#include <array>
#include <cstring>
#include <numeric>
#include <utility>
#include <vector>

#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/SmallVector.h>

namespace at {
namespace native {
namespace utils {

constexpr int64_t kChunkSize = 16;

template <typename T>
void AddMoments(
int64_t m0_add,
const T& m1_add,
const T& m2_add,
int64_t& m0,
T& m1,
T& m2) {
const int64_t n = m0 + m0_add;
const T c = n == 0 ? 0 : static_cast<T>(m0_add) / static_cast<T>(n);
const T delta = m1_add - m1;
m1 += c * delta;
m2 += m2_add + delta * delta * c * static_cast<T>(m0);
m0 = n;
}

template <typename T>
void AddMomentsVec(
int64_t m0_add,
const vec::Vectorized<T>& m1_add,
const vec::Vectorized<T>& m2_add,
int64_t& m0,
vec::Vectorized<T>& m1,
vec::Vectorized<T>& m2) {
using Vec = vec::Vectorized<T>;
const int64_t n = m0 + m0_add;
const T c = n == 0 ? 0 : static_cast<T>(m0_add) / static_cast<T>(n);
const Vec c_vec(c);
const Vec delta = m1_add - m1;
m1 += c_vec * delta;
m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
m0 = n;
}

// Compute rowwise moments by Welford algorithm and cascade sum to improve
// numerical stability.
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
// https://en.wikipedia.org/wiki/Pairwise_summation
template <typename T, int64_t kMaxDepth>
std::pair<T, T> RowwiseMomentsImpl(const T* X, int64_t N) {
using Vec = vec::Vectorized<T>;

constexpr int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = CeilLog2(m);

const Vec kZeroVec(T(0));
c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);

for (int64_t i = 0; i < m; ++i) {
const T* X_ptr = X + i * kChunkSize * kVecSize;
const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
Vec m1_vec(0);
Vec m2_vec(0);
for (int64_t j = 0; j < m0; ++j) {
const Vec x_vec = Vec::loadu(X_ptr + j * kVecSize);
const Vec delta_vec = x_vec - m1_vec;
const Vec c_vec = Vec(T(1) / static_cast<T>(j + 1));
m1_vec += delta_vec * c_vec;
m2_vec += delta_vec * (x_vec - m1_vec);
}
AddMomentsVec(m0, m1_vec, m2_vec, m0_stk[0], m1_stk[0], m2_stk[0]);
int64_t mask = i + 1;
for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
AddMomentsVec(
m0_stk[j - 1],
m1_stk[j - 1],
m2_stk[j - 1],
m0_stk[j],
m1_stk[j],
m2_stk[j]);
m0_stk[j - 1] = 0;
m1_stk[j - 1] = kZeroVec;
m2_stk[j - 1] = kZeroVec;
mask >>= 1;
}
}
for (int64_t i = 1; i < depth; ++i) {
AddMomentsVec(
m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
}

std::array<T, kVecSize> m1_arr{};
std::array<T, kVecSize> m2_arr{};
m1_stk[0].store(m1_arr.data());
m2_stk[0].store(m2_arr.data());

int64_t m0 = 0;
T m1 = 0;
T m2 = 0;
for (int64_t i = n * kVecSize; i < N; ++i) {
const T delta = X[i] - m1;
++m0;
m1 += delta / static_cast<T>(m0);
m2 += delta * (X[i] - m1);
}
for (int64_t i = 0; i < kVecSize; ++i) {
AddMoments(n, m1_arr[i], m2_arr[i], m0, m1, m2);
}

return std::make_pair(m1, m2 / static_cast<T>(N));
}

template <typename T>
std::pair<T, T> RowwiseMoments(const T* X, int64_t N) {
using Vec = vec::Vectorized<T>;
constexpr int64_t kVecSize = Vec::size();
const int64_t n = N / kVecSize;
const int64_t m = divup(n, kChunkSize);
const int64_t depth = CeilLog2(m);
if (depth <= 4) {
return RowwiseMomentsImpl<T, 4>(X, N);
} else if (depth <= 8) {
return RowwiseMomentsImpl<T, 8>(X, N);
} else if (depth <= 16) {
return RowwiseMomentsImpl<T, 16>(X, N);
} else if (depth <= 32) {
return RowwiseMomentsImpl<T, 32>(X, N);
} else {
return RowwiseMomentsImpl<T, 64>(X, N);
}
}

} // namespace utils
} // namespace native
} // namespace at
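A hedged usage sketch of the new helper (the function example_row_moments and its arguments are illustrative, not from the commit; it assumes compilation inside an ATen CPU translation unit, as in group_norm_kernel.cpp above). RowwiseMoments returns the mean and the biased variance of one contiguous row:

#include <cstdint>
#include <tuple>

#include <ATen/native/cpu/moments_utils.h>

void example_row_moments(const float* row, int64_t length) {
  float mean;
  float var;
  // first = mean, second = m2 / N (biased variance).
  std::tie(mean, var) = at::native::utils::RowwiseMoments(row, length);
  // The GroupNorm kernel then forms rstd as 1 / sqrt(max(var, 0) + eps).
}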
11 changes: 7 additions & 4 deletions aten/src/ATen/native/cpu/utils.h
@@ -3,15 +3,18 @@
#include <ATen/cpu/vec/vec.h>
#include <c10/util/llvmMathExtras.h>

namespace at { namespace native { namespace {
namespace at {
namespace native {

namespace {

template <typename T>
inline T data_index_init(T offset) {
return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T &x, const T &X, Args &&... args) {
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
offset = data_index_init(offset, std::forward<Args>(args)...);
x = offset % X;
return offset / X;
@@ -22,7 +25,7 @@ inline bool data_index_step() {
}

template <typename T, typename... Args>
inline bool data_index_step(T &x, const T &X, Args &&... args) {
inline bool data_index_step(T& x, const T& X, Args&&... args) {
if (data_index_step(std::forward<Args>(args)...)) {
x = ((x + 1) == X) ? 0 : (x + 1);
return x == 0;
@@ -47,4 +50,4 @@ T CeilLog2(const T& x) {
} // namespace utils

} // namespace native
} // namespace at// namespace at::native::<anonymous>
} // namespace at
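The data_index_init/data_index_step helpers touched above only change formatting, not behavior. A hedged sketch of their usual usage pattern (the sizes, starting offset, and function name example_2d_walk are illustrative, not from the commit):

#include <cstdint>

#include <ATen/native/cpu/utils.h>

namespace at {
namespace native {

void example_2d_walk() {
  const int64_t H = 4;
  const int64_t W = 3;
  int64_t h = 0;
  int64_t w = 0;
  int64_t begin = 5;                   // arbitrary starting flat offset
  data_index_init(begin, h, H, w, W);  // begin = 5 maps to h = 1, w = 2
  for (int64_t i = begin; i < H * W; ++i) {
    // ... process element (h, w) ...
    data_index_step(h, H, w, W);       // w advances fastest, carries into h
  }
}

} // namespace native
} // namespace at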
31 changes: 31 additions & 0 deletions aten/src/ATen/native/cuda/block_reduce.cuh
@@ -1,5 +1,8 @@
#pragma once

#include <thrust/tuple.h>

#include <ATen/native/SharedReduceOps.h>
#include <ATen/cuda/DeviceUtils.cuh>

namespace at {
@@ -45,6 +48,34 @@ __inline__ __device__ T BlockReduceSum(T val, T* shared) {
return val;
}

template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
val = op.combine(val, op.warp_shfl_down(val, offset));
}
return val;
}

template <typename T, class ReduceOp>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
const int lid = threadIdx.x % C10_WARP_SIZE;
const int wid = threadIdx.x / C10_WARP_SIZE;
val = WarpReduce(val, op);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < blockDim.x / C10_WARP_SIZE) ? shared[lid]
: identity_element;
if (wid == 0) {
val = WarpReduce(val, op);
}
return val;
}

} // namespace cuda_utils
} // namespace native
} // namespace at
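A hedged sketch of the interface the new WarpReduce/BlockReduce templates expect from ReduceOp: combine() merges two partial results and warp_shfl_down() pulls a value from a higher lane in the warp. Both SumOp and block_sum_example below are hypothetical, not part of the commit:

#include <ATen/native/cuda/block_reduce.cuh>

// Hypothetical op: any type with these two device members works with BlockReduce.
struct SumOp {
  __device__ float combine(float a, float b) const { return a + b; }
  __device__ float warp_shfl_down(float val, int offset) const {
    return __shfl_down_sync(0xffffffff, val, offset);
  }
};

// Hypothetical kernel: one block reduces n elements to a single sum.
__global__ void block_sum_example(const float* x, float* out, int n) {
  __shared__ float shared[C10_WARP_SIZE];  // one slot per warp
  float val = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    val += x[i];
  }
  val = at::native::cuda_utils::BlockReduce(val, SumOp{}, 0.f, shared);
  if (threadIdx.x == 0) {
    *out = val;
  }
}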
