From 73de58ca7226cf6711f910733337da4333acc9bd Mon Sep 17 00:00:00 2001 From: Jun Xu Date: Wed, 14 Sep 2022 14:54:10 -0700 Subject: [PATCH] [Security] Add a check for empty variant tensor input to CompositeTensorVariantToComponents. So an exception will be raised instead of segfault. PiperOrigin-RevId: 474397914 --- tensorflow/core/kernels/composite_tensor_ops.cc | 6 ++++++ .../kernel_tests/composite_tensor_ops_test.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/tensorflow/core/kernels/composite_tensor_ops.cc b/tensorflow/core/kernels/composite_tensor_ops.cc index f41b02991bba43..bc4f96e6bb2fe6 100644 --- a/tensorflow/core/kernels/composite_tensor_ops.cc +++ b/tensorflow/core/kernels/composite_tensor_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/kernels/composite_tensor_variant.h" @@ -66,6 +67,11 @@ class CompositeTensorVariantToComponents : public OpKernel { void Compute(OpKernelContext* context) override { Tensor encoded_t = context->input(0); + OP_REQUIRES( + context, encoded_t.flat().size() > 0, + errors::InvalidArgument("Input `encoded` must not be an empty variant " + "tensor, but got ", + encoded_t.DebugString())); auto* encoded = encoded_t.flat()(0).get(); // Check that the encoded TypeSpec is compatible with the expected TypeSpec. diff --git a/tensorflow/python/kernel_tests/composite_tensor_ops_test.py b/tensorflow/python/kernel_tests/composite_tensor_ops_test.py index e5e9d1ef9bf6d9..7a10cae3ebc10a 100644 --- a/tensorflow/python/kernel_tests/composite_tensor_ops_test.py +++ b/tensorflow/python/kernel_tests/composite_tensor_ops_test.py @@ -18,11 +18,13 @@ from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import composite_tensor_ops +from tensorflow.python.ops import gen_composite_tensor_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -83,6 +85,18 @@ def testEncodingErrors(self, value, spec, message): with self.assertRaisesRegex(ValueError, message): composite_tensor_ops.composite_tensor_to_variants(value(), spec) + def testDecodingEmptyNonScalarTensorError(self): + if not context.executing_eagerly(): + # Creating a variant tensor of an empty list is not allowed in eager mode. + return + + with self.assertRaisesRegex(errors.InvalidArgumentError, + 'must not be an empty variant tensor'): + gen_composite_tensor_ops.CompositeTensorVariantToComponents( + encoded=constant_op.constant([], dtype=dtypes.variant), + metadata='', + Tcomponents=[dtypes.int32]) + def testRoundTripThroughTensorProto(self): value = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) encoded = composite_tensor_ops.composite_tensor_to_variants(value)