From 3ff201830e24f43c0bfcae06671aad4c46f72969 Mon Sep 17 00:00:00 2001 From: Ed Schmerling Date: Fri, 3 Nov 2017 17:23:48 -0700 Subject: [PATCH] Add skip predicate to inrange, fixes #53 --- README.md | 2 +- src/ball_tree.jl | 16 +++++++++------- src/brute_tree.jl | 12 +++++++++--- src/inrange.jl | 22 ++++++++++++---------- src/kd_tree.jl | 14 ++++++++------ src/knn.jl | 2 +- src/tree_ops.jl | 25 +++++++++++++++++-------- test/test_inrange.jl | 13 +++++++++++++ 8 files changed, 70 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index d9d1a67..32a3c5b 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ idxs, dists = knn(kdtree, point, k, true) A range search finds all neighbors within the range `r` of given point(s). This is done with the method: ```jl -inrange(tree, points, r, sortres = false) -> idxs +inrange(tree, points, r, sortres = false, skip = always_false) -> idxs ``` Note that for performance reasons the distances are not returned. The arguments to `inrange` are the same as for `knn` except that `sortres` just sorts the returned index vector. diff --git a/src/ball_tree.jl b/src/ball_tree.jl index 1d2de4f..8733e5c 100644 --- a/src/ball_tree.jl +++ b/src/ball_tree.jl @@ -189,9 +189,10 @@ end function _inrange(tree::BallTree{V}, point::AbstractVector, radius::Number, - idx_in_ball::Vector{Int}) where {V} + idx_in_ball::Vector{Int}, + skip::Function) where {V} ball = HyperSphere(convert(V, point), convert(eltype(V), radius)) # The "query ball" - inrange_kernel!(tree, 1, point, ball, idx_in_ball) # Call the recursive range finder + inrange_kernel!(tree, 1, point, ball, idx_in_ball, skip) # Call the recursive range finder return end @@ -199,7 +200,8 @@ function inrange_kernel!(tree::BallTree, index::Int, point::AbstractVector, query_ball::HyperSphere, - idx_in_ball::Vector{Int}) + idx_in_ball::Vector{Int}, + skip::Function) @NODE 1 if index > length(tree.hyper_spheres) @@ -216,17 +218,17 @@ function inrange_kernel!(tree::BallTree, # At a leaf node, check all points in the leaf node if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true) + add_points_inrange!(idx_in_ball, tree, index, point, query_ball.r, true, skip) return end # The query ball encloses the sub tree bounding sphere. Add all points in the # sub tree without checking the distance function. if encloses(tree.metric, sphere, query_ball) - addall(tree, index, idx_in_ball) + addall(tree, index, idx_in_ball, skip) else # Recursively call the left and right sub tree. - inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball) - inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball) + inrange_kernel!(tree, getleft(index), point, query_ball, idx_in_ball, skip) + inrange_kernel!(tree, getright(index), point, query_ball, idx_in_ball, skip) end end diff --git a/src/brute_tree.jl b/src/brute_tree.jl index 12cc844..8293fda 100644 --- a/src/brute_tree.jl +++ b/src/brute_tree.jl @@ -55,8 +55,9 @@ end function _inrange(tree::BruteTree, point::AbstractVector, radius::Number, - idx_in_ball::Vector{Int}) - inrange_kernel!(tree, point, radius, idx_in_ball) + idx_in_ball::Vector{Int}, + skip::Function) + inrange_kernel!(tree, point, radius, idx_in_ball, skip) return end @@ -64,8 +65,13 @@ end function inrange_kernel!(tree::BruteTree, point::AbstractVector, r::Number, - idx_in_ball::Vector{Int}) + idx_in_ball::Vector{Int}, + skip::Function) for i in 1:length(tree.data) + if skip(i) + continue + end + @POINT 1 d = evaluate(tree.metric, tree.data[i], point) if d <= r diff --git a/src/inrange.jl b/src/inrange.jl index b47522e..4466d7b 100644 --- a/src/inrange.jl +++ b/src/inrange.jl @@ -1,28 +1,30 @@ check_radius(r) = r < 0 && throw(ArgumentError("the query radius r must be ≧ 0")) """ - inrange(tree::NNTree, points, radius [, sortres=false]) -> indices + inrange(tree::NNTree, points, radius [, sortres=false, skip=always_false]) -> indices Find all the points in the tree which is closer than `radius` to `points`. If -`sortres = true` the resulting indices are sorted. +`sortres = true` the resulting indices are sorted. `skip` is an optional predicate +to determine if a point that would be returned should be skipped. """ function inrange(tree::NNTree, points::Vector{T}, radius::Number, - sortres=false) where {T <: AbstractVector} + sortres=false, + skip::Function=always_false) where {T <: AbstractVector} check_input(tree, points) check_radius(radius) idxs = [Vector{Int}() for _ in 1:length(points)] for i in 1:length(points) - inrange_point!(tree, points[i], radius, sortres, idxs[i]) + inrange_point!(tree, points[i], radius, sortres, idxs[i], skip) end return idxs end -function inrange_point!(tree, point, radius, sortres, idx) - _inrange(tree, point, radius, idx) +function inrange_point!(tree, point, radius, sortres, idx, skip) + _inrange(tree, point, radius, idx, skip) if tree.reordered @inbounds for j in 1:length(idx) idx[j] = tree.indices[idx[j]] @@ -32,15 +34,15 @@ function inrange_point!(tree, point, radius, sortres, idx) return end -function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false) where {V, T <: Number} +function inrange(tree::NNTree{V}, point::AbstractVector{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number} check_input(tree, point) check_radius(radius) idx = Int[] - inrange_point!(tree, point, radius, sortres, idx) + inrange_point!(tree, point, radius, sortres, idx, skip) return idx end -function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false) where {V, T <: Number} +function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=false, skip::Function=always_false) where {V, T <: Number} dim = size(point, 1) npoints = size(point, 2) if isbits(T) @@ -48,5 +50,5 @@ function inrange(tree::NNTree{V}, point::Matrix{T}, radius::Number, sortres=fals else new_data = SVector{dim,T}[SVector{dim,T}(point[:, i]) for i in 1:npoints] end - inrange(tree, new_data, radius, sortres) + inrange(tree, new_data, radius, sortres, skip) end diff --git a/src/kd_tree.jl b/src/kd_tree.jl index e80304f..09c2e65 100644 --- a/src/kd_tree.jl +++ b/src/kd_tree.jl @@ -203,10 +203,11 @@ end function _inrange(tree::KDTree, point::AbstractVector, radius::Number, - idx_in_ball = Int[]) + idx_in_ball = Int[], + skip::Function = always_false) init_min = get_min_distance(tree.hyper_rec, point) inrange_kernel!(tree, 1, point, eval_op(tree.metric, radius, zero(init_min)), idx_in_ball, - init_min) + init_min, skip) return end @@ -216,7 +217,8 @@ function inrange_kernel!(tree::KDTree, point::AbstractVector, r::Number, idx_in_ball::Vector{Int}, - min_dist) + min_dist, + skip::Function) @NODE 1 # Point is outside hyper rectangle, skip the whole sub tree if min_dist > r @@ -225,7 +227,7 @@ function inrange_kernel!(tree::KDTree, # At a leaf node. Go through all points in node and add those in range if isleaf(tree.tree_data.n_internal_nodes, index) - add_points_inrange!(idx_in_ball, tree, index, point, r, false) + add_points_inrange!(idx_in_ball, tree, index, point, r, false, skip) return end @@ -247,7 +249,7 @@ function inrange_kernel!(tree::KDTree, ddiff = max(zero(lo - p_dim), lo - p_dim) end # Call closer sub tree - inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist) + inrange_kernel!(tree, close, point, r, idx_in_ball, min_dist, skip) # TODO: We could potentially also keep track of the max distance # between the point and the hyper rectangle and add the whole sub tree @@ -259,5 +261,5 @@ function inrange_kernel!(tree::KDTree, ddiff_pow = eval_pow(M, ddiff) diff_tot = eval_diff(M, split_diff_pow, ddiff_pow) new_min = eval_reduce(M, min_dist, diff_tot) - inrange_kernel!(tree, far, point, r, idx_in_ball, new_min) + inrange_kernel!(tree, far, point, r, idx_in_ball, new_min, skip) end diff --git a/src/knn.jl b/src/knn.jl index c2605f2..02acafb 100644 --- a/src/knn.jl +++ b/src/knn.jl @@ -5,7 +5,7 @@ function check_k(tree, k) end """ - knn(tree::NNTree, points, k [, sortres=false]) -> indices, distances + knn(tree::NNTree, points, k [, sortres=false, skip=always_false]) -> indices, distances Performs a lookup of the `k` nearest neigbours to the `points` from the data in the `tree`. If `sortres = true` the result is sorted such that the results are diff --git a/src/tree_ops.jl b/src/tree_ops.jl index 68fa4bf..bfff956 100644 --- a/src/tree_ops.jl +++ b/src/tree_ops.jl @@ -94,14 +94,14 @@ end tree::NNTree, index::Int, point::AbstractVector, do_end::Bool, skip::F) where {F} for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end + @POINT 1 idx = tree.reordered ? z : tree.indices[z] dist_d = evaluate(tree.metric, tree.data[idx], point, do_end) if dist_d <= best_dists[1] - if skip(tree.indices[z]) - continue - end - best_dists[1] = dist_d best_idxs[1] = idx percolate_down!(best_dists, best_idxs, dist_d, idx) @@ -116,8 +116,13 @@ end # This will probably prevent SIMD and other optimizations so some care is needed # to evaluate if it is worth it. @inline function add_points_inrange!(idx_in_ball::Vector{Int}, tree::NNTree, - index::Int, point::AbstractVector, r::Number, do_end::Bool) + index::Int, point::AbstractVector, r::Number, + do_end::Bool, skip::Function) for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end + @POINT 1 idx = tree.reordered ? z : tree.indices[z] dist_d = evaluate(tree.metric, tree.data[idx], point, do_end) @@ -129,18 +134,22 @@ end # Add all points in this subtree since we have determined # they are all within the desired range -function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int}) +function addall(tree::NNTree, index::Int, idx_in_ball::Vector{Int}, skip::Function) tree_data = tree.tree_data @NODE 1 if isleaf(tree.tree_data.n_internal_nodes, index) for z in get_leaf_range(tree.tree_data, index) + if skip(tree.indices[z]) + continue + end + @POINT_UNCHECKED 1 idx = tree.reordered ? z : tree.indices[z] push!(idx_in_ball, idx) end return else - addall(tree, getleft(index), idx_in_ball) - addall(tree, getright(index), idx_in_ball) + addall(tree, getleft(index), idx_in_ball, skip) + addall(tree, getright(index), idx_in_ball, skip) end end diff --git a/test/test_inrange.jl b/test/test_inrange.jl index f963665..0a9b163 100644 --- a/test/test_inrange.jl +++ b/test/test_inrange.jl @@ -44,3 +44,16 @@ end end end + +@testset "inrange skip" begin + @testset "tree type" for TreeType in trees_with_brute + data = rand(2, 1000) + tree = TreeType(data) + id = 123 + + idxs = inrange(tree, data[:, id], 2, true) + @test id in idxs + idxs = inrange(tree, data[:, id], 2, true, i -> i == id) + @test !(id in idxs) + end +end