From f509f5a91b41d00be615d510d8c954d7ce182350 Mon Sep 17 00:00:00 2001 From: Deven Desai Date: Thu, 6 Jun 2019 16:05:59 +0000 Subject: [PATCH] Adding ROCm support for the matrix_diag op --- tensorflow/core/kernels/matrix_diag_op.cc | 8 ++++---- tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc index 7779525d2310e5..1906bc6fc1093a 100644 --- a/tensorflow/core/kernels/matrix_diag_op.cc +++ b/tensorflow/core/kernels/matrix_diag_op.cc @@ -17,9 +17,9 @@ limitations under the License. #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/matrix_diag_op.h" @@ -170,7 +170,7 @@ struct MatrixDiagPart { } // namespace functor -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. namespace functor { @@ -220,6 +220,6 @@ TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG_GPU); #undef REGISTER_BATCH_MATRIX_DIAG_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc index cfb1fa10fc37d5..cb718d282b65d6 100644 --- a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc +++ b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -37,4 +37,4 @@ TF_CALL_complex128(DEFINE_GPU_SPEC); } // end namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM