Skip to content
Permalink
Browse files Browse the repository at this point in the history
Disallow dims input of 0 in tf.raw_ops.UnravelIndex
PiperOrigin-RevId: 384284198
Change-Id: Ia1804ef1aec57b4d857ea507e6891bcccde18e9b
  • Loading branch information
pak-laura authored and tensorflower-gardener committed Jul 12, 2021
1 parent f9e7f42 commit a776040
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions tensorflow/core/kernels/unravel_index_op.cc
Expand Up @@ -53,6 +53,14 @@ class UnravelIndexOp : public OpKernel {
dims_tensor.shape().DebugString(), "\""));

auto dims = dims_tensor.vec<Tidx>();
// Make sure dims does not contain a zero
for (int i = 0; i < dims.size(); i++) {
OP_REQUIRES(
ctx, dims(i) != 0,
errors::InvalidArgument("Input dims cannot contain a dim of zero, "
"but dims contains zero at index ",
i));
}

// Chek to make sure indices is not out of boundary
Eigen::Tensor<Tidx, 0, Eigen::RowMajor> dims_prod_eigen = dims.prod();
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/kernel_tests/array_ops_test.py
Expand Up @@ -1575,7 +1575,7 @@ def testUnravelIndexZeroDim(self):
with self.cached_session():
for dtype in [dtypes.int32, dtypes.int64]:
with self.assertRaisesRegex(errors.InvalidArgumentError,
"index is out of bound as with dims"):
"dims cannot contain a dim of zero"):
indices = constant_op.constant([2, 5, 7], dtype=dtype)
dims = constant_op.constant([3, 0], dtype=dtype)
self.evaluate(array_ops.unravel_index(indices=indices, dims=dims))
Expand Down

0 comments on commit a776040

Please sign in to comment.