[ROCm] Adding ROCm support for the dynamic_partition_op #30128

Merged
84 changes: 50 additions & 34 deletions tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -16,13 +16,13 @@ limitations under the License.
// The algorithm for dynamic partition has the following steps:
// 1. Let N be the size of partitions. We initialize a new vector indices_in
// with the values 0, 1, 2, ..., N-1.
// 2. We apply cub::DeviceRadixSort::SortPairs to the key - value pairs given
// by partitions and indices_in. This will result in two new vectors
// 2. We apply gpuprim::DeviceRadixSort::SortPairs to the key-value pairs
// given by partitions and indices_in. This will result in two new vectors
// partitions_out and indices_out, with partitions_out sorted.
// 3. The first dimension of outputs[i] is equal to the number of i-values in
// partitions_out. We determine it in two steps:
// - apply cub::DeviceReduce::ReduceByKey to count how many times each value
// appears in partitions_out,
// - apply gpuprim::DeviceReduce::ReduceByKey to count how many times each
// value appears in partitions_out,
// - move the results to partition_count. This handles missing values
// (corresponding to empty parts).
// 4. Because partition_count is on the GPU, we bring it asynchronously to
@@ -31,14 +31,18 @@ limitations under the License.
// This works because for each interval of i-values, indices_out points
// to the slices which should form output[i].
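// For example, partitions = [0, 0, 1, 2, 0] with data = [10, 20, 30, 40, 50]
// yields outputs[0] = [10, 20, 50], outputs[1] = [30], outputs[2] = [40]:
// the stable sort produces partitions_out = [0, 0, 0, 1, 2] and
// indices_out = [0, 1, 4, 2, 3], so output[0] gathers slices 0, 1, 4,
// output[1] gathers slice 2, and output[2] gathers slice 3.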

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#if GOOGLE_CUDA
#include "third_party/cub/device/device_radix_sort.cuh"
#include "third_party/cub/device/device_reduce.cuh"
#include "third_party/cub/iterator/constant_input_iterator.cuh"
#include "third_party/cub/thread/thread_operators.cuh"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -50,6 +54,12 @@ limitations under the License.
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/transform_output_iterator.h"

#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif
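// hipCUB exposes the same API as CUB on top of rocPRIM, so the gpuprim::
// calls below compile unchanged for both CUDA and ROCm builds.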

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
@@ -59,14 +69,14 @@ namespace {
template <typename T>
__global__ void RangeInitKernel(const T start, const T delta, const int32 size,
T* out) {
CUDA_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
GPU_1D_KERNEL_LOOP(i, size) { out[i] = start + i * delta; }
}

__global__ void MoveValuesKernel(const int32* keys, const int32* values,
const int32* size, int32 out_size,
int32* out) {
int32 N = min(ldg(size), out_size);
CUDA_1D_KERNEL_LOOP(i, N) {
GPU_1D_KERNEL_LOOP(i, N) {
int32 key = ldg(keys + i);
int32 value = ldg(values + i);
if (FastBoundsCheck(key, out_size)) out[key] = value;
@@ -78,10 +88,10 @@ __global__ void MoveValuesKernel(const int32* keys, const int32* values,
template <typename T>
void RangeInit(const GPUDevice& d, const T start, const T delta,
const int32 size, typename TTypes<T>::Flat out) {
GpuLaunchConfig config = GetCudaLaunchConfig(size, d);
TF_CHECK_OK(CudaLaunchKernel(RangeInitKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), start,
delta, size, out.data()));
GpuLaunchConfig config = GetGpuLaunchConfig(size, d);
TF_CHECK_OK(GpuLaunchKernel(RangeInitKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), start,
delta, size, out.data()));
}

// Given *num_runs pairs (key, value), this function moves the value
@@ -93,21 +103,21 @@ void MoveValues(const GPUDevice& d, int32* keys, int32* values, int32* num_runs,
// This is valid for correct inputs, because then out_size >= *num_runs.
// For wrong inputs, we may have out_size < *num_runs. In this case we will
// only handle the first out_size values.
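// (out_size here is the number of partitions, while *num_runs is the number
// of distinct keys that ReduceByKey produced.)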
GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d);
TF_CHECK_OK(CudaLaunchKernel(MoveValuesKernel, config.block_count,
config.thread_per_block, 0, d.stream(), keys,
values, num_runs, out_size, out));
GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
TF_CHECK_OK(GpuLaunchKernel(MoveValuesKernel, config.block_count,
config.thread_per_block, 0, d.stream(), keys,
values, num_runs, out_size, out));
}

template <typename T>
void CallGatherKernel(const GPUDevice& d, const T* params, const int32* indices,
T* out, int64 gather_dim_size, int64 indices_size,
int64 slice_size, int64 out_size) {
GpuLaunchConfig config = GetCudaLaunchConfig(out_size, d);
TF_CHECK_OK(CudaLaunchKernel(
GatherOpKernel<T, int32, true>, config.block_count,
config.thread_per_block, 0, d.stream(), params, indices, out,
gather_dim_size, indices_size, slice_size, out_size));
GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
TF_CHECK_OK(GpuLaunchKernel(GatherOpKernel<T, int32, true>,
config.block_count, config.thread_per_block, 0,
d.stream(), params, indices, out, gather_dim_size,
indices_size, slice_size, out_size));
}

struct IdentityOp {
@@ -181,7 +191,7 @@ class BoundedOutputIterator
// I + P + max(3N + R + P, O + N), where:
// I - the size of the input
// N - the size of the partitions tensor
// R - the temporary storage used by cub::RadixSort, about 2N
// R - the temporary storage used by gpuprim::RadixSort, about 2N
// P - the number of partitions
// O - the size of the output
// So roughly the cost is I + P + max(5N, O + N).
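// For instance, with N = 1M int32 partition indices, P = 8 parts, and a
// [1M, 4] float input (I = O = 4M elements), 3N + R + P ~= 5M and
// O + N = 5M, so the peak is roughly I + P + 5M ~= 9M elements.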
@@ -326,7 +336,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
Tensor* indices_out, DoneCallback done) {
int32 N = partitions->NumElements();
const GPUDevice& device = c->eigen_device<GPUDevice>();
const cudaStream_t& cu_stream = GetCudaStream(c);
const gpuStream_t& cu_stream = GetGpuStream(c);

// Initialize the indices_in tensor using the Range GPU kernel.
RangeInit(device, 0, 1, N, indices_in->flat<int32>());
@@ -338,7 +348,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
// Determine temporary device storage requirements.
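// (Per the CUB calling convention, passing NULL as the temp-storage pointer
// makes SortPairs only report the required scratch size in
// temp_storage_bytes; no sorting happens on that first call.)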
Tensor cub_temp_storage;
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs(
gpuprim::DeviceRadixSort::SortPairs(
NULL, temp_storage_bytes, partitions_ptr, partitions_out_ptr,
indices_in_ptr, indices_out_ptr, N, 0, sizeof(int32) * 8, cu_stream);
// Allocate temporary storage.
@@ -349,7 +359,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
&cub_temp_storage),
done);
// Radix-sort the partition information.
cub::DeviceRadixSort::SortPairs(
gpuprim::DeviceRadixSort::SortPairs(
cub_temp_storage.flat<int8>().data(), temp_storage_bytes,
partitions_ptr, partitions_out_ptr, indices_in_ptr, indices_out_ptr, N,
0, sizeof(int32) * 8, cu_stream);
@@ -359,7 +369,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
Tensor* partition_count, Tensor* indices_out,
DoneCallback done) {
const GPUDevice& device = c->eigen_device<GPUDevice>();
const cudaStream_t& cu_stream = GetCudaStream(c);
const gpuStream_t& cu_stream = GetGpuStream(c);
int32 N = partitions->NumElements();
Tensor indices_in;
Tensor partitions_out;
@@ -396,8 +406,14 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
BoundedOutputIterator aggregates_out_it(aggregates_out_ptr, id_op,
num_partitions_);

#if GOOGLE_CUDA
cub::ConstantInputIterator<int32> values_in(1);
cub::Sum reduction_op;
#elif TENSORFLOW_USE_ROCM
using ConstantInputIterator =
::rocprim::constant_iterator<int32, ptrdiff_t>;
ConstantInputIterator values_in(1);
#endif
gpuprim::Sum reduction_op;
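// Both cub::ConstantInputIterator and rocprim::constant_iterator yield an
// endless stream of 1s, so reducing by key with gpuprim::Sum counts how
// many times each partition index occurs.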

// Allocate space on GPU for the number of runs. This is required by CUB.
Tensor num_runs;
@@ -408,9 +424,9 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
// Determine temporary device storage requirements
Tensor cub_temp_storage;
size_t temp_storage_bytes = 0;
cub::DeviceReduce::ReduceByKey(NULL, temp_storage_bytes, keys_in_ptr,
unique_out_it, values_in, aggregates_out_it,
num_runs_ptr, reduction_op, N, cu_stream);
gpuprim::DeviceReduce::ReduceByKey(
NULL, temp_storage_bytes, keys_in_ptr, unique_out_it, values_in,
aggregates_out_it, num_runs_ptr, reduction_op, N, cu_stream);
// Allocate temporary storage.
OP_REQUIRES_OK_ASYNC(
c,
@@ -422,10 +438,10 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
// each index appears in partitions. The distinct indices are stored
// in unique_out, while the count is stored in aggregates_out.
// The total number of distinct indices is stored in num_runs.
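// For example, partitions_out = [0, 0, 0, 2] yields unique_out = [0, 2],
// aggregates_out = [3, 1], and *num_runs = 2; MoveValues then scatters these
// counts so that partition_count = [3, 0, 1].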
cub::DeviceReduce::ReduceByKey(cub_temp_storage.flat<int8>().data(),
temp_storage_bytes, keys_in_ptr,
unique_out_it, values_in, aggregates_out_it,
num_runs_ptr, reduction_op, N, cu_stream);
gpuprim::DeviceReduce::ReduceByKey(
cub_temp_storage.flat<int8>().data(), temp_storage_bytes, keys_in_ptr,
unique_out_it, values_in, aggregates_out_it, num_runs_ptr, reduction_op,
N, cu_stream);
// We are not done yet. unique_out only contains the indices that appeared
// at least once in partitions. We move each value from aggregates_out
// to the corresponding position in partition_count. This will handle
@@ -468,4 +484,4 @@ TF_CALL_complex128(REGISTER_DYNAMIC_PARTITION_GPU);

} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM