[ROCm] Creating GpuLaunchKernel #28565

Closed
1 change: 1 addition & 0 deletions tensorflow/core/kernels/check_numerics_op_gpu.cu.cc
@@ -24,6 +24,7 @@ limitations under the License.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {
1 change: 1 addition & 0 deletions tensorflow/core/kernels/scan_ops_gpu.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/scan_ops.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"
#include "tensorflow/core/util/permutation_input_iterator.h"
#include "tensorflow/core/util/permutation_output_iterator.h"
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/tridiagonal_matmul_op_gpu.cu.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {
@@ -77,7 +78,7 @@ class TridiagonalMatMulOpGpu : public OpKernel {

const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
CudaLaunchConfig cfg = GetCudaLaunchConfig(1, device);
-  TF_CHECK_OK(CudaLaunchKernel(
+  TF_CHECK_OK(GpuLaunchKernel(
TridiagonalMatMulKernel<Scalar>, cfg.block_count, cfg.thread_per_block,
0, device.stream(), batch_size, m, n, superdiag.flat<Scalar>().data(),
maindiag.flat<Scalar>().data(), subdiag.flat<Scalar>().data(),
1 change: 1 addition & 0 deletions tensorflow/core/kernels/tridiagonal_solve_op_gpu.cu.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

namespace tensorflow {
33 changes: 33 additions & 0 deletions tensorflow/core/util/gpu_kernel_helper.h
@@ -71,6 +71,39 @@ inline const gpuStream_t& GetGpuStream(OpKernelContext* context) {
return *ptr;
}

+// Launches a GPU kernel through cudaLaunchKernel in CUDA environment, or
+// hipLaunchKernel in ROCm environment with the given arguments.
+//
+// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
+template <typename... Ts, typename... Args>
+Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
+                       size_t shared_memory_size_bytes, gpuStream_t stream,
+                       Args... arguments) {
+  static_assert(detail::NoneIsReference<Ts...>(),
+                "Kernels with reference arguments have undefined behaviour.");
+#if GOOGLE_CUDA
+  auto func_ptr = absl::bit_cast<const void*>(function);
+  // Cast arguments and forward them as an array of pointers.
+  auto args_tuple = std::tuple<Ts...>(arguments...);
+  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
+  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim,
+                                 arg_ptrs.data(), shared_memory_size_bytes,
+                                 stream);
+  if (result != cudaSuccess) {
+    return errors::Internal(cudaGetErrorString(result));
+  }
+#elif TENSORFLOW_USE_ROCM
+  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
+                     stream, std::forward<Args>(arguments)...);
+#endif
+  return Status::OK();
+}
+
+// Perfect forwarding to make CudaLaunchKernel available to both ROCm and
+// CUDA builds.
+template <typename... Args>
+auto CudaLaunchKernel(Args&&... args)
+    -> decltype(GpuLaunchKernel(std::forward<Args>(args)...)) {
+  return GpuLaunchKernel(std::forward<Args>(args)...);
+}

__host__ __device__ inline tensorflow::bfloat16 CudaLdg(
const tensorflow::bfloat16* address) {
tensorflow::bfloat16 return_value;
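For reference, a call site drives the new helper exactly as the tridiagonal_matmul change above does: build a launch configuration, then pass the kernel, launch shape, shared-memory size, and stream, followed by the kernel's arguments. The kernel and wrapper below are an illustrative sketch, not code from this PR; ScaleKernel and LaunchScale are hypothetical names, while CudaLaunchConfig, GetCudaLaunchConfig, cfg.block_count, cfg.thread_per_block, and device.stream() are the same utilities the diff itself uses.

// Sketch only (not from this PR): a trivial element-wise kernel launched
// through GpuLaunchKernel. The same call compiles for CUDA and ROCm, since
// the helper dispatches to cudaLaunchKernel or hipLaunchKernelGGL.
__global__ void ScaleKernel(const float* in, float alpha, int n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = alpha * in[i];
}

Status LaunchScale(OpKernelContext* context, const float* in, float alpha,
                   int n, float* out) {
  const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
  CudaLaunchConfig cfg = GetCudaLaunchConfig(n, device);
  return GpuLaunchKernel(ScaleKernel, cfg.block_count, cfg.thread_per_block,
                         /*shared_memory_size_bytes=*/0, device.stream(),
                         in, alpha, n, out);
}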
26 changes: 0 additions & 26 deletions tensorflow/core/util/gpu_launch_config.h
@@ -416,32 +416,6 @@ constexpr bool NoneIsReference() {
return NoneTrue<(std::is_reference<Ts>::value)...>::value;
}
} // namespace detail

-#if GOOGLE_CUDA
-// Launches a CUDA kernel through cudaLaunchKernel with the given arguments.
-//
-// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
-template <typename... Ts, typename... Args>
-Status CudaLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
-                        size_t shared_memory_size_bytes, cudaStream_t stream,
-                        Args... arguments) {
-  static_assert(detail::NoneIsReference<Ts...>(),
-                "Kernels with reference arguments have undefined behaviour.");
-  // Cast arguments and forward them as an array of pointers.
-  auto args_tuple = std::tuple<Ts...>(arguments...);
-  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
-  auto func_ptr = absl::bit_cast<const void*>(function);
-  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim,
-                                 arg_ptrs.data(), shared_memory_size_bytes,
-                                 stream);
-  if (result != cudaSuccess) {
-    return errors::Internal(cudaGetErrorString(result));
-  }
-  return Status::OK();
-}
-#endif  // GOOGLE_CUDA

} // namespace tensorflow

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
1 change: 1 addition & 0 deletions tensorflow/examples/adding_an_op/cuda_op_kernel.cu.cc
@@ -16,6 +16,7 @@ limitations under the License.
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/gpu_launch_config.h"

__global__ void AddOneKernel(const int* in, const int N, int* out) {
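With the helper header included, the example's host-side launcher can route through GpuLaunchKernel instead of the <<<...>>> launch syntax. A minimal sketch under stated assumptions: the launcher name AddOneKernelLauncher and the 32x256 launch shape follow the adding_an_op example file, TF_CHECK_OK is assumed to be in scope via the TensorFlow headers, and this rewrite itself is not part of the diff.

// Sketch only: launching the example's AddOneKernel through GpuLaunchKernel.
// A null stream selects the default CUDA/HIP stream.
void AddOneKernelLauncher(const int* in, const int N, int* out) {
  TF_CHECK_OK(tensorflow::GpuLaunchKernel(
      AddOneKernel, /*grid_dim=*/32, /*block_dim=*/256,
      /*shared_memory_size_bytes=*/0, /*stream=*/nullptr, in, N, out));
}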