From b7d4805838a64e6a33135b19281345032c400f56 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Thu, 30 May 2019 19:31:17 +0000 Subject: [PATCH] Add ROCm support for 2D and 3D convolution ops. MIOpen has its own algorithm find logic and would return the size of scratch memory required to client applications. Modify AlgorithmConfig to track such information. --- .../core/kernels/conv_grad_filter_ops.cc | 31 +++++++++--- .../core/kernels/conv_grad_input_ops.cc | 35 +++++++++++--- tensorflow/core/kernels/conv_grad_ops.h | 4 +- tensorflow/core/kernels/conv_grad_ops_3d.cc | 47 ++++++++++++++++--- tensorflow/core/kernels/conv_ops.cc | 30 ++++++++++-- tensorflow/core/kernels/conv_ops.h | 8 ++-- tensorflow/core/kernels/conv_ops_3d.cc | 22 +++++++-- tensorflow/core/kernels/conv_ops_gpu.h | 4 +- tensorflow/core/kernels/conv_ops_gpu_2.cu.cc | 4 +- tensorflow/core/kernels/conv_ops_test.cc | 4 +- tensorflow/stream_executor/dnn.h | 18 ++++++- tensorflow/stream_executor/rocm/rocm_dnn.cc | 7 +-- 12 files changed, 166 insertions(+), 48 deletions(-) diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 321bc1908ba7ab..cdbe0aa759f721 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -47,12 +47,12 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/proto/proto_utils.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace { @@ -430,7 +430,7 @@ template struct LaunchConv2DBackpropFilterOp; template struct LaunchConv2DBackpropFilterOp; // GPU definitions. -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // The slow version (but compiles for GPU) // A dummy type to group forward backward filter autotune results together. @@ -683,8 +683,8 @@ void LaunchConv2DBackpropFilterOp::operator()( Tensor compatible_input; if (padding_top != padding_bottom || padding_left != padding_right) { // Pad the input in the same way we did during the forward pass, so that - // cuDNN receives the same input during the backward pass function as it did - // during the forward pass function. + // cuDNN or MIOpen receives the same input during the backward pass function + // as it did during the forward pass function. const int64 padding_rows_diff = std::abs(padding_bottom - padding_top); const int64 padding_cols_diff = std::abs(padding_right - padding_left); const int64 new_in_rows = @@ -871,6 +871,7 @@ void LaunchConv2DBackpropFilterOp::operator()( AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find( conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo(stream->parent()), @@ -907,6 +908,22 @@ void LaunchConv2DBackpropFilterOp::operator()( filter_desc, output_desc, conv_desc, stream->parent(), results); OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config)); +#elif TENSORFLOW_USE_ROCM + ProfileResult best_result; + DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + ctx); + bool miopen_find_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc, + filter_desc, &filter_backprop_ptr, &scratch_allocator, + AlgorithmConfig(), &best_result) + .ok(); + OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find backward filter algorithm!")); + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -921,7 +938,7 @@ void LaunchConv2DBackpropFilterOp::operator()( if (!cudnn_launch_status) { ctx->SetStatus(errors::Internal( - "cuDNN Backward Filter function launch failure : input shape(", + "DNN Backward Filter function launch failure : input shape(", input.shape().DebugString(), ") filter shape(", filter_shape.DebugString(), ")")); return; @@ -1003,6 +1020,6 @@ template struct LaunchConv2DBackpropFilterOp; template struct LaunchConv2DBackpropFilterOp; template struct LaunchConv2DBackpropFilterOp; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index d995b44e2a4c42..6844e3a8264725 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -47,12 +47,12 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/proto/proto_utils.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace { @@ -765,7 +765,9 @@ template struct LaunchConv2DBackpropInputOp; template struct LaunchConv2DBackpropInputOp; template struct LaunchConv2DBackpropInputOp; -#if GOOGLE_CUDA +// GPU definitions. +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +// The slow version (but compiles for GPU) // A dummy type to group forward backward data autotune results together. struct ConvBackwardDataAutoTuneGroup { @@ -911,8 +913,8 @@ void LaunchConv2DBackpropInputOp::operator()( TensorShape compatible_input_shape; if (padding_top != padding_bottom || padding_left != padding_right) { // Pad the input in the same way we did during the forward pass, so that - // cuDNN receives the same input during the backward pass function as it did - // during the forward pass function. + // cuDNN or MIOpen receives the same input during the backward pass function + // as it did during the forward pass function. const int64 padding_rows_diff = std::abs(padding_bottom - padding_top); const int64 padding_cols_diff = std::abs(padding_right - padding_left); const int64 new_in_rows = @@ -1094,6 +1096,7 @@ void LaunchConv2DBackpropInputOp::operator()( AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find( conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo(stream->parent()), @@ -1128,6 +1131,24 @@ void LaunchConv2DBackpropInputOp::operator()( in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc, output_desc, conv_desc, stream->parent(), results); OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config)); +#elif TENSORFLOW_USE_ROCM + // MIOpen has its own Find and autotuner so use it here, passing + // default AlgorithmConfig to force a search + DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx); + ProfileResult best_result; + bool miopen_find_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, + AlgorithmConfig(), &best_result) + .ok(); + OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find backwards-data algorithm!")); + + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -1141,7 +1162,7 @@ void LaunchConv2DBackpropInputOp::operator()( if (!cudnn_launch_status) { ctx->SetStatus(errors::Internal( - "cuDNN Backward Data function launch failure : input shape(", + "DNN Backward Data function launch failure : input shape(", input_shape.DebugString(), ") filter shape(", filter_shape.DebugString(), ")")); return; @@ -1255,6 +1276,6 @@ template struct LaunchConv2DBackpropInputOp; template struct LaunchConv2DBackpropInputOp; template struct LaunchConv2DBackpropInputOp; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index 173f92806f911e..1a61db76b6e32d 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -191,7 +191,7 @@ struct LaunchConv2DBackpropFilterOp { Tensor* filter_backprop, TensorFormat data_format); }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template struct LaunchConv2DBackpropInputOp { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, @@ -211,7 +211,7 @@ struct LaunchConv2DBackpropFilterOp { const std::vector& explicit_paddings, Tensor* filter_backprop, TensorFormat data_format); }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Information about a single spatial dimension for a convolution // backpropagation. diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 655f4f0e593430..e371c59c6c1fd5 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/core/kernels/eigen_contraction_kernel.h" #endif -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #endif @@ -1049,7 +1049,7 @@ TF_CALL_half(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL // GPU definitions of both ops. -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Forward declarations of the functor specializations for GPU. // This ensures that the custom implementation is used instead of the default // Eigen one (which is used for CPU). @@ -1355,15 +1355,16 @@ class Conv3DBackpropInputOp : public OpKernel { using se::dnn::AlgorithmDesc; using se::dnn::ProfileResult; AlgorithmConfig algorithm_config; + ProfileResult best_result; + ProfileResult best_result_no_scratch; if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find( conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo( stream->parent()), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. @@ -1401,6 +1402,21 @@ class Conv3DBackpropInputOp : public OpKernel { algorithm_config.set_algorithm_no_scratch( best_result_no_scratch.algorithm()); } +#elif TENSORFLOW_USE_ROCM + DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, + context); + bool miopen_find_status = + stream + ->ThenConvolveBackwardDataWithAlgorithm( + filter_desc, filter_ptr, output_desc, out_backprop_ptr, + conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator, + AlgorithmConfig(), &best_result) + .ok(); + OP_REQUIRES(context, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find backward data algorithm!")); + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -1761,15 +1777,16 @@ class Conv3DBackpropFilterOp : public OpKernel { using se::dnn::AlgorithmDesc; using se::dnn::ProfileResult; AlgorithmConfig algorithm_config; + ProfileResult best_result; + ProfileResult best_result_no_scratch; if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find( conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms( conv_parameters.ShouldIncludeWinogradNonfusedAlgo( stream->parent()), &algorithms)); - ProfileResult best_result; - ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { // TODO(zhengxq): profile each algorithm multiple times to better // accuracy. @@ -1808,6 +1825,22 @@ class Conv3DBackpropFilterOp : public OpKernel { algorithm_config.set_algorithm_no_scratch( best_result_no_scratch.algorithm()); } +#elif TENSORFLOW_USE_ROCM + DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + context); + bool miopen_find_status = + stream + ->ThenConvolveBackwardFilterWithAlgorithm( + input_desc, input_ptr, output_desc, out_backprop_ptr, + conv_desc, filter_desc, &filter_backprop_ptr, + &scratch_allocator, AlgorithmConfig(), &best_result) + .ok(); + OP_REQUIRES( + context, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find backward filter algorithm!")); + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -1866,6 +1899,6 @@ TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 0ea367d425091f..c19376e7f2afc9 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -18,9 +18,9 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops.h" @@ -57,11 +57,13 @@ limitations under the License. #include "tensorflow/core/kernels/xsmm_conv2d.h" #endif -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/proto/proto_utils.h" +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM +#if GOOGLE_CUDA #include "tensorflow/stream_executor/cuda/ptxas_utils.h" #include "tensorflow/stream_executor/cuda/redzone_allocator.h" #include "tensorflow/stream_executor/tf_allocator_adapter.h" @@ -574,7 +576,7 @@ template struct LaunchConv2DOp; template struct LaunchConv2DOp; template struct LaunchConv2DOp; -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM int64 GetDnnWorkspaceLimit(const string& envvar_in_mb, int64 default_value_in_bytes) { const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); @@ -978,6 +980,7 @@ void LaunchConv2DOp::operator()( AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; OP_REQUIRES( ctx, @@ -1049,6 +1052,23 @@ void LaunchConv2DOp::operator()( output_tensor, input_desc, filter_desc, output_desc, conv_desc, stream->parent(), results); OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config)); +#elif TENSORFLOW_USE_ROCM + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + ProfileResult best_result; + bool miopen_find_status = + stream + ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc, + filter_ptr, conv_desc, output_desc, + &output_ptr, &scratch_allocator, + AlgorithmConfig(), &best_result) + .ok(); + + OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find conv algorithm!")); + + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -1137,6 +1157,6 @@ template struct LaunchConv2DOp; template struct LaunchConv2DOp; template struct LaunchConv2DOp; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index ccd24fcdd4c5e4..2f09383714ce75 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/util/tensor_format.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM namespace tensorflow { @@ -41,7 +41,7 @@ struct LaunchConv2DOp { TensorFormat data_format); }; -#ifdef GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM template struct LaunchConv2DOp { void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, @@ -51,7 +51,7 @@ struct LaunchConv2DOp { const std::vector& explicit_paddings, Tensor* output, TensorFormat data_format); }; -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Used to keep track of persistent memory buffers used within the op. // It uses malloc and free to avoid the time cost of initializing the memory. diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index 32a5a5e0a36149..076db5c5442a7c 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/protobuf/autotuning.pb.h" #include "tensorflow/core/util/proto/proto_utils.h" @@ -186,7 +186,7 @@ TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // A dummy type to group forward convolution autotune results together. struct Conv3dAutoTuneGroup { @@ -435,6 +435,7 @@ struct LaunchConvOp { if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find( conv_parameters, &algorithm_config)) { +#if GOOGLE_CUDA std::vector algorithms; OP_REQUIRES(ctx, stream->parent()->GetConvolveAlgorithms( @@ -477,6 +478,21 @@ struct LaunchConvOp { filter_ptr, output_ptr, input_desc, filter_desc, output_desc, conv_desc, stream->parent(), results); OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config)); +#elif TENSORFLOW_USE_ROCM + ProfileResult best_result; + DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + bool miopen_find_status = + stream + ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc, + filter_ptr, conv_desc, output_desc, + &output_ptr, &scratch_allocator, + AlgorithmConfig(), &best_result) + .ok(); + OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(), + errors::NotFound("Failed to find conv algorithm!")); + algorithm_config.set_algorithm(best_result.algorithm()); + algorithm_config.set_scratch_size(best_result.scratch_size()); +#endif AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config); } @@ -555,6 +571,6 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("Conv3D").Device(DEVICE_GPU).TypeConstraint("T"), Conv3DOp); -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 70ffcd6291dd3d..7906f74c616e69 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ #define TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include #include @@ -214,6 +214,6 @@ typedef Eigen::GpuDevice GPUDevice; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_ diff --git a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc index 52859af950e3c3..f23630783bd840 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_2.cu.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU @@ -36,4 +36,4 @@ template struct functor::InflatePadAndShuffle; } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index bb0cd9e26e2c93..ab338a2550c195 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -36,7 +36,7 @@ limitations under the License. namespace tensorflow { -#if GOOGLE_CUDA +#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM struct ConvParametersPeer { template @@ -91,7 +91,7 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) { conv_params_large.ShouldIncludeWinogradNonfusedAlgoPreCudnn7()); } -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM class FusedResizePadConvOpTest : public OpsTestBase { protected: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index f410875e77eede..c63f355fdfeaaa 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -822,10 +822,22 @@ class ProfileResult { // algorithm: the primary algorithm that should be used. // algorithm_no_scratch: a secondary algorithm that should be used, if the // the allocation for the scratch memory fails. +// scrach_size: specify the size of scratch memory in bytes needed for the +// algorithm used +// +// On CUDA platform with CUDNN library, algorithm and algorithm_no_scratch +// would be used. On ROCm platform with MIOpen library, algorithm and +// scratch_size would be used. The major difference between the two platforms +// are whether it's possible to get an algorithm without scratch memory. On +// CUDA + CUDNN it's possible, and algorithm_no_scratch can be used to track +// such information, whereas on ROCm + MIOpen there is no guarantee to getting +// one without scratch memory, and scratch_size field is used to track it. class AlgorithmConfig { public: AlgorithmConfig() {} explicit AlgorithmConfig(AlgorithmDesc algorithm) : algorithm_(algorithm) {} + AlgorithmConfig(AlgorithmDesc algorithm, size_t scratch_size) + : algorithm_(algorithm), scratch_size_(scratch_size) {} AlgorithmConfig(AlgorithmDesc algorithm, AlgorithmDesc algorithm_no_scratch) : algorithm_(algorithm), algorithm_no_scratch_(algorithm_no_scratch) {} absl::optional algorithm() const { return algorithm_; } @@ -836,9 +848,12 @@ class AlgorithmConfig { void set_algorithm_no_scratch(AlgorithmDesc val) { algorithm_no_scratch_ = val; } + absl::optional scratch_size() const { return scratch_size_; } + void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmConfig& other) const { return this->algorithm_ == other.algorithm_ && - this->algorithm_no_scratch_ == other.algorithm_no_scratch_; + this->algorithm_no_scratch_ == other.algorithm_no_scratch_ && + this->scratch_size_ == other.scratch_size_; } bool operator!=(const AlgorithmConfig& other) const { return !(*this == other); @@ -848,6 +863,7 @@ class AlgorithmConfig { private: absl::optional algorithm_; absl::optional algorithm_no_scratch_; + absl::optional scratch_size_; }; // Describes a local response normalization (LRN). LRN is used e.g. in diff --git a/tensorflow/stream_executor/rocm/rocm_dnn.cc b/tensorflow/stream_executor/rocm/rocm_dnn.cc index 622ecf16a7db0f..4a0df0af17125b 100644 --- a/tensorflow/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/stream_executor/rocm/rocm_dnn.cc @@ -2732,12 +2732,7 @@ port::Status MIOpenSupport::DoPrepareForConvolution( } else { // An algorithm has been specified. *algorithm_desc = *algo_desc; - // commenting this line out for the upstream repo, since - // AlgorithmConfig::scratch_size_ has been removed in the upstream repo but - // is still used in the ROCM develop-upstream repo - // - // scratch_memory_size = *(algorithm_config.scratch_size()); - // + scratch_memory_size = *(algorithm_config.scratch_size()); } // allocate scratch memory