Skip to content

Commit

Permalink
Check for unexpected scalars in the shape argument to ParallelConcat.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 504901518
  • Loading branch information
jsmeredith authored and tensorflower-gardener committed Jan 26, 2023
1 parent 789ed75 commit da66bc6
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/inplace_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class ParallelConcatUpdate : public OpKernel {
OP_REQUIRES(
ctx, value.dim_size(0) > loc_,
errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
" is less than loc_=", loc_));
" must be greater than loc_ = ", loc_));

auto update = ctx->input(1);

Expand Down
7 changes: 7 additions & 0 deletions tensorflow/core/ops/array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <algorithm>
#include <ostream>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/full_type.pb.h"
Expand Down Expand Up @@ -309,6 +310,12 @@ REGISTER_OP("ParallelConcat")
return errors::InvalidArgument(
"All input shapes must be fully defined.");
}
if (c->Rank(c->input(i)) < 1) {
return errors::InvalidArgument(
"The rank of all input shapes must be greater than 0, "
"but input ",
i, " had rank ", c->Rank(c->input(i)), ".");
}
DimensionHandle unused;
if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
return errors::InvalidArgument("Size of first dimension must be 1.");
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/python/kernel_tests/array_ops/stack_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ def f():
y = gen_array_ops.parallel_concat(values=[["tf"]], shape=0)
return y

with self.assertRaisesRegex(errors.InvalidArgumentError,
r"0th dimension of value .* is less than"):
with self.assertRaisesRegex(
errors.InvalidArgumentError, r"0th dimension .* must be greater than"
):
f()

def testSimpleParallelGPU(self):
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/ops/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
Expand Down Expand Up @@ -91,6 +92,20 @@ def g(x):
conc = g.get_concrete_function(tensor_spec.TensorSpec([10, None]))
self.assertAllEqual(conc.output_shapes.as_list(), [10])

@test_util.run_in_graph_and_eager_modes
def testParallelConcatFailsWithRankZeroShape(self):
op = array_ops.ParallelConcat
para = {"shape": 0, "values": [1]}

def func():
y = op(**para)
return y

with self.assertRaisesRegex(
Exception, "(rank|dimension) of .* must be greater than .* 0"
):
func()


if __name__ == "__main__":
test.main()

0 comments on commit da66bc6

Please sign in to comment.