Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tensorflow/core/kernels/scan_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/numeric_op.h"
Expand Down Expand Up @@ -89,7 +89,7 @@ class ScanOp : public OpKernel {
bool exclusive_;
};

#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
namespace functor {

// Forward declarations of GPU functors
Expand All @@ -111,7 +111,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
#undef DECLARE

} // namespace functor
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumsum kernels
#define REGISTER_CPU_KERNELS(type) \
Expand All @@ -130,7 +130,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS);
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumsum") \
Expand All @@ -148,7 +148,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
ScanOp<GPUDevice, type, Eigen::internal::SumReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register Cumprod kernels
#define REGISTER_CPU_KERNELS(type) \
Expand All @@ -167,7 +167,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_GPU_KERNELS(type) \
REGISTER_KERNEL_BUILDER( \
Name("Cumprod") \
Expand All @@ -185,6 +185,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
ScanOp<GPUDevice, type, Eigen::internal::ProdReducer<type>, int64>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // namespace tensorflow
65 changes: 38 additions & 27 deletions tensorflow/core/kernels/scan_ops_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

// Define CUB_USE_COOPERATIVE_GROUPS ahead of the CUB includes below when the
// CUDA_VERSION macro is at least 9000.
// NOTE(review): in a ROCm build CUDA_VERSION is presumably undefined; the
// preprocessor then evaluates it as 0 and this define is skipped -- confirm
// that is the intended behaviour.
#if CUDA_VERSION >= 9000
#define CUB_USE_COOPERATIVE_GROUPS
#endif // CUDA_VERSION >= 9000

// Pull in the block-level load/scan/store primitives and the counting /
// transform input iterators for whichever GPU backend is being built:
//   * GOOGLE_CUDA:          the individual CUB headers plus cuComplex.h.
//   * TENSORFLOW_USE_ROCM:  the hipcub.hpp header from the rocprim archive.
// Only one branch is taken; CUDA wins if both macros happen to be set.
#if GOOGLE_CUDA
#include "third_party/cub/block/block_load.cuh"
#include "third_party/cub/block/block_scan.cuh"
#include "third_party/cub/block/block_store.cuh"
#include "third_party/cub/iterator/counting_input_iterator.cuh"
#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "third_party/gpus/cuda/include/cuComplex.h"
#elif TENSORFLOW_USE_ROCM
#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp"
#endif // GOOGLE_CUDA / TENSORFLOW_USE_ROCM
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/scan_ops.h"
Expand All @@ -38,6 +42,12 @@ limitations under the License.
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/permutation_output_iterator.h"

// Backend-neutral alias for the GPU primitives namespace: `gpuprim` names
// ::cub in a CUDA build and ::hipcub in a ROCm build, so the kernels in this
// header can refer to gpuprim::BlockLoad, gpuprim::BlockScan, etc. without
// per-backend #if blocks at each use site.
#if GOOGLE_CUDA
namespace gpuprim = ::cub;
#elif TENSORFLOW_USE_ROCM
namespace gpuprim = ::hipcub;
#endif // GOOGLE_CUDA / TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;
Expand Down Expand Up @@ -160,12 +170,13 @@ struct IdentityValue {
template <typename T, typename Op, int BlockDim = 128, int ItemsPerThread = 4>
__global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
bool exclusive, bool reverse, Op op) {
typedef cub::BlockLoad<T, BlockDim, ItemsPerThread, cub::BLOCK_LOAD_TRANSPOSE>
typedef gpuprim::BlockLoad<T, BlockDim, ItemsPerThread,
gpuprim::BLOCK_LOAD_TRANSPOSE>
BlockLoad;
typedef cub::BlockStore<T, BlockDim, ItemsPerThread,
cub::BLOCK_STORE_TRANSPOSE>
typedef gpuprim::BlockStore<T, BlockDim, ItemsPerThread,
gpuprim::BLOCK_STORE_TRANSPOSE>
BlockStore;
typedef cub::BlockScan<T, BlockDim> BlockScan;
typedef gpuprim::BlockScan<T, BlockDim> BlockScan;

// Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan
__shared__ union {
Expand All @@ -189,11 +200,11 @@ __global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz,
problem_length - (block_offset % problem_length));

// first construct a counting iterator that has the desired start point
typedef cub::TransformInputIterator<int, MapIndexToLocation,
cub::CountingInputIterator<int>>
typedef gpuprim::TransformInputIterator<int, MapIndexToLocation,
gpuprim::CountingInputIterator<int>>
MapIterType;

cub::CountingInputIterator<int> counting_iter(block_offset);
gpuprim::CountingInputIterator<int> counting_iter(block_offset);

// Next map the iterator to the actual locations in memory
MapIterType map_iter(counting_iter, map_op);
Expand Down Expand Up @@ -243,39 +254,39 @@ void LaunchScan(const GPUDevice& d, typename TTypes<T, 3>::ConstTensor in,
if (ideal_block_size >= 1024 && std::is_same<T, float>::value) {
const int block_size = 1024;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
} else if (ideal_block_size >= 512) {
const int block_size = 512;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
} else if (ideal_block_size >= 256) {
const int block_size = 256;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
} else if (ideal_block_size >= 128) {
const int block_size = 128;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
} else if (ideal_block_size >= 64) {
const int block_size = 64;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
} else {
const int block_size = 32;
TF_CHECK_OK(
CudaLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
GpuLaunchKernel(scan_kernel<T, Op, block_size, items_per_thread>,
num_blocks, block_size, 0, d.stream(), in.data(),
out.data(), dimx, dimy, dimz, exclusive, reverse, op));
}
}

Expand All @@ -302,6 +313,6 @@ struct Scan<GPUDevice, Eigen::internal::ProdReducer<T>, T> {
} // namespace functor
} // end namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scan_ops_gpu_double.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

Expand All @@ -28,4 +28,4 @@ template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<double>,
double>;
} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scan_ops_gpu_float.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

Expand All @@ -28,4 +28,4 @@ template struct functor::Scan<GpuDevice, Eigen::internal::ProdReducer<float>,
float>;
} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/scan_ops_gpu_half.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

Expand All @@ -28,4 +28,4 @@ template struct functor::Scan<
GpuDevice, Eigen::internal::ProdReducer<Eigen::half>, Eigen::half>;
} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM