Skip to content

Commit

Permalink
add filter logic to query result
Browse files Browse the repository at this point in the history
  • Loading branch information
GlenDC committed Apr 8, 2024
1 parent d158c0f commit 3e122da
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 13 deletions.
73 changes: 61 additions & 12 deletions venndb-macros/src/generate_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,6 @@ fn generate_db_struct_method_append(
);

Some(quote! {
// TODO: handle duplicate key,
// but only have error if we have possible error cases
let #entry_field_name = match self.#map_name.entry(data.#field_name.clone()) {
::venndb::__internal::hash_map::Entry::Occupied(_) => return Err(#db_duplicate_error_kind_creation),
::venndb::__internal::hash_map::Entry::Vacant(entry) => entry,
Expand Down Expand Up @@ -516,8 +514,12 @@ fn generate_query_struct_impl(
name, name_query
);

let name_query_result_kind = format_ident!("{}Kind", name_query_result);

let name_query_result_iter = format_ident!("{}Iter", name_query_result);

let name_query_result_iter_kind = format_ident!("{}Kind", name_query_result_iter);

let name_query_result_iter_doc = format!(
"An iterator over the found instances of [`{}`] queried using [`{}`], generated by `#[derive(VennDB)]`.",
name, name_query
Expand Down Expand Up @@ -546,8 +548,6 @@ fn generate_query_struct_impl(
self
}

// TODO: support a filter on the result based on a predicate

/// Execute the query on the database, returning an iterator over the results.
#vis fn execute(&self) -> Option<#name_query_result<'a>> {
let mut filter = ::venndb::__internal::bitvec![1; self.db.rows.len()];
Expand All @@ -557,7 +557,7 @@ fn generate_query_struct_impl(
if filter.any() {
Some(#name_query_result {
rows: &self.db.rows,
v: filter,
references: #name_query_result_kind::Bits(filter),
})
} else {
None
Expand All @@ -569,45 +569,94 @@ fn generate_query_struct_impl(
#[derive(Debug)]
#vis struct #name_query_result<'a> {
rows: &'a [#name],
v: ::venndb::__internal::BitVec,
references: #name_query_result_kind,
}

#[derive(Debug)]
enum #name_query_result_kind {
Bits(::venndb::__internal::BitVec),
Indices(Vec<usize>),
}

impl<'a> #name_query_result<'a> {
#[doc=#query_result_method_doc_first]
#vis fn first(&self) -> &'a #name {
let index = self.v.iter_ones().next().expect("should contains at least one result");
let index = match &self.references {
#name_query_result_kind::Bits(v) => v.iter_ones().next().unwrap(),
#name_query_result_kind::Indices(i) => i[0],
};
&self.rows[index]
}

#[doc=#query_result_method_doc_any]
#vis fn any(&self) -> &'a #name {
let n = ::venndb::__internal::rand_usize() % self.v.count_ones();
let index = self.v.iter_ones().nth(n).unwrap();
let index = match &self.references {
#name_query_result_kind::Bits(v) => {
let n = ::venndb::__internal::rand_usize() % v.count_ones();
v.iter_ones().nth(n).unwrap()
}
#name_query_result_kind::Indices(i) => {
let n = ::venndb::__internal::rand_usize() % i.len();
i[n]
}
};
&self.rows[index]
}

#[doc=#query_result_method_doc_iter]
#vis fn iter(&self) -> #name_query_result_iter<'a, '_> {
#name_query_result_iter {
rows: self.rows,
iter_ones: self.v.iter_ones(),
indices: match &self.references {
#name_query_result_kind::Bits(v) => #name_query_result_iter_kind::Bits(v.iter_ones()),
#name_query_result_kind::Indices(i) => #name_query_result_iter_kind::Indices(i.iter()),
},
}
}

/// Filter the found results with the given predicate.
#vis fn filter<F>(&self, predicate: F) -> Option<#name_query_result<'a>>
where
F: Fn(&#name) -> bool,
{
let indices: Vec<usize> = match &self.references {
#name_query_result_kind::Bits(v) => v.iter_ones().filter(|index| predicate(&self.rows[*index])).collect(),
#name_query_result_kind::Indices(i) => i.iter().filter(|&index| predicate(&self.rows[*index])).map(|index| *index).collect(),
};

if indices.is_empty() {
return None;
}

Some(#name_query_result {
rows: self.rows,
references: #name_query_result_kind::Indices(indices),
})
}
}

#[doc=#name_query_result_iter_doc]
#vis struct #name_query_result_iter<'a, 'b> {
rows: &'a [#name],
iter_ones: ::venndb::__internal::IterOnes<'b, usize, ::venndb::__internal::Lsb0>,
indices: #name_query_result_iter_kind<'b>,
}

impl<'a, 'b> Iterator for #name_query_result_iter<'a, 'b> {
type Item = &'a #name;

fn next(&mut self) -> Option<Self::Item> {
self.iter_ones.next().map(|index| &self.rows[index])
let maybe_index = match &mut self.indices {
#name_query_result_iter_kind::Bits(v) => v.next(),
#name_query_result_iter_kind::Indices(i) => i.next().cloned(),
};
maybe_index.map(|index| &self.rows[index])
}
}

#vis enum #name_query_result_iter_kind<'a> {
Bits(::venndb::__internal::IterOnes<'a, usize, ::venndb::__internal::Lsb0>),
Indices(::std::slice::Iter<'a, usize>),
}
}
}

Expand Down
47 changes: 46 additions & 1 deletion venndb-usage/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ mod tests {

#[test]
fn test_employee_duplicate_key() {
// TODO: replace with error instead of panic
let mut db = EmployeeDB::default();
db.append(Employee {
id: 1,
Expand Down Expand Up @@ -397,4 +396,50 @@ mod tests {
assert_eq!(iter.next().unwrap().id, 2);
assert!(iter.next().is_none());
}

#[test]
fn test_db_result_filter() {
let db = EmployeeDB::from_rows(vec![
Employee {
id: 1,
name: "Alice".to_string(),
is_manager: true,
is_admin: false,
is_active: true,
department: Department::Engineering,
},
Employee {
id: 2,
name: "Bob".to_string(),
is_manager: false,
is_admin: false,
is_active: true,
department: Department::Engineering,
},
])
.unwrap();

let mut query = db.query();
query.is_active(true);
let results = query.execute().unwrap();
let rows = results.iter().collect::<Vec<_>>();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].id, 1);
assert_eq!(rows[1].id, 2);

let results = results
.filter(|r| r.department == Department::Engineering)
.unwrap();
let rows = results.iter().collect::<Vec<_>>();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].id, 1);
assert_eq!(rows[1].id, 2);

let results = results.filter(|r| r.is_manager).unwrap();
let rows = results.iter().collect::<Vec<_>>();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].id, 1);

assert!(results.filter(|r| r.is_admin).is_none());
}
}

0 comments on commit 3e122da

Please sign in to comment.