Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two improvements to tf.split's shape function #21113

Merged
merged 2 commits into from
Aug 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 26 additions & 23 deletions tensorflow/core/ops/array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,38 +631,41 @@ REGISTER_OP("SplitV")
return errors::InvalidArgument(
"Length of size_splits should be equal to num_outputs");
}
int64_t cumsum_outputs = 0;
int64_t total_size = 0;
bool has_neg_one = false;
for (const auto size : data) {
if (size == -1) {
if (has_neg_one) {
return errors::InvalidArgument(
"size_splits can only have one -1");
}
has_neg_one = true;
} else {
total_size += size;
}
}
auto split_dim_size = c->Value(c->Dim(input, split_dim));
// If the sizes of the splits are known, then
// make sure that the sizes add up to the expected
// dimension size, with the possibility of a -1.
// Specify the full output shapes.
for (int i = 0; i < num_outputs; ++i) {
output_shape = c->UnknownShapeOfRank(rank);
TF_RETURN_IF_ERROR(c->ReplaceDim(input, split_dim,
c->MakeDim(data[i]), &output_shape));
auto size = data[i];
if (data[i] == -1 && c->ValueKnown(split_dim_size)) {
size = split_dim_size - total_size;
}
TF_RETURN_IF_ERROR(
c->ReplaceDim(input, split_dim, c->MakeDim(size), &output_shape));
c->set_output(i, output_shape);
if (data[i] == -1 && !has_neg_one)
has_neg_one = true;
else if (data[i] == -1 && has_neg_one)
return errors::InvalidArgument("size_splits can only have one -1");
else
cumsum_outputs += data[i];
}
auto split_dim_size = c->Value(c->Dim(input, split_dim));
if (has_neg_one) {
if (cumsum_outputs < split_dim_size)
cumsum_outputs = split_dim_size;
else
cumsum_outputs = split_dim_size + 1;
if (c->ValueKnown(split_dim_size)) {
if (has_neg_one ? total_size > split_dim_size
: total_size != split_dim_size) {
return errors::InvalidArgument(
"can't split axis of size ", split_dim_size,
" into pieces of size [", str_util::Join(data, ","), "]");
}
}
if (c->ValueKnown(c->Dim(input, split_dim)) &&
cumsum_outputs != c->Value(c->Dim(input, split_dim)))
return errors::InvalidArgument(
"Sum of output sizes must match "
"the size of the original Tensor along the split dimension "
"or the sum of the positive sizes must be less if it contains a "
"-1");
}

return Status::OK();
Expand Down
30 changes: 30 additions & 0 deletions tensorflow/python/kernel_tests/split_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,26 @@ def testSpecialCasesVariable(self):
for dtype in _TEST_DTYPES:
self._testHugeNumberOfTensorsVariable(dtype)

@test_util.run_in_graph_and_eager_modes
def testDegenerateVariable(self):
  """Checks split with a -1 size that resolves to a zero-length slice.

  A 4x4 input is split along each axis with size_splits of [-1, 4] and
  [4, -1]. The explicit 4 consumes the whole axis, so the piece in the -1
  position is compared against an empty slice of the input.
  """
  inp = np.random.rand(4, 4).astype("f")
  # Each row: (size_splits, axis, expected first piece, expected second piece).
  scenarios = (
      ([-1, 4], 0, inp[0:0, :], inp[0:4, :]),
      ([4, -1], 0, inp[0:4, :], inp[4:4, :]),
      ([-1, 4], 1, inp[:, 0:0], inp[:, 0:4]),
      ([4, -1], 1, inp[:, 0:4], inp[:, 4:4]),
  )
  with test_util.device(use_gpu=True):
    for size_splits, axis, want_first, want_second in scenarios:
      pieces = self.evaluate(array_ops.split(inp, size_splits, axis))
      self.assertAllEqual(pieces[0], want_first)
      self.assertAllEqual(pieces[1], want_second)

def _testGradientsSimpleVariable(self, dtype):
inp = self._makeData((4, 4), dtype)
with test_util.device(use_gpu=True):
Expand Down Expand Up @@ -336,6 +356,16 @@ def testShapeFunctionEdgeCases(self):
for s in splits:
self.assertEqual(None, s.get_shape().ndims)

def testVariableShapeFunction(self):
  """Exercises shape inference for split when size_splits contains a -1.

  Two cases on a 1-D input are covered: explicit sizes larger than the axis
  are expected to raise ValueError, and a single -1 entry is expected to be
  inferred as the remainder of the axis.
  """
  # size_splits too big: the explicit size 3 exceeds the axis length of 2.
  with self.assertRaises(ValueError):
    array_ops.split([0, 1], [3, -1], axis=0)

  # Correct inference of variable dimension: 3 elements minus the explicit
  # 2 leaves 1 for the -1 slot.
  s0, s1 = array_ops.split([0, 1, 2], [2, -1], axis=0)
  # unittest assertions are used instead of bare `assert`, which is removed
  # under `python -O` and reports no expected/actual values on failure.
  self.assertEqual([2], s0.shape.as_list())
  self.assertEqual([1], s1.shape.as_list())

def testNonexistentDimTensor(self):
x = array_ops.placeholder(dtypes.int32)
values = np.zeros([5, 30])
Expand Down