Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent nullptr deref in SparseTensorSliceDataset
The arguments must determine a valid sparse tensor. This means that when indices are empty then the values must be empty too (and the reverse).

Also added test, by modifying existing test with empty sparse tensor to now run with an invalid sparse tensor input.

PiperOrigin-RevId: 388562757
Change-Id: Id8b54cd7c2316025b4f9a77292c8fb5344d17609
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Aug 3, 2021
1 parent 234a51a commit 02cc160
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
Expand Up @@ -241,6 +241,17 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel {
errors::InvalidArgument(
"Input indices should be a matrix but received shape ",
indices->shape().DebugString()));

const auto num_indices = indices->NumElements();
const auto num_values = values->NumElements();
if (num_indices == 0 || num_values == 0) {
OP_REQUIRES(ctx, num_indices == num_values,
errors::InvalidArgument(
"If indices or values are empty, the other one must also "
"be. Got indices of shape ",
indices->shape().DebugString(), " and values of shape ",
values->shape().DebugString()));
}
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(values->shape()),
errors::InvalidArgument(
"Input values should be a vector but received shape ",
Expand Down
Expand Up @@ -118,6 +118,26 @@ def testEmptySparseTensorSlices(self):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

@combinations.generate(combinations.combine(tf_api_version=1, mode=["graph"]))
def testEmptySparseTensorSlicesInvalid(self):
"""Test a dataset based on invalid `tf.sparse.SparseTensor`."""
st = array_ops.sparse_placeholder(dtypes.float64)
iterator = dataset_ops.make_initializable_iterator(
dataset_ops.Dataset.from_sparse_tensor_slices(st))
init_op = iterator.initializer

with self.cached_session() as sess:
# Test with an empty sparse tensor but with non empty values.
empty_indices = np.empty((0, 4), dtype=np.int64)
non_empty_values = [1, 2, 3, 4]
empty_dense_shape = [0, 4, 37, 9]
sparse_feed = sparse_tensor.SparseTensorValue(empty_indices,
non_empty_values,
empty_dense_shape)
# Here, we expect the test to fail when running the feed.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(init_op, feed_dict={st: sparse_feed})

@combinations.generate(combinations.combine(tf_api_version=2, mode=["eager"]))
def testFromSparseTensorSlicesError(self):
with self.assertRaises(AttributeError):
Expand Down

0 comments on commit 02cc160

Please sign in to comment.