Added column-wise parallel quantization and dequantization for the padded all-gather kernel for 1D input (nrows==1) (pytorch#1743)

Summary: Pull Request resolved: pytorch#1743

Differential Revision: D45425499

fbshipit-source-id: 06f1fc9dd9b86ef58d92f9277b4f44f5c251d663
Xiao Sun authored and facebook-github-bot committed May 1, 2023
1 parent 54e5ab0 commit 24a037b
Showing 1 changed file with 119 additions and 35 deletions.
154 changes: 119 additions & 35 deletions fbgemm_gpu/src/quantize_ops.cu
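For orientation before the diff: both kernels manipulate a padded FP8 rowwise format that packs each logical row into buckets of row_dim elements, appending 8 metadata bytes per bucket. A minimal sketch of the layout arithmetic as implied by the kernels below; the helper names here are illustrative, not FBGEMM API:

// 8 metadata bytes per bucket: a 4-byte float scale, then a 4-byte int that
// holds either the pad length (last bucket) or a negative index offset.
constexpr int kBucketOverhead = 8;

// Output bytes for one logical row of `ncols` floats quantized in buckets of
// `row_dim` elements; equals ncols_aligned + nbuckets * 8 in the kernel.
constexpr int padded_row_bytes(int ncols, int row_dim) {
  const int nbuckets = (ncols + row_dim - 1) / row_dim;
  return nbuckets * (row_dim + kBucketOverhead);
}

static_assert(padded_row_bytes(600, 256) == 792, "3 buckets of 256 + 8 bytes");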
@@ -11,17 +11,17 @@
 #include <math_constants.h>
 #endif
 
-#include "fbgemm_gpu/embedding_common.h"
-#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
-#include "fbgemm_gpu/quantize_ops.cuh"
-#include "fbgemm_gpu/quantize_ops_utils.h"
-#include "fbgemm_gpu/sparse_ops_utils.h"
-
 #include <ATen/ATen.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/core/TensorAccessor.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include "fbgemm_gpu/embedding_common.h"
+#include "fbgemm_gpu/fbgemm_cuda_utils.cuh"
+#include "fbgemm_gpu/quantize_ops.cuh"
+#include "fbgemm_gpu/quantize_ops_utils.h"
+#include "fbgemm_gpu/sparse_ops.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
 
 using Tensor = at::Tensor;
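The newly added fbgemm_gpu/sparse_ops.h include is presumably what declares asynchronous_complete_cumsum_gpu, which the dequantization path below uses to turn per-bucket pad values into compaction offsets.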

@@ -172,34 +172,59 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(

   const int ncols_aligned = (ncols + row_dim - 1) / row_dim * row_dim;
   int pad = ncols_aligned - ncols;
+  const auto row_ext = row_dim + 8;
   const int output_columns =
       ncols_aligned + (ncols + row_dim - 1) / row_dim * 8;
 
   const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
 
-  if (row < nrows) {
-    const input_t* input_row = input + row * ncols;
-    std::uint8_t* output_row = output + row * output_columns;
-    for (int col = 0; col < ncols; col += row_dim) {
-      int col_offset = col / row_dim * 8;
-      int last_buc_idx = (ncols - col) / row_dim *
-          -1; // negative suggests an index offset
-      float* output_row_scale =
-          reinterpret_cast<float*>(output_row + col + col_offset + row_dim);
-      int buc_end = (row_dim < ncols - col) ? row_dim : ncols - col;
-      float minimum_element =
-          fbgemm_gpu::min(input_row + col, input_row + buc_end + col);
-      float maximum_element =
-          fbgemm_gpu::max(input_row + col, input_row + buc_end + col);
-      auto scale =
-          max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
-      output_row_scale[0] = scale;
-      output_row_scale[1] = *reinterpret_cast<float*>(
-          (ncols - col > row_dim) ? &last_buc_idx : &pad);
-      for (int bi = 0; bi < std::min(row_dim, (int)(ncols - col)); ++bi) {
-        output_row[col + bi + col_offset] =
-            float_to_hfp8(input_row[col + bi] * scale, ebit, bias, max_pos);
-      }
+  if (row >= nrows && nrows != 1)
+    return;
+  // 1D case: treat each row_dim-sized bucket as its own row (implicit unsqueeze)
+  if (nrows == 1) {
+    const auto threads = (ncols + row_dim - 1) / row_dim;
+    if (row >= threads)
+      return;
+    const input_t* const input_row = input + row * row_dim;
+    std::uint8_t* output_row = output + row * row_ext;
+    int last_buc_idx = row - (threads - 1);
+    float* output_row_scale = reinterpret_cast<float*>(output_row + row_dim);
+    const auto range = (row == threads - 1) ? row_dim - pad : row_dim;
+    float minimum_element = fbgemm_gpu::min(input_row, input_row + range);
+    float maximum_element = fbgemm_gpu::max(input_row, input_row + range);
+    auto scale =
+        max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
+    output_row_scale[0] = scale;
+    // Only the last bucket stores the real pad; earlier ones store a negative
+    // index offset so the host can find the next nonzero pad for output sizing.
+    output_row_scale[1] =
+        *reinterpret_cast<float*>((row == threads - 1) ? &pad : &last_buc_idx);
+    for (int col = 0; col < range; col += 1) {
+      output_row[col] =
+          float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
+    }
+    return;
+  }
+  // 2D case
+  const input_t* input_row = input + row * ncols;
+  std::uint8_t* output_row = output + row * output_columns;
+  for (int col = 0; col < ncols; col += row_dim) {
+    int col_offset = col / row_dim * 8;
+    int last_buc_idx = (ncols - col) / row_dim * -1;
+    float* output_row_scale =
+        reinterpret_cast<float*>(output_row + col + col_offset + row_dim);
+    int buc_end = (row_dim < ncols - col) ? row_dim : ncols - col;
+    float minimum_element =
+        fbgemm_gpu::min(input_row + col, input_row + buc_end + col);
+    float maximum_element =
+        fbgemm_gpu::max(input_row + col, input_row + buc_end + col);
+    auto scale =
+        max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
+    output_row_scale[0] = scale;
+    output_row_scale[1] = *reinterpret_cast<float*>(
+        (ncols - col > row_dim) ? &last_buc_idx : &pad);
+    for (int bi = 0; bi < std::min(row_dim, (int)(ncols - col)); ++bi) {
+      output_row[col + bi + col_offset] =
+          float_to_hfp8(input_row[col + bi] * scale, ebit, bias, max_pos);
     }
   }
 }
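To make the 1D bookkeeping above concrete, a hedged host-side walkthrough with illustrative numbers; it mirrors the kernel's arithmetic and nothing more:

#include <cassert>

int main() {
  const int ncols = 600, row_dim = 256;
  const int threads = (ncols + row_dim - 1) / row_dim; // 3 buckets, one thread each
  const int pad = threads * row_dim - ncols;           // 168 padded elements
  // Second metadata field (output_row_scale[1]) per bucket:
  //   bucket 0: 0 - (threads - 1) = -2  (negative index offset to the last bucket)
  //   bucket 1: 1 - (threads - 1) = -1
  //   bucket 2: pad = 168               (the real pad, counted by the host)
  assert(row_dim - pad == 88);            // bucket 2 quantizes only 88 valid elements
  assert(threads * (row_dim + 8) == 792); // total quantized bytes
  return 0;
}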
@@ -429,6 +454,24 @@ __global__ inline void _fused8bitrowwise_to_float_cuda_kernel(
     }
   }
 }
+
+__global__ inline void _get_padding_value_kernel(
+    const int nrows,
+    const int ncols,
+    const int row_dim,
+    const std::uint8_t* const __restrict__ input,
+    int* const __restrict__ offsets) {
+  const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int row_ext = row_dim + 8;
+  const auto threads = (ncols + row_ext - 1) / row_ext;
+  if (row >= threads)
+    return;
+  const std::uint8_t* const input_row = input + row * row_ext;
+  int pad = *reinterpret_cast<const int*>(input_row + row_dim + 4);
+  pad = (pad > 0) ? pad : 0;
+  offsets[row] = pad;
+}
+
 template <typename output_t>
 __global__ inline void _PaddedFP8rowwise_to_float_cuda_kernel(
     const std::uint8_t* const __restrict__ input,
@@ -437,13 +480,31 @@ __global__ inline void _PaddedFP8rowwise_to_float_cuda_kernel(
     const int output_columns,
     output_t* const __restrict__ output,
     const bool forward,
-    const int row_dim) {
+    const int row_dim,
+    int* const __restrict__ offsets) {
   const int row_ext = row_dim + 8;
   const int ebit = forward ? 4 : 5;
   const int bias = forward ? 15 : 31;
 
   const int64_t row = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  if (row >= nrows) {
+  if (row >= nrows && nrows != 1) {
     return;
   }
+  if (nrows == 1) {
+    const auto threads = (ncols + row_ext - 1) / row_ext;
+    if (row >= threads)
+      return;
+    const std::uint8_t* const input_row = input + row * row_ext;
+    output_t* output_row = output + row * row_dim;
+    const float* input_row_scale =
+        reinterpret_cast<const float*>(input_row + row_dim);
+    int pad = *reinterpret_cast<const int*>(&input_row_scale[1]);
+    pad = (pad > 0) ? pad : 0;
+    const int pad_offset = offsets[row];
+    for (int col = 0; col < row_dim - pad; col++) {
+      output_row[col - pad_offset] =
+          hfp8_to_float(input_row[col], ebit, bias) / input_row_scale[0];
+    }
+    return;
+  }
   const std::uint8_t* const input_row = input + row * ncols;
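The compacting write in the new 1D branch (output_row[col - pad_offset]) is easiest to see with numbers. In the all-gather setting the 1D stream is a concatenation of padded rows, so interior buckets can carry real pads; a hedged walkthrough assuming row_dim = 256 and buckets whose pad fields clamp to {0, 40, 16}:

#include <cassert>

int main() {
  const int row_dim = 256;
  const int pads[3] = {0, 40, 16};
  // Exclusive prefix sum of the pads, i.e. what the complete cumsum provides.
  int offsets[3], run = 0;
  for (int b = 0; b < 3; ++b) { offsets[b] = run; run += pads[b]; }
  // Bucket b writes cols [0, row_dim - pads[b]) starting at b * row_dim - offsets[b].
  assert(1 * row_dim - offsets[1] == 256); // bucket 1 starts right after bucket 0
  assert(2 * row_dim - offsets[2] == 472); // bucket 2 closes bucket 1's 40-pad gap
  return 0; // valid elements written: 256 + 216 + 240 = 712
}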
@@ -916,7 +977,8 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
   }
   constexpr int threads_per_block = 256;
-  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
+  const auto num_blocks = cuda_calc_xblock_count(
+      nrows == 1 ? (ncols + row_dim - 1) / row_dim : nrows, threads_per_block);
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
@@ -1160,7 +1222,28 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
   }
   constexpr int threads_per_block = 256;
-  const auto num_blocks = cuda_calc_xblock_count(nrows, threads_per_block);
+  const auto num_blocks = cuda_calc_xblock_count(
+      nrows == 1 ? (ncols + row_ext - 1) / row_ext + 1 : nrows,
+      threads_per_block);
+  Tensor offsets = at::empty(
+      (nrows == 1) ? num_blocks * threads_per_block : 0, // one int per thread
+      input.options().dtype(at::kInt));
+  if (nrows == 1) {
+    _get_padding_value_kernel<<<
+        num_blocks,
+        threads_per_block,
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        nrows,
+        ncols,
+        row_dim,
+        input.data_ptr<std::uint8_t>(),
+        offsets.data_ptr<int>());
+    offsets = asynchronous_complete_cumsum_gpu(offsets);
+  }
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
       output.scalar_type(), "PaddedFP8rowwise_to_float_cuda_kernel", [&] {
         _PaddedFP8rowwise_to_float_cuda_kernel<scalar_t>
@@ -1174,7 +1257,8 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
                 output_columns,
                 output.data_ptr<scalar_t>(),
                 forward,
-                row_dim);
+                row_dim,
+                offsets.data_ptr<int>());
       });
   C10_CUDA_KERNEL_LAUNCH_CHECK();
