Skip to content
Permalink
Browse files Browse the repository at this point in the history
Validate data_splits for tf.StringNGrams.
Without validation, we can cause a heap buffer overflow which results in data leakage and/or segfaults.

PiperOrigin-RevId: 332543478
Change-Id: Iee5bda24497a195d09d122355502480830b1b317
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 18, 2020
1 parent 57f0a5e commit 0462de5
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
13 changes: 13 additions & 0 deletions tensorflow/core/kernels/string_ngrams_op.cc
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace text {
Expand Down Expand Up @@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel {
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
const auto& splits_vec = splits->flat<SPLITS_TYPE>();

// Validate that the splits are valid indices into data
const int input_data_size = data->flat<tstring>().size();
const int splits_vec_size = splits_vec.size();
for (int i = 0; i < splits_vec_size; ++i) {
bool valid_splits = splits_vec(i) >= 0;
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
OP_REQUIRES(
context, valid_splits,
errors::InvalidArgument("Invalid split value ", splits_vec(i),
", must be in [0,", input_data_size, "]"));
}

int num_batch_items = splits_vec.size() - 1;
tensorflow::Tensor* ngrams_splits;
OP_REQUIRES_OK(
Expand Down
23 changes: 22 additions & 1 deletion tensorflow/python/ops/raw_ops_test.py
Expand Up @@ -18,16 +18,21 @@
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class RawOpsTest(test.TestCase):
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):

def testSimple(self):
x = constant_op.constant(1)
Expand Down Expand Up @@ -58,6 +63,22 @@ def testDefaults(self):
gen_math_ops.Any(input=x, axis=0),
gen_math_ops.Any(input=x, axis=0, keep_dims=False))

@parameterized.parameters([[0, 8]], [[-1, 6]])
def testStringNGramsBadDataSplits(self, splits):
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Invalid split value"):
self.evaluate(
gen_string_ops.string_n_grams(
data=data,
data_splits=splits,
separator="",
ngram_widths=[2],
left_pad="",
right_pad="",
pad_width=0,
preserve_short_sequences=False))


if __name__ == "__main__":
ops.enable_eager_execution()
Expand Down

0 comments on commit 0462de5

Please sign in to comment.