Skip to content

Commit

Permalink
[Security] Add a check for empty variant tensor input to CompositeTen…
Browse files Browse the repository at this point in the history
…sorVariantToComponents.

So an exception will be raised instead of segfault.

PiperOrigin-RevId: 474397914
  • Loading branch information
JXRiver authored and tensorflow-jenkins committed Oct 21, 2022
1 parent ee897ca commit 73de58c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/composite_tensor_ops.cc
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/kernels/composite_tensor_variant.h"
Expand Down Expand Up @@ -66,6 +67,11 @@ class CompositeTensorVariantToComponents : public OpKernel {

void Compute(OpKernelContext* context) override {
Tensor encoded_t = context->input(0);
OP_REQUIRES(
context, encoded_t.flat<Variant>().size() > 0,
errors::InvalidArgument("Input `encoded` must not be an empty variant "
"tensor, but got ",
encoded_t.DebugString()));
auto* encoded = encoded_t.flat<Variant>()(0).get<CompositeTensorVariant>();

// Check that the encoded TypeSpec is compatible with the expected TypeSpec.
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/kernel_tests/composite_tensor_ops_test.py
Expand Up @@ -18,11 +18,13 @@

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import composite_tensor_ops
from tensorflow.python.ops import gen_composite_tensor_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
Expand Down Expand Up @@ -83,6 +85,18 @@ def testEncodingErrors(self, value, spec, message):
with self.assertRaisesRegex(ValueError, message):
composite_tensor_ops.composite_tensor_to_variants(value(), spec)

def testDecodingEmptyNonScalarTensorError(self):
if not context.executing_eagerly():
# Creating a variant tensor of an empty list is not allowed in eager mode.
return

with self.assertRaisesRegex(errors.InvalidArgumentError,
'must not be an empty variant tensor'):
gen_composite_tensor_ops.CompositeTensorVariantToComponents(
encoded=constant_op.constant([], dtype=dtypes.variant),
metadata='',
Tcomponents=[dtypes.int32])

def testRoundTripThroughTensorProto(self):
value = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
encoded = composite_tensor_ops.composite_tensor_to_variants(value)
Expand Down

0 comments on commit 73de58c

Please sign in to comment.