Skip to content

Commit

Permalink
[Bug fix] Redefining HNSW index error (#4142)
Browse files Browse the repository at this point in the history
  • Loading branch information
emmanuel-keller committed Jun 6, 2024
1 parent 57b488f commit 39c1b33
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 74 deletions.
13 changes: 9 additions & 4 deletions core/src/idx/trees/hnsw/heuristic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,17 @@ impl Heuristic {
let mut ex = c.to_set();
let mut ext = Vec::with_capacity(m_max.min(c.len()));
for (_, e_id) in c.to_vec().into_iter() {
for &e_adj in layer.get_edges(&e_id).unwrap_or_else(|| unreachable!()).iter() {
if e_adj != q_id && ex.insert(e_adj) {
if let Some(d) = elements.get_distance(q_pt, &e_adj) {
ext.push((d, e_adj));
if let Some(e_conn) = layer.get_edges(&e_id) {
for &e_adj in e_conn.iter() {
if e_adj != q_id && ex.insert(e_adj) {
if let Some(d) = elements.get_distance(q_pt, &e_adj) {
ext.push((d, e_adj));
}
}
}
} else {
#[cfg(debug_assertions)]
unreachable!()
}
}
for (e_dist, e_id) in ext {
Expand Down
23 changes: 13 additions & 10 deletions core/src/idx/trees/hnsw/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,17 +235,20 @@ where
let neighbors = self.graph.add_node_and_bidirectional_edges(q_id, neighbors);

for e_id in neighbors {
let e_conn =
self.graph.get_edges(&e_id).unwrap_or_else(|| unreachable!("Element: {}", e_id));
if e_conn.len() > self.m_max {
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_c = self.build_priority_list(elements, e_id, e_conn);
let mut e_new_conn = self.graph.new_edges();
heuristic.select(elements, self, e_id, e_pt, e_c, &mut e_new_conn);
#[cfg(debug_assertions)]
assert!(!e_new_conn.contains(&e_id));
self.graph.set_node(e_id, e_new_conn);
if let Some(e_conn) = self.graph.get_edges(&e_id) {
if e_conn.len() > self.m_max {
if let Some(e_pt) = elements.get_vector(&e_id) {
let e_c = self.build_priority_list(elements, e_id, e_conn);
let mut e_new_conn = self.graph.new_edges();
heuristic.select(elements, self, e_id, e_pt, e_c, &mut e_new_conn);
#[cfg(debug_assertions)]
assert!(!e_new_conn.contains(&e_id));
self.graph.set_node(e_id, e_new_conn);
}
}
} else {
#[cfg(debug_assertions)]
unreachable!("Element: {}", e_id);
}
}
eps
Expand Down
78 changes: 46 additions & 32 deletions core/src/idx/trees/hnsw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,39 +121,46 @@ where
mut ep_id: ElementId,
top_up_layers: usize,
) {
let mut ep_dist =
self.elements.get_distance(q_pt, &ep_id).unwrap_or_else(|| unreachable!());

if q_level < top_up_layers {
for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
(ep_dist, ep_id) = layer
.search_single(&self.elements, q_pt, ep_dist, ep_id, 1)
.peek_first()
.unwrap_or_else(|| unreachable!())
if let Some(mut ep_dist) = self.elements.get_distance(q_pt, &ep_id) {
if q_level < top_up_layers {
for layer in self.layers[q_level..top_up_layers].iter_mut().rev() {
if let Some(ep_dist_id) =
layer.search_single(&self.elements, q_pt, ep_dist, ep_id, 1).peek_first()
{
(ep_dist, ep_id) = ep_dist_id;
} else {
#[cfg(debug_assertions)]
unreachable!()
}
}
}
}

let mut eps = DoublePriorityQueue::from(ep_dist, ep_id);
let mut eps = DoublePriorityQueue::from(ep_dist, ep_id);

let insert_to_up_layers = q_level.min(top_up_layers);
if insert_to_up_layers > 0 {
for layer in self.layers.iter_mut().take(insert_to_up_layers).rev() {
eps = layer.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
let insert_to_up_layers = q_level.min(top_up_layers);
if insert_to_up_layers > 0 {
for layer in self.layers.iter_mut().take(insert_to_up_layers).rev() {
eps = layer.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
}
}
}

self.layer0.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);
self.layer0.insert(&self.elements, &self.heuristic, self.efc, q_id, q_pt, eps);

if top_up_layers < q_level {
for layer in self.layers[top_up_layers..q_level].iter_mut() {
if !layer.add_empty_node(q_id) {
unreachable!("Already there {}", q_id);
if top_up_layers < q_level {
for layer in self.layers[top_up_layers..q_level].iter_mut() {
if !layer.add_empty_node(q_id) {
#[cfg(debug_assertions)]
unreachable!("Already there {}", q_id);
}
}
}
}

if q_level > top_up_layers {
self.enter_point = Some(q_id);
if q_level > top_up_layers {
self.enter_point = Some(q_id);
}
} else {
#[cfg(debug_assertions)]
unreachable!()
}
}

Expand Down Expand Up @@ -242,15 +249,22 @@ where

fn search_ep(&self, pt: &SharedVector) -> Option<(f64, ElementId)> {
if let Some(mut ep_id) = self.enter_point {
let mut ep_dist =
self.elements.get_distance(pt, &ep_id).unwrap_or_else(|| unreachable!());
for layer in self.layers.iter().rev() {
(ep_dist, ep_id) = layer
.search_single(&self.elements, pt, ep_dist, ep_id, 1)
.peek_first()
.unwrap_or_else(|| unreachable!());
if let Some(mut ep_dist) = self.elements.get_distance(pt, &ep_id) {
for layer in self.layers.iter().rev() {
if let Some(ep_dist_id) =
layer.search_single(&self.elements, pt, ep_dist, ep_id, 1).peek_first()
{
(ep_dist, ep_id) = ep_dist_id;
} else {
#[cfg(debug_assertions)]
unreachable!()
}
}
Some((ep_dist, ep_id))
} else {
#[cfg(debug_assertions)]
unreachable!()
}
Some((ep_dist, ep_id))
} else {
None
}
Expand Down
11 changes: 8 additions & 3 deletions core/src/sql/statements/define/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ impl DefineIndexStatement {
// Clear the cache
run.clear_cache();
// Check if index already exists
if self.if_not_exists
&& run.get_tb_index(opt.ns(), opt.db(), &self.what, &self.name).await.is_ok()
{
let index_exists =
run.get_tb_index(opt.ns(), opt.db(), &self.what, &self.name).await.is_ok();
if self.if_not_exists && index_exists {
return Err(Error::IxAlreadyExists {
value: self.name.to_string(),
});
Expand Down Expand Up @@ -74,6 +74,11 @@ impl DefineIndexStatement {
Err(e) => return Err(e),
}

// Clear the index store cache
if index_exists {
ctx.get_index_stores().index_removed(opt, &mut run, &self.what, &self.name).await?;
}

// Process the statement
let key = crate::key::table::ix::new(opt.ns(), opt.db(), &self.what, &self.name);
run.add_ns(opt.ns(), opt.strict).await?;
Expand Down
59 changes: 34 additions & 25 deletions lib/tests/define.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1081,29 +1081,14 @@ async fn define_statement_index_multiple_unique_existing() -> Result<(), Error>
DEFINE INDEX test ON user COLUMNS account, email UNIQUE;
INFO FOR TABLE user;
";
let dbs = new_ds().await?;
let ses = Session::owner().with_ns("test").with_db("test");
let res = &mut dbs.execute(sql, &ses, None).await?;
assert_eq!(res.len(), 7);
//
for _ in 0..4 {
let tmp = res.remove(0).result;
assert!(tmp.is_ok());
}
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database index `test` already contains ['apple', 'test@surrealdb.com'], with record `user:1`"#
));
//
let tmp = res.remove(0).result;
assert!(matches!(
tmp.err(),
Some(e) if e.to_string() == r#"Database index `test` already contains ['apple', 'test@surrealdb.com'], with record `user:1`"#
));

let tmp = res.remove(0).result?;
let mut t = Test::try_new(sql).await?;
t.skip_ok(4);
t.expect_error(
r#"Database index `test` already contains ['apple', 'test@surrealdb.com'], with record `user:1`"#,
);
t.expect_error(
r#"Database index `test` already contains ['apple', 'test@surrealdb.com'], with record `user:1`"#,
);
let val = Value::parse(
"{
events: {},
Expand All @@ -1113,8 +1098,7 @@ async fn define_statement_index_multiple_unique_existing() -> Result<(), Error>
lives: {},
}",
);
assert_eq!(tmp, val);
//
t.expect_value(val);
Ok(())
}

Expand Down Expand Up @@ -1232,6 +1216,31 @@ async fn define_statement_index_multiple_unique_embedded_multiple() -> Result<()
Ok(())
}

#[tokio::test]
async fn define_statement_index_multiple_hnsw() -> Result<(), Error> {
let sql = "
CREATE pts:3 SET point = [8,9,10,11];
DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12;
DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12;
INFO FOR TABLE pts;
";
let mut t = Test::try_new(sql).await?;
t.skip_ok(3);
let val = Value::parse(
"{
events: {},
fields: {},
tables: {},
indexes: {
hnsw_pts: 'DEFINE INDEX hnsw_pts ON pts FIELDS point HNSW DIMENSION 4 DIST EUCLIDEAN TYPE F32 EFC 500 M 12 M0 24 LM 0.40242960438184466f'
},
lives: {},
}",
);
t.expect_value(val);
Ok(())
}

#[tokio::test]
async fn define_statement_index_on_schemafull_without_permission() -> Result<(), Error> {
let sql = "
Expand Down

0 comments on commit 39c1b33

Please sign in to comment.