Skip to content

Commit

Permalink
Merge pull request #49701 from geetachavan1/cherrypicks_ZNR2C
Browse files Browse the repository at this point in the history
Prevent overflow in sparse op
  • Loading branch information
mihaimaruseac committed May 26, 2021
2 parents 01e36a0 + 5788883 commit e4de6e9
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tensorflow/core/kernels/sparse_split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,18 @@ class SparseSplitOp : public OpKernel {
input_shape.vec<int64>()(split_dim), "), got ",
num_split_));

// Prevent overflow by constructing the dense shape separately
TensorShape dense_shape;
const auto input_shape_flat = input_shape.flat<int64>();
for (int i = 0; i < input_shape.NumElements(); i++) {
OP_REQUIRES_OK(context,
dense_shape.AddDimWithStatus(input_shape_flat(i)));
}

sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(context,
sparse::SparseTensor::Create(
input_indices, input_values,
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
sparse::SparseTensor::Create(input_indices, input_values,
dense_shape, &sparse_tensor));

std::vector<sparse::SparseTensor> outputs;
OP_REQUIRES_OK(context,
Expand Down

0 comments on commit e4de6e9

Please sign in to comment.