diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 338f61ff6642cb..646366dc3aa080 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -82,12 +82,12 @@ Status Concat(OpKernelContext* context, const gtl::ArraySlice& inputs, context->allocate_temp(DataTypeToEnum::value, output_shape, output)); if (output->NumElements() > 0) { auto output_flat = output->shaped({1, output->NumElements()}); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value) { ConcatGPU(context, inputs_flat, output, &output_flat); return Status::OK(); } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM ConcatCPU(context->device(), inputs_flat, &output_flat); } @@ -172,7 +172,7 @@ Status SplitCPU(OpKernelContext* context, const Tensor& input, return Status::OK(); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Handles the general case, on GPU. template @@ -183,7 +183,7 @@ Status SplitGPU(OpKernelContext* context, const Tensor& input, LOG(FATAL) << "Not yet implemented"; // Crash ok } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // The outer function that dispatches to the various Split*() functions above. template @@ -197,10 +197,10 @@ Status Split(OpKernelContext* context, const Tensor& input, return Status::OK(); } -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // TODO(olston, apassos): Handle non-CPU cases. 
// return SplitGPU<T>(context, input, sizes, outputs); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM return SplitCPU(context, input, sizes, outputs); } diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h index 175c45285d6711..7303a3409a21c0 100644 --- a/tensorflow/core/kernels/concat_lib.h +++ b/tensorflow/core/kernels/concat_lib.h @@ -47,7 +47,7 @@ void ConcatCPU( const std::vector<std::unique_ptr<typename TTypes<T>::ConstMatrix>>& inputs, typename TTypes<T>::Matrix* output); -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template <typename T> void ConcatGPU( OpKernelContext* c, @@ -73,7 +73,7 @@ TF_CALL_bfloat16(REGISTER); TF_CALL_bool(REGISTER); TF_CALL_uint8(REGISTER); #undef REGISTER -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #ifdef TENSORFLOW_USE_SYCL template <typename T>