Skip to content

Commit

Permalink
Merge pull request #52834 from tensorflow/mm-cp-4-on-r2.4
Browse files Browse the repository at this point in the history
Error checking for zero size filters for tf.nn.convolution (conv2d, c…
  • Loading branch information
mihaimaruseac committed Oct 28, 2021
2 parents f5bfc75 + a0473f2 commit 6b54dd0
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 136 deletions.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/conv_ops.cc
Expand Up @@ -201,6 +201,10 @@ struct LaunchConv2DOp<GPUDevice, int32> {
"attempted to be run because the input depth of ",
in_depth, " does not match the filter input depth of ",
filter.dim_size(2)));
OP_REQUIRES(
ctx, filter.NumElements() > 0,
errors::InvalidArgument("filter must not have zero elements "
"(i.e. all dimensions must be non-zero)"));

for (int64 explicit_padding : explicit_paddings) {
if (!FastBoundsCheck(explicit_padding, std::numeric_limits<int>::max())) {
Expand Down Expand Up @@ -674,6 +678,11 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
const int64 patch_cols = filter.dim_size(1);
const int64 patch_depths = filter.dim_size(2);

OP_REQUIRES(
ctx, filter.NumElements() > 0,
errors::InvalidArgument("filter must not have zero elements "
"(i.e. all dimensions must be non-zero)"));

// If the filter in-depth (patch_depths) is 1 and smaller than the input
// depth, it's a depthwise convolution. More generally, if the filter in-depth
// divides but is smaller than the input depth, it is a grouped convolution.
Expand Down
11 changes: 11 additions & 0 deletions tensorflow/python/kernel_tests/conv_ops_3d_test.py
Expand Up @@ -24,6 +24,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
Expand Down Expand Up @@ -460,6 +461,16 @@ def testKernelSizeMatchesInputSize(self):
padding="VALID",
expected=[1.5625, 1.875])

def testZeroSizedFilterThrowsIllegalArgument(self):
tensor_in_sizes = [1, 1, 1, 1, 1]
x1 = self._CreateNumpyTensor(tensor_in_sizes)
filter_in = np.ones((1, 1, 0, 1, 1), dtype=np.float32)
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError, "filter must not have zero elements"
"|has a non-positive dimension"):
self.evaluate(
nn_ops.conv3d(x1, filter_in, strides=[1, 1, 1, 1, 1], padding="SAME"))

def _ConstructAndTestGradientForConfig(
self, batch, input_shape, filter_shape, in_depth, out_depth, stride,
padding, test_input, data_format, use_gpu):
Expand Down
216 changes: 80 additions & 136 deletions tensorflow/python/kernel_tests/conv_ops_test.py
Expand Up @@ -2522,139 +2522,83 @@ def testShapeFunctionEdgeCases(self):
strides=[1, 1, 1, 1],
padding=[0, 0, 0, 0])

@test_util.deprecated_graph_mode_only
def testOpEdgeCases(self):
with self.cached_session() as sess:
# Illegal strides.
with self.assertRaisesRegex(errors_impl.UnimplementedError,
"strides in the batch and depth"):
input_placeholder = array_ops.placeholder(dtypes.float32)
input_val = np.ones([10, 10])
filter_placeholder = array_ops.placeholder(dtypes.float32)
filter_val = np.ones([10, 10])
sess.run(
nn_ops.conv2d(
input_placeholder,
filter_placeholder,
strides=[2, 1, 1, 1],
padding="SAME"),
feed_dict={
input_placeholder: input_val,
filter_placeholder: filter_val
})
with self.assertRaisesRegex(errors_impl.UnimplementedError,
"strides in the batch and depth"):
input_placeholder = array_ops.placeholder(dtypes.float32)
filter_placeholder = array_ops.placeholder(dtypes.float32)
input_val = np.ones([10, 10])
filter_val = np.ones([10, 10])
sess.run(
nn_ops.conv2d(
input_placeholder,
filter_placeholder,
strides=[1, 1, 1, 2],
padding="SAME"),
feed_dict={
input_placeholder: input_val,
filter_placeholder: filter_val
})

# Filter larger than input.
with self.assertRaisesRegex(ValueError, "Negative dimension size"):
input_placeholder = array_ops.placeholder(
dtypes.float32, shape=[32, 20, 20, 3])
input_val = np.ones([32, 20, 20, 3])
filter_placeholder = array_ops.placeholder(
dtypes.float32, shape=[20, 21, 3, 2])
filter_val = np.ones([20, 21, 3, 2])

sess.run(
nn_ops.conv2d(
input_placeholder,
filter_placeholder,
strides=[1, 1, 1, 1],
padding="VALID"),
feed_dict={
input_placeholder: input_val,
filter_placeholder: filter_val
})
with self.assertRaisesRegex(ValueError, "Negative dimension size"):
input_placeholder = array_ops.placeholder(
dtypes.float32, shape=[32, 20, 20, 3])
input_val = np.ones([32, 20, 20, 3])
filter_placeholder = array_ops.placeholder(
dtypes.float32, shape=[21, 20, 3, 2])
filter_val = np.ones([21, 20, 3, 2])
sess.run(
nn_ops.conv2d(
input_placeholder,
filter_placeholder,
strides=[1, 1, 1, 1],
padding="VALID"),
feed_dict={
input_placeholder: input_val,
filter_placeholder: filter_val
})

