Skip to content

Commit

Permalink
Merge pull request #54075 from tensorflow/cherrypick-894ba6b75bc-on-r2.6
Browse files Browse the repository at this point in the history
Merge pull request #53695 from yongtang:53660-tf.sparse.split-crash
  • Loading branch information
mihaimaruseac committed Jan 25, 2022
2 parents 727391e + c4649e2 commit 1e0187c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tensorflow/core/kernels/sparse_split_op.cc
Expand Up @@ -30,11 +30,16 @@ class SparseSplitOp : public OpKernel {
}

void Compute(OpKernelContext* context) override {
const int64 axis_input = context->input(0).scalar<int64>()();
const Tensor& input_axis = context->input(0);
const Tensor& input_indices = context->input(1);
const Tensor& input_values = context->input(2);
const Tensor& input_shape = context->input(3);

OP_REQUIRES(context, TensorShapeUtils::IsScalar(input_axis.shape()),
errors::InvalidArgument(
"Input axis should be a scalar but received shape ",
input_axis.shape().DebugString()),
done);
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
errors::InvalidArgument(
"Input indices should be a matrix but received shape ",
Expand All @@ -48,7 +53,8 @@ class SparseSplitOp : public OpKernel {
"Input shape should be a vector but received shape ",
input_shape.shape().DebugString()));

const int64 input_rank = input_shape.vec<int64>().size();
const int64 axis_input = input_axis.scalar<int64_t>()();
const int64 input_rank = input_shape.vec<int64_t>().size();
const int64 axis = (axis_input < 0) ? input_rank + axis_input : axis_input;

OP_REQUIRES(
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/kernel_tests/sparse_split_op_test.py
Expand Up @@ -257,6 +257,15 @@ def testArgumentErrors(self):
with self.assertRaisesRegex(ValueError, 'axis is required'):
sparse_ops.sparse_split(num_split=2, sp_input=1)

def testInvalidArgumentError(self):
# Test case for GitHub issue 53660.
axis = [1, 2]
with self.assertRaisesRegexp(errors.InvalidArgumentError,
r'axis should be a scalar'):
self.evaluate(
sparse_ops.sparse_split(
sp_input=self._SparseTensor_4x6(), num_split=3, axis=axis))


if __name__ == '__main__':
test.main()

0 comments on commit 1e0187c

Please sign in to comment.