[ROCm] CUDA/ROCm shared interface #36641

Merged
76 changes: 75 additions & 1 deletion tensorflow/core/util/gpu_device_functions.h
@@ -22,6 +22,8 @@ limitations under the License.
* Wraps the warp-cooperative intrinsics introduced in CUDA 9 to provide
* backwards compatibility, see go/volta-porting for details.
* Provides atomic operations on types that aren't natively supported.
* Defines a number of macros and types providing a shared interface
* to either CUDA or ROCm APIs, depending on the build.
*/

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -33,12 +35,62 @@ limitations under the License.
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuComplex.h"
#include "third_party/gpus/cuda/include/cuda.h"
#else
#include "rocm/include/hip/hip_complex.h"
#endif

#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_cuda_alias.h"

namespace tensorflow {
#if GOOGLE_CUDA
using gpuFloatComplex = cuFloatComplex;
using gpuDoubleComplex = cuDoubleComplex;
using gpuStream_t = cudaStream_t;
using gpuEvent_t = cudaEvent_t;
#define gpuEventRecord cudaEventRecord
#define gpuEventSynchronize cudaEventSynchronize
#define gpuEventDestroy cudaEventDestroy
#define gpuEventCreate cudaEventCreate
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#elif TENSORFLOW_USE_ROCM
using gpuFloatComplex = hipFloatComplex;
using gpuDoubleComplex = hipDoubleComplex;
using gpuStream_t = hipStream_t;
using gpuEvent_t = hipEvent_t;
using cudaError = int;
using cudaError_t = int;
#define cudaSuccess 0
#define cudaGetLastError hipGetLastError
#define gpuEventRecord hipEventRecord
#define gpuEventDestroy hipEventDestroy
#define gpuEventSynchronize hipEventSynchronize
#define gpuEventCreate hipEventCreate
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
static std::string cudaGetErrorString(int err) { return std::to_string(err); }
#endif
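
// Illustrative sketch (not part of this change): with the aliases above, the
// same code compiles to cudaEvent*/cudaStream_t calls under GOOGLE_CUDA and to
// hipEvent*/hipStream_t calls under TENSORFLOW_USE_ROCM. The function name and
// the idea of receiving a ready-made gpuStream_t are assumptions made purely
// for illustration:
//
//   void WaitForStream(gpuStream_t stream) {
//     gpuEvent_t done;
//     // Timing is disabled because the event is only used for synchronization.
//     gpuEventCreateWithFlags(&done, gpuEventDisableTiming);
//     gpuEventRecord(done, stream);   // marks the work currently enqueued
//     gpuEventSynchronize(done);      // blocks until that work has finished
//     gpuEventDestroy(done);
//   }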

#define TF_RETURN_IF_CUDA_ERROR(result) \
  do { \
    cudaError_t error(result); \
    if (!SE_PREDICT_TRUE(error == cudaSuccess)) { \
      return errors::Internal("Cuda call failed with ", \
                              cudaGetErrorString(error)); \
    } \
  } while (0)

#define TF_OP_REQUIRES_CUDA_SUCCESS(context, result) \
  do { \
    cudaError_t error(result); \
    if (!SE_PREDICT_TRUE(error == cudaSuccess)) { \
      context->SetStatus(errors::Internal("Cuda call failed with ", \
                                          cudaGetErrorString(error))); \
      return; \
    } \
  } while (0)
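
// Illustrative sketch (not part of this change): typical call sites for the
// two macros above. Only cudaGetLastError() is used here because the ROCm
// branch of this header aliases it to hipGetLastError; the surrounding
// function and the kernel launch are placeholders invented for illustration.
//
//   Status CheckRecentKernelLaunch() {
//     // ... launch a kernel on some gpuStream_t ...
//     TF_RETURN_IF_CUDA_ERROR(cudaGetLastError());
//     return Status::OK();
//   }
//
//   // Inside an OpKernel::Compute(OpKernelContext* context), the second macro
//   // sets an error status on the context and returns early instead:
//   //   TF_OP_REQUIRES_CUDA_SUCCESS(context, cudaGetLastError());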

namespace tensorflow {
// According to HIP developer guide at
// https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#assert
// assert is not supported by HIP. While we are waiting for assert support in
@@ -878,6 +930,28 @@ __device__ inline std::complex<double> operator/(
}
#endif // GOOGLE_CUDA

namespace functor {
// ROCm hcc (clang) has severe difficulties dealing with std::complex directly,
// due to a header issue. This template assists in casting std::complex into
// the corresponding internal ROCm types.
template <class T>
struct MapComplexToHipComplex {
  typedef T TM;
};

#if TENSORFLOW_USE_ROCM
template <>
struct MapComplexToHipComplex<std::complex<float> > {
  typedef hipFloatComplex TM;
};

template <>
struct MapComplexToHipComplex<std::complex<double> > {
  typedef hipDoubleComplex TM;
};
#endif
}  // namespace functor
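
// Illustrative sketch (not part of this change): intended use of
// MapComplexToHipComplex in a launcher templated on T. On ROCm builds,
// std::complex buffers are reinterpreted as hipFloatComplex/hipDoubleComplex
// before reaching device code; on CUDA builds TM is just T, so the cast is a
// no-op. The launcher name and signature below are invented for illustration.
//
//   template <typename T>
//   void LaunchFill(typename functor::MapComplexToHipComplex<T>::TM* out,
//                   int n);
//
//   template <typename T>
//   void Fill(T* out, int n) {
//     using TM = typename functor::MapComplexToHipComplex<T>::TM;
//     LaunchFill<T>(reinterpret_cast<TM*>(out), n);
//   }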

} // namespace tensorflow

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM