Skip to content
Permalink
Browse files Browse the repository at this point in the history
[security] Fix failed shape check in RaggedTensorToVariant.
`row_splits` must have rank 1.

PiperOrigin-RevId: 461915027
  • Loading branch information
JXRiver authored and tensorflower-gardener committed Jul 19, 2022
1 parent a1d8a71 commit 88f93df
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/ragged_tensor_to_variant_op.cc
Expand Up @@ -188,6 +188,10 @@ class RaggedTensorToVariantOp : public OpKernel {
batched_ragged_input.mutable_nested_splits()->reserve(
ragged_nested_splits_len);
for (int i = 0; i < ragged_nested_splits_len; i++) {
OP_REQUIRES(context, ragged_nested_splits_in[i].dims() == 1,
errors::InvalidArgument("Requires nested_row_splits[", i, "]",
" to be rank 1 but is rank ",
ragged_nested_splits_in[i].dims()));
batched_ragged_input.append_splits(ragged_nested_splits_in[i]);
}

Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/ops/ragged/ragged_tensor_test.py
Expand Up @@ -1468,6 +1468,21 @@ def testUnbatchVariantInDataset(self):
for i in range(3):
self.assertAllEqual(sess.run(rt[i]), out)

def testToVariantInvalidParams(self):
self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r'be rank 1 but is rank 0',
gen_ragged_conversion_ops.ragged_tensor_to_variant,
rt_nested_splits=[0, 1, 2],
rt_dense_values=[0, 1, 2],
batched_input=True)

self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r'be rank 1 but is rank 2',
gen_ragged_conversion_ops.ragged_tensor_to_variant,
rt_nested_splits=[[[0]], [[1]], [[2]]],
rt_dense_values=[0, 1, 2],
batched_input=True)

def testFromVariantInvalidParams(self):
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
batched_variant = rt._to_variant(batched_input=True)
Expand Down

0 comments on commit 88f93df

Please sign in to comment.