Skip to content

Commit

Permalink
perf: Refactor the Trie and TrieNode structs to replace the 2 vec…
Browse files Browse the repository at this point in the history
…tors used to store values and trie nodes with a single root node. The values are now stored directly in each TrieNode (instead of the node holding an integer index into a separate values array), and the Trie contains only one field: the root node of the trie. This reduces the number of effective lines of code from 127 to 97.

This seems to have improved performance in some cases (the benchmark was run twice to make sure the numbers are reasonably reproducible):

* -30% on prefix/postfix matches (was ~350µs, is now ~230µs)
* Up to -98% on mismatch in some cases (was ~70µs, is now ~1µs)
* Interestingly, we see a +20% regression for massive_match, even though match sees an improvement of -33%

Full Criterion benchmark results (the diff is with the previous implementation with vectors):

```
trie_match              time:   [6.5473 ns 6.7279 ns 6.9287 ns]
                        change: [-34.375% -32.939% -31.369%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild

trie_mismatch           time:   [3.6109 ns 3.6956 ns 3.7926 ns]
                        change: [-46.306% -44.989% -43.610%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild

trie_massive_match      time:   [154.76 µs 156.44 µs 158.59 µs]
                        change: [+18.768% +21.621% +24.627%] (p = 0.00 < 0.05)
                        Performance has regressed.
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high mild

trie_massive_mismatch_on_0
                        time:   [25.287 µs 25.821 µs 26.399 µs]
                        change: [-14.458% -11.912% -9.2597%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  5 (5.00%) high mild

trie_massive_mismatch_on_1
                        time:   [1.1365 µs 1.1609 µs 1.1929 µs]
                        change: [-98.281% -98.262% -98.240%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
  4 (4.00%) high mild
  3 (3.00%) high severe

trie_massive_mismatch_on_2
                        time:   [24.619 µs 24.728 µs 24.879 µs]
                        change: [-3.7176% -2.9119% -1.7259%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 6 outliers among 100 measurements (6.00%)
  3 (3.00%) high mild
  3 (3.00%) high severe

trie_massive_mismatch_on_3
                        time:   [22.300 µs 22.377 µs 22.473 µs]
                        change: [-12.754% -12.221% -11.670%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  4 (4.00%) high mild
  1 (1.00%) high severe

trie_prefixes_match     time:   [240.67 µs 247.70 µs 255.57 µs]
                        change: [-30.363% -28.485% -26.708%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 1 outliers among 100 measurements (1.00%)
  1 (1.00%) high mild

trie_postfixes_match    time:   [227.52 µs 233.84 µs 241.09 µs]
                        change: [-28.965% -26.333% -23.562%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 7 outliers among 100 measurements (7.00%)
  4 (4.00%) high mild
  3 (3.00%) high severe

trie_prefix_longest_match
                        time:   [133.98 µs 137.34 µs 141.30 µs]
                        change: [-41.923% -39.851% -37.699%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 20 outliers among 100 measurements (20.00%)
  5 (5.00%) high mild
  15 (15.00%) high severe

trie_massive_prefixes_match
                        time:   [228.73 µs 234.37 µs 241.08 µs]
                        change: [-31.061% -29.485% -27.961%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 17 outliers among 100 measurements (17.00%)
  6 (6.00%) high mild
  11 (11.00%) high severe

trie_massive_longest_prefixes_match
                        time:   [157.31 µs 160.18 µs 163.82 µs]
                        change: [-28.283% -25.535% -22.555%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 11 outliers among 100 measurements (11.00%)
  7 (7.00%) high mild
  4 (4.00%) high severe

trie_massive_postfixes_match
                        time:   [226.97 µs 231.64 µs 237.56 µs]
                        change: [-21.204% -18.930% -16.464%] (p = 0.00 < 0.05)
                        Performance has improved.
Found 5 outliers among 100 measurements (5.00%)
  3 (3.00%) high mild
  2 (2.00%) high severe
```

BREAKING CHANGE: the find prefixes/postfixes functions (`find_prefixes`, `find_longest_prefix`, `find_postfixes`) now return references to the stored values instead of cloned values
  • Loading branch information
vemonet committed Dec 23, 2023
1 parent b0bf31a commit e7b9531
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 172 deletions.
1 change: 1 addition & 0 deletions .rustfmt.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
chain_width = 70
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ trie.insert("abc".bytes(), "ABC");
trie.insert("abcde".bytes(), "ABCDE");

