Skip to content
Permalink
Browse files Browse the repository at this point in the history
Ensuring that the input to DeserializeSparse is not a scalar.
PiperOrigin-RevId: 400554784
Change-Id: Ib658701040d4f707f20b8706e251d5fff46b2671
  • Loading branch information
rohan100jain authored and tensorflower-gardener committed Oct 3, 2021
1 parent 7b596c4 commit d3738dd
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tensorflow/core/ops/sparse_ops.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
Expand Down Expand Up @@ -159,6 +160,8 @@ REGISTER_OP("DeserializeSparse")
.Attr("Tserialized: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?, ..., ?, 3] vector.
ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
Expand Up @@ -16,10 +16,12 @@

import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test

Expand Down Expand Up @@ -460,6 +462,18 @@ def testDeserializeManyFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testDeserializeInvalidVariant(self):
mu = gen_resource_variable_ops.mutex_v2()
mu_lock = gen_resource_variable_ops.mutex_lock(mutex=mu)

@def_function.function
def f():
return sparse_ops.deserialize_sparse(
serialized_sparse=mu_lock, dtype=dtypes.int32)

with self.assertRaisesRegex(ValueError, r"Shape must be at least rank 1"):
f()


if __name__ == "__main__":
test.main()

0 comments on commit d3738dd

Please sign in to comment.