Skip to content

Commit

Permalink
Use cub::DeviceSelect::UniqueByKey for EmbeddingBackward (#68376)
Browse files Browse the repository at this point in the history
Summary:
NVIDIA/cub#405 is still under review, API might change before it finally lands into cub 1.16, please wait for NVIDIA/cub#405 before merging this. Tested locally and tests pass.

Pull Request resolved: #68376

Reviewed By: bdhirsh

Differential Revision: D34706782

Pulled By: ngimel

fbshipit-source-id: a465d39bc24354d1047af1ee85be05a1de361c86
  • Loading branch information
zasdfgbnm authored and facebook-github-bot committed Mar 8, 2022
1 parent 568872a commit 68a69bb
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 4 deletions.
30 changes: 30 additions & 0 deletions aten/src/ATen/cuda/cub.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <iterator>
#include <limits>

#include <c10/util/C++17.h>

#include <ATen/cuda/cub_definitions.cuh>

#if USE_GLOBAL_CUB_WRAPPED_NAMESPACE()
Expand Down Expand Up @@ -161,6 +163,34 @@ inline void segmented_sort_pairs(
}
}

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename KeysOutputIteratorT, typename ValuesOutputIteratorT, typename NumSelectedIteratorT>
inline void unique_by_key(
KeysInputIteratorT keys_in, ValuesInputIteratorT values_in,
KeysOutputIteratorT keys_out, ValuesOutputIteratorT values_out,
NumSelectedIteratorT num_selected, int64_t num_input_items)
{
// TODO: use thrust::discard_iterator to handle null keys_out when https://github.com/NVIDIA/cub/issues/406 is fixed.
constexpr bool null_keys_out = std::is_same<KeysOutputIteratorT, std::nullptr_t>::value;
using KeyT = typename std::iterator_traits<KeysInputIteratorT>::value_type;
using RealKeysOutputIteratorT = typename std::conditional<null_keys_out, KeyT *, KeysOutputIteratorT>::type;
RealKeysOutputIteratorT keys_out_;
auto allocator = c10::cuda::CUDACachingAllocator::get();
c10::DataPtr keys_out_owner;
c10::guts::if_constexpr<null_keys_out>(
[&](auto _) {
keys_out_owner = allocator->allocate(num_input_items * sizeof(KeyT));
keys_out_ = static_cast<KeyT *>(keys_out_owner.get());
},
[&](auto _) {
keys_out_ = keys_out;
}
);
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::UniqueByKey,
keys_in, values_in, keys_out_, values_out, num_selected, num_input_items, c10::cuda::getCurrentCUDAStream());
}
#endif

namespace impl {

template<typename InputIteratorT1, typename InputIteratorT2, typename OutputIteratorT, class ScanOpT>
Expand Down
10 changes: 9 additions & 1 deletion aten/src/ATen/cuda/cub_definitions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#define CUB_SUPPORTS_NV_BFLOAT16() false
#endif

// cub sort support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// cub support for CUB_WRAPPED_NAMESPACE is added to cub 1.13.1 in:
// https://github.com/NVIDIA/cub/pull/326
// CUB_WRAPPED_NAMESPACE is defined globally in cmake/Dependencies.cmake
// starting from CUDA 11.5
Expand All @@ -28,6 +28,14 @@
#define USE_GLOBAL_CUB_WRAPPED_NAMESPACE() false
#endif

// cub support for UniqueByKey is added to cub 1.16 in:
// https://github.com/NVIDIA/cub/pull/405
#if CUB_VERSION >= 101600
#define CUB_SUPPORTS_UNIQUE_BY_KEY() true
#else
#define CUB_SUPPORTS_UNIQUE_BY_KEY() false
#endif

// cub support for scan by key is added to cub 1.15
// in https://github.com/NVIDIA/cub/pull/376
#if CUB_VERSION >= 101500
Expand Down
26 changes: 23 additions & 3 deletions aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/cub.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/TensorUtils.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/cuda/SortingCommon.cuh>
Expand All @@ -10,6 +10,10 @@

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif

namespace at {
namespace native {

Expand Down Expand Up @@ -175,8 +179,10 @@ __global__ void sum_and_scatter(

} // anon namespace

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif

Tensor embedding_backward_cuda_kernel(
const Tensor &grad,
Expand All @@ -200,10 +206,24 @@ Tensor embedding_backward_cuda_kernel(
// spawn a warp per index. In this context, a segment is a number of rows that should
// be summarized.
// Unit: index in `sorted_indices` and `orig_indices`
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
});
#else
AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
cuda::cub::unique_by_key(
sorted_indices.data_ptr<index_t>(), thrust::make_counting_iterator(0),
nullptr, segment_offsets.data_ptr<index_t>(),
num_of_segments_tensor.data_ptr<int64_t>(), sorted_indices.numel());
num_of_segments = num_of_segments_tensor.item<int64_t>();
});
#endif

AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
// may not be the full `NROWS_PER_THREAD` number of rows)
Expand Down

0 comments on commit 68a69bb

Please sign in to comment.