Skip to content

Commit

Permalink
Merge pull request #49566 from geetachavan1/cherrypicks_XS4JA
Browse files Browse the repository at this point in the history
Fix heap buffer overflow in tf.raw_ops.UnicodeEncode.
  • Loading branch information
mihaimaruseac committed May 27, 2021
2 parents 9d2b67c + 6fb0006 commit f113232
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tensorflow/core/kernels/unicode_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,17 @@ class UnicodeEncodeOp : public OpKernel {
const Tensor& input_splits = context->input(1);
const auto input_splits_flat = input_splits.flat<SPLITS_TYPE>();

// Operation will treat first argument in input_splits as if it were zero
// regardless of its actual value since splits should begin with zero and
// end with the length of the input values vector.
OP_REQUIRES(
context, input_splits_flat(0) == 0,
errors::InvalidArgument("First value in input_splits must be zero."));
OP_REQUIRES(context,
input_splits_flat(input_splits_flat.size() - 1) ==
input_tensor_flat.size(),
errors::InvalidArgument("Last value in input_splits must be "
"equal to length of input_tensor."));
// Since we limit to a 2-D input (flat_values of rank 1 and a single splits
// tensor), our output dimension will be 1 with it's size equal to the
// number of splits (outer dimension or ragged tensor).
Expand All @@ -548,6 +559,14 @@ class UnicodeEncodeOp : public OpKernel {
for (int i = 1; i < input_splits_flat.size(); ++i) {
icu::UnicodeString unicode_string;
icu::UnicodeStringAppendable appendable_unicode_string(unicode_string);
OP_REQUIRES(
context, input_splits_flat(i - 1) <= input_splits_flat(i),
errors::InvalidArgument(
"Values in input_splits must be equal or in ascending order."));
OP_REQUIRES(
context, input_splits_flat(i) <= input_tensor_flat.size(),
errors::InvalidArgument("Values in input_splits must be less than or "
"equal to input_tensor length."));
for (; idx < input_splits_flat(i); ++idx) {
int32 code_point = input_tensor_flat(idx);
// Check for invalid code point
Expand Down

0 comments on commit f113232

Please sign in to comment.