Skip to content

Commit

Permalink
Merge pull request #53927 from tensorflow/cherrypick-965b97e4a9650495…
Browse files Browse the repository at this point in the history
…cda5a8c210ef6684b4b9eceb-on-r2.7

Properly validate sparse tensor in `SparseTensorSliceDataset`
  • Loading branch information
mihaimaruseac committed Jan 24, 2022
2 parents de220cf + 8774b05 commit c7f78a6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
39 changes: 20 additions & 19 deletions tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
Expand Up @@ -238,28 +238,29 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input("dense_shape", &dense_shape));

OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices->shape()),
errors::InvalidArgument(
"Input indices should be a matrix but received shape ",
indices->shape().DebugString()));

const auto num_indices = indices->NumElements();
const auto num_values = values->NumElements();
if (num_indices == 0 || num_values == 0) {
OP_REQUIRES(ctx, num_indices == num_values,
errors::InvalidArgument(
"If indices or values are empty, the other one must also "
"be. Got indices of shape ",
indices->shape().DebugString(), " and values of shape ",
values->shape().DebugString()));
}
errors::InvalidArgument("Input indices must be a matrix. Got: ",
indices->shape().DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
errors::InvalidArgument(
"Input values should be a vector but received shape ",
indices->shape().DebugString()));
errors::InvalidArgument("Input values must be a vector. Got: ",
values->shape().DebugString()));
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(dense_shape->shape()),
errors::InvalidArgument("Input shape must be a vector. Got: ",
dense_shape->shape().DebugString()));
OP_REQUIRES(
ctx, values->shape().dim_size(0) == indices->shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices. ", "Got ",
values->shape().dim_size(0),
" values, indices shape: ", indices->shape().DebugString()));
OP_REQUIRES(
ctx, dense_shape->shape().dim_size(0) == indices->shape().dim_size(1),
errors::InvalidArgument(
"Number of dimensions must match second dimension of indices. ",
"Got ", dense_shape->shape().dim_size(0),
" dimensions, indices shape: ", indices->shape().DebugString()));
OP_REQUIRES(ctx, dense_shape->NumElements() > 0,
errors::InvalidArgument(
"Input shape should be a vector but received shape ",
dense_shape->shape().DebugString()));
"The shape argument requires at least one element."));

// We currently ensure that `sparse_tensor` is ordered in the
// batch dimension.
Expand Down
Expand Up @@ -138,6 +138,25 @@ def testEmptySparseTensorSlicesInvalid(self):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={st: sparse_feed})

@combinations.generate(combinations.combine(tf_api_version=1, mode=["graph"]))
def testEmptySparseTensorSlicesInvalid2(self):
"""Test a dataset based on invalid `tf.sparse.SparseTensor`."""
st = array_ops.sparse_placeholder(dtypes.float64)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_sparse_tensor_slices(st))
init_op = iterator.initializer

with self.cached_session() as sess:
# Test with an empty sparse tensor but with non empty values.
empty_indices = [[]]
empty_values = []
dense_shape = [1, 1]
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values,
dense_shape)
# Here, we expect the test to fail when running the feed.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={st: sparse_feed})

@combinations.generate(combinations.combine(tf_api_version=2, mode=["eager"]))
def testFromSparseTensorSlicesError(self):
with self.assertRaises(AttributeError):
Expand Down

0 comments on commit c7f78a6

Please sign in to comment.