Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix multiple vulnerabilities in tf.raw_ops.*CountSparseOutput.
Also add tests for these API points, both for the happy paths and for the vulnerable ones.

PiperOrigin-RevId: 332563222
Change-Id: Ib3b52116a83a134c2e742a7c66e5e956db8fba05
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 19, 2020
1 parent 4eab87c commit 3cbb917
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tensorflow/core/kernels/count_ops.cc
Expand Up @@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
const Tensor& weights = context->input(3);
bool use_weights = weights.NumElements() > 0;

OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
errors::InvalidArgument(
"Input indices must be a 2-dimensional tensor. Got: ",
indices.shape().DebugString()));

if (use_weights) {
OP_REQUIRES(
context, weights.shape() == values.shape(),
errors::InvalidArgument(
"Weights and values must have the same shape. Weight shape: ",
weights.shape().DebugString(),
"; values shape: ", values.shape().DebugString()));
}

bool is_1d = shape.NumElements() == 1;
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
int num_values = values.NumElements();

OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
"Got ", num_values,
" values, indices shape: ", indices.shape().DebugString()));

const auto indices_values = indices.matrix<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
Expand Down Expand Up @@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
bool use_weights = weights.NumElements() > 0;
bool is_1d = false;

if (use_weights) {
OP_REQUIRES(
context, weights.shape() == values.shape(),
errors::InvalidArgument(
"Weights and values must have the same shape. Weight shape: ",
weights.shape().DebugString(),
"; values shape: ", values.shape().DebugString()));
}

const auto splits_values = splits.flat<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
int num_batches = splits.NumElements() - 1;
int num_values = values.NumElements();

OP_REQUIRES(
context, num_batches > 0,
errors::InvalidArgument(
"Must provide at least 2 elements for the splits argument"));
OP_REQUIRES(context, splits_values(0) == 0,
errors::InvalidArgument("Splits must start with 0, not with ",
splits_values(0)));
OP_REQUIRES(context, splits_values(num_batches) == num_values,
errors::InvalidArgument(
"Splits must end with the number of values, got ",
splits_values(num_batches), " instead of ", num_values));

auto per_batch_counts = BatchedMap<W>(num_batches);
T max_value = 0;
int batch_idx = 0;
Expand Down
118 changes: 118 additions & 0 deletions tensorflow/python/ops/bincount_ops_test.py
Expand Up @@ -25,7 +25,9 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import gen_count_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
Expand Down Expand Up @@ -834,5 +836,121 @@ def test_ragged_input_different_shape_fails(self):
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))


@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):

def testSparseCountSparseOutputBadIndicesShape(self):
indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
values = [1, 1, 1, 10]
weights = [1, 2, 4, 6]
dense_shape = [2, 3]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Input indices must be a 2-dimensional tensor"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))

def testSparseCountSparseOutputBadWeightsShape(self):
indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
values = [1, 1, 1, 10]
weights = [1, 2, 4]
dense_shape = [2, 3]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Weights and values must have the same shape"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))

def testSparseCountSparseOutputBadNumberOfValues(self):
indices = [[0, 0], [0, 1], [1, 0]]
values = [1, 1, 1, 10]
weights = [1, 2, 4, 6]
dense_shape = [2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"Number of values must match first dimension of indices"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))

def testRaggedCountSparseOutput(self):
splits = [0, 4, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
output_indices, output_values, output_shape = self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits, values=values, weights=weights, binary_output=False))
self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
output_indices)
self.assertAllEqual([7, 3, 5, 7, 6], output_values)
self.assertAllEqual([2, 11], output_shape)

def testRaggedCountSparseOutputBadWeightsShape(self):
splits = [0, 4, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Weights and values must have the same shape"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))

def testRaggedCountSparseOutputEmptySplits(self):
splits = []
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"Must provide at least 2 elements for the splits argument"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))

def testRaggedCountSparseOutputBadSplitsStart(self):
splits = [1, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Splits must start with 0"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))

def testRaggedCountSparseOutputBadSplitsEnd(self):
splits = [0, 5]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Splits must end with the number of values"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))


if __name__ == "__main__":
test.main()

0 comments on commit 3cbb917

Please sign in to comment.