diff --git a/tensorflow/core/kernels/scan_ops.cc b/tensorflow/core/kernels/scan_ops.cc index ea42fdefb4124b..70dad7d9d3dee9 100644 --- a/tensorflow/core/kernels/scan_ops.cc +++ b/tensorflow/core/kernels/scan_ops.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" @@ -89,7 +89,7 @@ class ScanOp : public OpKernel { bool exclusive_; }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace functor { // Forward declarations of GPU functors @@ -111,7 +111,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS); #undef DECLARE } // namespace functor -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Register Cumsum kernels #define REGISTER_CPU_KERNELS(type) \ @@ -130,7 +130,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_ALL_REDUCERS); TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("Cumsum") \ @@ -148,7 +148,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); ScanOp, int64>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS) #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Register Cumprod kernels #define REGISTER_CPU_KERNELS(type) \ @@ -167,7 +167,7 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS) TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); #undef REGISTER_CPU_KERNELS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ Name("Cumprod") \ @@ -185,6 +185,6 @@ TF_CALL_NUMBER_TYPES(REGISTER_CPU_KERNELS); ScanOp, int64>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS) #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/scan_ops_gpu.h b/tensorflow/core/kernels/scan_ops_gpu.h index 685fe3bf950aa0..5bcbdd9dded1c0 100644 --- a/tensorflow/core/kernels/scan_ops_gpu.h +++ b/tensorflow/core/kernels/scan_ops_gpu.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ #define TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -24,12 +24,16 @@ limitations under the License. #define CUB_USE_COOPERATIVE_GROUPS #endif // CUDA_VERSION >= 9000 +#if GOOGLE_CUDA #include "third_party/cub/block/block_load.cuh" #include "third_party/cub/block/block_scan.cuh" #include "third_party/cub/block/block_store.cuh" #include "third_party/cub/iterator/counting_input_iterator.cuh" #include "third_party/cub/iterator/transform_input_iterator.cuh" #include "third_party/gpus/cuda/include/cuComplex.h" +#elif TENSORFLOW_USE_ROCM +#include "external/rocprim_archive/hipcub/include/hipcub/hipcub.hpp" +#endif #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/scan_ops.h" @@ -38,6 +42,12 @@ limitations under the License. #include "tensorflow/core/util/permutation_input_iterator.h" #include "tensorflow/core/util/permutation_output_iterator.h" +#if GOOGLE_CUDA +namespace gpuprim = ::cub; +#elif TENSORFLOW_USE_ROCM +namespace gpuprim = ::hipcub; +#endif + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -160,12 +170,13 @@ struct IdentityValue { template __global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz, bool exclusive, bool reverse, Op op) { - typedef cub::BlockLoad + typedef gpuprim::BlockLoad BlockLoad; - typedef cub::BlockStore + typedef gpuprim::BlockStore BlockStore; - typedef cub::BlockScan BlockScan; + typedef gpuprim::BlockScan BlockScan; // Allocate aliased shared memory for BlockLoad, BlockStore, and BlockScan __shared__ union { @@ -189,11 +200,11 @@ __global__ void scan_kernel(const T* in, T* out, int dimx, int dimy, int dimz, problem_length - (block_offset % problem_length)); // first construct a counting iterator that has the desired start point - typedef cub::TransformInputIterator> + typedef gpuprim::TransformInputIterator> MapIterType; - cub::CountingInputIterator counting_iter(block_offset); + gpuprim::CountingInputIterator counting_iter(block_offset); // Next map the iterator to the actual locations in memory MapIterType map_iter(counting_iter, map_op); @@ -243,39 +254,39 @@ void LaunchScan(const GPUDevice& d, typename TTypes::ConstTensor in, if (ideal_block_size >= 1024 && std::is_same::value) { const int block_size = 1024; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } else if (ideal_block_size >= 512) { const int block_size = 512; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } else if (ideal_block_size >= 256) { const int block_size = 256; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } else if (ideal_block_size >= 128) { const int block_size = 128; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } else if (ideal_block_size >= 64) { const int block_size = 64; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } else { const int block_size = 32; TF_CHECK_OK( - CudaLaunchKernel(scan_kernel, - num_blocks, block_size, 0, d.stream(), in.data(), - out.data(), dimx, dimy, dimz, exclusive, reverse, op)); + GpuLaunchKernel(scan_kernel, + num_blocks, block_size, 0, d.stream(), in.data(), + out.data(), dimx, dimy, dimz, exclusive, reverse, op)); } } @@ -302,6 +313,6 @@ struct Scan, T> { } // namespace functor } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_GPU_H_ diff --git a/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc index adce37e473c4f3..f304c5cc53cddb 100644 --- a/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc +++ b/tensorflow/core/kernels/scan_ops_gpu_double.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -28,4 +28,4 @@ template struct functor::Scan, double>; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc index b72415822d0eeb..1d0780541cc04b 100644 --- a/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc +++ b/tensorflow/core/kernels/scan_ops_gpu_float.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -28,4 +28,4 @@ template struct functor::Scan, float>; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc b/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc index f9fb528be98efc..3ea7c5a47c7fd3 100644 --- a/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc +++ b/tensorflow/core/kernels/scan_ops_gpu_half.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -28,4 +28,4 @@ template struct functor::Scan< GpuDevice, Eigen::internal::ProdReducer, Eigen::half>; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM