Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ data=$(printf "${i} %.0s" {1..128})
redis-cli hnsw.node.add test1 node${i-1} ${data}
done

redis-cli bgsave
# redis-cli bgsave

redis-cli hnsw.get test1
redis-cli hnsw.node.get test1 node1
Expand All @@ -17,6 +17,7 @@ redis-cli hnsw.search test1 5 ${data}
for i in {1..100}
do
redis-cli hnsw.node.del test1 node${i-1}
sleep 0.1
done

redis-cli hnsw.del test1
109 changes: 65 additions & 44 deletions src/hnsw/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::convert::From;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::rc::Rc;
use std::sync::{Arc, RwLock};
use std::sync::{Arc, RwLock, Weak};
use std::thread;

#[derive(Debug)]
Expand Down Expand Up @@ -84,12 +84,13 @@ where
}

type NodeRef<T> = Arc<RwLock<_Node<T>>>;
type NodeRefWeak<T> = Weak<RwLock<_Node<T>>>;

#[derive(Clone)]
pub struct _Node<T: Float> {
pub name: String,
pub data: Vec<T>,
pub neighbors: Vec<Vec<Node<T>>>,
pub neighbors: Vec<Vec<NodeWeak<T>>>,
}

impl<T> fmt::Debug for _Node<T>
Expand All @@ -108,7 +109,7 @@ where
.iter()
.map(|l| {
l.into_iter()
.map(|n| n.read().name.to_owned())
.map(|n| n.upgrade().read().name.to_owned())
.collect::<Vec<String>>()
})
.collect::<Vec<Vec<String>>>(),
Expand All @@ -127,28 +128,45 @@ impl<T: Float> _Node<T> {
}
}

fn add_neighbor(&mut self, level: usize, neighbor: Node<T>, capacity: Option<usize>) {
/// Records `neighbor` as a neighbor of this node at `level`, skipping it if an
/// equal handle is already stored there.
///
/// `push_levels(level, capacity)` is called first, before the per-level list
/// is indexed.
fn add_neighbor(&mut self, level: usize, neighbor: NodeWeak<T>, capacity: Option<usize>) {
    self.push_levels(level, capacity);

    let layer = &mut self.neighbors[level];
    if layer.contains(&neighbor) {
        // Already linked at this level; keep the list duplicate-free.
        return;
    }
    layer.push(neighbor);
}

fn rm_neighbor(&mut self, level: usize, neighbor: &Node<T>) {
/// Removes the first stored handle equal to `neighbor` from this node's
/// neighbor list at `level`.
///
/// # Panics
///
/// Panics if `level` is out of range or if `neighbor` is not present at that
/// level.
fn rm_neighbor(&mut self, level: usize, neighbor: &NodeWeak<T>) {
    let layer = &mut self.neighbors[level];
    let slot = layer
        .iter()
        .position(|candidate| candidate == neighbor)
        .unwrap();
    layer.remove(slot);
}
}

/// Non-owning handle to a node: a newtype around `NodeRefWeak<T>`
/// (a `Weak` pointer to the node's `RwLock`), so holding one does not keep
/// the node alive.
#[derive(Debug, Clone)]
pub struct NodeWeak<T: Float>(pub NodeRefWeak<T>);

/// Identity comparison: two weak handles are equal exactly when they point at
/// the same allocation; node contents are never inspected.
impl<T: Float> PartialEq for NodeWeak<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ptr_eq(rhs)
    }
}

// fn clear_neighbors(&mut self, level: usize) {
// let neighbors = &mut self.neighbors;
// let cap = neighbors[level].capacity();
// neighbors[level] = Vec::with_capacity(cap);
// }
// Marker impl: lets `NodeWeak` be used where `Eq` is required (e.g. as a
// `HashSet` element together with its `Hash` impl).
impl<T: Float> Eq for NodeWeak<T> {}

