
[cuda] vectorized gamma and beta loading in vectorized_layer_norm #107287

Closed
wants to merge 11 commits
33 changes: 23 additions & 10 deletions aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -225,33 +225,40 @@ __device__ __inline__ void vectorized_layer_norm_kernel_impl(
 auto i1 = blockIdx.x;
 const T * block_row = X + i1 * N;
 WelfordDataLN wd = compute_stats(block_row, N, s_data);
+
 using vec_t = aligned_vector<T, vec_size>;
 const vec_t * X_vec = reinterpret_cast<const vec_t*>(block_row);
+const vec_t * gamma_vec = (gamma != nullptr) ? reinterpret_cast<const vec_t*>(gamma) : nullptr;
+const vec_t * beta_vec = (beta != nullptr) ? reinterpret_cast<const vec_t*>(beta) : nullptr;
 vec_t * Y_vec = reinterpret_cast<vec_t*>(Y + i1 * N);
+
 const int numx = blockDim.x * blockDim.y;
 const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
 const int n_vec_to_read = N/vec_size;
+
 T_ACC rstd_val = c10::cuda::compat::rsqrt(wd.sigma2 + eps);
-//no tail, N is guaranteed to be multiple of vec size
+
+// No tail, N is guaranteed to be multiple of vec size
 for (int i = thrx; i < n_vec_to_read; i += numx) {
   vec_t data = X_vec[i];
   vec_t out;
-  //computation is performed in T_ACC, X is cast to T_ACC and result is implicitly cast to T
-  if (gamma != nullptr && beta != nullptr) {
+
+  // Computation is performed in T_ACC, X is cast to T_ACC and result is implicitly cast to T
+  if (gamma_vec != nullptr && beta_vec != nullptr) {
     #pragma unroll
     for (int ii=0; ii < vec_size; ii++){
-      out.val[ii] = static_cast<T_ACC>(gamma[i*vec_size + ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean))
-                    + static_cast<T_ACC>(beta[i*vec_size + ii]);
+      out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean))
+                    + static_cast<T_ACC>(beta_vec[i].val[ii]);
     }
-  } else if (gamma != nullptr) {
+  } else if (gamma_vec != nullptr) {
     #pragma unroll
     for (int ii=0; ii < vec_size; ii++){
-      out.val[ii] = static_cast<T_ACC>(gamma[i*vec_size + ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean));
+      out.val[ii] = static_cast<T_ACC>(gamma_vec[i].val[ii]) * (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean));
     }
-  } else if (beta != nullptr) {
+  } else if (beta_vec != nullptr) {
     #pragma unroll
     for (int ii=0; ii < vec_size; ii++){
-      out.val[ii] = (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) + static_cast<T_ACC>(beta[i*vec_size + ii]);
+      out.val[ii] = (rstd_val * (static_cast<T_ACC>(data.val[ii]) - wd.mean)) + static_cast<T_ACC>(beta_vec[i].val[ii]);
     }
   } else {
     #pragma unroll
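
The gist of the hunk above: X and Y were already accessed through aligned_vector, while gamma and beta were still read element by element (gamma[i*vec_size + ii]). The patch routes gamma and beta through the same vec_t type, so each thread issues one wide load per vector instead of vec_size scalar loads. Below is a minimal standalone sketch of that pattern, with a hypothetical kernel name and a plain scale-and-shift in place of the full layer norm arithmetic; aligned_vector here is a simplified stand-in for ATen's definition.

#include <cuda_runtime.h>

// Simplified stand-in for ATen's aligned_vector: alignas widens the struct's
// alignment so one vec_t access can lower to a single wide memory instruction.
template <typename T, int vec_size>
struct alignas(sizeof(T) * vec_size) aligned_vector {
  T val[vec_size];
};

// Hypothetical kernel: y = gamma * x + beta, with all four arrays accessed
// through vec_t. Assumes n % vec_size == 0 and vec-aligned pointers, the same
// preconditions the PR checks on the host side.
template <typename T, int vec_size>
__global__ void scale_shift_vectorized(const T* x, const T* gamma,
                                       const T* beta, T* y, int n) {
  using vec_t = aligned_vector<T, vec_size>;
  const vec_t* x_vec = reinterpret_cast<const vec_t*>(x);
  const vec_t* gamma_vec = reinterpret_cast<const vec_t*>(gamma);
  const vec_t* beta_vec = reinterpret_cast<const vec_t*>(beta);
  vec_t* y_vec = reinterpret_cast<vec_t*>(y);
  const int n_vec = n / vec_size;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_vec;
       i += gridDim.x * blockDim.x) {
    vec_t xv = x_vec[i];      // one wide load instead of vec_size scalar loads
    vec_t gv = gamma_vec[i];  // ditto for gamma ...
    vec_t bv = beta_vec[i];   // ... and beta, which is what this PR adds
    vec_t out;
    #pragma unroll
    for (int ii = 0; ii < vec_size; ii++) {
      out.val[ii] = gv.val[ii] * xv.val[ii] + bv.val[ii];
    }
    y_vec[i] = out;           // one wide store
  }
}

// Example instantiation: float with vec_size = 4 gives 16-byte (128-bit) accesses.
template __global__ void scale_shift_vectorized<float, 4>(
    const float*, const float*, const float*, float*, int);

Because the alignment is promised to the compiler via alignas, each vec_t access can be emitted as a single wide load or store, provided the underlying pointers really are aligned — which is exactly what the host-side gating in the next hunk verifies.
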
@@ -715,14 +722,20 @@ void LayerNormKernelImplInternal(
 T* Y_data = Y->data_ptr<T>();
 T_ACC* mean_data = mean->data_ptr<T_ACC>();
 T_ACC* rstd_data = rstd->data_ptr<T_ACC>();

+// check if can take fast path - all tensors are properly aligned, N is less than 2^24 (to use float count),
+// N is multiple of vec_size (so that all rows are aligned if tensor is aligned)
 auto can_vectorize = [&](const T * ptr, int alignment){uint64_t addr = reinterpret_cast<uint64_t>(ptr); return addr % alignment == 0;};
 constexpr int num_vec_elems = vec_size;
 constexpr int alignment = num_vec_elems * sizeof(T);
+bool can_vec_X = can_vectorize(X_data, alignment);
+bool can_vec_Y = can_vectorize(Y_data, alignment);
+bool can_vec_gamma = gamma.defined() ? can_vectorize(gamma_data, alignment) : true;
+bool can_vec_beta = beta.defined() ? can_vectorize(beta_data, alignment) : true;

 if ((std::is_same<T, float>::value || std::is_same<T, at::Half>::value || std::is_same<T, at::BFloat16>::value) &&
     N <= static_cast<int64_t>(1ULL << std::numeric_limits<float>::digits) && N % num_vec_elems == 0 &&
-    can_vectorize(X_data, alignment) && can_vectorize(Y_data, alignment)) {
+    can_vec_X && can_vec_Y && can_vec_gamma && can_vec_beta) {
   launch_vectorized_layer_norm_kernel(static_cast<int>(N), M, eps, X_data, gamma_data, beta_data, Y_data, mean_data, rstd_data);
 } else {
   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
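
For context on the gating above: the fast path is only legal when every tensor the kernel will reinterpret_cast is aligned to vec_size * sizeof(T), so this hunk extends the existing X/Y checks to gamma and beta, with an undefined gamma or beta never blocking vectorization. Below is a host-side sketch of the combined predicate, with a hypothetical helper name and raw pointers (nullptr for an undefined tensor) standing in for the ATen tensors.

#include <cstdint>
#include <limits>

// Hypothetical helper mirroring the fast-path condition in the diff.
// nullptr plays the role of an undefined gamma/beta tensor, matching the
// `gamma.defined() ? can_vectorize(...) : true` ternaries above.
template <typename T, int vec_size>
bool can_use_vectorized_path(const T* x, const T* y, const T* gamma,
                             const T* beta, int64_t N) {
  auto aligned = [](const T* ptr) {
    // Same test as the can_vectorize lambda: address divisible by the
    // vector width in bytes.
    return reinterpret_cast<uintptr_t>(ptr) % (vec_size * sizeof(T)) == 0;
  };
  return
      // The Welford count is kept in float; integers up to 2^24
      // (numeric_limits<float>::digits = 24) are exactly representable.
      N <= (int64_t{1} << std::numeric_limits<float>::digits) &&
      // The kernel has no tail handling, so rows must split into whole vectors.
      N % vec_size == 0 &&
      aligned(x) && aligned(y) &&
      (gamma == nullptr || aligned(gamma)) &&
      (beta == nullptr || aligned(beta));
}

The N % vec_size == 0 requirement also ensures each row pointer X + i1 * N stays aligned whenever the base pointer is, which is why checking the base address of each tensor once is enough.
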