From e7b9531f146b85fd798b59dbac0b01de146042e5 Mon Sep 17 00:00:00 2001 From: Vincent Emonet Date: Sat, 23 Dec 2023 14:13:44 +0100 Subject: [PATCH] perf: Refactor the `Trie` and `TrieNode` structs to replace the 2 vectors used to store values and trie nodes with the root node of the trie. The values are now directly stored in the TrieNode (instead of having integers to get values from an array), and the Trie only contains 1 field: the root node of the trie. This reduces the number of effective lines of code from 127 to 97. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This seems to have improved performance in some cases (the benchmark was run twice to make sure those numbers are somewhat reproducible): * -30% on prefix/postfix matches (was ~350µs, is now ~230µs) * Sometimes -98% on mismatch (was ~70µs, is now ~1µs) * Interestingly, we see a +20% regression for massive_match, even though match sees an improvement of -33% Full Criterion benchmark results (the change is relative to the previous vector-based implementation): ``` trie_match time: [6.5473 ns 6.7279 ns 6.9287 ns] change: [-34.375% -32.939% -31.369%] (p = 0.00 < 0.05) Performance has improved. Found 1 outliers among 100 measurements (1.00%) 1 (1.00%) high mild trie_mismatch time: [3.6109 ns 3.6956 ns 3.7926 ns] change: [-46.306% -44.989% -43.610%] (p = 0.00 < 0.05) Performance has improved. Found 1 outliers among 100 measurements (1.00%) 1 (1.00%) high mild trie_massive_match time: [154.76 µs 156.44 µs 158.59 µs] change: [+18.768% +21.621% +24.627%] (p = 0.00 < 0.05) Performance has regressed. Found 5 outliers among 100 measurements (5.00%) 5 (5.00%) high mild trie_massive_mismatch_on_0 time: [25.287 µs 25.821 µs 26.399 µs] change: [-14.458% -11.912% -9.2597%] (p = 0.00 < 0.05) Performance has improved. 
Found 5 outliers among 100 measurements (5.00%) 5 (5.00%) high mild trie_massive_mismatch_on_1 time: [1.1365 µs 1.1609 µs 1.1929 µs] change: [-98.281% -98.262% -98.240%] (p = 0.00 < 0.05) Performance has improved. Found 7 outliers among 100 measurements (7.00%) 4 (4.00%) high mild 3 (3.00%) high severe trie_massive_mismatch_on_2 time: [24.619 µs 24.728 µs 24.879 µs] change: [-3.7176% -2.9119% -1.7259%] (p = 0.00 < 0.05) Performance has improved. Found 6 outliers among 100 measurements (6.00%) 3 (3.00%) high mild 3 (3.00%) high severe trie_massive_mismatch_on_3 time: [22.300 µs 22.377 µs 22.473 µs] change: [-12.754% -12.221% -11.670%] (p = 0.00 < 0.05) Performance has improved. Found 5 outliers among 100 measurements (5.00%) 4 (4.00%) high mild 1 (1.00%) high severe trie_prefixes_match time: [240.67 µs 247.70 µs 255.57 µs] change: [-30.363% -28.485% -26.708%] (p = 0.00 < 0.05) Performance has improved. Found 1 outliers among 100 measurements (1.00%) 1 (1.00%) high mild trie_postfixes_match time: [227.52 µs 233.84 µs 241.09 µs] change: [-28.965% -26.333% -23.562%] (p = 0.00 < 0.05) Performance has improved. Found 7 outliers among 100 measurements (7.00%) 4 (4.00%) high mild 3 (3.00%) high severe trie_prefix_longest_match time: [133.98 µs 137.34 µs 141.30 µs] change: [-41.923% -39.851% -37.699%] (p = 0.00 < 0.05) Performance has improved. Found 20 outliers among 100 measurements (20.00%) 5 (5.00%) high mild 15 (15.00%) high severe trie_massive_prefixes_match time: [228.73 µs 234.37 µs 241.08 µs] change: [-31.061% -29.485% -27.961%] (p = 0.00 < 0.05) Performance has improved. Found 17 outliers among 100 measurements (17.00%) 6 (6.00%) high mild 11 (11.00%) high severe trie_massive_longest_prefixes_match time: [157.31 µs 160.18 µs 163.82 µs] change: [-28.283% -25.535% -22.555%] (p = 0.00 < 0.05) Performance has improved. 
Found 11 outliers among 100 measurements (11.00%) 7 (7.00%) high mild 4 (4.00%) high severe trie_massive_postfixes_match time: [226.97 µs 231.64 µs 237.56 µs] change: [-21.204% -18.930% -16.464%] (p = 0.00 < 0.05) Performance has improved. Found 5 outliers among 100 measurements (5.00%) 3 (3.00%) high mild 2 (2.00%) high severe ``` BREAKING CHANGE: the find prefixes/postfixes functions now return references to values instead of cloned values --- .rustfmt.toml | 1 + README.md | 8 +- src/trie.rs | 217 +++++++++++++++----------------------------- src/trie_node.rs | 63 +++++++++---- tests/trie_tests.rs | 6 +- 5 files changed, 123 insertions(+), 172 deletions(-) create mode 100644 .rustfmt.toml diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..6d71923 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +chain_width = 70 diff --git a/README.md b/README.md index f904e49..bae3940 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,10 @@ trie.insert("abc".bytes(), "ABC"); trie.insert("abcde".bytes(), "ABCDE"); let prefixes = trie.find_prefixes("abcd".bytes()); -assert_eq!(prefixes, vec!["A", "AB", "ABC"]); +assert_eq!(prefixes, vec![&"A", &"AB", &"ABC"]); let longest = trie.find_longest_prefix("abcd".bytes()); -assert_eq!(longest, Some("ABC")); +assert_eq!(longest, Some("ABC").as_ref()); ``` ### 🔍 Find postfixes @@ -71,7 +71,7 @@ trie.insert("applet".bytes(), "Applet"); trie.insert("apricot".bytes(), "Apricot"); let strings = trie.find_postfixes("app".bytes()); -assert_eq!(strings, vec!["App", "Apple", "Applet"]); +assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]); ``` ### 🔑 Key-based retrieval functions @@ -88,7 +88,7 @@ trie.insert("applet".bytes(), "Applet"); assert!(trie.contains_key("app".bytes())); assert!(!trie.contains_key("not_existing_key".bytes())); assert_eq!(trie.get("app".bytes()), Some("App").as_ref()); -assert_eq!(trie.get("none".bytes()), None); +assert_eq!(trie.get("none".bytes()), None.as_ref()); for (k, v) in trie.iter() { 
println!("kv: {:?} {}", k, v); diff --git a/src/trie.rs b/src/trie.rs index 7dc7243..86bec7e 100644 --- a/src/trie.rs +++ b/src/trie.rs @@ -7,13 +7,12 @@ use serde::{Deserialize, Serialize}; use std::clone::Clone; use std::cmp::{Eq, Ord}; -/// Prefix tree object +/// Prefix tree object, contains 1 field for the `root` node of the tree #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct Trie { +pub struct Trie { /// Root of the prefix tree - nodes: Vec>, - values: Vec, + root: TrieNode, } impl Trie { @@ -26,10 +25,9 @@ impl Trie { /// /// let t = Trie::::new(); /// ``` - pub fn new() -> Trie { + pub fn new() -> Self { Trie { - nodes: Vec::>::new(), - values: Vec::::new(), + root: TrieNode::default(), } } @@ -44,7 +42,7 @@ impl Trie { /// assert!(t.is_empty()); /// ``` pub fn is_empty(&self) -> bool { - self.nodes.is_empty() + self.root.children.is_empty() } /// Adds a new key to the `Trie` @@ -62,41 +60,7 @@ impl Trie { /// assert!(!t.is_empty()); /// ``` pub fn insert>(&mut self, key: I, value: V) { - let mut node_id = if self.is_empty() { - self.create_new_node() - } else { - 0usize - }; - for c in key { - if let Some(id) = self.nodes[node_id].find(&c) { - node_id = id; - } else { - let new_node_id = self.create_new_node(); - self.nodes[node_id].insert(&c, new_node_id); - node_id = new_node_id; - } - } - // NOTE: nicer syntax, but some lines missed by coverage - // for c in key { - // node_id = self.nodes[node_id] - // .find(&c) - // .unwrap_or_else(|| { - // let new_node_id = self.create_new_node(); - // self.nodes[node_id].insert(&c, new_node_id); - // new_node_id - // }); - // } - let value_id = match self.nodes[node_id].get_value() { - Some(id) => { - self.values[id] = value; - id - } - None => { - self.values.push(value); - self.values.len() - 1 - } - }; - self.nodes[node_id].set_value(value_id); + self.root.insert(key, value); } /// Clears the trie @@ -114,8 +78,7 @@ impl Trie { /// assert!(t.is_empty()); /// 
``` pub fn clear(&mut self) { - self.nodes.clear(); - self.values.clear(); + self.root = TrieNode::default(); } /// Looks for the key in trie @@ -136,11 +99,12 @@ impl Trie { /// assert!(!t.contains_key(another_data)); /// ``` pub fn contains_key>(&self, key: I) -> bool { - if self.values.is_empty() && self.nodes.is_empty() { + if self.is_empty() { return false; } + // self.root.find_node(key).is_some() match self.find_node(key) { - Some(node_id) => self.nodes[node_id].may_be_leaf(), + Some(node) => node.may_be_leaf(), None => false, } } @@ -162,10 +126,7 @@ impl Trie { /// assert_eq!(t.get(another_data), None); /// ``` pub fn get>(&self, key: I) -> Option<&V> { - self.find_node(key) - .and_then(|node_id| self.nodes[node_id].get_value()) - .and_then(|value_id| self.values.get(value_id)) - // .cloned() + self.find_node(key).and_then(|node| node.get_value()) } /// Sets the value pointed by a key @@ -189,16 +150,9 @@ impl Trie { /// .is_err()); /// ``` pub fn set_value>(&mut self, key: I, value: V) -> Result<(), TrieError> { - self.find_node(key) + self.find_node_mut(key) .ok_or_else(|| TrieError::NotFound("Key not found".to_string())) - .and_then(|node_id| { - self.nodes[node_id] - .get_value() - .ok_or_else(|| TrieError::NotFound(format!("Value not found {}", node_id))) - .map(|value_id| { - self.values[value_id] = value; - }) - }) + .map(|node| node.set_value(value)) } /// Returns a list of all prefixes in the trie for a given string, ordered from smaller to longer. 
@@ -214,19 +168,19 @@ impl Trie { /// trie.insert("abcde".bytes(), "ABCDE"); /// /// let prefixes = trie.find_prefixes("abcd".bytes()); - /// assert_eq!(prefixes, vec!["ABC", "ABCD"]); - /// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&str>::new()); - /// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&str>::new()); + /// assert_eq!(prefixes, vec![&"ABC", &"ABCD"]); + /// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&&str>::new()); + /// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&&str>::new()); /// ``` - pub fn find_prefixes>(&self, key: I) -> Vec { - let mut node_id = 0usize; + pub fn find_prefixes>(&self, key: I) -> Vec<&V> { + let mut node = &self.root; let mut prefixes = Vec::new(); - for c in key { - if let Some(child_id) = self.nodes[node_id].find(&c) { - node_id = child_id; - if let Some(value_id) = self.nodes[node_id].get_value() { - prefixes.push(self.values[value_id].clone()); + for k in key { + if let Some(next) = node.children.iter().find(|(ckey, _)| ckey == &k).map(|(_, n)| n) { + if let Some(value) = &next.value { + prefixes.push(value); } + node = next; } else { break; } @@ -246,31 +200,30 @@ impl Trie { /// trie.insert("http://purl.obolibrary.org/obo/DOID_".bytes(), "doid"); /// trie.insert("http://purl.obolibrary.org/obo/".bytes(), "obo"); /// - /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid")); - /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo")); - /// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None); - /// assert_eq!(trie.find_longest_prefix("httno".bytes()), None); + /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid").as_ref()); + /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo").as_ref()); + /// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None.as_ref()); + /// 
assert_eq!(trie.find_longest_prefix("httno".bytes()), None.as_ref()); /// ``` - pub fn find_longest_prefix>(&self, key: I) -> Option { - if self.nodes.is_empty() { - return None; - } - let mut node_id = 0usize; - let mut last_value_id: Option = None; - for c in key { - if let Some(child_id) = self.nodes[node_id].find(&c) { - node_id = child_id; - if self.nodes[node_id].may_be_leaf() { - last_value_id = self.nodes[node_id].get_value(); + pub fn find_longest_prefix>(&self, key: I) -> Option<&V> { + { + let mut current = &self.root; + let mut last_value: Option<&V> = None.as_ref(); + for k in key { + if let Some((_, next_node)) = current.children.iter().find(|(key, _)| key == &k) { + if next_node.value.is_some() { + last_value = next_node.value.as_ref(); + } + current = next_node; + } else { + break; } - } else { - break; } + last_value } - last_value_id.map(|id| self.values[id].clone()) } - /// Returns a list of all strings in the trie that start with the given prefix. + /// Returns a list of all strings in the `Trie` that start with the given prefix. /// /// # Example /// @@ -284,58 +237,40 @@ impl Trie { /// trie.insert("apricot".bytes(), "Apricot"); /// /// let strings = trie.find_postfixes("app".bytes()); - /// assert_eq!(strings, vec!["App", "Apple", "Applet"]); - /// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&str>::new()); - /// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&str>::new()); + /// assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]); + /// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&&str>::new()); + /// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&&str>::new()); /// ``` - pub fn find_postfixes>(&self, prefix: I) -> Vec { - match self.find_node(prefix) { - Some(node_id) => { - // Collects all values from the subtree rooted at the given node. 
- let mut values = Vec::new(); - self.dfs(node_id, &mut values); - values - } - None => Vec::new(), + pub fn find_postfixes>(&self, prefix: I) -> Vec<&V> { + let mut postfixes = Vec::new(); + if let Some(node) = self.find_node(prefix) { + self.collect_values(node, &mut postfixes); } + postfixes } - /// Depth-first search to collect values. - fn dfs(&self, node_id: usize, values: &mut Vec) { - if let Some(value_id) = self.nodes[node_id].get_value() { - values.push(self.values[value_id].clone()); + #[allow(clippy::only_used_in_recursion)] + fn collect_values<'a>(&self, node: &'a TrieNode, values: &mut Vec<&'a V>) { + if let Some(ref value) = node.value { + values.push(value); } - for &(_, child_id) in &self.nodes[node_id].children { - self.dfs(child_id, values); + for (_, child) in &node.children { + self.collect_values(child, values); } } - /// Finds the node in the trie by the key + /// Finds the node in the `Trie` for a given key /// /// Internal API - fn find_node>(&self, key: I) -> Option { - if self.nodes.is_empty() { - return None; - } - let mut node_id = 0usize; - for c in key { - match self.nodes[node_id].find(&c) { - Some(child_id) => node_id = child_id, - None => return None, - } - } - Some(node_id) + fn find_node>(&self, key: I) -> Option<&TrieNode> { + self.root.find_node(key) } - /// Creates a new node and returns the node id - /// - /// Internal API - fn create_new_node(&mut self) -> usize { - self.nodes.push(TrieNode::new(None)); - self.nodes.len() - 1 + fn find_node_mut>(&mut self, key: I) -> Option<&mut TrieNode> { + self.root.find_node_mut(key) } - /// Iterate the nodes in the trie + /// Iterate the nodes in the `Trie` /// /// # Example /// @@ -366,39 +301,31 @@ impl Default for Trie { } /// Iterator for the `Trie` struct -pub struct TrieIterator<'a, K, V> { - trie: &'a Trie, - stack: Vec<(usize, Vec)>, // Stack with node id and current path +pub struct TrieIterator<'a, K: Eq + Ord + Clone, V> { + stack: Vec<(&'a TrieNode, Vec)>, // Stack with node 
reference and current path } -impl<'a, K, V> TrieIterator<'a, K, V> { +impl<'a, K: Eq + Ord + Clone, V: Clone> TrieIterator<'a, K, V> { fn new(trie: &'a Trie) -> Self { TrieIterator { - trie, - stack: vec![(0, Vec::new())], // Start with root node and empty path + stack: vec![(&trie.root, Vec::new())], // Start with root node and empty path } } } -impl<'a, K, V> Iterator for TrieIterator<'a, K, V> -where - K: Eq + Ord + Clone, - V: Clone, -{ +impl<'a, K: Eq + Ord + Clone, V: Clone> Iterator for TrieIterator<'a, K, V> { type Item = (Vec, V); // Yield key-value pairs - fn next(&mut self) -> Option { - while let Some((node_id, path)) = self.stack.pop() { - let node = &self.trie.nodes[node_id]; + while let Some((node, path)) = self.stack.pop() { // Push children to the stack with updated path - for &(ref key_part, child_id) in &node.children { + for (key_part, child) in &node.children { let mut new_path = path.clone(); new_path.push(key_part.clone()); - self.stack.push((child_id, new_path)); + self.stack.push((child, new_path)); } // Return value if it exists - if let Some(value_id) = node.get_value() { - return Some((path, self.trie.values[value_id].clone())); + if let Some(ref value) = node.value { + return Some((path, value.clone())); } } None diff --git a/src/trie_node.rs b/src/trie_node.rs index f9e0f1b..14ce666 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -5,46 +5,69 @@ use serde::{Deserialize, Serialize}; use std::clone::Clone; use std::cmp::{Eq, Ord}; +/// A node in the `Trie`, it holds a value, and a list of children nodes #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct TrieNode { - pub value: Option, - pub children: Vec<(T, usize)>, +pub struct TrieNode { + pub value: Option, + pub children: Vec<(K, TrieNode)>, } -impl TrieNode { - pub fn new(value: Option) -> TrieNode { +impl TrieNode { + pub fn new() -> Self { TrieNode { - value, - children: Vec::<(T, usize)>::new(), + value: None, + children: 
Vec::new(), } } - pub fn find(&self, key: &T) -> Option { - if self.children.is_empty() { - // Slightly improves performance by avoiding closure creation in further code - return None; + /// Insert a node in the trie + pub fn insert>(&mut self, mut key: I, value: V) { + if let Some(part) = key.next() { + if let Some(child) = self.children.iter_mut().find(|child| child.0 == part) { + child.1.insert(key, value); + } else { + let mut new_node = TrieNode::new(); + new_node.insert(key, value); + self.children.push((part, new_node)); + } + } else { + self.value = Some(value); } - if let Ok(idx) = self.children.binary_search_by(|x| x.0.cmp(key)) { - return Some(self.children[idx].1); + } + + /// Recursively find a node searching through children + pub fn find_node>(&self, mut key: I) -> Option<&Self> { + if let Some(p) = key.next() { + self.children.iter().find(|c| c.0 == p)?.1.find_node(key) + } else { + Some(self) } - None } - pub fn insert(&mut self, key: &T, child_id: usize) { - self.children.push((key.clone(), child_id)); - self.children.sort_by(|a, b| a.0.cmp(&b.0)); + pub fn find_node_mut>(&mut self, mut key: I) -> Option<&mut Self> { + if let Some(p) = key.next() { + self.children.iter_mut().find(|c| c.0 == p)?.1.find_node_mut(key) + } else { + Some(self) + } } - pub fn set_value(&mut self, value: usize) { + pub fn set_value(&mut self, value: V) { self.value = Some(value); } - pub fn get_value(&self) -> Option { - self.value + pub fn get_value(&self) -> Option<&V> { + self.value.as_ref() } pub fn may_be_leaf(&self) -> bool { self.value.is_some() } } + +impl Default for TrieNode { + fn default() -> Self { + Self::new() + } +} diff --git a/tests/trie_tests.rs b/tests/trie_tests.rs index 1c3b16a..71c3d1d 100644 --- a/tests/trie_tests.rs +++ b/tests/trie_tests.rs @@ -64,9 +64,9 @@ mod tests { trie.insert("abcd".bytes(), "ABCD"); trie.insert("abcde".bytes(), "ABCDE"); let prefixes = trie.find_prefixes("abcd".bytes()); - assert_eq!(prefixes, vec!["ABC", "ABCD"]); - 
assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&str>::new()); - assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&str>::new()); + assert_eq!(prefixes, vec![&"ABC", &"ABCD"]); + assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&&str>::new()); + assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&&str>::new()); } #[test]