# Filter larger than input + padding.
with self.assertRaisesRegex(ValueError, "Negative dimension size"):
input_placeholder = array_ops.placeholder(
dtypes.float32, shape=[32, 20, 20, 3])
input_val = np.ones([32, 20, 20, 3])
filter_placeholder = array_ops.placeholder(
dtypes.float32, shape=[24, 25, 3, 2])
filter_val = np.ones([24, 25, 3, 2])
sess.run(
nn_ops.conv2d(
input_placeholder,
filter_placeholder,
strides=[1, 1, 1, 1],
padding=[[0, 0], [2, 2], [2, 2], [0, 0]]),
feed_dict={
input_placeholder: input_val,
filter_placeholder: filter_val
})

# Negative padding during backprop.
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"All elements of explicit_paddings must be nonnegative"):
filter_placeholder = array_ops.placeholder(
dtypes.float32, shape=[18, 18, 3, 2])
filter_val = np.ones([18, 18, 3, 2])
out_backprop = array_ops.placeholder(
dtypes.float32, shape=[32, 3, 2, 2])
out_backprop_val = np.ones([32, 3, 2, 2])
sess.run(
nn_ops.conv2d_backprop_input([32, 20, 20, 3],
filter_placeholder,
out_backprop,
strides=[1, 1, 1, 1],
padding=[[0, 0], [-1, 0], [0, 0],
[0, 0]]),
feed_dict={
filter_placeholder: filter_val,
out_backprop: out_backprop_val
})
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"All elements of explicit_paddings must be nonnegative"):
input_placeholder = array_ops.placeholder(
dtypes.float32, shape=[32, 20, 20, 3])
input_val = np.ones([32, 20, 20, 3])
out_backprop = array_ops.placeholder(
dtypes.float32, shape=[32, 3, 2, 2])
out_backprop_val = np.ones([32, 3, 2, 2])
sess.run(
nn_ops.conv2d_backprop_filter(
input_placeholder, [18, 18, 3, 2],
out_backprop,
strides=[1, 1, 1, 1],
padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]),
feed_dict={
input_placeholder: input_val,
out_backprop: out_backprop_val
})
# Illegal strides.
with self.assertRaisesRegex((ValueError, errors_impl.UnimplementedError),
"strides in the batch and depth"):
input_val = np.ones([2, 4, 10, 10])
filter_val = np.ones([2, 4, 10, 10])
self.evaluate(
nn_ops.conv2d(
input_val, filter_val, strides=[2, 1, 1, 1], padding="SAME"))
with self.assertRaisesRegex((ValueError, errors_impl.UnimplementedError),
"strides in the batch and depth"):
input_val = np.ones([2, 4, 10, 10])
filter_val = np.ones([2, 4, 10, 10])
self.evaluate(
nn_ops.conv2d(
input_val, filter_val, strides=[1, 1, 1, 2], padding="SAME"))

# TODO(b/195689143): Will enable when fixed for V2 behavior
# # Filter larger than input.
# with self.assertRaisesRegex(ValueError, "Negative dimension size"):
# input_val = np.ones([32, 20, 20, 3])
# filter_val = np.ones([20, 21, 3, 2])
# self.evaluate(
# nn_ops.conv2d(
# input_val, filter_val, strides=[1, 1, 1, 1], padding="VALID"))
# with self.assertRaisesRegex(ValueError, "Negative dimension size"):
# input_val = np.ones([32, 20, 20, 3])
# filter_val = np.ones([21, 20, 3, 2])
# self.evaluate(
# nn_ops.conv2d(
# input_val, filter_val, strides=[1, 1, 1, 1], padding="VALID"))
#
# # Filter larger than input + padding.
# with self.assertRaisesRegex(ValueError, "Negative dimension size"):
# input_val = np.ones([32, 20, 20, 3])
# filter_val = np.ones([24, 25, 3, 2])
# self.evaluate(
# nn_ops.conv2d(
# input_val,
# filter_val,
# strides=[1, 1, 1, 1],
# padding=[[0, 0], [2, 2], [2, 2], [0, 0]]))

# Filter dimensions must be greater than 0.
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError, "filter must not have zero elements"
"|has a non-positive dimension"):
input_val = np.ones([1, 1, 1, 1])
filter_val = np.ones([1, 0, 1, 1])
self.evaluate(
nn_ops.conv2d(
input_val, filter_val, strides=[1, 1, 1, 1], padding="SAME"))

# Negative padding during backprop.
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"All elements of explicit_paddings must be nonnegative"):
filter_val = np.ones([18, 18, 3, 2])
out_backprop_val = np.ones([32, 3, 2, 2])
self.evaluate(
nn_ops.conv2d_backprop_input([32, 20, 20, 3],
filter_val,
out_backprop_val,
strides=[1, 1, 1, 1],
padding=[[0, 0], [-1, 0], [0, 0], [0,
0]]))
with self.assertRaisesRegex(
errors_impl.InvalidArgumentError,
"All elements of explicit_paddings must be nonnegative"):
input_val = np.ones([32, 20, 20, 3])
out_backprop_val = np.ones([32, 3, 2, 2])
self.evaluate(
nn_ops.conv2d_backprop_filter(
input_val, [18, 18, 3, 2],
out_backprop_val,
strides=[1, 1, 1, 1],
padding=[[0, 0], [-1, 0], [0, 0], [0, 0]]))


class DepthwiseConv2DTest(test.TestCase):
Expand All @@ -2664,10 +2608,10 @@ def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
"""Verifies the output values of the convolution function.
Args:
tensor_in_sizes: Input tensor dimensions in
[batch, input_rows, input_cols, input_depth].
filter_in_sizes: Filter tensor dimensions in
[filter_rows, filter_cols, input_depth, depth_multiplier].
tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
input_cols, input_depth].
filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
input_depth, depth_multiplier].
stride: Stride.
padding: Padding type.
expected: An array containing the expected operation outputs.
Expand Down

0 comments on commit 6b54dd0

Please sign in to comment.