Skip to content

Commit

Permalink
metric/kd: Flatten the tree representation
Browse files Browse the repository at this point in the history
  • Loading branch information
tavianator committed May 2, 2020
1 parent 9699f46 commit e9a81a6
Showing 1 changed file with 54 additions and 59 deletions.
113 changes: 54 additions & 59 deletions src/metric/kd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,61 +66,71 @@ where
struct KdNode<T> {
/// The value stored in this node.
item: T,
/// The left subtree, if any.
left: Option<Box<Self>>,
/// The right subtree, if any.
right: Option<Box<Self>>,
/// The size of the left subtree.
left_len: usize,
}

impl<T: Cartesian> KdNode<T> {
/// Create a new KdNode.
fn new(i: usize, mut items: Vec<T>) -> Option<Box<Self>> {
if items.is_empty() {
return None;
fn new(item: T) -> Self {
Self { item, left_len: 0 }
}

/// Build a k-d tree recursively.
fn build(slice: &mut [KdNode<T>], i: usize) {
if slice.is_empty() {
return;
}

items.sort_unstable_by_key(|x| OrderedFloat::from(x.coordinate(i)));

let mid = items.len() / 2;
let right: Vec<T> = items.drain((mid + 1)..).collect();
let item = items.pop().unwrap();
let j = (i + 1) % item.dimensions();
Some(Box::new(Self {
item,
left: Self::new(j, items),
right: Self::new(j, right),
}))
slice.sort_unstable_by_key(|n| OrderedFloat::from(n.item.coordinate(i)));

let mid = slice.len() / 2;
slice.swap(0, mid);

let (node, children) = slice.split_first_mut().unwrap();
let (left, right) = children.split_at_mut(mid);
node.left_len = left.len();

let j = (i + 1) % node.item.dimensions();
Self::build(left, j);
Self::build(right, j);
}

/// Recursively search for nearest neighbors.
fn search<'a, U, N>(&'a self, i: usize, closest: &mut [f64], neighborhood: &mut N)
where
fn recurse<'a, U, N>(
slice: &'a [KdNode<T>],
i: usize,
closest: &mut [f64],
neighborhood: &mut N,
) where
T: 'a,
U: CartesianMetric<&'a T>,
N: Neighborhood<&'a T, U>,
{
neighborhood.consider(&self.item);
let (node, children) = slice.split_first().unwrap();
neighborhood.consider(&node.item);

let target = neighborhood.target();
let ti = target.coordinate(i);
let si = self.item.coordinate(i);
let j = (i + 1) % self.item.dimensions();
let ni = node.item.coordinate(i);
let j = (i + 1) % node.item.dimensions();

let (near, far) = if ti <= si {
(&self.left, &self.right)
let (left, right) = children.split_at(node.left_len);
let (near, far) = if ti <= ni {
(left, right)
} else {
(&self.right, &self.left)
(right, left)
};

if let Some(near) = near {
near.search(j, closest, neighborhood);
if !near.is_empty() {
Self::recurse(near, j, closest, neighborhood);
}

if let Some(far) = far {
if !far.is_empty() {
let saved = closest[i];
closest[i] = si;
closest[i] = ni;
if neighborhood.contains_distance(target.distance(closest)) {
far.search(j, closest, neighborhood);
Self::recurse(far, j, closest, neighborhood);
}
closest[i] = saved;
}
Expand All @@ -129,16 +139,14 @@ impl<T: Cartesian> KdNode<T> {

/// A [k-d tree](https://en.wikipedia.org/wiki/K-d_tree).
#[derive(Debug)]
pub struct KdTree<T> {
root: Option<Box<KdNode<T>>>,
}
pub struct KdTree<T>(Vec<KdNode<T>>);

impl<T: Cartesian> FromIterator<T> for KdTree<T> {
/// Create a new k-d tree from a set of points.
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
Self {
root: KdNode::new(0, items.into_iter().collect()),
}
let mut nodes: Vec<_> = items.into_iter().map(KdNode::new).collect();
KdNode::build(nodes.as_mut_slice(), 0);
Self(nodes)
}
}

Expand All @@ -153,40 +161,27 @@ where
U: 'b,
N: Neighborhood<&'a T, &'b U>,
{
let target = neighborhood.target();
let dims = target.dimensions();
let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();
if !self.0.is_empty() {
let target = neighborhood.target();
let dims = target.dimensions();
let mut closest: Vec<_> = (0..dims).map(|i| target.coordinate(i)).collect();

if let Some(root) = &self.root {
root.search(0, &mut closest, &mut neighborhood);
KdNode::recurse(&self.0, 0, &mut closest, &mut neighborhood);
}

neighborhood
}
}

/// An iterator that the moves values out of a k-d tree.
#[derive(Debug)]
pub struct IntoIter<T> {
stack: Vec<Box<KdNode<T>>>,
}

impl<T> IntoIter<T> {
fn new(node: Option<Box<KdNode<T>>>) -> Self {
Self {
stack: node.into_iter().collect(),
}
}
}
pub struct IntoIter<T>(std::vec::IntoIter<KdNode<T>>);

impl<T> Iterator for IntoIter<T> {
type Item = T;

fn next(&mut self) -> Option<T> {
self.stack.pop().map(|node| {
self.stack.extend(node.left);
self.stack.extend(node.right);
node.item
})
self.0.next().map(|n| n.item)
}
}

Expand All @@ -195,7 +190,7 @@ impl<T> IntoIterator for KdTree<T> {
type IntoIter = IntoIter<T>;

fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self.root)
IntoIter(self.0.into_iter())
}
}

Expand Down

0 comments on commit e9a81a6

Please sign in to comment.