diff --git a/cmd.sh b/cmd.sh index 7f01256..c938313 100755 --- a/cmd.sh +++ b/cmd.sh @@ -6,7 +6,7 @@ data=$(printf "${i} %.0s" {1..128}) redis-cli hnsw.node.add test1 node${i-1} ${data} done -redis-cli bgsave +# redis-cli bgsave redis-cli hnsw.get test1 redis-cli hnsw.node.get test1 node1 @@ -17,6 +17,7 @@ redis-cli hnsw.search test1 5 ${data} for i in {1..100} do redis-cli hnsw.node.del test1 node${i-1} +sleep 0.1 done redis-cli hnsw.del test1 diff --git a/src/hnsw/hnsw.rs b/src/hnsw/hnsw.rs index 90e80f4..d47477d 100644 --- a/src/hnsw/hnsw.rs +++ b/src/hnsw/hnsw.rs @@ -12,7 +12,7 @@ use std::convert::From; use std::fmt; use std::hash::{Hash, Hasher}; use std::rc::Rc; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, RwLock, Weak}; use std::thread; #[derive(Debug)] @@ -84,12 +84,13 @@ where } type NodeRef = Arc>>; +type NodeRefWeak = Weak>>; #[derive(Clone)] pub struct _Node { pub name: String, pub data: Vec, - pub neighbors: Vec>>, + pub neighbors: Vec>>, } impl fmt::Debug for _Node @@ -108,7 +109,7 @@ where .iter() .map(|l| { l.into_iter() - .map(|n| n.read().name.to_owned()) + .map(|n| n.upgrade().read().name.to_owned()) .collect::>() }) .collect::>>(), @@ -127,7 +128,7 @@ impl _Node { } } - fn add_neighbor(&mut self, level: usize, neighbor: Node, capacity: Option) { + fn add_neighbor(&mut self, level: usize, neighbor: NodeWeak, capacity: Option) { self.push_levels(level, capacity); let neighbors = &mut self.neighbors; if !neighbors[level].contains(&neighbor) { @@ -135,7 +136,7 @@ impl _Node { } } - fn rm_neighbor(&mut self, level: usize, neighbor: &Node) { + fn rm_neighbor(&mut self, level: usize, neighbor: &NodeWeak) { let neighbors = &mut self.neighbors; let index = neighbors[level] .iter() @@ -143,12 +144,29 @@ impl _Node { .unwrap(); neighbors[level].remove(index); } +} + +#[derive(Debug, Clone)] +pub struct NodeWeak(pub NodeRefWeak); + +impl PartialEq for NodeWeak { + fn eq(&self, other: &Self) -> bool { + Weak::ptr_eq(&self.0, &other.0) + } +} - // fn clear_neighbors(&mut self, level: usize) { - // let neighbors = &mut self.neighbors; - // let cap = neighbors[level].capacity(); - // neighbors[level] = Vec::with_capacity(cap); - // } +impl Eq for NodeWeak {} + +impl Hash for NodeWeak { + fn hash(&self, state: &mut H) { + self.upgrade().read().name.hash(state); + } +} + +impl NodeWeak { + pub fn upgrade(&self) -> Node { + Node(self.0.upgrade().unwrap()) + } } #[derive(Debug, Clone)] @@ -191,20 +209,19 @@ impl Node { node.push_levels(level, capacity); } - fn add_neighbor(&self, level: usize, neighbor: Node, capacity: Option) { + fn add_neighbor(&self, level: usize, neighbor: NodeWeak, capacity: Option) { let node = &mut self.0.try_write().unwrap(); node.add_neighbor(level, neighbor, capacity); } - fn rm_neighbor(&self, level: usize, neighbor: &Node) { + fn rm_neighbor(&self, level: usize, neighbor: &NodeWeak) { let node = &mut self.0.try_write().unwrap(); node.rm_neighbor(level, neighbor); } - // fn clear_neighbors(&self, level: usize) { - // let node = &mut self.0.try_write().unwrap(); - // node.clear_neighbors(level); - // } + pub fn downgrade(&self) -> NodeWeak { + NodeWeak(Arc::downgrade(&self.0)) + } } type SimPairRef = Rc>>; @@ -291,9 +308,9 @@ pub struct Index { pub level_mult: f64, // level generation factor pub node_count: usize, // count of nodes pub max_layer: usize, // idx of top layer - pub layers: Vec>>, // distinct nodes in each layer + pub layers: Vec>>, // distinct nodes in each layer pub nodes: HashMap>, // hashmap of nodes - pub enterpoint: Option>, // enterpoint node + pub enterpoint: Option>, // enterpoint node rng_: StdRng, // rng for level generation } @@ -347,7 +364,7 @@ impl fmt::Debug for Index { self.node_count, self.max_layer, match &self.enterpoint { - Some(node) => node.read().name.clone(), + Some(node) => node.upgrade().read().name.clone(), None => "null".to_owned(), }, ) @@ -400,10 +417,10 @@ where if self.node_count == 0 { let node = Node::new(name, data, self.m_max_0); - self.enterpoint = Some(node.clone()); + self.enterpoint = Some(node.downgrade()); let mut layer = HashSet::new(); - layer.insert(node.clone()); + layer.insert(node.downgrade()); self.layers.push(layer); self.nodes.insert(name.to_owned(), node); @@ -432,7 +449,7 @@ where self.node_count -= 1; for lc in (0..(self.max_layer + 1)).rev() { - if self.layers[lc].remove(&node) { + if self.layers[lc].remove(&node.downgrade()) { break; } } @@ -455,7 +472,7 @@ where // update enterpoint if necessary match &self.enterpoint { - Some(ep) if node == *ep => { + Some(ep) if node == ep.upgrade() => { let mut new_ep = None; for lc in (0..(self.max_layer + 1)).rev() { match self.layers[lc].iter().next() { @@ -518,8 +535,8 @@ where let mut lc = l_max; while lc > l { - w = self.search_level(data, &ep, 1, lc); - ep = w.pop().unwrap().read().node.clone(); + w = self.search_level(data, &ep.upgrade(), 1, lc); + ep = w.pop().unwrap().read().node.downgrade(); if lc == 0 { break; @@ -529,7 +546,7 @@ where let mut updated = HashSet::new(); for lc in (0..(min(l_max, l) + 1)).rev() { - w = self.search_level(data, &ep, self.ef_construction, lc); + w = self.search_level(data, &ep.upgrade(), self.ef_construction, lc); let mut neighbors = self.select_neighbors(query, &w, self.m, lc, true, true, None); self.connect_neighbors(query, &neighbors, lc); @@ -551,10 +568,10 @@ where for n in eneighbors { let ensim = OrderedFloat::from((self.mfunc)( &enr.data, - &n.read().data, + &n.upgrade().read().data, self.data_dim, )); - let enpair = SimPair::new(ensim, n.to_owned()); + let enpair = SimPair::new(ensim, n.upgrade()); econn.push(enpair); } } @@ -570,7 +587,7 @@ where } } - ep = w.peek().unwrap().read().node.clone(); + ep = w.peek().unwrap().read().node.downgrade(); } // update nodes in redis @@ -583,14 +600,14 @@ where // new enterpoint if we're in a higher layer if l > l_max { self.max_layer = l; - self.enterpoint = Some(query.to_owned()); + self.enterpoint = Some(query.downgrade()); while self.layers.len() < l + 1 { self.layers.push(HashSet::new()); } } // add node to layer set - self.layers[l].insert(query.to_owned()); + self.layers[l].insert(query.downgrade()); Ok(()) } @@ -641,6 +658,7 @@ where let cpr = cpair.read(); let neighbors = &cpr.node.read().neighbors[level]; for neighbor in neighbors { + let neighbor = neighbor.upgrade(); if !v.contains(&neighbor) { v.insert(neighbor.clone()); @@ -699,8 +717,9 @@ where let epair = ccopy.pop().unwrap(); for eneighbor in &epair.read().node.read().neighbors[lc] { - if *eneighbor == *query - || (ignored_node.is_some() && *eneighbor == *ignored_node.unwrap()) + let eneighbor = eneighbor.upgrade(); + if eneighbor == *query + || (ignored_node.is_some() && eneighbor == *ignored_node.unwrap()) { continue; } @@ -765,9 +784,9 @@ where let npair = neighbors.pop().unwrap(); let npr = npair.read(); - query.add_neighbor(level, npr.node.clone(), Some(self.m_max_0)); + query.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0)); npr.node - .add_neighbor(level, query.clone(), Some(self.m_max_0)); + .add_neighbor(level, query.downgrade(), Some(self.m_max_0)); } } @@ -788,9 +807,9 @@ where while !newconn.is_empty() { let newpair = newconn.pop().unwrap(); let npr = newpair.read(); - node.add_neighbor(level, npr.node.clone(), Some(self.m_max_0)); + node.add_neighbor(level, npr.node.downgrade(), Some(self.m_max_0)); npr.node - .add_neighbor(level, node.clone(), Some(self.m_max_0)); + .add_neighbor(level, node.downgrade(), Some(self.m_max_0)); updated.insert(npr.node.clone()); // if new neighbor exists in the old set then we remove it from // the set of neighbors to be removed @@ -806,14 +825,14 @@ where while !rmconn.is_empty() { let rmpair = rmconn.pop().unwrap(); let rmpr = rmpair.read(); - node.rm_neighbor(level, &rmpr.node); + node.rm_neighbor(level, &rmpr.node.downgrade()); // if node to be removed is the ignored node then pass match ignored_node { Some(n) if rmpr.node == *n => { continue; } _ => { - rmpr.node.rm_neighbor(level, &node); + rmpr.node.rm_neighbor(level, &node.downgrade()); updated.insert(rmpr.node.clone()); } } @@ -828,6 +847,7 @@ where let mut updated = HashSet::new(); for n in neighbors { + let n = n.upgrade(); let nnewconn: BinaryHeap>; let mut nconn: BinaryHeap>; { @@ -836,6 +856,7 @@ where nconn = BinaryHeap::with_capacity(nneighbors.len()); for nn in nneighbors { + let nn = nn.upgrade(); let nnsim = OrderedFloat::from((self.mfunc)(&nr.data, &nn.read().data, self.data_dim)); let nnpair = SimPair::new(nnsim, nn.to_owned()); @@ -843,10 +864,10 @@ where } let m_max = if lc == 0 { self.m_max_0 } else { self.m_max }; - nnewconn = self.select_neighbors(n, &nconn, m_max, lc, true, true, Some(node)); + nnewconn = self.select_neighbors(&n, &nconn, m_max, lc, true, true, Some(node)); } updated.insert(n.clone()); - let up = self.update_node_connections(n, &nnewconn, &nconn, lc, Some(node)); + let up = self.update_node_connections(&n, &nnewconn, &nconn, lc, Some(node)); for u in up { updated.insert(u); } @@ -861,12 +882,12 @@ where let mut lc = l_max; while lc > 0 { - let w = self.search_level(query, &ep, 1, lc); - ep = w.peek().unwrap().read().node.clone(); + let w = self.search_level(query, &ep.upgrade(), 1, lc); + ep = w.peek().unwrap().read().node.downgrade(); lc -= 1; } - let mut w = self.search_level(query, &ep, ef, 0); + let mut w = self.search_level(query, &ep.upgrade(), ef, 0); let mut res = Vec::with_capacity(k); while res.len() < k && !w.is_empty() { diff --git a/src/hnsw/hnsw_tests.rs b/src/hnsw/hnsw_tests.rs index 163afbe..3e8807a 100644 --- a/src/hnsw/hnsw_tests.rs +++ b/src/hnsw/hnsw_tests.rs @@ -1,6 +1,7 @@ use crate::hnsw::hnsw::*; use crate::hnsw::metrics::euclidean; use std::sync::Arc; +use std::{thread, time}; #[test] fn hnsw_test() { @@ -22,6 +23,18 @@ fn hnsw_test() { let data = vec![i as f32; 4]; index.add_node(&name, &data, mock_fn).unwrap(); } + // sleep for a brief period to make sure all threads are done + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + for i in 0..100 { + let node_name = format!("node{}", i); + let node = index.nodes.get(&node_name).unwrap(); + let sc = Arc::strong_count(&node.0); + if sc > 1 { + println!("{:?}", node); + } + assert_eq!(sc, 1); + } assert_eq!(index.node_count, 100); assert_ne!(index.enterpoint, None); @@ -44,15 +57,22 @@ fn hnsw_test() { assert_eq!(index.node_count, 100 - i - 1); assert_eq!(index.nodes.get(&node_name).is_none(), true); for l in &index.layers { - assert_eq!(l.contains(&node), false); + assert_eq!(l.contains(&node.downgrade()), false); } for n in index.nodes.values() { for l in &n.read().neighbors { for nn in l { - assert_ne!(*nn, node); + assert_ne!(nn.upgrade(), node); } } } - assert_eq!(Arc::strong_count(&node.0), 1); + // sleep for a brief period to make sure all threads are done + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + let sc = Arc::strong_count(&node.0); + if sc > 1 { + println!("Delete {:?}", node); + } + assert_eq!(sc, 1); } } diff --git a/src/lib.rs b/src/lib.rs index eedf4ad..c5c76c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -191,7 +191,7 @@ fn make_index(ctx: &Context, ir: &IndexRedis) -> Result, RedisEr Some(node) => node, None => return Err(format!("Node: {} does not exist", node_name).into()), }; - node_layer.push(nn.clone()); + node_layer.push(nn.downgrade()); } target.write().neighbors.push(node_layer); } @@ -205,7 +205,7 @@ fn make_index(ctx: &Context, ir: &IndexRedis) -> Result, RedisEr Some(n) => n, None => return Err(format!("Node: {} does not exist", node_name).into()), }; - node_layer.insert(node.clone()); + node_layer.insert(node.downgrade()); } index.layers.push(node_layer); } @@ -217,7 +217,7 @@ fn make_index(ctx: &Context, ir: &IndexRedis) -> Result, RedisEr Some(n) => n, None => return Err(format!("Node: {} does not exist", node_name).into()), }; - Some(node.clone()) + Some(node.downgrade()) } None => None, }; @@ -292,6 +292,15 @@ fn delete_node(ctx: &Context, args: Vec) -> RedisResult { let node_name = format!("{}.{}.{}", PREFIX, &args[1], &args[2]); + // TODO return error if node has more than 1 strong_count + let node = index.nodes.get(&node_name).unwrap(); + if Arc::strong_count(&node.0) > 1 { + return Err(format!( + "{} is being accessed, unable to delete. Try again later", + &node_name + ) + .into()); + } match index.delete_node(&node_name, update_node) { Err(e) => return Err(e.error_string().into()), _ => (), diff --git a/src/types.rs b/src/types.rs index 778954d..5c41bc9 100644 --- a/src/types.rs +++ b/src/types.rs @@ -44,7 +44,7 @@ impl From<&Index> for IndexRedis { .iter() .map(|l| { l.into_iter() - .map(|n| n.read().name.clone()) + .map(|n| n.upgrade().read().name.clone()) .collect::>() }) .collect(), @@ -54,7 +54,7 @@ impl From<&Index> for IndexRedis { .map(|k| k.clone()) .collect::>(), enterpoint: match &index.enterpoint { - Some(ep) => Some(ep.read().name.clone()), + Some(ep) => Some(ep.upgrade().read().name.clone()), None => None, }, } @@ -271,7 +271,7 @@ impl From<&Node> for NodeRedis { .into_iter() .map(|l| { l.into_iter() - .map(|n| n.read().name.clone()) + .map(|n| n.upgrade().read().name.clone()) .collect::>() }) .collect(),