Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A negative size in one of the split sizes allowed the computed size o… #52701

Merged
merged 1 commit into from
Oct 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/split_v_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ class SplitVOpBase : public OpKernel {
(*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size;
}

for (int i = 0; i < split_sizes_vec->size(); ++i) {
const Tlen& split_size = (*split_sizes_vec)[i];
OP_REQUIRES(context, split_size >= Tlen(0),
errors::InvalidArgument("Split size at index ", i,
" must be >= 0. Got: ", split_size));
}

// Special case 2: split along the 1st dimension. The requirements are that
// either we are splitting the outer dimension of two or more such that
// every outer subpart is aligned or that the split sizes mean that they are
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/ops/array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,12 @@ REGISTER_OP("SplitV")
if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
size = split_dim_size - total_size;
}
// If we have a negative known size (either explicit, or computed
// via -1), then the split sizes are invalid.
if (size < -1 || (size == -1 && c->ValueKnown(split_dim_size))) {
return errors::InvalidArgument("Split size at index ", i,
" must be >= 0. Got: ", size);
}
TF_RETURN_IF_ERROR(
c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
c->set_output(i, output_shape);
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/python/kernel_tests/split_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,24 @@ def testNonexistentDimTensor(self):
"must have exactly one element"):
sess.run(y, {x: np.array([], dtype=np.int32), splits: [4, 11, 15]})

@test_util.run_in_graph_and_eager_modes
def testNegativeSizes(self):
x = constant_op.constant([1, 2, 3], dtypes.float32)
# A size of -1 signifies to determine size based on sum of other splits.
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Split size at index 1 must be >= 0. Got: -2"):
splits = [-1, -2]
self.evaluate(array_ops.split(x, splits, axis=0))

@test_util.run_in_graph_and_eager_modes
def testBadSplitSizes(self):
x = constant_op.constant([1, 2], dtypes.float32)
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"Determined shape must either match input"
"|can't split axis"):
splits = [1, 2]
self.evaluate(array_ops.split(x, splits, axis=0))


if __name__ == "__main__":
test.main()