let prefixes = trie.find_prefixes("abcd".bytes());
assert_eq!(prefixes, vec!["A", "AB", "ABC"]);
assert_eq!(prefixes, vec![&"A", &"AB", &"ABC"]);

let longest = trie.find_longest_prefix("abcd".bytes());
assert_eq!(longest, Some("ABC"));
assert_eq!(longest, Some("ABC").as_ref());
```

### 🔍 Find postfixes
Expand All @@ -71,7 +71,7 @@ trie.insert("applet".bytes(), "Applet");
trie.insert("apricot".bytes(), "Apricot");

let strings = trie.find_postfixes("app".bytes());
assert_eq!(strings, vec!["App", "Apple", "Applet"]);
assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]);
```

### 🔑 Key-based retrieval functions
Expand All @@ -88,7 +88,7 @@ trie.insert("applet".bytes(), "Applet");
assert!(trie.contains_key("app".bytes()));
assert!(!trie.contains_key("not_existing_key".bytes()));
assert_eq!(trie.get("app".bytes()), Some("App").as_ref());
assert_eq!(trie.get("none".bytes()), None);
assert_eq!(trie.get("none".bytes()), None.as_ref());

for (k, v) in trie.iter() {
println!("kv: {:?} {}", k, v);
Expand Down
217 changes: 72 additions & 145 deletions src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ use serde::{Deserialize, Serialize};
use std::clone::Clone;
use std::cmp::{Eq, Ord};

/// Prefix tree object
/// Prefix tree object, contains 1 field for the `root` node of the tree
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct Trie<K, V> {
pub struct Trie<K: Eq + Ord + Clone, V> {
/// Root of the prefix tree
nodes: Vec<TrieNode<K>>,
values: Vec<V>,
root: TrieNode<K, V>,
}

impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
Expand All @@ -26,10 +25,9 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
///
/// let t = Trie::<char, String>::new();
/// ```
pub fn new() -> Trie<K, V> {
pub fn new() -> Self {
Trie {
nodes: Vec::<TrieNode<K>>::new(),
values: Vec::<V>::new(),
root: TrieNode::default(),
}
}

Expand All @@ -44,7 +42,7 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// assert!(t.is_empty());
/// ```
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
self.root.children.is_empty()
}

/// Adds a new key to the `Trie`
Expand All @@ -62,41 +60,7 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// assert!(!t.is_empty());
/// ```
pub fn insert<I: Iterator<Item = K>>(&mut self, key: I, value: V) {
let mut node_id = if self.is_empty() {
self.create_new_node()
} else {
0usize
};
for c in key {
if let Some(id) = self.nodes[node_id].find(&c) {
node_id = id;
} else {
let new_node_id = self.create_new_node();
self.nodes[node_id].insert(&c, new_node_id);
node_id = new_node_id;
}
}
// NOTE: nicer syntax, but some lines missed by coverage
// for c in key {
// node_id = self.nodes[node_id]
// .find(&c)
// .unwrap_or_else(|| {
// let new_node_id = self.create_new_node();
// self.nodes[node_id].insert(&c, new_node_id);
// new_node_id
// });
// }
let value_id = match self.nodes[node_id].get_value() {
Some(id) => {
self.values[id] = value;
id
}
None => {
self.values.push(value);
self.values.len() - 1
}
};
self.nodes[node_id].set_value(value_id);
self.root.insert(key, value);
}

/// Clears the trie
Expand All @@ -114,8 +78,7 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// assert!(t.is_empty());
/// ```
pub fn clear(&mut self) {
self.nodes.clear();
self.values.clear();
self.root = TrieNode::default();
}

/// Looks for the key in trie
Expand All @@ -136,11 +99,12 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// assert!(!t.contains_key(another_data));
/// ```
pub fn contains_key<I: Iterator<Item = K>>(&self, key: I) -> bool {
if self.values.is_empty() && self.nodes.is_empty() {
if self.is_empty() {
return false;
}
// self.root.find_node(key).is_some()
match self.find_node(key) {
Some(node_id) => self.nodes[node_id].may_be_leaf(),
Some(node) => node.may_be_leaf(),
None => false,
}
}
Expand All @@ -162,10 +126,7 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// assert_eq!(t.get(another_data), None);
/// ```
pub fn get<I: Iterator<Item = K>>(&self, key: I) -> Option<&V> {
self.find_node(key)
.and_then(|node_id| self.nodes[node_id].get_value())
.and_then(|value_id| self.values.get(value_id))
// .cloned()
self.find_node(key).and_then(|node| node.get_value())
}

/// Sets the value pointed by a key
Expand All @@ -189,16 +150,9 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// .is_err());
/// ```
pub fn set_value<I: Iterator<Item = K>>(&mut self, key: I, value: V) -> Result<(), TrieError> {
self.find_node(key)
self.find_node_mut(key)
.ok_or_else(|| TrieError::NotFound("Key not found".to_string()))
.and_then(|node_id| {
self.nodes[node_id]
.get_value()
.ok_or_else(|| TrieError::NotFound(format!("Value not found {}", node_id)))
.map(|value_id| {
self.values[value_id] = value;
})
})
.map(|node| node.set_value(value))
}

/// Returns a list of all prefixes in the trie for a given string, ordered from smaller to longer.
Expand All @@ -214,19 +168,19 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// trie.insert("abcde".bytes(), "ABCDE");
///
/// let prefixes = trie.find_prefixes("abcd".bytes());
/// assert_eq!(prefixes, vec!["ABC", "ABCD"]);
/// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&str>::new());
/// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&str>::new());
/// assert_eq!(prefixes, vec![&"ABC", &"ABCD"]);
/// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&&str>::new());
/// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&&str>::new());
/// ```
pub fn find_prefixes<I: Iterator<Item = K>>(&self, key: I) -> Vec<V> {
let mut node_id = 0usize;
pub fn find_prefixes<I: Iterator<Item = K>>(&self, key: I) -> Vec<&V> {
let mut node = &self.root;
let mut prefixes = Vec::new();
for c in key {
if let Some(child_id) = self.nodes[node_id].find(&c) {
node_id = child_id;
if let Some(value_id) = self.nodes[node_id].get_value() {
prefixes.push(self.values[value_id].clone());
for k in key {
if let Some(next) = node.children.iter().find(|(ckey, _)| ckey == &k).map(|(_, n)| n) {
if let Some(value) = &next.value {
prefixes.push(value);
}
node = next;
} else {
break;
}
Expand All @@ -246,31 +200,30 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// trie.insert("http://purl.obolibrary.org/obo/DOID_".bytes(), "doid");
/// trie.insert("http://purl.obolibrary.org/obo/".bytes(), "obo");
///
/// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid"));
/// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo"));
/// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None);
/// assert_eq!(trie.find_longest_prefix("httno".bytes()), None);
/// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid").as_ref());
/// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo").as_ref());
/// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None.as_ref());
/// assert_eq!(trie.find_longest_prefix("httno".bytes()), None.as_ref());
/// ```
pub fn find_longest_prefix<I: Iterator<Item = K>>(&self, key: I) -> Option<V> {
if self.nodes.is_empty() {
return None;
}
let mut node_id = 0usize;
let mut last_value_id: Option<usize> = None;
for c in key {
if let Some(child_id) = self.nodes[node_id].find(&c) {
node_id = child_id;
if self.nodes[node_id].may_be_leaf() {
last_value_id = self.nodes[node_id].get_value();
pub fn find_longest_prefix<I: Iterator<Item = K>>(&self, key: I) -> Option<&V> {
{
let mut current = &self.root;
let mut last_value: Option<&V> = None.as_ref();
for k in key {
if let Some((_, next_node)) = current.children.iter().find(|(key, _)| key == &k) {
if next_node.value.is_some() {
last_value = next_node.value.as_ref();
}
current = next_node;
} else {
break;
}
} else {
break;
}
last_value
}
last_value_id.map(|id| self.values[id].clone())
}

/// Returns a list of all strings in the trie that start with the given prefix.
/// Returns a list of all strings in the `Trie` that start with the given prefix.
///
/// # Example
///
Expand All @@ -284,58 +237,40 @@ impl<K: Eq + Ord + Clone, V: Clone> Trie<K, V> {
/// trie.insert("apricot".bytes(), "Apricot");
///
/// let strings = trie.find_postfixes("app".bytes());
/// assert_eq!(strings, vec!["App", "Apple", "Applet"]);
/// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&str>::new());
/// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&str>::new());
/// assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]);
/// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&&str>::new());
/// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&&str>::new());
/// ```
pub fn find_postfixes<I: Iterator<Item = K>>(&self, prefix: I) -> Vec<V> {
match self.find_node(prefix) {
Some(node_id) => {
// Collects all values from the subtree rooted at the given node.
let mut values = Vec::new();
self.dfs(node_id, &mut values);
values
}
None => Vec::new(),
pub fn find_postfixes<I: Iterator<Item = K>>(&self, prefix: I) -> Vec<&V> {
let mut postfixes = Vec::new();
if let Some(node) = self.find_node(prefix) {
self.collect_values(node, &mut postfixes);
}
postfixes
}

/// Depth-first search to collect values.
fn dfs(&self, node_id: usize, values: &mut Vec<V>) {
if let Some(value_id) = self.nodes[node_id].get_value() {
values.push(self.values[value_id].clone());
#[allow(clippy::only_used_in_recursion)]
fn collect_values<'a>(&self, node: &'a TrieNode<K, V>, values: &mut Vec<&'a V>) {
if let Some(ref value) = node.value {
values.push(value);
}
for &(_, child_id) in &self.nodes[node_id].children {
self.dfs(child_id, values);
for (_, child) in &node.children {
self.collect_values(child, values);
}
}

/// Finds the node in the trie by the key
/// Finds the node in the `Trie` for a given key
///
/// Internal API
fn find_node<I: Iterator<Item = K>>(&self, key: I) -> Option<usize> {
if self.nodes.is_empty() {
return None;
}
let mut node_id = 0usize;
for c in key {
match self.nodes[node_id].find(&c) {
Some(child_id) => node_id = child_id,
None => return None,
}
}
Some(node_id)
fn find_node<I: Iterator<Item = K>>(&self, key: I) -> Option<&TrieNode<K, V>> {
self.root.find_node(key)
}

/// Creates a new node and returns the node id
///
/// Internal API
fn create_new_node(&mut self) -> usize {
self.nodes.push(TrieNode::new(None));
self.nodes.len() - 1
fn find_node_mut<I: Iterator<Item = K>>(&mut self, key: I) -> Option<&mut TrieNode<K, V>> {
self.root.find_node_mut(key)
}

/// Iterate the nodes in the trie
/// Iterate the nodes in the `Trie`
///
/// # Example
///
Expand Down Expand Up @@ -366,39 +301,31 @@ impl<T: Eq + Ord + Clone, U: Clone> Default for Trie<T, U> {
}

/// Iterator for the `Trie` struct
pub struct TrieIterator<'a, K, V> {
trie: &'a Trie<K, V>,
stack: Vec<(usize, Vec<K>)>, // Stack with node id and current path
pub struct TrieIterator<'a, K: Eq + Ord + Clone, V> {
stack: Vec<(&'a TrieNode<K, V>, Vec<K>)>, // Stack with node reference and current path
}

impl<'a, K, V> TrieIterator<'a, K, V> {
impl<'a, K: Eq + Ord + Clone, V: Clone> TrieIterator<'a, K, V> {
fn new(trie: &'a Trie<K, V>) -> Self {
TrieIterator {
trie,
stack: vec![(0, Vec::new())], // Start with root node and empty path
stack: vec![(&trie.root, Vec::new())], // Start with root node and empty path
}
}
}

impl<'a, K, V> Iterator for TrieIterator<'a, K, V>
where
K: Eq + Ord + Clone,
V: Clone,
{
impl<'a, K: Eq + Ord + Clone, V: Clone> Iterator for TrieIterator<'a, K, V> {
type Item = (Vec<K>, V); // Yield key-value pairs

fn next(&mut self) -> Option<Self::Item> {
while let Some((node_id, path)) = self.stack.pop() {
let node = &self.trie.nodes[node_id];
while let Some((node, path)) = self.stack.pop() {
// Push children to the stack with updated path
for &(ref key_part, child_id) in &node.children {
for (key_part, child) in &node.children {
let mut new_path = path.clone();
new_path.push(key_part.clone());
self.stack.push((child_id, new_path));
self.stack.push((child, new_path));
}
// Return value if it exists
if let Some(value_id) = node.get_value() {
return Some((path, self.trie.values[value_id].clone()));
if let Some(ref value) = node.value {
return Some((path, value.clone()));
}
}
None
Expand Down
Loading

0 comments on commit e7b9531

Please sign in to comment.