Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent overflow in sparse op
PiperOrigin-RevId: 372442006
Change-Id: I60fe31cd7e56fb3501e97c63500caf902ddeee96
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed May 6, 2021
1 parent 0908c2f commit 4c0ee93
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tensorflow/core/kernels/sparse_split_op.cc
Expand Up @@ -63,11 +63,18 @@ class SparseSplitOp : public OpKernel {
input_shape.vec<int64>()(axis),
"), got ", num_split_));

// Prevent overflow by constructing the dense shape separately
TensorShape dense_shape;
const auto input_shape_flat = input_shape.flat<int64>();
for (int i = 0; i < input_shape.NumElements(); i++) {
OP_REQUIRES_OK(context,
dense_shape.AddDimWithStatus(input_shape_flat(i)));
}

sparse::SparseTensor sparse_tensor;
OP_REQUIRES_OK(context,
sparse::SparseTensor::Create(
input_indices, input_values,
TensorShape(input_shape.vec<int64>()), &sparse_tensor));
sparse::SparseTensor::Create(input_indices, input_values,
dense_shape, &sparse_tensor));

std::vector<sparse::SparseTensor> outputs;
OP_REQUIRES_OK(context, sparse::SparseTensor::Split<T>(
Expand Down

0 comments on commit 4c0ee93

Please sign in to comment.