Skip to content

Commit

Permalink
Adding ROCm support for the "split" op
Browse files Browse the repository at this point in the history
  • Loading branch information
deven-amd committed May 19, 2019
1 parent 31dfd51 commit 476e554
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 38 deletions.
53 changes: 29 additions & 24 deletions tensorflow/core/kernels/split_lib_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

Expand Down Expand Up @@ -87,7 +87,7 @@ __global__ void SplitOpKernel(const T* input, int32 prefix_dim_size,
int32 size = prefix_dim_size * split_dim_size * suffix_dim_size;
int32 piece_size = split_dim_size / num_split;

CUDA_1D_KERNEL_LOOP(offset, size) {
GPU_1D_KERNEL_LOOP(offset, size) {
// Calculate the index into input from offset.
int32 i = offset / (split_dim_size * suffix_dim_size);
int32 j = (offset % (split_dim_size * suffix_dim_size)) / suffix_dim_size;
Expand Down Expand Up @@ -124,7 +124,11 @@ __global__ void split_v_kernel(const T* input_ptr,
int num_outputs = output_ptr_data.size;

// verbose declaration needed due to template
#if GOOGLE_CUDA
extern __shared__ __align__(sizeof(T)) unsigned char smem[];
#elif TENSORFLOW_USE_ROCM
HIP_DYNAMIC_SHARED(unsigned char, smem);
#endif
IntType* smem_col_scan = reinterpret_cast<IntType*>(smem);

if (useSmem) {
Expand All @@ -144,7 +148,7 @@ __global__ void split_v_kernel(const T* input_ptr,
// works well when there are many small segments and when the
// segments are much longer
IntType segment =
cuda_helper::upper_bound<IntType>(col_scan, num_outputs, gidx) - 1;
gpu_helper::upper_bound<IntType>(col_scan, num_outputs, gidx) - 1;

IntType curr_offset = col_scan[segment];
IntType curr_segment = segment;
Expand Down Expand Up @@ -181,7 +185,7 @@ __global__ void SplitVOpKernel_fixed(const T* input, int32 prefix_dim_size,
int32 size = prefix_dim_size * suffix_dim_size;
int32 piece_size = suffix_dim_size / num_split;

CUDA_1D_KERNEL_LOOP(offset, size) {
GPU_1D_KERNEL_LOOP(offset, size) {
// Calculate the index into input from offset.
int32 i = offset / suffix_dim_size;
int32 j = offset % suffix_dim_size;
Expand All @@ -198,13 +202,13 @@ void SplitOpGPULaunch<T>::Run(const Eigen::GpuDevice& d, const T* input,
int32 prefix_dim_size, int32 split_dim_size,
int32 suffix_dim_size,
const GpuDeviceArrayStruct<T*>& output_ptr_data) {
GpuLaunchConfig config = GetCudaLaunchConfig(
prefix_dim_size * split_dim_size * suffix_dim_size, d);
GpuLaunchConfig config =
GetGpuLaunchConfig(prefix_dim_size * split_dim_size * suffix_dim_size, d);

TF_CHECK_OK(CudaLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size, suffix_dim_size,
output_ptr_data));
TF_CHECK_OK(GpuLaunchKernel(SplitOpKernel<T>, config.block_count,
config.thread_per_block, 0, d.stream(), input,
prefix_dim_size, split_dim_size, suffix_dim_size,
output_ptr_data));
}

template <typename T, typename IntType>
Expand All @@ -217,10 +221,10 @@ void SplitVOpGPULaunch<T, IntType>::Run(
GpuLaunchConfig config =
GetGpuLaunchConfig(total_rows * total_cols, gpu_device);

TF_CHECK_OK(CudaLaunchKernel(SplitVOpKernel_fixed<T>, config.block_count,
config.thread_per_block, 0,
gpu_device.stream(), input_ptr, total_rows,
total_cols, output_ptr_data));
TF_CHECK_OK(
GpuLaunchKernel(SplitVOpKernel_fixed<T>, dim3(config.block_count),
dim3(config.thread_per_block), 0, gpu_device.stream(),
input_ptr, total_rows, total_cols, output_ptr_data));
} else {
auto config = GetGpu2DLaunchConfig(total_cols, total_rows, gpu_device);
IntType smem_max = gpu_device.sharedMemPerBlock();
Expand All @@ -229,16 +233,17 @@ void SplitVOpGPULaunch<T, IntType>::Run(
// memory on most processors possibly due to decreasing occupancy
// 4096 inputs is a lot, most code will take the smem path
const int32 kMaxSmemBytesPerformance = 16384;
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance)
TF_CHECK_OK(CudaLaunchKernel(
split_v_kernel<T, IntType, true>, config.block_count,
config.thread_per_block, smem_usage, gpu_device.stream(), input_ptr,
output_scan, total_rows, total_cols, output_ptr_data));
else
TF_CHECK_OK(CudaLaunchKernel(
split_v_kernel<T, IntType, false>, config.block_count,
config.thread_per_block, 0, gpu_device.stream(), input_ptr,
if (smem_usage < smem_max && smem_usage < kMaxSmemBytesPerformance) {
TF_CHECK_OK(GpuLaunchKernel(
(split_v_kernel<T, IntType, true>), dim3(config.block_count),
dim3(config.thread_per_block), smem_usage, gpu_device.stream(),
input_ptr, output_scan, total_rows, total_cols, output_ptr_data));
} else {
TF_CHECK_OK(GpuLaunchKernel(
(split_v_kernel<T, IntType, false>), dim3(config.block_count),
dim3(config.thread_per_block), 0, gpu_device.stream(), input_ptr,
output_scan, total_rows, total_cols, output_ptr_data));
}
}
}

Expand All @@ -261,4 +266,4 @@ TF_CALL_bfloat16(REGISTER_GPU_KERNEL);

} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
12 changes: 6 additions & 6 deletions tensorflow/core/kernels/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

Expand Down Expand Up @@ -266,7 +266,7 @@ class SplitOpCPU : public SplitOpBase<CPUDevice, T> {
}
};

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Partial specialization for GPU
template <typename T>
Expand Down Expand Up @@ -323,7 +323,7 @@ class SplitOpGPU : public SplitOpBase<GPUDevice, T> {
errors::Internal("Launch of gpu kernel for SplitOp failed"));
}
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
template <typename T>
Expand Down Expand Up @@ -407,7 +407,7 @@ REGISTER_SPLIT(quint8);

#undef REGISTER_SPLIT

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type) \
REGISTER_KERNEL_BUILDER(Name("Split") \
Expand All @@ -422,7 +422,7 @@ TF_CALL_complex128(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type) \
Expand Down
16 changes: 8 additions & 8 deletions tensorflow/core/kernels/split_v_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ limitations under the License.

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <numeric>

Expand All @@ -33,12 +33,12 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/kernels/gpu_device_array.h"
#include "tensorflow/core/kernels/split_lib_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

Expand Down Expand Up @@ -328,7 +328,7 @@ class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> {
}
};

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Partial specialization for GPU
template <typename T, typename Tlen>
Expand Down Expand Up @@ -436,7 +436,7 @@ class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> {
}
}
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_SPLIT(type, len_type) \
REGISTER_KERNEL_BUILDER(Name("SplitV") \
Expand All @@ -456,7 +456,7 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN);
#undef REGISTER_SPLIT_LEN
#undef REGISTER_SPLIT

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define REGISTER_GPU(type, len_type) \
REGISTER_KERNEL_BUILDER(Name("SplitV") \
Expand Down Expand Up @@ -496,6 +496,6 @@ REGISTER_GPU_int32(int64);

#undef REGISTER_GPU_int32

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // end namespace tensorflow

0 comments on commit 476e554

Please sign in to comment.