Skip to content

Commit

Permalink
BUG: KDTree balanced_tree is unbalanced for degenerate data
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Jul 5, 2021
1 parent 5ddd07a commit a69718a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 50 deletions.
87 changes: 43 additions & 44 deletions scipy/spatial/ckdtree/src/build.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -105,55 +105,55 @@ build(ckdtree *self, ckdtree_intp_t start_idx, intptr_t end_idx,
else {
/* split with the sliding midpoint rule */
split = (maxval + minval) / 2;
}

p = start_idx;
q = end_idx - 1;
while (p <= q) {
if (data[indices[p] * m + d] < split)
++p;
else if (data[indices[q] * m + d] >= split)
--q;
else {
ckdtree_intp_t t = indices[p];
indices[p] = indices[q];
indices[q] = t;
++p;
--q;
p = start_idx;
q = end_idx - 1;
while (p <= q) {
if (data[indices[p] * m + d] < split)
++p;
else if (data[indices[q] * m + d] >= split)
--q;
else {
ckdtree_intp_t t = indices[p];
indices[p] = indices[q];
indices[q] = t;
++p;
--q;
}
}
}
/* slide midpoint if necessary */
if (p == start_idx) {
/* no points less than split */
j = start_idx;
split = data[indices[j] * m + d];
for (i = start_idx+1; i < end_idx; ++i) {
if (data[indices[i] * m + d] < split) {
j = i;
split = data[indices[j] * m + d];
/* slide midpoint if necessary */
if (p == start_idx) {
/* no points less than split */
j = start_idx;
split = data[indices[j] * m + d];
for (i = start_idx+1; i < end_idx; ++i) {
if (data[indices[i] * m + d] < split) {
j = i;
split = data[indices[j] * m + d];
}
}
ckdtree_intp_t t = indices[start_idx];
indices[start_idx] = indices[j];
indices[j] = t;
p = start_idx + 1;
q = start_idx;
}
ckdtree_intp_t t = indices[start_idx];
indices[start_idx] = indices[j];
indices[j] = t;
p = start_idx + 1;
q = start_idx;
}
else if (p == end_idx) {
/* no points greater than split */
j = end_idx - 1;
split = data[indices[j] * m + d];
for (i = start_idx; i < end_idx-1; ++i) {
if (data[indices[i] * m + d] > split) {
j = i;
split = data[indices[j] * m + d];
else if (p == end_idx) {
/* no points greater than split */
j = end_idx - 1;
split = data[indices[j] * m + d];
for (i = start_idx; i < end_idx-1; ++i) {
if (data[indices[i] * m + d] > split) {
j = i;
split = data[indices[j] * m + d];
}
}
ckdtree_intp_t t = indices[end_idx-1];
indices[end_idx-1] = indices[j];
indices[j] = t;
p = end_idx - 1;
q = end_idx - 2;
}
ckdtree_intp_t t = indices[end_idx-1];
indices[end_idx-1] = indices[j];
indices[j] = t;
p = end_idx - 1;
q = end_idx - 2;
}

if (CKDTREE_LIKELY(_compact)) {
Expand Down Expand Up @@ -244,4 +244,3 @@ build_weights (ckdtree *self, double *node_weights, double *weights)
add_weights(self, node_weights, 0, weights);
return 0;
}

11 changes: 5 additions & 6 deletions scipy/spatial/tests/test_kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,16 +1254,15 @@ def test_kdtree_duplicated_inputs(kdtree_type):
# it shall not divide more than 3 nodes.
# root left (1), and right (2)
kdtree = kdtree_type(data, leafsize=1)
assert_equal(kdtree.size, 3)
assert kdtree.size == 3

kdtree = kdtree_type(data)
assert_equal(kdtree.size, 3)
assert kdtree.size == 3

# if compact_nodes are disabled, the number
# of nodes is n (per leaf) + (m - 1)* 2 (splits per dimension) + 1
# and the root
# if compact_nodes is disabled, the maximum number
# of nodes is that of a balanced tree (2 * n - 1)
kdtree = kdtree_type(data, compact_nodes=False, leafsize=1)
assert_equal(kdtree.size, n + m * 2 - 1)
assert kdtree.size <= 2 * n - 1

def test_kdtree_noncumulative_nondecreasing(kdtree_type):
# check kdtree with duplicated inputs
Expand Down

0 comments on commit a69718a

Please sign in to comment.