Skip to content

Commit

Permalink
metric/vp: Flatten the tree representation
Browse files Browse the repository at this point in the history
  • Loading branch information
tavianator committed May 2, 2020
1 parent 5de377b commit 9699f46
Showing 1 changed file with 51 additions and 82 deletions.
133 changes: 51 additions & 82 deletions src/metric/vp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,94 +11,76 @@ struct VpNode<T> {
item: T,
/// The radius of this node.
radius: f64,
/// The subtree inside the radius, if any.
inside: Option<Box<Self>>,
/// The subtree outside the radius, if any.
outside: Option<Box<Self>>,
/// The size of the subtree inside the radius.
inside_len: usize,
}

impl<T: Metric> VpNode<T> {
/// Create a new VpNode.
fn new(mut items: Vec<T>) -> Option<Box<Self>> {
if items.is_empty() {
return None;
fn new(item: T) -> Self {
Self {
item,
radius: 0.0,
inside_len: 0,
}
}

let item = items.pop().unwrap();

items.sort_by_cached_key(|a| item.distance(a));

let mid = items.len() / 2;
let outside: Vec<T> = items.drain(mid..).collect();
/// Build a VP tree recursively.
fn build(slice: &mut [VpNode<T>]) {
if let Some((node, children)) = slice.split_first_mut() {
let item = &node.item;
children.sort_by_cached_key(|n| item.distance(&n.item));

let radius = items.last().map(|l| item.distance(l).into()).unwrap_or(0.0);
let (inside, outside) = children.split_at_mut(children.len() / 2);
if let Some(last) = inside.last() {
node.radius = item.distance(&last.item).into();
}
node.inside_len = inside.len();

Some(Box::new(Self {
item,
radius,
inside: Self::new(items),
outside: Self::new(outside),
}))
Self::build(inside);
Self::build(outside);
}
}
}

trait VpSearch<'a, T, U, N> {
/// Recursively search for nearest neighbors.
fn search(&'a self, neighborhood: &mut N);

/// Search the inside subtree.
fn search_inside(&'a self, distance: f64, neighborhood: &mut N);

/// Search the outside subtree.
fn search_outside(&'a self, distance: f64, neighborhood: &mut N);
}
fn recurse<'a, U, N>(slice: &'a [VpNode<T>], neighborhood: &mut N)
where
T: 'a,
U: Metric<&'a T>,
N: Neighborhood<&'a T, U>,
{
let (node, children) = slice.split_first().unwrap();
let (inside, outside) = children.split_at(node.inside_len);

impl<'a, T, U, N> VpSearch<'a, T, U, N> for VpNode<T>
where
T: 'a,
U: Metric<&'a T>,
N: Neighborhood<&'a T, U>,
{
fn search(&'a self, neighborhood: &mut N) {
let distance = neighborhood.consider(&self.item).into();
let distance = neighborhood.consider(&node.item).into();

if distance <= self.radius {
self.search_inside(distance, neighborhood);
self.search_outside(distance, neighborhood);
if distance <= node.radius {
if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
Self::recurse(inside, neighborhood);
}
if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
Self::recurse(outside, neighborhood);
}
} else {
self.search_outside(distance, neighborhood);
self.search_inside(distance, neighborhood);
}
}

fn search_inside(&'a self, distance: f64, neighborhood: &mut N) {
if let Some(inside) = &self.inside {
if neighborhood.contains(distance - self.radius) {
inside.search(neighborhood);
if !outside.is_empty() && neighborhood.contains(node.radius - distance) {
Self::recurse(outside, neighborhood);
}
}
}

fn search_outside(&'a self, distance: f64, neighborhood: &mut N) {
if let Some(outside) = &self.outside {
if neighborhood.contains(self.radius - distance) {
outside.search(neighborhood);
if !inside.is_empty() && neighborhood.contains(distance - node.radius) {
Self::recurse(inside, neighborhood);
}
}
}
}

/// A [vantage-point tree](https://en.wikipedia.org/wiki/Vantage-point_tree).
#[derive(Debug)]
pub struct VpTree<T> {
root: Option<Box<VpNode<T>>>,
}
pub struct VpTree<T>(Vec<VpNode<T>>);

impl<T: Metric> FromIterator<T> for VpTree<T> {
fn from_iter<I: IntoIterator<Item = T>>(items: I) -> Self {
Self {
root: VpNode::new(items.into_iter().collect::<Vec<_>>()),
}
let mut nodes: Vec<_> = items.into_iter().map(VpNode::new).collect();
VpNode::build(nodes.as_mut_slice());
Self(nodes)
}
}

Expand All @@ -113,36 +95,23 @@ where
U: 'b,
N: Neighborhood<&'a T, &'b U>,
{
if let Some(root) = &self.root {
root.search(&mut neighborhood);
if !self.0.is_empty() {
VpNode::recurse(&self.0, &mut neighborhood);
}

neighborhood
}
}

/// An iterator that moves values out of a VP tree.
#[derive(Debug)]
pub struct IntoIter<T> {
stack: Vec<Box<VpNode<T>>>,
}

impl<T> IntoIter<T> {
fn new(node: Option<Box<VpNode<T>>>) -> Self {
Self {
stack: node.into_iter().collect(),
}
}
}
pub struct IntoIter<T>(std::vec::IntoIter<VpNode<T>>);

impl<T> Iterator for IntoIter<T> {
type Item = T;

fn next(&mut self) -> Option<T> {
self.stack.pop().map(|node| {
self.stack.extend(node.inside);
self.stack.extend(node.outside);
node.item
})
self.0.next().map(|n| n.item)
}
}

Expand All @@ -151,7 +120,7 @@ impl<T> IntoIterator for VpTree<T> {
type IntoIter = IntoIter<T>;

fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self.root)
IntoIter(self.0.into_iter())
}
}

Expand Down

0 comments on commit 9699f46

Please sign in to comment.