Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorflow/core/ops/sparse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
Expand Down Expand Up @@ -159,6 +160,8 @@ REGISTER_OP("DeserializeSparse")
.Attr("Tserialized: {string, variant} = DT_STRING")
.SetShapeFn([](InferenceContext* c) {
// serialized sparse is [?, ..., ?, 3] vector.
ShapeHandle unused_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape));
DimensionHandle unused;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
Expand Down
14 changes: 14 additions & 0 deletions tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import numpy as np

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test

Expand Down Expand Up @@ -464,6 +466,18 @@ def testDeserializeManyFailsInvalidProto(self):
self._testDeserializeFailsInvalidProtoHelper(
sparse_ops.serialize_sparse, sparse_ops.deserialize_many_sparse)

def testDeserializeInvalidVariant(self):
mu = gen_resource_variable_ops.mutex_v2()
mu_lock = gen_resource_variable_ops.mutex_lock(mutex=mu)

@def_function.function
def f():
return sparse_ops.deserialize_sparse(
serialized_sparse=mu_lock, dtype=dtypes.int32)

with self.assertRaisesRegex(ValueError, r"Shape must be at least rank 1"):
f()


if __name__ == "__main__":
test.main()