Skip to content

Commit

Permalink
Add skip predicate to inrange, fixes KristofferC#53
Browse files Browse the repository at this point in the history
  • Loading branch information
schmrlng committed Nov 4, 2017
1 parent 012a3da commit 3657a02
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 35 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ idxs, dists = knn(kdtree, point, k, true)
A range search finds all neighbors within the range `r` of given point(s).
This is done with the method:
```jl
inrange(tree, points, r, sortres = false) -> idxs
inrange(tree, points, r, sortres = false, skip = always_false) -> idxs
```
Note that for performance reasons the distances are not returned. The arguments to `inrange` are the same as for `knn` except that `sortres` just sorts the returned index vector.

Expand Down
16 changes: 9 additions & 7 deletions src/ball_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,19 @@ end
function _inrange(tree::BallTree{V},
point::AbstractVector,
radius::Number,
idx_in_ball::Vector{Int}) where {V}
idx_in_ball::Vector{Int},
skip::Function) where {V}
ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball"
inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder
inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip) # Call the recursive range finder
return
end

function inrange_kernel!(tree::BallTree,
index::Int,
point::AbstractVector,
query_ball::HyperSphere,
idx_in_ball::Vector{Int})
idx_in_ball::Vector{Int},
skip::Function)
@NODE 1

if index > length(tree.hyper_spheres)
Expand All @@ -216,17 +218,17 @@ function inrange_kernel!(tree::BallTree,

# At a leaf node, check all points in the leaf node
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true)
add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true, skip)
return
end

# The query ball encloses the sub tree bounding sphere. Add all points in the
# sub tree without checking the distance function.
if encloses(tree.metric, sphere, query_ball)
addall(tree, index, idx_in_ball)
addall(tree, index, idx_in_ball, skip)
else
# Recursively call the left and right sub tree.
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball)
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball)
inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip)
inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip)
end
end
12 changes: 9 additions & 3 deletions src/brute_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,23 @@ end
function _inrange(tree::BruteTree,
point::AbstractVector,
radius::Number,
idx_in_ball::Vector{Int})
inrange_kernel!(tree, point, radius, idx_in_ball)
idx_in_ball::Vector{Int},
skip::Function)
inrange_kernel!(tree, point, radius, idx_in_ball, skip)
return
end


function inrange_kernel!(tree::BruteTree,
point::AbstractVector,
r::Number,
idx_in_ball::Vector{Int})
idx_in_ball::Vector{Int},
skip::Function)
for i in 1:length(tree.data)
if skip(i)
continue
end

@POINT 1
d = evaluate(tree.metric, tree.data[i], point)
if d <= r
Expand Down
22 changes: 12 additions & 10 deletions src/inrange.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0"))

"""
inrange(tree::NNTree, points, radius [, sortres=false]) -> indices
inrange(tree::NNTree, points, radius [, sortres=false, skip=always_false]) -> indices
Find all the points in the tree that are within distance `radius` of `points`. If
`sortres = true` the resulting indices are sorted.
`sortres = true` the resulting indices are sorted. `skip` is an optional predicate
applied to a point's index; points for which it returns `true` are excluded from the result.
"""
function inrange(tree::NNTree,
points::Vector{T},
radius::Number,
sortres=false) where {T <: AbstractVector}
sortres=false,
skip::Function=always_false) where {T <: AbstractVector}
check_input(tree, points)
check_radius(radius)

idxs = [Vector{Int}() for _ in 1:length(points)]

for i in 1:length(points)
inrange_point!(tree, points[i], radius, sortres, idxs[i])
inrange_point!(tree, points[i], radius, sortres, idxs[i], skip)
end
return idxs
end

function inrange_point!(tree, point, radius, sortres, idx)
_inrange(tree, point, radius, idx)
function inrange_point!(tree, point, radius, sortres, idx, skip)
_inrange(tree, point, radius, idx, skip)
if tree.reordered
@inbounds for j in 1:length(idx)
idx[j] = tree.indices[idx[j]]
Expand All @@ -32,21 +34,21 @@ function inrange_point!(tree, point, radius, sortres, idx)
return
end

function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number}
function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number}
check_input(tree, point)
check_radius(radius)
idx = Int[]
inrange_point!(tree, point, radius, sortres, idx)
inrange_point!(tree, point, radius, sortres, idx, skip)
return idx
end

