Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix failed check in SparseTensorToCSRSparseMatrix
Security vulnerability fix. A `CHECK` fails if inputing either an empty `dense_shape`,
or a non-rank-2 `indices`. Added appropriate checks and tests.

PiperOrigin-RevId: 446053984
  • Loading branch information
cantonios authored and tensorflower-gardener committed May 2, 2022
1 parent 68ae78b commit ea50a40
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
Expand Up @@ -67,6 +67,13 @@ class SparseTensorToCSRSparseMatrixCPUOp : public OpKernel {
const Tensor& values = ctx->input(1);
const Tensor& dense_shape = ctx->input(2);
const int rank = dense_shape.NumElements();
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(dense_shape.shape()),
errors::InvalidArgument("dense_shape must be rank 1 but got rank",
dense_shape.shape().dims()));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(indices.shape()),
errors::InvalidArgument("indices must be rank 2 but got rank",
indices.shape().dims()));
OP_REQUIRES(ctx, rank == 2 || rank == 3,
errors::InvalidArgument("SparseTensor must have rank 2 or 3; ",
"but indices has rank: ", rank));
Expand Down
Expand Up @@ -168,6 +168,25 @@ def testSparseTensorConversion(self):
self.assertAllClose(a_values, a_st_rt_value.values)
self.assertAllEqual(a_dense_shape, a_st_rt_value.dense_shape)

def testSparseTensorConversionInvalidInputShapes(self):
values = constant_op.constant(
0.554979503, shape=[5], dtype=dtypes.float32)
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 1"):
indices = constant_op.constant(0, shape=[5, 2], dtype=dtypes.int64)
dense_shape = constant_op.constant(53, shape=[], dtype=dtypes.int64)
csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
indices=indices, values=values, dense_shape=dense_shape)
self.evaluate(csr)

with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
"must be rank 2"):
indices = constant_op.constant(0, shape=[5], dtype=dtypes.int64)
dense_shape = constant_op.constant(53, shape=[1], dtype=dtypes.int64)
csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(
indices=indices, values=values, dense_shape=dense_shape)
self.evaluate(csr)

# TODO(b/139491352): Add handle_data propagation to array_ops.identity.
@test_util.run_deprecated_v1
def testCSRSparseMatrixResourceVariable(self):
Expand Down

0 comments on commit ea50a40

Please sign in to comment.