Fix https://github.com/tensorflow/tensorflow/issues/29432 #29454

Merged 1 commit on Jun 5, 2019
tensorflow/core/kernels/fused_batch_norm_op.cc (63 changes: 52 additions & 11 deletions)
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
@@ -247,8 +248,8 @@ struct FusedBatchNorm<CPUDevice, T, U> {
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
-Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
-Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
+Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
+Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif

@@ -341,8 +342,8 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
-Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
-Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
+Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
+Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif

@@ -385,6 +386,26 @@ struct FusedBatchNormGrad<CPUDevice, T, U> {
};

#if GOOGLE_CUDA
+
+namespace {
+// NOTE(ezhulenev): See `BatchnormSpatialPersistentEnabled` documentation in the
+// `cuda_dnn.cc` for details.
+bool BatchnormSpatialPersistentEnabled() {
+#if CUDNN_VERSION >= 7402
+  static bool is_enabled = [] {
+    bool is_enabled = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
+        "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT",
+        /*default_val=*/false, &is_enabled));
+    return is_enabled;
+  }();
+  return is_enabled;
+#else
+  return false;
+#endif
+}
+}  // namespace
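
For reference, the helper reduces to this minimal standalone sketch of the same read-once env-var gate, using std::getenv in place of TensorFlow's ReadBoolFromEnvVar; the function name and accepted values here are illustrative, not from the PR:

#include <cstdlib>
#include <cstring>

// Reads TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT once per process and caches
// the result; later calls reuse the cached value, mirroring the static-local
// lambda pattern in the helper above.
bool SpatialPersistentEnvEnabled() {
  static const bool enabled = [] {
    const char* v = std::getenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT");
    return v != nullptr &&
           (std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0);
  }();
  return enabled;
}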

template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x,
@@ -575,13 +596,28 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
const int64 height = GetTensorDim(x, tensor_format, 'H');
const int64 width = GetTensorDim(x, tensor_format, 'W');

+// Check if cuDNN batch normalization has a fast NHWC implementation:
+// (1) TensorFlow enabled batchnorm spatial persistence, and
+// (2) FusedBatchNormGradV3 passed non-null reserve space and allocator.
+const bool fast_nhwc_batch_norm = BatchnormSpatialPersistentEnabled() &&
+                                  DataTypeToEnum<T>::value == DT_HALF &&
+                                  reserve_space != nullptr &&
+                                  workspace_allocator != nullptr;
+
+// If the input tensor is in NHWC format and we have a fast cuDNN
+// implementation, there is no need to do data format conversion.
+TensorFormat compute_format =
+    fast_nhwc_batch_norm && tensor_format == FORMAT_NHWC ? FORMAT_NHWC
+                                                         : FORMAT_NCHW;
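// Illustrative summary of the layout decision above (added note, not in the PR):
//   tensor_format   fast_nhwc_batch_norm   compute_format
//   NHWC            true                   NHWC  (no transposes needed)
//   NHWC            false                  NCHW  (transpose in and out)
//   NCHW            either                 NCHW  (cuDNN's default layout here)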

VLOG(2) << "FusedBatchNormGrad:"
<< " batch_size: " << batch_size << " channels: " << channels
<< " height: " << height << " width: " << width
<< " y_backprop shape: " << y_backprop.shape().DebugString()
<< " x shape: " << x.shape().DebugString()
<< " scale shape: " << scale.shape().DebugString()
<< " tensor format: " << tensor_format;
<< " tensor format: " << tensor_format
<< " compute format: " << compute_format;

// Inputs
Tensor y_backprop_maybe_transformed = y_backprop;
@@ -593,9 +629,9 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
Tensor x_backprop_transformed;
se::DeviceMemory<T> x_backprop_ptr;

-if (tensor_format == FORMAT_NCHW) {
+if (tensor_format == compute_format) {
x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
-} else if (tensor_format == FORMAT_NHWC) {
+} else if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
// Transform inputs from 'NHWC' to 'NCHW'
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
@@ -629,17 +665,22 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
x_backprop_ptr =
StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
} else {
-context->SetStatus(
-    errors::Internal("Unsupported tensor format: ", tensor_format));
+context->SetStatus(errors::Internal(
+    "Unsupported tensor format: ", ToString(tensor_format),
+    " and compute format: ", ToString(compute_format)));
return;
}

+const se::dnn::DataLayout data_layout =
+    compute_format == FORMAT_NHWC ? se::dnn::DataLayout::kBatchYXDepth
+                                  : se::dnn::DataLayout::kBatchDepthYX;
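// Note (added for clarity, not in the PR): kBatchYXDepth is StreamExecutor's
// name for NHWC and kBatchDepthYX for NCHW, so the batch descriptor below
// always matches whichever layout was chosen as compute_format.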

se::dnn::BatchDescriptor x_desc;
x_desc.set_count(batch_size)
.set_feature_map_count(channels)
.set_height(height)
.set_width(width)
-.set_layout(se::dnn::DataLayout::kBatchDepthYX);
+.set_layout(data_layout);

se::dnn::BatchDescriptor scale_offset_desc;
scale_offset_desc.set_count(1)
@@ -681,7 +722,7 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
errors::Internal("cuDNN launch failure : input shape (",
x.shape().DebugString(), ")"));
}
-if (tensor_format == FORMAT_NHWC) {
+if (tensor_format == FORMAT_NHWC && compute_format == FORMAT_NCHW) {
functor::NCHWToNHWC<GPUDevice, T, 4>()(
context->eigen_device<GPUDevice>(),
const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
tensorflow/core/kernels/fused_batch_norm_op_test.cc (2 changes: 1 addition & 1 deletion)
@@ -157,7 +157,7 @@ static Graph* FusedBatchNormInference(int n, int h, int w, int c,
Node* empty = test::graph::Constant(g, empty_t, "empty");

Node* fused_batch_norm;
-TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "FusedBatchNormV2")
+TF_CHECK_OK(NodeBuilder(g->NewName("fused_batch_norm"), "FusedBatchNormV3")
.Input(x)
.Input(other) // scale
.Input(other) // offset
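
The benchmark switches from FusedBatchNormV2 to FusedBatchNormV3, presumably because only the V3 op family carries the reserve-space outputs that the fast NHWC gradient path checks for (the reserve_space != nullptr condition above).

A hedged sketch of how a user would opt in to the fast path (assumed workflow, not part of this PR; setenv is POSIX):

#include <cstdlib>

// Set the opt-in variable before the first FusedBatchNorm kernel executes;
// the gate caches the value on first read, so later changes have no effect.
void EnableFastNhwcBatchNorm() {
  setenv("TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", "1", /*overwrite=*/1);
}

With the variable set, FusedBatchNormV3/FusedBatchNormGradV3 on DT_HALF NHWC tensors with cuDNN >= 7.4.2 should take the transpose-free path; any other combination falls back to the NCHW conversion.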