function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false) where {V, T <: Number}
function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number}
dim = size(point, 1)
npoints = size(point, 2)
if isbits(T)
new_data = reinterpret(SVector{dim,T}, point, (length(point) ÷ dim,))
else
new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints]
end
inrange(tree, new_data, radius, sortres)
inrange(tree, new_data, radius, sortres, skip)
end
14 changes: 8 additions & 6 deletions src/kd_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,11 @@ end
function _inrange(tree::KDTree,
point::AbstractVector,
radius::Number,
idx_in_ball = Int[])
idx_in_ball = Int[],
skip::Function = always_false)
init_min = get_min_distance(tree.hyper_rec, point)
inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball,
init_min)
init_min, skip)
return
end

Expand All @@ -216,7 +217,8 @@ function inrange_kernel!(tree::KDTree,
point::AbstractVector,
r::Number,
idx_in_ball::Vector{Int},
min_dist)
min_dist,
skip::Function)
@NODE 1
# Point is outside hyper rectangle, skip the whole sub tree
if min_dist > r
Expand All @@ -225,7 +227,7 @@ function inrange_kernel!(tree::KDTree,

# At a leaf node. Go through all points in node and add those in range
if isleaf(tree.tree_data.n_internal_nodes, index)
add_points_inrange!(idx_in_ball, tree, index, point, r, false)
add_points_inrange!(idx_in_ball, tree, index, point, r, false, skip)
return
end

Expand All @@ -247,7 +249,7 @@ function inrange_kernel!(tree::KDTree,
ddiff = max(zero(lo - p_dim), lo - p_dim)
end
# Call closer sub tree
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist)
inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist, skip)

# TODO: We could potentially also keep track of the max distance
# between the point and the hyper rectangle and add the whole sub tree
Expand All @@ -259,5 +261,5 @@ function inrange_kernel!(tree::KDTree,
ddiff_pow = eval_pow(M, ddiff)
diff_tot = eval_diff(M, split_diff_pow, ddiff_pow)
new_min = eval_reduce(M, min_dist, diff_tot)
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min)
inrange_kernel!(tree, far, point, r, idx_in_ball, new_min, skip)
end
25 changes: 17 additions & 8 deletions src/tree_ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ end
tree::NNTree, index::Int, point::AbstractVector,
do_end::Bool, skip::F) where {F}
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT 1
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
if dist_d <= best_dists[1]
if skip(tree.indices[z])
continue
end

best_dists[1] = dist_d
best_idxs[1] = idx
percolate_down!(best_dists, best_idxs, dist_d, idx)
Expand All @@ -116,8 +116,13 @@ end
# This will probably prevent SIMD and other optimizations so some care is needed
# to evaluate if it is worth it.
@inline function add_points_inrange!(idx_in_ball::Vector{Int}, tree::NNTree,
index::Int, point::AbstractVector, r::Number, do_end::Bool)
index::Int, point::AbstractVector, r::Number,
do_end::Bool, skip::Function)
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT 1
idx = tree.reordered ? z : tree.indices[z]
dist_d = evaluate(tree.metric, tree.data[idx], point, do_end)
Expand All @@ -129,18 +134,22 @@ end

# Add all points in this subtree since we have determined
# they are all within the desired range
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int})
function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int}, skip::Function)
tree_data = tree.tree_data
@NODE 1
if isleaf(tree.tree_data.n_internal_nodes, index)
for z in get_leaf_range(tree.tree_data, index)
if skip(tree.indices[z])
continue
end

@POINT_UNCHECKED 1
idx = tree.reordered ? z : tree.indices[z]
push!(idx_in_ball, idx)
end
return
else
addall(tree, getleft(index), idx_in_ball)
addall(tree, getright(index), idx_in_ball)
addall(tree, getleft(index), idx_in_ball, skip)
addall(tree, getright(index), idx_in_ball, skip)
end
end
13 changes: 13 additions & 0 deletions test/test_inrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@
end
end
end

# Check that the `skip` predicate of `inrange` removes exactly the skipped index
# from the result, for every tree type listed in `trees_with_brute`.
@testset "inrange skip" begin
    @testset "tree type" for TreeType in trees_with_brute
        data = rand(2, 1000)
        tree = TreeType(data)
        id = 123
        query = data[:, id]

        # Without a predicate the query point itself (distance zero) is in range.
        idxs_all = inrange(tree, query, 2, true)
        @test id in idxs_all

        # Skipping `id` must drop that index ...
        idxs_skip = inrange(tree, query, 2, true, i -> i == id)
        @test !(id in idxs_skip)
        # ... and nothing else: every other neighbour must still be reported.
        # Both sides are sorted explicitly so the comparison does not rely on
        # the order in which the tree returns indices.
        @test sort(idxs_skip) == sort(filter(i -> i != id, idxs_all))
    end
end

0 comments on commit 3657a02

Please sign in to comment.