Automated rollback of commit 65bad48
PiperOrigin-RevId: 295346786
Change-Id: I089897b1b02373c02049b03bc3b986a2442a5127
tensorflower-gardener committed Feb 15, 2020
1 parent 65bad48 commit 80acd88
Showing 5 changed files with 105 additions and 264 deletions.
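
At a glance (a summary inferred from the diff below, not text from the commit itself): the rollback removes the compile-time bool is_training template parameter and the exponential_avg_factor argument from the FusedBatchNorm functor, so training versus inference is again selected by a runtime flag, roughly:

    // Interface introduced by 65bad48, removed here:
    template <typename Device, typename T, typename U, bool is_training>
    struct FusedBatchNorm;

    // Interface restored by this rollback (is_training becomes the last
    // runtime argument of operator()):
    template <typename Device, typename T, typename U>
    struct FusedBatchNorm;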
290 changes: 99 additions & 191 deletions tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -81,163 +81,25 @@ Status ParseActivationMode(OpKernelConstruction* context,
}

// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T, typename U, bool is_training>
template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ true> {
void operator()(OpKernelContext* context, const Tensor& x_input,
const Tensor& scale_input, const Tensor& offset_input,
const Tensor& running_mean_input,
const Tensor& running_variance_input,
const Tensor* side_input, U epsilon, U exponential_avg_factor,
FusedBatchNormActivationMode activation_mode,
Tensor* y_output, Tensor* running_mean_output,
Tensor* running_var_output, Tensor* saved_batch_mean_output,
Tensor* saved_batch_var_output, TensorFormat tensor_format,
bool use_reserved_space) {
OP_REQUIRES(context, side_input == nullptr,
errors::Internal(
"The CPU implementation of FusedBatchNorm does not support "
"side input."));
OP_REQUIRES(context,
activation_mode == FusedBatchNormActivationMode::kIdentity,
errors::Internal("The CPU implementation of FusedBatchNorm "
"does not support activations."));

if (use_reserved_space) {
Tensor* dummy_reserve_space = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(5, {}, &dummy_reserve_space));
// Initialize the memory, to avoid sanitizer alerts.
dummy_reserve_space->flat<U>()(0) = U();
}
Tensor transformed_x;
Tensor transformed_y;
if (tensor_format == FORMAT_NCHW) {
const int64 in_batch = GetTensorDim(x_input, tensor_format, 'N');
const int64 in_rows = GetTensorDim(x_input, tensor_format, 'H');
const int64 in_cols = GetTensorDim(x_input, tensor_format, 'W');
const int64 in_depths = GetTensorDim(x_input, tensor_format, 'C');
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NHWC, in_batch,
in_rows, in_cols, in_depths),
&transformed_x));
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value,
ShapeFromFormat(FORMAT_NHWC, in_batch,
in_rows, in_cols, in_depths),
&transformed_y));
// Perform NCHW to NHWC
std::vector<int32> perm = {0, 2, 3, 1};
OP_REQUIRES_OK(
context, ::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
x_input, perm, &transformed_x));
} else {
transformed_x = x_input;
transformed_y = *y_output;
}
typename TTypes<T, 4>::Tensor x(transformed_x.tensor<T, 4>());
typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
typename TTypes<U>::ConstVec old_mean(running_mean_input.vec<U>());
typename TTypes<U>::ConstVec old_variance(running_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
typename TTypes<U>::Vec new_mean(running_mean_output->vec<U>());
typename TTypes<U>::Vec new_variance(running_var_output->vec<U>());
typename TTypes<U>::Vec saved_batch_mean(saved_batch_mean_output->vec<U>());
typename TTypes<U>::Vec saved_batch_var(saved_batch_var_output->vec<U>());

const CPUDevice& d = context->eigen_device<CPUDevice>();

const int depth = x.dimension(3);
const int size = x.size();
const int rest_size = size / depth;
Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
Eigen::array<int, 1> reduce_dims({0});
Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif

auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
// This adjustment is for Bessel's correction
U rest_size_adjust =
static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);

Eigen::Tensor<U, 1, Eigen::RowMajor> batch_mean(depth);
Eigen::Tensor<U, 1, Eigen::RowMajor> batch_variance(depth);

batch_mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
auto x_centered = x_rest_by_depth -
batch_mean.reshape(one_by_depth).broadcast(bcast_spec);

batch_variance.device(d) =
x_centered.square().sum(reduce_dims) * rest_size_inv;
auto scaling_factor = ((batch_variance + epsilon).rsqrt() * scale)
.eval()
.reshape(one_by_depth)
.broadcast(bcast_spec);
auto x_scaled = x_centered * scaling_factor;
auto x_shifted =
(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
.template cast<T>();

y.reshape(rest_by_depth).device(d) = x_shifted;
if (exponential_avg_factor == U(1.0)) {
saved_batch_var.device(d) = batch_variance;
saved_batch_mean.device(d) = batch_mean;
new_variance.device(d) = batch_variance * rest_size_adjust;
new_mean.device(d) = batch_mean;
} else {
U one_minus_factor = U(1) - exponential_avg_factor;
saved_batch_var.device(d) = batch_variance;
saved_batch_mean.device(d) = batch_mean;
new_variance.device(d) =
one_minus_factor * old_variance +
(exponential_avg_factor * rest_size_adjust) * batch_variance;
new_mean.device(d) =
one_minus_factor * old_mean + exponential_avg_factor * batch_mean;
}

if (tensor_format == FORMAT_NCHW) {
// Perform NHWC to NCHW
const std::vector<int32> perm = {0, 3, 1, 2};
const Status s = ::tensorflow::DoTranspose(
context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
if (!s.ok()) {
context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
}
}
}
};

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
struct FusedBatchNorm<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x_input,
const Tensor& scale_input, const Tensor& offset_input,
const Tensor& estimated_mean_input,
const Tensor& estimated_variance_input,
const Tensor* side_input, U epsilon, U exponential_avg_factor,
const Tensor* side_input, U epsilon,
FusedBatchNormActivationMode activation_mode,
Tensor* y_output, Tensor* batch_mean_output,
Tensor* batch_var_output, Tensor* saved_mean_output,
Tensor* saved_var_output, TensorFormat tensor_format,
bool use_reserved_space) {
bool use_reserved_space, bool is_training) {
OP_REQUIRES(context, side_input == nullptr,
errors::Internal(
"The CPU implementation of FusedBatchNorm does not support "
@@ -288,7 +150,9 @@ struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
estimated_variance_input.vec<U>());
typename TTypes<T, 4>::Tensor y(transformed_y.tensor<T, 4>());
typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
typename TTypes<U>::Vec batch_variance(batch_var_output->vec<U>());
typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());

const CPUDevice& d = context->eigen_device<CPUDevice>();

@@ -304,36 +168,80 @@ struct FusedBatchNorm<CPUDevice, T, U, /* is_training= */ false> {
#else
Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
one_by_depth.set(1, depth);
Eigen::IndexList<Eigen::type2index<0>> reduce_dims;
Eigen::IndexList<Eigen::Index, Eigen::type2index<1>> bcast_spec;
bcast_spec.set(0, rest_size);
#endif

auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
auto x_centered =
x_rest_by_depth -
estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
.eval()
.reshape(one_by_depth)
.broadcast(bcast_spec);
auto x_scaled = x_centered * scaling_factor;
auto x_shifted =
(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
.template cast<T>();

y.reshape(rest_by_depth).device(d) = x_shifted;
batch_mean.device(d) = estimated_mean;
batch_variance.device(d) = estimated_variance;
const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
// This adjustment is for Bessel's correction
U rest_size_adjust =
static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);
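// For example, with rest_size = N*H*W = 8, the batch variance computed below
// is normalized by 1/8, and rest_size_adjust = 8/7 converts it to the
// unbiased estimate that is written to the moving (running) variance.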

if (tensor_format == FORMAT_NCHW) {
// Perform NHWC to NCHW
const std::vector<int32> perm = {0, 3, 1, 2};
const Status s = ::tensorflow::DoTranspose(
context->eigen_device<CPUDevice>(), transformed_y, perm, y_output);
if (!s.ok()) {
context->SetStatus(errors::InvalidArgument("Transpose failed: ", s));
Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
BlockingCounter barrier(1);
std::atomic<uint8> task_counter;
auto on_done = [&]() {
uint8 count = --task_counter;
if (count == 0) {
if (tensor_format == FORMAT_NCHW) {
// Perform NHWC to NCHW
const std::vector<int32> perm = {0, 3, 1, 2};
const Status s =
::tensorflow::DoTranspose(context->eigen_device<CPUDevice>(),
transformed_y, perm, y_output);
if (!s.ok()) {
context->SetStatus(
errors::InvalidArgument("Transpose failed: ", s));
}
}
barrier.DecrementCount();
}
};
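// The assignments below use Eigen's asynchronous device interface:
// task_counter is set to the number of expressions dispatched, each
// .device(d, on_done) assignment invokes on_done when it finishes, and the
// last one to complete performs the NHWC->NCHW transpose (if needed) and
// releases the barrier that the calling thread waits on.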
if (is_training) {
// TODO(b/137108598): Extend kernel to allow use of exponential averaging.
mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
auto x_centered =
x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);

variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
.eval()
.reshape(one_by_depth)
.broadcast(bcast_spec);
auto x_scaled = x_centered * scaling_factor;
auto x_shifted =
(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
.template cast<T>();

task_counter = 5;
y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
batch_var.device(d, on_done) = variance * rest_size_adjust;
saved_var.device(d, on_done) = variance;
batch_mean.device(d, on_done) = mean;
saved_mean.device(d, on_done) = mean;
} else { // is_training == false
auto x_centered =
x_rest_by_depth -
estimated_mean.reshape(one_by_depth).broadcast(bcast_spec);
auto scaling_factor = ((estimated_variance + epsilon).rsqrt() * scale)
.eval()
.reshape(one_by_depth)
.broadcast(bcast_spec);
auto x_scaled = x_centered * scaling_factor;
auto x_shifted =
(x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec))
.template cast<T>();
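// In effect: y = scale * (x - estimated_mean) * rsqrt(estimated_variance +
// epsilon) + offset, evaluated with the frozen statistics passed in for
// inference.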

task_counter = 3;
y.reshape(rest_by_depth).device(d, on_done) = x_shifted;
mean.device(d, on_done) = estimated_mean;
variance.device(d, on_done) = estimated_variance;
}
barrier.Wait();
}
};

@@ -754,17 +662,17 @@ class CudnnBatchNormAllocatorInOutput : public ScratchAllocator {
bool output_allocated = false;
};

template <typename T, typename U, bool is_training>
struct FusedBatchNorm<GPUDevice, T, U, is_training> {
template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x,
const Tensor& scale, const Tensor& offset,
const Tensor& estimated_mean,
const Tensor& estimated_variance, const Tensor* side_input,
U epsilon, U exponential_avg_factor,
FusedBatchNormActivationMode activation_mode, Tensor* y,
Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
Tensor* saved_inv_var, TensorFormat tensor_format,
bool use_reserved_space) {
U epsilon, FusedBatchNormActivationMode activation_mode,
Tensor* y, Tensor* batch_mean, Tensor* batch_var,
Tensor* saved_mean, Tensor* saved_inv_var,
TensorFormat tensor_format, bool use_reserved_space,
bool is_training) {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

@@ -929,13 +837,15 @@ struct FusedBatchNorm<GPUDevice, T, U, is_training> {
workspace_allocator.reset(
new functor::CudnnBatchNormAllocatorInTemp<uint8>(context));
}
// TODO(b/137108598): Extend kernel to allow use of exponential averaging.
const double exponential_average_factor = 1.0;
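// With a factor of 1.0, the running mean/variance are overwritten by the
// statistics of the current batch rather than blended exponentially, which
// is why the exponential_avg_factor plumbing is dropped elsewhere in this
// rollback.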
bool cudnn_launch_status =
stream
->ThenBatchNormalizationForward(
x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
estimated_variance_ptr, side_input_ptr, x_desc,
scale_offset_desc, static_cast<double>(epsilon),
static_cast<double>(exponential_avg_factor),
exponential_average_factor,
AsDnnActivationMode(activation_mode), &y_ptr, &batch_mean_ptr,
&batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
is_training, reserve_space_allocator.get(),
@@ -1165,10 +1075,6 @@ class FusedBatchNormOpBase : public OpKernel {
float epsilon;
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
epsilon_ = U(epsilon);
float exponential_avg_factor;
OP_REQUIRES_OK(context, context->GetAttr("exponential_avg_factor",
&exponential_avg_factor));
exponential_avg_factor_ = U(exponential_avg_factor);
string tensor_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
@@ -1259,6 +1165,17 @@ class FusedBatchNormOpBase : public OpKernel {
"channel dimension to be a multiple of 4."));
}

if (is_training_) {
OP_REQUIRES(
context, estimated_mean.dim_size(0) == 0,
errors::InvalidArgument("estimated_mean must be empty for training",
estimated_mean.shape().DebugString()));
OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
errors::InvalidArgument(
"estimated_variance must be empty for training",
estimated_variance.shape().DebugString()));
}

Tensor* y = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, x.shape(), &y));
@@ -1275,24 +1192,15 @@ class FusedBatchNormOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
&saved_maybe_inv_var));

if (is_training_) {
functor::FusedBatchNorm<Device, T, U, true>()(
context, x, scale, offset, estimated_mean, estimated_variance,
side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
tensor_format_, use_reserved_space);
} else {
functor::FusedBatchNorm<Device, T, U, false>()(
context, x, scale, offset, estimated_mean, estimated_variance,
side_input, epsilon_, exponential_avg_factor_, activation_mode_, y,
batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
tensor_format_, use_reserved_space);
}
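// The single remaining instantiation handles both paths; is_training_ is now
// passed as the final runtime argument instead of selecting a template
// specialization.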
functor::FusedBatchNorm<Device, T, U>()(
context, x, scale, offset, estimated_mean, estimated_variance,
side_input, epsilon_, activation_mode_, y, batch_mean, batch_var,
saved_mean, saved_maybe_inv_var, tensor_format_, use_reserved_space,
is_training_);
}

private:
U epsilon_;
U exponential_avg_factor_;
TensorFormat tensor_format_;
bool is_training_;
bool has_side_input_;