6 changes: 2 additions & 4 deletions fbgemm_gpu/experimental/gen_ai/src/comm/car.cu
@@ -31,11 +31,9 @@
#endif
#include <cub/cub.cuh>

-#include <fbgemm_gpu/sparse_ops_utils.h>
-
#include <torch/torch.h>

-#include <fbgemm_gpu/utils/vec_quant.cuh>
+#include "fbgemm_gpu/utils/cuda_block_count.h"
+#include "fbgemm_gpu/utils/vec_quant.cuh"

namespace fbgemm_gpu {

2 changes: 1 addition & 1 deletion fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
@@ -16,7 +16,7 @@
#include <ATen/cuda/Atomic.cuh>
#include <algorithm>

-#include <fbgemm_gpu/sparse_ops_utils.h>
+#include "fbgemm_gpu/utils/cuda_block_count.h"
#include "fbgemm_gpu/utils/cuda_prelude.cuh"
#include "fbgemm_gpu/utils/stochastic_rounding.cuh"

154 changes: 1 addition & 153 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -13,6 +13,7 @@
#include <optional>
#include <string>

#include "fbgemm_gpu/utils/cuda_block_count.h"
#include "fbgemm_gpu/utils/ops_utils.h"

inline bool torch_tensor_on_cpu_check(const at::Tensor& ten) {
@@ -288,159 +289,6 @@ std::string tensor_on_same_gpu_if_not_optional_check(
TORCH_CHECK(tensors_on_same_gpu.empty(), tensors_on_same_gpu); \
} while (false)

/// Determine an appropriate CUDA block count along the x axis
///
/// When launching CUDA kernels the number of blocks B is often calculated
/// w.r.t. the number of threads T and items to be processed N as
/// B=(N+T-1)/T - which is integer division rounding up.
/// This function abstracts that calculation, performs it in an
/// overflow-safe manner, and limits the return value appropriately.
///
/// This is a general function for all integral data types.
/// The goal of this set of functions is to ensure correct calculations
/// across a variety of data types without forcing the programmer to
/// cast to an appropriate type (which is dangerous because we don't
/// have conversion warnings enabled). The values of the variables
/// can then be checked for correctness at run-time.
/// Specialized functions below handle various combinations of signed
/// and unsigned inputs. This system prevents "pointless comparison
/// against zero" warnings from the compiler for unsigned types
/// (simpler ways of suppressing this warning didn't work) while
/// maintaining the various warnings.
///
/// Function is designed to facilitate run-time value checking.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_xblock_count_base(
Integer1 num_items,
Integer2 threads_per_block) {
// The number of threads can be as high as 2048 on some newer architectures,
// but this is not portable.
TORCH_CHECK(threads_per_block <= 1024, "Number of threads must be <=1024!");
// The CUDA specification at
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// states that for compute capability 3.5-* the grid dimension of a kernel
  // launch must be <=2^31-1.
constexpr uint64_t max_blocks = 2147483647;
const auto u_num_items = static_cast<uint64_t>(num_items);
const auto u_threads = static_cast<uint64_t>(threads_per_block);
// Overflow safe variant of (a + b - 1) / b
const uint64_t blocks =
u_num_items / u_threads + (u_num_items % u_threads != 0);
return static_cast<uint32_t>(std::min(blocks, max_blocks));
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
threads_per_block >= 0,
"When calculating thread counts, the number of threads must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
TORCH_CHECK(
threads_per_block >= 0,
"When calculating thread counts, the number of threads must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

/// Determine an appropriate CUDA block count.
///
/// See cuda_calc_xblock_count_base() for details.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_block_count(
Integer1 num_items,
Integer2 threads_per_block) {
// The CUDA specification at
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// states that the grid dimension of a kernel launch must generally
// be <=65535. (For compute capability 3.5-* the grid's x-dimension must
// be <=2^31-1.) Because this function does not know which dimension
// is being calculated, we use the smaller limit.
constexpr uint32_t max_blocks = 65535;
return std::min(
cuda_calc_xblock_count(num_items, threads_per_block), max_blocks);
}

// A wrapper class for passing dynamically sized dimension information (e.g.
// tensor.dims()) from the host to device.
constexpr size_t kStackArrayMaxDims = 5;

template <typename T>
struct StackArray {
T vals[kStackArrayMaxDims];
size_t ndim;
};

inline at::Tensor aligned_grad_output_tensor_for_cuda_backwards(
const at::Tensor& grad_output) {
auto aligned_grad_output = grad_output;
155 changes: 155 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cuda_block_count.h
@@ -0,0 +1,155 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <ATen/ATen.h>
#include <cstdint>

/// Determine an appropriate CUDA block count along the x axis
///
/// When launching CUDA kernels the number of blocks B is often calculated
/// w.r.t. the number of threads T and items to be processed N as
/// B=(N+T-1)/T - which is integer division rounding up.
/// This function abstracts that calculation, performs it in an
/// overflow-safe manner, and limits the return value appropriately.
///
/// This is a general function for all integral data types.
/// The goal of this set of functions is to ensure correct calculations
/// across a variety of data types without forcing the programmer to
/// cast to an appropriate type (which is dangerous because we don't
/// have conversion warnings enabled). The values of the variables
/// can then be checked for correctness at run-time.
/// Specialized functions below handle various combinations of signed
/// and unsigned inputs. This system prevents "pointless comparison
/// against zero" warnings from the compiler for unsigned types
/// (simpler ways of suppressing this warning didn't work) while
/// maintaining the various warnings.
///
/// Function is designed to facilitate run-time value checking.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_xblock_count_base(
Integer1 num_items,
Integer2 threads_per_block) {
// The number of threads can be as high as 2048 on some newer architectures,
// but this is not portable.
TORCH_CHECK(threads_per_block <= 1024, "Number of threads must be <=1024!");
// The CUDA specification at
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// states that for compute capability 3.5-* the grid dimension of a kernel
  // launch must be <=2^31-1.
constexpr uint64_t max_blocks = 2147483647;
const auto u_num_items = static_cast<uint64_t>(num_items);
const auto u_threads = static_cast<uint64_t>(threads_per_block);
// Overflow safe variant of (a + b - 1) / b
const uint64_t blocks =
u_num_items / u_threads + (u_num_items % u_threads != 0);
return static_cast<uint32_t>(std::min(blocks, max_blocks));
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
threads_per_block >= 0,
"When calculating thread counts, the number of threads must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_signed<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_signed<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
TORCH_CHECK(
num_items >= 0,
"When calculating block counts, the number of items must be positive!");
TORCH_CHECK(
threads_per_block >= 0,
"When calculating thread counts, the number of threads must be positive!");
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

// See: cuda_calc_xblock_count_base
template <
typename Integer1,
typename Integer2,
std::enable_if_t<
        std::is_integral<Integer1>::value && std::is_unsigned<Integer1>::value,
bool> = true,
std::enable_if_t<
std::is_integral<Integer2>::value && std::is_unsigned<Integer2>::value,
bool> = true>
constexpr uint32_t cuda_calc_xblock_count(
Integer1 num_items,
Integer2 threads_per_block) {
return cuda_calc_xblock_count_base(num_items, threads_per_block);
}

/// Determine an appropriate CUDA block count.
///
/// See cuda_calc_xblock_count_base() for details.
template <
typename Integer1,
typename Integer2,
std::enable_if_t<std::is_integral<Integer1>::value, bool> = true,
std::enable_if_t<std::is_integral<Integer2>::value, bool> = true>
constexpr uint32_t cuda_calc_block_count(
Integer1 num_items,
Integer2 threads_per_block) {
// The CUDA specification at
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
// states that the grid dimension of a kernel launch must generally
// be <=65535. (For compute capability 3.5-* the grid's x-dimension must
// be <=2^31-1.) Because this function does not know which dimension
// is being calculated, we use the smaller limit.
constexpr uint32_t max_blocks = 65535;
return std::min(
cuda_calc_xblock_count(num_items, threads_per_block), max_blocks);
}
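
The blocks computation in cuda_calc_xblock_count_base deliberately avoids the textbook ceiling-division formula (N + T - 1) / T, whose intermediate sum can wrap around when N sits near the top of the uint64_t range. A minimal standalone sketch of the difference (plain C++; ceil_div_naive and ceil_div_safe are illustrative helpers, not FBGEMM functions):

#include <cstdint>
#include <iostream>

// Textbook ceiling division: the intermediate n + d - 1 wraps around
// when n is near UINT64_MAX, producing a bogus small quotient.
constexpr uint64_t ceil_div_naive(uint64_t n, uint64_t d) {
  return (n + d - 1) / d;
}

// Rearranged form used by cuda_calc_xblock_count_base: no intermediate
// value ever exceeds n, so the computation cannot overflow.
constexpr uint64_t ceil_div_safe(uint64_t n, uint64_t d) {
  return n / d + (n % d != 0);
}

int main() {
  constexpr uint64_t n = UINT64_MAX - 2;
  constexpr uint64_t d = 128;
  std::cout << ceil_div_naive(n, d) << '\n';  // wraps around: prints 0
  std::cout << ceil_div_safe(n, d) << '\n';   // correct: 2^57
  return 0;
}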
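
For context, a typical call site pairs one of these helpers with a kernel launch. The following is a hypothetical sketch (scale_kernel and scale_inplace are invented for illustration and do not appear in this PR); with a signed int64_t item count and a signed int block size, overload resolution picks the both-signed variant, which validates both arguments at run time:

#include <ATen/ATen.h>
#include "fbgemm_gpu/utils/cuda_block_count.h"

__global__ void scale_kernel(float* data, int64_t n, float alpha) {
  const int64_t idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) {
    data[idx] *= alpha;
  }
}

void scale_inplace(at::Tensor& t, float alpha) {
  constexpr int kThreads = 256;
  const int64_t n = t.numel();
  // Overflow-safe ceiling division, capped at the 2^31 - 1 x-grid limit.
  const uint32_t blocks = cuda_calc_xblock_count(n, kThreads);
  scale_kernel<<<blocks, kThreads>>>(t.data_ptr<float>(), n, alpha);
}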
11 changes: 11 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/common.cuh
@@ -27,6 +27,7 @@
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/utils/binary_search_range.cuh"
#include "fbgemm_gpu/utils/cuda_block_count.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/fixed_divisor.cuh"
#include "fbgemm_gpu/utils/inclusive_sum_scan.cuh"
@@ -39,6 +40,16 @@ namespace fbgemm_gpu {

using Tensor = at::Tensor;

+// A wrapper class for passing dynamically sized dimension information (e.g.
+// tensor.dims()) from the host to device.
+constexpr size_t kStackArrayMaxDims = 5;
+
+template <typename T>
+struct StackArray {
+  T vals[kStackArrayMaxDims];
+  size_t ndim;
+};
+
namespace {

// template <typename T>
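
A short sketch of the host-to-device pattern the StackArray comment describes (print_ndim_kernel and launch_with_dims are hypothetical, not part of this PR): device code cannot dereference host-side containers such as at::IntArrayRef, so the dimensions are copied into a StackArray, which is trivially copyable and can be passed to the kernel by value:

__global__ void print_ndim_kernel(StackArray<int64_t> dims) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    // The array and its length arrive in kernel-parameter space.
    printf("ndim = %d, dim0 = %ld\n",
           static_cast<int>(dims.ndim),
           static_cast<long>(dims.vals[0]));
  }
}

void launch_with_dims(const at::Tensor& t) {
  TORCH_CHECK(t.dim() <= static_cast<int64_t>(kStackArrayMaxDims));
  StackArray<int64_t> dims;
  dims.ndim = static_cast<size_t>(t.dim());
  for (int64_t d = 0; d < t.dim(); ++d) {
    dims.vals[d] = t.size(d);
  }
  print_ndim_kernel<<<1, 32>>>(dims);  // struct is copied by value
}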
1 change: 1 addition & 0 deletions fbgemm_gpu/src/quantize_ops/common.cuh
@@ -23,6 +23,7 @@
#include "fbgemm_gpu/quantize_ops_utils.h"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/utils/cuda_block_count.h"
#include "fbgemm_gpu/utils/dispatch_macros.h"
#include "fbgemm_gpu/utils/float.cuh"
#include "fbgemm_gpu/utils/ops_utils.h"
1 change: 1 addition & 0 deletions fbgemm_gpu/src/sparse_ops/common.cuh
@@ -9,6 +9,7 @@
#include "fbgemm_gpu/sparse_ops.cuh"
#include "fbgemm_gpu/sparse_ops.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#include "fbgemm_gpu/utils/cuda_block_count.h"
#include "fbgemm_gpu/utils/ops_utils.h"

#include <ATen/ATen.h>