Add shape checks to FusedBatchNorm kernels.
PiperOrigin-RevId: 399755576
Change-Id: If8049fde109cc33badb5509d174b9b95aee1ea5e
reedwm authored and tensorflower-gardener committed Sep 29, 2021
1 parent 56f86d6 commit aab9998
Showing 2 changed files with 154 additions and 7 deletions.
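For illustration (not part of the commit): a minimal sketch of the failure mode these checks address, written against the public tf.raw_ops wrapper and assuming a TensorFlow build that includes this change. With the fix, the forward kernel rejects an empty mean/variance up front in inference mode instead of possibly reading past the end of those buffers; previously the empty tensors slipped through the `NumElements() != 0` guard.

import tensorflow as tf

x = tf.ones((2, 4, 4, 3))  # NHWC input with 3 channels
scale = tf.ones((3,))
offset = tf.ones((3,))
empty = tf.constant([], dtype=tf.float32)

try:
  # In inference mode the kernel now insists that mean/variance hold one
  # value per channel; an empty tensor is no longer accepted.
  tf.raw_ops.FusedBatchNormV3(
      x=x, scale=scale, offset=offset, mean=empty, variance=empty,
      is_training=False)
except tf.errors.InvalidArgumentError as e:
  print(e)  # "When is_training=false, mean must have the same number ..."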
38 changes: 31 additions & 7 deletions tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -1340,18 +1340,20 @@ class FusedBatchNormOpBase : public OpKernel {
         errors::InvalidArgument("offset must have the same number of elements "
                                 "as the channels of x, got ",
                                 offset.NumElements(), " and ", num_channels));
-    if (estimated_mean.NumElements() != 0) {
+    if (!is_training_ || exponential_avg_factor_ != 1.) {
+      std::string prefix_msg = is_training_ ? "When exponential_avg_factor != 1"
+                                            : "When is_training=false";
       OP_REQUIRES(context, estimated_mean.NumElements() == num_channels,
                   errors::InvalidArgument(
-                      "mean must be empty or have the same number of "
-                      "elements as the channels of x, got ",
+                      prefix_msg,
+                      ", mean must have the same number "
+                      "of elements as the channels of x, got ",
                       estimated_mean.NumElements(), " and ", num_channels));
-    }
-    if (estimated_variance.NumElements() != 0) {
       OP_REQUIRES(context, estimated_variance.NumElements() == num_channels,
                   errors::InvalidArgument(
-                      "variance must be empty or have the same number of "
-                      "elements as the channels of x, got ",
+                      prefix_msg,
+                      ", variance must have the same "
+                      "number of elements as the channels of x, got ",
                       estimated_variance.NumElements(), " and ", num_channels));
     }

@@ -1543,6 +1545,11 @@ class FusedBatchNormGradOpBase : public OpKernel {
                 errors::InvalidArgument(
                     "saved variance must be 1-dimensional",
                     saved_maybe_inv_var_or_pop_var.shape().DebugString()));
+    OP_REQUIRES(
+        context, x.shape() == y_backprop.shape(),
+        errors::InvalidArgument(
+            "x and y_backprop must have same shape, but x has shape ",
+            x.shape(), " and y_backprop has shape ", y_backprop.shape()));
     if (use_activation) {
       OP_REQUIRES(
           context, x.dim_size(3) % 4 == 0,
@@ -1569,6 +1576,23 @@
           errors::InvalidArgument("Error during tensor copy."));
     }
 
+    const auto num_channels = GetTensorDim(x, tensor_format_, 'C');
+    OP_REQUIRES(
+        context, scale.NumElements() == num_channels,
+        errors::InvalidArgument("scale must have the same number of elements "
+                                "as the channels of x, got ",
+                                scale.NumElements(), " and ", num_channels));
+    OP_REQUIRES(
+        context, saved_mean_or_pop_mean.NumElements() == num_channels,
+        errors::InvalidArgument("reserve_space_1 must have the same number of "
+                                "elements as the channels of x, got ",
+                                saved_mean_or_pop_mean.NumElements(), " and ", num_channels));
+    OP_REQUIRES(
+        context, saved_maybe_inv_var_or_pop_var.NumElements() == num_channels,
+        errors::InvalidArgument("reserve_space_2 must have the same number of "
+                                "elements as the channels of x, got ",
+                                saved_maybe_inv_var_or_pop_var.NumElements(), " and ", num_channels));
+
     Tensor* x_backprop = nullptr;
     auto alloc_shape = use_reshape ? dest_shape : x_shape;
     OP_REQUIRES_OK(context,
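The gradient-side checks can be exercised the same way. A sketch (again not part of the commit, and assuming a build with this fix) that trips the new x/y_backprop shape check through the raw op backing the tests below:

import tensorflow as tf

x = tf.ones((2, 2, 2, 2))
dy = tf.ones((2, 2, 2, 3))  # deliberately mismatched with x
vec = tf.ones((2,))         # per-channel scale and saved statistics

try:
  # FusedBatchNormGradV2 is the op wrapped by gen_nn_ops.fused_batch_norm_grad_v2.
  tf.raw_ops.FusedBatchNormGradV2(
      y_backprop=dy, x=x, scale=vec,
      reserve_space_1=vec, reserve_space_2=vec, is_training=True)
except tf.errors.InvalidArgumentError as e:
  print(e)  # "x and y_backprop must have same shape, but x has shape ..."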
123 changes: 123 additions & 0 deletions tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -16,10 +16,13 @@
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import gradient_checker
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
@@ -694,6 +697,126 @@ def test5dBatchNormFollowedByRelu(self):
       y_ref = np.maximum(y_ref, 0.)
       self.assertAllClose(y_ref, y_val, atol=1e-3)
 
+  def testEagerShapeErrors(self):
+    with context.eager_mode():
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((3,))
+      offset = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'scale must have the same number of elements'):
+        nn_impl.fused_batch_norm(x, scale, offset)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((3,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'offset must have the same number of elements'):
+        nn_impl.fused_batch_norm(x, scale, offset)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((0,))
+      variance = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When is_training=false, mean must have the same number of elements'):
+        nn_impl.fused_batch_norm(
+            x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((2,))
+      variance = array_ops.ones((0,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When is_training=false, variance must have the same number of '
+          'elements'):
+        nn_impl.fused_batch_norm(
+            x, scale, offset, mean=mean, variance=variance, is_training=False)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((0,))
+      variance = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When exponential_avg_factor != 1, mean must have the same number of '
+          'elements'):
+        nn_impl.fused_batch_norm(
+            x,
+            scale,
+            offset,
+            mean=mean,
+            variance=variance,
+            exponential_avg_factor=0.5)
+
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      offset = array_ops.ones((2,))
+      mean = array_ops.ones((2,))
+      variance = array_ops.ones((0,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'When exponential_avg_factor != 1, variance must have the same '
+          'number of elements'):
+        nn_impl.fused_batch_norm(
+            x,
+            scale,
+            offset,
+            mean=mean,
+            variance=variance,
+            exponential_avg_factor=0.5)
+
+  def testEagerShapeGradErrors(self):
+    with context.eager_mode():
+      y_backprop = array_ops.ones((2, 2, 2, 3))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
+                                  'x and y_backprop must have same shape,'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((3,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'scale must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((3,))
+      reserve_space_2 = array_ops.ones((2,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'reserve_space_1 must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
+      y_backprop = array_ops.ones((2, 2, 2, 2))
+      x = array_ops.ones((2, 2, 2, 2))
+      scale = array_ops.ones((2,))
+      reserve_space_1 = array_ops.ones((2,))
+      reserve_space_2 = array_ops.ones((3,))
+      with self.assertRaisesRegex(
+          errors_impl.InvalidArgumentError,
+          'reserve_space_2 must have the same number of elements'):
+        gen_nn_ops.fused_batch_norm_grad_v2(y_backprop, x, scale,
+                                            reserve_space_1, reserve_space_2)
+
 
 if __name__ == '__main__':
   test.main()
