Skip to content
Permalink
Browse files Browse the repository at this point in the history
[Security] Raise an exception when input to CompositeTensorVariantToC…
…omponents is not a valid CompositeTensorVariant tensor.

So TF won't crash.

PiperOrigin-RevId: 474594628
  • Loading branch information
JXRiver authored and tensorflower-gardener committed Sep 15, 2022
1 parent 8d2ce9a commit bf594d0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tensorflow/core/kernels/composite_tensor_ops.cc
Expand Up @@ -73,6 +73,10 @@ class CompositeTensorVariantToComponents : public OpKernel {
"tensor, but got ",
encoded_t.DebugString()));
auto* encoded = encoded_t.flat<Variant>()(0).get<CompositeTensorVariant>();
OP_REQUIRES(context, encoded != nullptr,
errors::InvalidArgument("The input `encoded` is not a valid "
"CompositeTensorVariant tensor, got ",
encoded_t.DebugString()));

// Check that the encoded TypeSpec is compatible with the expected TypeSpec.
// For now, we just check that the class matches.
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/kernel_tests/composite_tensor_ops_test.py
Expand Up @@ -25,6 +25,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import composite_tensor_ops
from tensorflow.python.ops import gen_composite_tensor_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
Expand Down Expand Up @@ -97,6 +98,18 @@ def testDecodingEmptyNonScalarTensorError(self):
metadata='',
Tcomponents=[dtypes.int32])

def testDecodingInvalidEncodedInputError(self):
with self.assertRaisesRegex(errors.InvalidArgumentError,
'not a valid CompositeTensorVariant tensor'):
self.evaluate(
gen_composite_tensor_ops.CompositeTensorVariantToComponents(
encoded=gen_list_ops.EmptyTensorList(
element_dtype=dtypes.int32,
element_shape=[1, 2],
max_num_elements=2),
metadata='',
Tcomponents=[dtypes.int32]))

def testRoundTripThroughTensorProto(self):
value = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
encoded = composite_tensor_ops.composite_tensor_to_variants(value)
Expand Down

0 comments on commit bf594d0

Please sign in to comment.