From 22d7d2089362cb9798e0acc7465d94a212124412 Mon Sep 17 00:00:00 2001 From: Rohan Jain Date: Sun, 3 Oct 2021 08:50:06 -0700 Subject: [PATCH] Ensuring that the input to DeserializeSparse is not a scalar. PiperOrigin-RevId: 400554784 Change-Id: Ib658701040d4f707f20b8706e251d5fff46b2671 --- tensorflow/core/ops/sparse_ops.cc | 3 +++ .../kernel_tests/sparse_serialization_ops_test.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index b1e40e66af8929..c79aeebd94bd33 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -159,6 +160,8 @@ REGISTER_OP("DeserializeSparse") .Attr("Tserialized: {string, variant} = DT_STRING") .SetShapeFn([](InferenceContext* c) { // serialized sparse is [?, ..., ?, 3] vector. + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape)); DimensionHandle unused; TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused)); c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py index 5a48eb825dbfa8..434287679181bb 100644 --- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py @@ -20,10 +20,12 @@ import numpy as np +from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test @@ -464,6 +466,18 @@ def testDeserializeManyFailsInvalidProto(self): self._testDeserializeFailsInvalidProtoHelper( sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse) + def testDeserializeInvalidVariant(self): + mu = gen_resource_variable_ops.mutex_v2() + mu_lock = gen_resource_variable_ops.mutex_lock(mutex=mu) + + @def_function.function + def f(): + return sparse_ops.deserialize_sparse( + serialized_sparse=mu_lock, dtype=dtypes.int32) + + with self.assertRaisesRegex(ValueError, r"Shape must be at least rank 1"): + f() + if __name__ == "__main__": test.main()