/// Hashes the handle by the address of the allocation it points at.
///
/// Equality for `NodeWeak` is pointer identity (`Weak::ptr_eq`), so hashing the
/// pointer keeps `Hash` consistent with `Eq`. Unlike hashing the node's name,
/// this never needs to upgrade the handle or lock the node, so hashing cannot
/// panic after the node has been dropped (for example when a `HashSet` holding
/// stale handles rehashes), and the value stays stable for the handle's
/// whole lifetime.
impl<T: Float> Hash for NodeWeak<T> {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // The allocation outlives every `Weak` that refers to it, so the
        // address is stable even once all strong references are gone.
        self.0.as_ptr().hash(state);
    }
}

impl<T: Float> NodeWeak<T> {
    /// Promotes this weak handle to a strong `Node` handle.
    ///
    /// # Panics
    ///
    /// Panics if the node has already been dropped, i.e. no strong reference
    /// to it remains. Callers are expected to only upgrade handles whose node
    /// is still owned elsewhere.
    pub fn upgrade(&self) -> Node<T> {
        let strong = self
            .0
            .upgrade()
            .expect("NodeWeak::upgrade called on a node that was already dropped");
        Node(strong)
    }
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -191,20 +209,19 @@ impl<T: Float> Node<T> {
node.push_levels(level, capacity);
}

fn add_neighbor(&self, level: usize, neighbor: Node<T>, capacity: Option<usize>) {
/// Write-locks the underlying node and forwards to its `add_neighbor`.
///
/// # Panics
///
/// Panics if the write lock cannot be taken immediately (`try_write` fails).
fn add_neighbor(&self, level: usize, neighbor: NodeWeak<T>, capacity: Option<usize>) {
    let mut guard = self.0.try_write().unwrap();
    guard.add_neighbor(level, neighbor, capacity);
}

fn rm_neighbor(&self, level: usize, neighbor: &Node<T>) {
/// Write-locks the underlying node and forwards to its `rm_neighbor`.
///
/// # Panics
///
/// Panics if the write lock cannot be taken immediately (`try_write` fails).
fn rm_neighbor(&self, level: usize, neighbor: &NodeWeak<T>) {
    let mut guard = self.0.try_write().unwrap();
    guard.rm_neighbor(level, neighbor);
}

// fn clear_neighbors(&self, level: usize) {
// let node = &mut self.0.try_write().unwrap();
// node.clear_neighbors(level);
// }
/// Creates a `NodeWeak` handle to this node via `Arc::downgrade`; the new
/// handle does not add a strong reference.
pub fn downgrade(&self) -> NodeWeak<T> {
    let weak = Arc::downgrade(&self.0);
    NodeWeak(weak)
}
}

type SimPairRef<T, R> = Rc<RefCell<_SimPair<T, R>>>;
Expand Down Expand Up @@ -291,9 +308,9 @@ pub struct Index<T: Float, R: Float> {
pub level_mult: f64, // level generation factor
pub node_count: usize, // count of nodes
pub max_layer: usize, // idx of top layer
pub layers: Vec<HashSet<Node<T>>>, // distinct nodes in each layer
pub layers: Vec<HashSet<NodeWeak<T>>>, // distinct nodes in each layer
pub nodes: HashMap<String, Node<T>>, // hashmap of nodes
pub enterpoint: Option<Node<T>>, // enterpoint node
pub enterpoint: Option<NodeWeak<T>>, // enterpoint node
rng_: StdRng, // rng for level generation
}

Expand Down Expand Up @@ -347,7 +364,7 @@ impl<T: Float, R: Float> fmt::Debug for Index<T, R> {
self.node_count,
self.max_layer,
match &self.enterpoint {
Some(node) => node.read().name.clone(),
Some(node) => node.upgrade().read().name.clone(),
None => "null".to_owned(),
},
)
Expand Down Expand Up @@ -400,10 +417,10 @@ where

if self.node_count == 0 {
let node = Node::new(name, data, self.m_max_0);
self.enterpoint = Some(node.clone());
self.enterpoint = Some(node.downgrade());

let mut layer = HashSet::new();
layer.insert(node.clone());
layer.insert(node.downgrade());
self.layers.push(layer);

self.nodes.insert(name.to_owned(), node);
Expand Down Expand Up @@ -432,7 +449,7 @@ where
self.node_count -= 1;

for lc in (0..(self.max_layer + 1)).rev() {
if self.layers[lc].remove(&node) {
if self.layers[lc].remove(&node.downgrade()) {
break;
}
}
Expand All @@ -455,7 +472,7 @@ where

// update enterpoint if necessary
match &self.enterpoint {
Some(ep) if node == *ep => {
Some(ep) if node == ep.upgrade() => {
let mut new_ep = None;
for lc in (0..(self.max_layer + 1)).rev() {
match self.layers[lc].iter().next() {
Expand Down Expand Up @@ -518,8 +535,8 @@ where

let mut lc = l_max;
while lc > l {
w = self.search_level(data, &ep, 1, lc);
ep = w.pop().unwrap().read().node.clone();
w = self.search_level(data, &ep.upgrade(), 1, lc);
ep = w.pop().unwrap().read().node.downgrade();

if lc == 0 {
break;
Expand All @@ -529,7 +546,7 @@ where

let mut updated = HashSet::new();
for lc in (0..(min(l_max, l) + 1)).rev() {
w = self.search_level(data, &ep, self.ef_construction, lc);
w = self.search_level(data, &ep.upgrade(), self.ef_construction, lc);
let mut neighbors = self.select_neighbors(query, &w, self.m, lc, true, true, None);
self.connect_neighbors(query, &neighbors, lc);

Expand All @@ -551,10 +568,10 @@ where
for n in eneighbors {
let ensim = OrderedFloat::from((self.mfunc)(
&enr.data,
&n.read().data,
&n.upgrade().read().data,
self.data_dim,
));
let enpair = SimPair::new(ensim, n.to_owned());
let enpair = SimPair::new(ensim, n.upgrade());
econn.push(enpair);
}
}
Expand All @@ -570,7 +587,7 @@ where
}
}

ep = w.peek().unwrap().read().node.clone();
ep = w.peek().unwrap().read().node.downgrade();
}

// update nodes in redis
Expand All @@ -583,14 +600,14 @@ where
// new enterpoint if we're in a higher layer
if l > l_max {
self.max_layer = l;
self.enterpoint = Some(query.to_owned());
self.enterpoint = Some(query.downgrade());
while self.layers.len() < l + 1 {
self.layers.push(HashSet::new());
}
}

// add node to layer set
self.layers[l].insert(query.to_owned());
self.layers[l].insert(query.downgrade());

Ok(())
}
Expand Down Expand Up @@ -641,6 +658,7 @@ where
let cpr = cpair.read();
let neighbors = &cpr.node.read().neighbors[level];
for neighbor in neighbors {
let neighbor = neighbor.upgrade();
if !v.contains(&neighbor) {
v.insert(neighbor.clone());

Expand Down Expand Up @@ -699,8 +717,9 @@ where
let epair = ccopy.pop().unwrap();

for eneighbor in &epair.read().node.read().neighbors[lc] {
if *eneighbor == *query
|| (ignored_node.is_some() && *eneighbor == *ignored_node.unwrap())
let eneighbor = eneighbor.upgrade();
if eneighbor == *query
|| (ignored_node.is_some() && eneighbor == *ignored_node.unwrap())
{
continue;
}
Expand Down Expand Up @@ -765,9 +784,9 @@ where
let npair = neighbors.pop().unwrap();
let npr = npair.read();

query.add_neighbor(level, npr.node.clone(), Some(self.m_max_0));
query.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0));
npr.node
.add_neighbor(level, query.clone(), Some(self.m_max_0));
.add_neighbor(level, query.downgrade(), Some(self.m_max_0));
}
}

Expand All @@ -788,9 +807,9 @@ where
while !newconn.is_empty() {
let newpair = newconn.pop().unwrap();
let npr = newpair.read();
node.add_neighbor(level, npr.node.clone(), Some(self.m_max_0));
node.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0));
npr.node
.add_neighbor(level, node.clone(), Some(self.m_max_0));
.add_neighbor(level, node.downgrade(), Some(self.m_max_0));
updated.insert(npr.node.clone());
// if new neighbor exists in the old set then we remove it from
// the set of neighbors to be removed
Expand All @@ -806,14 +825,14 @@ where
while !rmconn.is_empty() {
let rmpair = rmconn.pop().unwrap();
let rmpr = rmpair.read();
node.rm_neighbor(level, &rmpr.node);
node.rm_neighbor(level, &rmpr.node.downgrade());
// if node to be removed is the ignored node then pass
match ignored_node {
Some(n) if rmpr.node == *n => {
continue;
}
_ => {
rmpr.node.rm_neighbor(level, &node);
rmpr.node.rm_neighbor(level, &node.downgrade());
updated.insert(rmpr.node.clone());
}
}
Expand All @@ -828,6 +847,7 @@ where
let mut updated = HashSet::new();

for n in neighbors {
let n = n.upgrade();
let nnewconn: BinaryHeap<SimPair<T, R>>;
let mut nconn: BinaryHeap<SimPair<T, R>>;
{
Expand All @@ -836,17 +856,18 @@ where
nconn = BinaryHeap::with_capacity(nneighbors.len());

for nn in nneighbors {
let nn = nn.upgrade();
let nnsim =
OrderedFloat::from((self.mfunc)(&nr.data, &nn.read().data, self.data_dim));
let nnpair = SimPair::new(nnsim, nn.to_owned());
nconn.push(nnpair);
}

let m_max = if lc == 0 { self.m_max_0 } else { self.m_max };
nnewconn = self.select_neighbors(n, &nconn, m_max, lc, true, true, Some(node));
nnewconn = self.select_neighbors(&n, &nconn, m_max, lc, true, true, Some(node));
}
updated.insert(n.clone());
let up = self.update_node_connections(n, &nnewconn, &nconn, lc, Some(node));
let up = self.update_node_connections(&n, &nnewconn, &nconn, lc, Some(node));
for u in up {
updated.insert(u);
}
Expand All @@ -861,12 +882,12 @@ where

let mut lc = l_max;
while lc > 0 {
let w = self.search_level(query, &ep, 1, lc);
ep = w.peek().unwrap().read().node.clone();
let w = self.search_level(query, &ep.upgrade(), 1, lc);
ep = w.peek().unwrap().read().node.downgrade();
lc -= 1;
}

let mut w = self.search_level(query, &ep, ef, 0);
let mut w = self.search_level(query, &ep.upgrade(), ef, 0);

let mut res = Vec::with_capacity(k);
while res.len() < k && !w.is_empty() {
Expand Down
26 changes: 23 additions & 3 deletions src/hnsw/hnsw_tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::hnsw::hnsw::*;
use crate::hnsw::metrics::euclidean;
use std::sync::Arc;
use std::{thread, time};

#[test]
fn hnsw_test() {
Expand All @@ -22,6 +23,18 @@ fn hnsw_test() {
let data = vec![i as f32; 4];
index.add_node(&name, &data, mock_fn).unwrap();
}
// sleep for a brief period to make sure all threads are done
let ten_millis = time::Duration::from_millis(10);
thread::sleep(ten_millis);
for i in 0..100 {
let node_name = format!("node{}", i);
let node = index.nodes.get(&node_name).unwrap();
let sc = Arc::strong_count(&node.0);
if sc > 1 {
println!("{:?}", node);
}
assert_eq!(sc, 1);
}
assert_eq!(index.node_count, 100);
assert_ne!(index.enterpoint, None);

Expand All @@ -44,15 +57,22 @@ fn hnsw_test() {
assert_eq!(index.node_count, 100 - i - 1);
assert_eq!(index.nodes.get(&node_name).is_none(), true);
for l in &index.layers {
assert_eq!(l.contains(&node), false);
assert_eq!(l.contains(&node.downgrade()), false);
}
for n in index.nodes.values() {
for l in &n.read().neighbors {
for nn in l {
assert_ne!(*nn, node);
assert_ne!(nn.upgrade(), node);
}
}
}
assert_eq!(Arc::strong_count(&node.0), 1);
// sleep for a brief period to make sure all threads are done
let ten_millis = time::Duration::from_millis(10);
thread::sleep(ten_millis);
let sc = Arc::strong_count(&node.0);
if sc > 1 {
println!("Delete {:?}", node);
}
assert_eq!(sc, 1);
}
}
Loading