Add BF16 in padded FP8 quantize ops (#2010)
Summary:

- Add BF16 support in `FloatToPaddedFP8RowwiseQuantized` and
  `PaddedFP8RowwiseQuantizedToFloat`.
- Refactor `src/quantize_ops/quantize_fp8_rowwise.cu`
- Move unit test from `hpc` to `fbgemm_gpu`

Reviewed By: jianyuh, summerdengfb, qchip

Differential Revision: D49166595
sryap authored and facebook-github-bot committed Sep 19, 2023
1 parent e41a2c5 commit 95a116c
Showing 6 changed files with 189 additions and 110 deletions.
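
The user-visible change in this commit is a new `output_dtype` argument on the padded FP8 dequantize path (default `0`, i.e. FP32, so existing callers are unchanged). Below is a minimal caller-side sketch, not part of the diff; the include path, the `fbgemm_gpu::` qualification of `SparseType`, and the helper name are assumptions:

```cpp
#include <ATen/ATen.h>

#include "fbgemm_gpu/sparse_ops.h" // assumed include path for the declaration used below

// Hypothetical helper: dequantize a padded-FP8 rowwise tensor directly to BF16.
at::Tensor dequantize_padded_fp8_to_bf16(const at::Tensor& quantized) {
  // output_dtype follows the SparseType enum; 0 (FP32) remains the default.
  // SparseType::BF16 is assumed to be visible from an fbgemm_gpu header.
  const auto bf16 = static_cast<int64_t>(fbgemm_gpu::SparseType::BF16);
  return fbgemm_gpu::_paddedFP8rowwise_to_float_gpu(
      quantized,
      /*forward=*/true,
      /*row_dim=*/256,
      /*output_last_dim=*/-1,
      /*output_dtype=*/bf16);
}
```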
32 changes: 32 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh
@@ -1696,6 +1696,19 @@ static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
#endif
}

// Helper functions for converting data to float
static DEVICE_INLINE float to_float(const float input) {
return input;
}

static DEVICE_INLINE float to_float(const at::Half input) {
return __half2float(input);
}

static DEVICE_INLINE float to_float(const at::BFloat16 input) {
return __bfloat162float(input);
}

#ifdef __HIP_PLATFORM_HCC__
// the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
// https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
@@ -1710,6 +1723,25 @@ static __host__ __device__ __nv_bfloat16 __float2bfloat16_rn(float f) {
}
#endif

// Helper functions for storing float in quantized storage
static DEVICE_INLINE void quantize_float_store(
at::BFloat16* output,
const float input) {
*reinterpret_cast<__nv_bfloat16*>(output) = __float2bfloat16(input);
}

static DEVICE_INLINE void quantize_float_store(
at::Half* output,
const float input) {
*output = __float2half(input);
}

static DEVICE_INLINE void quantize_float_store(
float* output,
const float input) {
*output = input;
}

#if !( \
((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
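
The `to_float` and `quantize_float_store` overloads added above let a single templated kernel body load any supported input type as `float`, compute in `float`, and store back in the output's native type. Here is a standalone CUDA sketch of that pattern, not taken from the diff: it uses the raw `__half`/`__nv_bfloat16` types instead of the `at::Half`/`at::BFloat16` wrappers and `DEVICE_INLINE`, the helper and kernel names are illustrative, and no HIP path is shown.

```cuda
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Load any supported element type as float.
__device__ __forceinline__ float load_as_float(const float x) {
  return x;
}
__device__ __forceinline__ float load_as_float(const __half x) {
  return __half2float(x);
}
__device__ __forceinline__ float load_as_float(const __nv_bfloat16 x) {
  return __bfloat162float(x);
}

// Store a float result in the output's native element type.
__device__ __forceinline__ void store_from_float(float* out, const float v) {
  *out = v;
}
__device__ __forceinline__ void store_from_float(__half* out, const float v) {
  *out = __float2half(v);
}
__device__ __forceinline__ void store_from_float(__nv_bfloat16* out, const float v) {
  *out = __float2bfloat16(v);
}

// One templated kernel covers FP32, FP16, and BF16 I/O with no per-type
// `if constexpr` branches in the loop body.
template <typename input_t, typename output_t>
__global__ void scale_kernel(
    const input_t* in,
    output_t* out,
    const float scale,
    const int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    store_from_float(&out[i], load_as_float(in[i]) * scale);
  }
}
```

With this overload set, supporting another storage type is two one-line overloads rather than another branch in every kernel.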
3 changes: 2 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -226,7 +226,8 @@ at::Tensor _paddedFP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward = true,
const int64_t row_dim = 256,
const int64_t output_last_dim = -1);
const int64_t output_last_dim = -1,
const int64_t output_dtype = 0);
at::Tensor _fused8bitrowwise_to_half_gpu(const at::Tensor& input);
at::Tensor _fused8bitrowwise_to_float_or_half_gpu(
const at::Tensor& input,
93 changes: 19 additions & 74 deletions fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -48,16 +48,8 @@ __global__ inline void _float_to_FP8rowwise_cuda_kernel(
max_pos / (kEpsilon + fmaxf(maximum_element, -minimum_element));
output_row_scale_bias[0] = scale;
for (int64_t col = 0; col < ncols; ++col) {
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_row[col] = float_to_hfp8(
__bfloat162float(input_row[col]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_row[col] = float_to_hfp8(
__half2float(input_row[col]) * scale, ebit, bias, max_pos);
} else {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
}
output_row[col] =
float_to_hfp8(to_float(input_row[col]) * scale, ebit, bias, max_pos);
}
}
}
@@ -95,15 +87,7 @@ __global__ inline void _get_FP8_qparam_cuda_kernel(
for (int64_t col = threadIdx.x; col < ncols; col += lane_width) {
// Get thread-local minmax. These are the smallest min and max ever seen
// by this thread.
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__bfloat162float(input_row[col])));
} else if constexpr (std::is_same<input_t, at::Half>::value) {
maximum_element =
fmaxf(maximum_element, fabs(__half2float(input_row[col])));
} else {
maximum_element = fmaxf(maximum_element, fabs(input_row[col]));
}
maximum_element = fmaxf(maximum_element, fabs(to_float(input_row[col])));
}
}

@@ -164,16 +148,8 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
// TODO: lift range_list into shared memory. However, when nrows is large,
// it might exceed the size of shared memory.
// output_addr[0] = lrintf((input[input_idx] - bias) * inverse_scale);
if constexpr (std::is_same<input_t, at::BFloat16>::value) {
output_addr[0] = float_to_hfp8(
__bfloat162float(input[input_idx]) * scale, ebit, bias, max_pos);
} else if constexpr (std::is_same<input_t, at::Half>::value) {
output_addr[0] = float_to_hfp8(
__half2float(input[input_idx]) * scale, ebit, bias, max_pos);
} else {
output_addr[0] =
float_to_hfp8(input[input_idx] * scale, ebit, bias, max_pos);
}
output_addr[0] = float_to_hfp8(
to_float(input[input_idx]) * scale, ebit, bias, max_pos);
}
}
}
@@ -201,15 +177,7 @@ __global__ inline void _FP8rowwise_to_float_cuda_kernel(

const float output_ =
hfp8_to_float(input_row[col], ebit, bias) / input_row_scale_bias[0];

if constexpr (std::is_same<output_t, at::BFloat16>::value) {
*reinterpret_cast<__nv_bfloat16*>(&output_row[col]) =
__float2bfloat16(output_);
} else if constexpr (std::is_same<output_t, at::Half>::value) {
output_row[col] = __half2float(output_);
} else {
output_row[col] = output_;
}
quantize_float_store(&output_row[col], output_);
}
}
}
@@ -348,8 +316,10 @@ _float_to_FP8rowwise_gpu(const Tensor& input, const bool forward) {
}
}
template <typename output_t>
Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
Tensor _FP8rowwise_to_float_gpu_t(
const Tensor& input,
bool forward,
const int64_t output_dtype) {
TENSOR_ON_CUDA_GPU(input);
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
@@ -371,24 +341,14 @@ Tensor _FP8rowwise_to_float_gpu_t(const Tensor& input, bool forward) {
// that size).
auto output_dims = input_sizes.vec();
output_dims[last_dim] = output_columns;
Tensor output;
if constexpr (std::is_same_v<output_t, float>) {
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else if constexpr (std::is_same_v<output_t, half>) { // T = at::Half
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kHalf));
} else if constexpr (std::is_same_v<
output_t,
__nv_bfloat16>) { // T = at::BFloat16
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false);
}
const auto output_sdtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(
output_sdtype == SparseType::FP32 || output_sdtype == SparseType::FP16 ||
output_sdtype == SparseType::BF16);
Tensor output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(getScalarType(output_sdtype)));
if (nrows == 0 || output_columns == 0) {
return output;
@@ -422,22 +382,7 @@ DLL_PUBLIC at::Tensor _FP8rowwise_to_float_gpu(
const at::Tensor& input,
bool forward,
const int64_t output_dtype) {
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
Tensor output;
switch (output_sparse_dtype) {
case SparseType::FP32:
output = _FP8rowwise_to_float_gpu_t<float>(input, forward);
break;
case SparseType::FP16:
output = _FP8rowwise_to_float_gpu_t<half>(input, forward);
break;
case SparseType::BF16:
output = _FP8rowwise_to_float_gpu_t<__nv_bfloat16>(input, forward);
break;
default:
TORCH_CHECK(false);
}
return output;
return _FP8rowwise_to_float_gpu_t(input, forward, output_dtype);
}
} // namespace fbgemm_gpu
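
For context on the refactor above: the output tensor is now allocated from a runtime `output_dtype` code, and a single dispatch over `output.scalar_type()` replaces the three explicit template instantiations (`<float>`, `<half>`, `<__nv_bfloat16>`). A hedged host-side sketch of that pattern follows; the function names are hypothetical, the FP16/BF16 integer codes are assumptions (only FP32 = 0 is confirmed by the default), and `AT_DISPATCH_FLOATING_TYPES_AND2` stands in for `FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16`.

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Map a SparseType-style integer code to a torch dtype.
at::ScalarType scalar_type_from_dtype_code(const int64_t code) {
  if (code == 1) { // assumed value of SparseType::FP16
    return at::kHalf;
  }
  if (code == 5) { // assumed value of SparseType::BF16
    return at::kBFloat16;
  }
  return at::kFloat; // 0 == SparseType::FP32, the default
}

at::Tensor dequantize_rowwise_sketch(
    const at::Tensor& input,
    const int64_t output_dtype) {
  // The real code also trims the per-row scale/bias columns when sizing the
  // output; that bookkeeping is omitted here.
  at::Tensor output = at::empty(
      input.sizes(),
      input.options().dtype(scalar_type_from_dtype_code(output_dtype)));

  // One runtime dispatch over the *output* dtype selects scalar_t as float,
  // at::Half, or at::BFloat16.
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::kHalf,
      at::kBFloat16,
      output.scalar_type(),
      "dequantize_rowwise_sketch",
      [&] {
        auto* out_ptr = output.data_ptr<scalar_t>();
        (void)out_ptr; // the real code launches _FP8rowwise_to_float_cuda_kernel<scalar_t> here
      });
  return output;
}
```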
2 changes: 1 addition & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -446,7 +446,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"MSFPQuantizedToFloat(Tensor input, int ebits, int mbits, int bias) -> Tensor");
m.def(
"PaddedFP8RowwiseQuantizedToFloat(Tensor input, bool forward, int row_dim, int output_last_dim=-1) -> Tensor");
"PaddedFP8RowwiseQuantizedToFloat(Tensor input, bool forward, int row_dim, int output_last_dim=-1, int output_dtype=0) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
51 changes: 26 additions & 25 deletions fbgemm_gpu/src/quantize_ops/quantize_padded_fp8_rowwise.cu
@@ -61,7 +61,7 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
*reinterpret_cast<float*>((row == threads - 1) ? &pad : &last_buc_idx);
for (int col = 0; col < range; col += 1) {
output_row[col] =
float_to_hfp8(input_row[col] * scale, ebit, bias, max_pos);
float_to_hfp8(to_float(input_row[col]) * scale, ebit, bias, max_pos);
}
return;
}
@@ -88,8 +88,8 @@ __global__ inline void _float_to_paddedFP8rowwise_cuda_kernel(
output_row_scale[1] = *reinterpret_cast<float*>(
(ncols - col > row_dim) ? &last_buc_idx : &pad);
for (int bi = 0; bi < std::min(row_dim, (int)(ncols - col)); ++bi) {
output_row[col + bi + col_offset] =
float_to_hfp8(input_row[col + bi] * scale, ebit, bias, max_pos);
output_row[col + bi + col_offset] = float_to_hfp8(
to_float(input_row[col + bi]) * scale, ebit, bias, max_pos);
}
}
}
@@ -160,7 +160,8 @@ __global__ inline void _PaddedFP8rowwise_to_float_1d_cuda_kernel(
const auto pad_offset = offsets[row];
output_t* output_row = output + row * row_dim - pad_offset;
for (int col = threadIdx.x; col < row_dim - pad; col += blockDim.x) {
output_row[col] = hfp8_to_float(input_row[col], ebit, bias) / scale;
const auto output_ = hfp8_to_float(input_row[col], ebit, bias) / scale;
quantize_float_store(&output_row[col], output_);
}
}

@@ -193,8 +194,9 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
// bucket
pad = (pad > 0) ? pad : 0;
for (int bi = 0; bi < row_dim - pad; ++bi) {
output_row[col + bi - col_offset] =
const auto output_ =
hfp8_to_float(input_row[col + bi], ebit, bias) / input_row_scale[0];
quantize_float_store(&output_row[col + bi - col_offset], output_);
}
col_offset = col_offset + 8 + pad;
}
Expand All @@ -203,7 +205,6 @@ __global__ inline void _PaddedFP8rowwise_to_float_2d_cuda_kernel(
} // namespace

// revising INT8 rowwise template for FP8 rowwise quantization
template <typename input_t>
Tensor _float_to_paddedFP8rowwise_gpu_t(
const Tensor& input,
const bool forward,
@@ -241,7 +242,7 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
const auto num_blocks = cuda_calc_xblock_count(
nrows == 1 ? (ncols + row_dim - 1) / row_dim : nrows, threads_per_block);

FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
input.scalar_type(), "_float_to_FP8rowwise_cuda_kernel", [&] {
_float_to_paddedFP8rowwise_cuda_kernel<scalar_t>
<<<num_blocks,
Expand All @@ -259,12 +260,12 @@ Tensor _float_to_paddedFP8rowwise_gpu_t(
return output;
}
template <typename output_t>
Tensor _paddedFP8rowwise_to_float_gpu_t(
const Tensor& input,
const bool forward,
const int64_t row_dim,
const int64_t output_last_dim) {
const int64_t output_last_dim,
const int64_t output_dtype) {
TENSOR_ON_CUDA_GPU(input);
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
@@ -328,16 +329,15 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
}
output_dims[last_dim] = output_columns;
Tensor output;
if constexpr (std::is_same_v<output_t, float>) {
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else { // T = at::Half
output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(at::kHalf));
}
const auto output_sdtype = static_cast<SparseType>(output_dtype);
TORCH_CHECK(
output_sdtype == SparseType::FP32 || output_sdtype == SparseType::FP16 ||
output_sdtype == SparseType::BF16);
Tensor output = at::empty(
output_dims, // 4 = sizeof(float)
input.options().dtype(getScalarType(output_sdtype)));
if (nrows == 0 || output_columns == 0) {
return output;
@@ -357,7 +357,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
constexpr int kMaxThreads = 1024;
const auto threads_per_block =
kMaxThreads < row_dim ? kMaxThreads : row_dim;
FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
output.scalar_type(), "PaddedFP8rowwise_to_float_1d_cuda_kernel", [&] {
_PaddedFP8rowwise_to_float_1d_cuda_kernel<scalar_t>
<<<num_rows,
@@ -375,7 +375,7 @@ Tensor _paddedFP8rowwise_to_float_gpu_t(
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
FBGEMM_DISPATCH_FLOAT_AND_HALF(
FBGEMM_DISPATCH_FLOAT_HALF_AND_BFLOAT16(
output.scalar_type(), "PaddedFP8rowwise_to_float_2d_cuda_kernel", [&] {
_PaddedFP8rowwise_to_float_2d_cuda_kernel<scalar_t>
<<<num_blocks,
@@ -402,16 +402,17 @@ DLL_PUBLIC Tensor _float_to_paddedFP8rowwise_gpu(
const Tensor& input,
const bool forward,
const int64_t row_dim) {
return _float_to_paddedFP8rowwise_gpu_t<float>(input, forward, row_dim);
return _float_to_paddedFP8rowwise_gpu_t(input, forward, row_dim);
}
DLL_PUBLIC at::Tensor _paddedFP8rowwise_to_float_gpu(
const at::Tensor& input,
const bool forward,
const int64_t row_dim,
const int64_t output_last_dim) {
return _paddedFP8rowwise_to_float_gpu_t<float>(
input, forward, row_dim, output_last_dim);
const int64_t output_last_dim,
const int64_t output_dtype) {
return _paddedFP8rowwise_to_float_gpu_t(
input, forward, row_dim, output_last_dim, output_dtype);
}
} // namespace fbgemm_gpu