Add ROCm support for 2D and 3D convolution ops.
MIOpen has its own algorithm-find logic and reports the required scratch
memory size back to client applications. Modify AlgorithmConfig to track this
information.
whchung committed Jun 26, 2019
1 parent 232fb86 commit b7d4805
Showing 12 changed files with 166 additions and 48 deletions.
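The AlgorithmConfig change described above is not among the hunks shown in this section. As a rough sketch only, the class below is a simplified stand-in for the real stream_executor::dnn::AlgorithmConfig; its members and accessors are assumptions that mirror the set_algorithm()/set_scratch_size() calls used in the ROCm paths further down.

// Simplified model of an autotune result that also tracks scratch memory.
// NOT the actual stream_executor::dnn::AlgorithmConfig; names are assumptions
// based on the set_algorithm()/set_scratch_size() calls in the hunks below.
#include <cstddef>
#include <cstdint>
#include <optional>

struct AlgorithmDesc {
  int64_t algo_id = 0;
  bool tensor_ops_enabled = false;
};

class AlgorithmConfig {
 public:
  void set_algorithm(AlgorithmDesc algo) { algorithm_ = algo; }
  const std::optional<AlgorithmDesc>& algorithm() const { return algorithm_; }

  // MIOpen's find step reports how much scratch memory the chosen algorithm
  // needs, so the config records it for later launches.
  void set_scratch_size(size_t scratch_size) { scratch_size_ = scratch_size; }
  std::optional<size_t> scratch_size() const { return scratch_size_; }

 private:
  std::optional<AlgorithmDesc> algorithm_;
  std::optional<size_t> scratch_size_;
};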
31 changes: 24 additions & 7 deletions tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -47,12 +47,12 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

@@ -430,7 +430,7 @@ template struct LaunchConv2DBackpropFilterOp<CPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group forward backward filter autotune results together.
@@ -683,8 +683,8 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
Tensor compatible_input;
if (padding_top != padding_bottom || padding_left != padding_right) {
// Pad the input in the same way we did during the forward pass, so that
// cuDNN receives the same input during the backward pass function as it did
// during the forward pass function.
// cuDNN or MIOpen receives the same input during the backward pass function
// as it did during the forward pass function.
const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
const int64 padding_cols_diff = std::abs(padding_right - padding_left);
const int64 new_in_rows =
@@ -871,6 +871,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
@@ -907,6 +908,22 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
filter_desc, output_desc, conv_desc,
stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
ProfileResult best_result;
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
ctx);
bool miopen_find_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
filter_desc, &filter_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward filter algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
@@ -921,7 +938,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(

if (!cudnn_launch_status) {
ctx->SetStatus(errors::Internal(
"cuDNN Backward Filter function launch failure : input shape(",
"DNN Backward Filter function launch failure : input shape(",
input.shape().DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
return;
@@ -1003,6 +1020,6 @@ template struct LaunchConv2DBackpropFilterOp<GPUDevice, float>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropFilterOp<GPUDevice, double>;

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // namespace tensorflow
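For orientation, the ROCm branch added above follows a find-once-and-cache pattern: on an autotune-cache miss, a single MIOpen find call (ThenConvolveBackwardFilterWithAlgorithm with a default AlgorithmConfig) both selects the algorithm and reports its scratch requirement, and the result is stored for later launches. Below is a minimal standalone sketch of that control flow, using hypothetical stand-in types rather than the real stream_executor/TensorFlow API.

// Hypothetical stand-ins for the real TF/stream_executor types; this models
// only the control flow of the ROCm autotune branch, not the actual API.
#include <cstddef>
#include <cstdint>
#include <map>
#include <stdexcept>

struct ConvParameters {
  int64_t key;
  bool operator<(const ConvParameters& other) const { return key < other.key; }
};

struct ProfileResult {
  int64_t algorithm = -1;
  size_t scratch_size = 0;
  bool is_valid() const { return algorithm >= 0; }
};

struct AlgorithmConfig {
  int64_t algorithm = -1;
  size_t scratch_size = 0;
};

// Models a single MIOpen "find" invocation: it runs the convolution once and
// reports which algorithm it picked and how much scratch memory it needs.
bool MiopenFind(const ConvParameters& /*params*/, ProfileResult* best) {
  best->algorithm = 0;           // placeholder algorithm id
  best->scratch_size = 1 << 20;  // 1 MiB of scratch, for illustration
  return true;
}

AlgorithmConfig GetOrFindAlgorithm(
    std::map<ConvParameters, AlgorithmConfig>& cache,
    const ConvParameters& params) {
  auto it = cache.find(params);
  if (it != cache.end()) return it->second;  // AutoTune cache hit

  ProfileResult best;
  if (!MiopenFind(params, &best) || !best.is_valid()) {
    throw std::runtime_error("Failed to find backward filter algorithm!");
  }
  AlgorithmConfig config;
  config.algorithm = best.algorithm;
  config.scratch_size = best.scratch_size;  // MIOpen-reported scratch size
  cache.emplace(params, config);            // AutoTune cache insert
  return config;
}

The same shape repeats in the backward-data, 3D, and forward-convolution hunks below; only the ThenConvolve* call and the error message differ.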
35 changes: 28 additions & 7 deletions tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -47,12 +47,12 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace {

@@ -765,7 +765,9 @@ template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

#if GOOGLE_CUDA
// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group forward backward data autotune results together.
struct ConvBackwardDataAutoTuneGroup {
@@ -911,8 +913,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
TensorShape compatible_input_shape;
if (padding_top != padding_bottom || padding_left != padding_right) {
// Pad the input in the same way we did during the forward pass, so that
// cuDNN receives the same input during the backward pass function as it did
// during the forward pass function.
// cuDNN or MIOpen receives the same input during the backward pass function
// as it did during the forward pass function.
const int64 padding_rows_diff = std::abs(padding_bottom - padding_top);
const int64 padding_cols_diff = std::abs(padding_right - padding_left);
const int64 new_in_rows =
@@ -1094,6 +1096,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
@@ -1128,6 +1131,24 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc,
output_desc, conv_desc, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
// MIOpen has its own Find and autotuner, so use it here, passing a
// default AlgorithmConfig to force a search.
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backwards-data algorithm!"));

algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
@@ -1141,7 +1162,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(

if (!cudnn_launch_status) {
ctx->SetStatus(errors::Internal(
"cuDNN Backward Data function launch failure : input shape(",
"DNN Backward Data function launch failure : input shape(",
input_shape.DebugString(), ") filter shape(",
filter_shape.DebugString(), ")"));
return;
@@ -1255,6 +1276,6 @@ template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // namespace tensorflow
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/conv_grad_ops.h
@@ -191,7 +191,7 @@ struct LaunchConv2DBackpropFilterOp {
Tensor* filter_backprop, TensorFormat data_format);
};

#ifdef GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
@@ -211,7 +211,7 @@ struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
const std::vector<int64>& explicit_paddings,
Tensor* filter_backprop, TensorFormat data_format);
};
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Information about a single spatial dimension for a convolution
// backpropagation.
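A side note on this header: the guard changes from #ifdef GOOGLE_CUDA to #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM, since #ifdef cannot express the || alternative. The two directives also behave differently when a macro is defined but expands to 0, as this small illustrative program shows (the macro values are chosen purely for illustration and are not taken from the TensorFlow build):

// Illustration of #ifdef vs. #if semantics; macro values are hypothetical.
#include <cstdio>

#define GOOGLE_CUDA 0          // defined, but "off"
#define TENSORFLOW_USE_ROCM 1  // "on"

int main() {
#ifdef GOOGLE_CUDA
  std::puts("#ifdef GOOGLE_CUDA: taken (the macro is defined, even though it is 0)");
#endif

#if GOOGLE_CUDA
  std::puts("#if GOOGLE_CUDA: taken");
#else
  std::puts("#if GOOGLE_CUDA: not taken (the value is 0)");
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  std::puts("#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM: taken");
#endif
  return 0;
}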
47 changes: 40 additions & 7 deletions tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
@@ -1049,7 +1049,7 @@ TF_CALL_half(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL

// GPU definitions of both ops.
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Forward declarations of the functor specializations for GPU.
// This ensures that the custom implementation is used instead of the default
// Eigen one (which is used for CPU).
@@ -1355,15 +1355,16 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
using se::dnn::AlgorithmDesc;
using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
ProfileResult best_result;
ProfileResult best_result_no_scratch;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
stream->parent()),
&algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times for better
// accuracy.
@@ -1401,6 +1402,21 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
bool miopen_find_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(context, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward data algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
@@ -1761,15 +1777,16 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
using se::dnn::AlgorithmDesc;
using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
ProfileResult best_result;
ProfileResult best_result_no_scratch;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
std::vector<AlgorithmDesc> algorithms;
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
stream->parent()),
&algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
// TODO(zhengxq): profile each algorithm multiple times for better
// accuracy.
@@ -1808,6 +1825,22 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
bool miopen_find_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator, AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(
context, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward filter algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
@@ -1866,6 +1899,6 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
TF_CALL_double(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // namespace tensorflow
30 changes: 25 additions & 5 deletions tensorflow/core/kernels/conv_ops.cc
@@ -18,9 +18,9 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/conv_ops.h"

@@ -57,11 +57,13 @@ limitations under the License.
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/ptxas_utils.h"
#include "tensorflow/stream_executor/cuda/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
@@ -574,7 +576,7 @@ template struct LaunchConv2DOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DOp<CPUDevice, float>;
template struct LaunchConv2DOp<CPUDevice, double>;

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
int64 GetDnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
@@ -978,6 +980,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
#if GOOGLE_CUDA
std::vector<AlgorithmDesc> algorithms;
OP_REQUIRES(
ctx,
@@ -1049,6 +1052,23 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
output_tensor, input_desc, filter_desc, output_desc,
conv_desc, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
filter_ptr, conv_desc, output_desc,
&output_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();

OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find conv algorithm!"));

algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
}

@@ -1137,6 +1157,6 @@ template struct LaunchConv2DOp<GPUDevice, float>;
template struct LaunchConv2DOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DOp<GPUDevice, double>;

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

} // namespace tensorflow
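Once an AlgorithmConfig carrying a scratch size is cached, later launches can skip the find step entirely. The sketch below shows how such a cached config might be consumed: allocate the recorded amount of scratch, then hand the workspace and the recorded algorithm to the convolution call. The types are simplified stand-ins, not the real DnnScratchAllocator or stream API.

// Stand-in types modeling the "launch with a cached algorithm" path.
#include <cstddef>
#include <cstdint>
#include <vector>

// Cached autotune result (see the sketch earlier on this page).
struct AlgorithmConfig {
  int64_t algorithm = -1;
  size_t scratch_size = 0;
};

// Stand-in for a scratch allocator: hands out a workspace capped at a limit.
class ScratchAllocator {
 public:
  explicit ScratchAllocator(size_t limit_bytes) : limit_bytes_(limit_bytes) {}
  // Returns nullptr if the request exceeds the workspace limit; the real
  // kernels would then fail the launch or fall back.
  void* Allocate(size_t bytes) {
    if (bytes > limit_bytes_) return nullptr;
    buffer_.resize(bytes);
    return buffer_.data();
  }

 private:
  size_t limit_bytes_;
  std::vector<unsigned char> buffer_;
};

// Models a subsequent launch: no find, no profiling, just the cached choice.
bool LaunchWithCachedConfig(const AlgorithmConfig& config,
                            size_t workspace_limit_bytes) {
  ScratchAllocator scratch(workspace_limit_bytes);
  void* workspace = scratch.Allocate(config.scratch_size);
  if (config.scratch_size > 0 && workspace == nullptr) return false;
  // ... hand `workspace` and `config.algorithm` to the convolution call ...
  return true;
}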
