diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..6d71923 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1 @@ +chain_width = 70 diff --git a/README.md b/README.md index f904e49..bae3940 100644 --- a/README.md +++ b/README.md @@ -50,10 +50,10 @@ trie.insert("abc".bytes(), "ABC"); trie.insert("abcde".bytes(), "ABCDE"); let prefixes = trie.find_prefixes("abcd".bytes()); -assert_eq!(prefixes, vec!["A", "AB", "ABC"]); +assert_eq!(prefixes, vec![&"A", &"AB", &"ABC"]); let longest = trie.find_longest_prefix("abcd".bytes()); -assert_eq!(longest, Some("ABC")); +assert_eq!(longest, Some("ABC").as_ref()); ``` ### 🔍 Find postfixes @@ -71,7 +71,7 @@ trie.insert("applet".bytes(), "Applet"); trie.insert("apricot".bytes(), "Apricot"); let strings = trie.find_postfixes("app".bytes()); -assert_eq!(strings, vec!["App", "Apple", "Applet"]); +assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]); ``` ### 🔑 Key-based retrieval functions @@ -88,7 +88,7 @@ trie.insert("applet".bytes(), "Applet"); assert!(trie.contains_key("app".bytes())); assert!(!trie.contains_key("not_existing_key".bytes())); assert_eq!(trie.get("app".bytes()), Some("App").as_ref()); -assert_eq!(trie.get("none".bytes()), None); +assert_eq!(trie.get("none".bytes()), None.as_ref()); for (k, v) in trie.iter() { println!("kv: {:?} {}", k, v); diff --git a/src/trie.rs b/src/trie.rs index 7dc7243..86bec7e 100644 --- a/src/trie.rs +++ b/src/trie.rs @@ -7,13 +7,12 @@ use serde::{Deserialize, Serialize}; use std::clone::Clone; use std::cmp::{Eq, Ord}; -/// Prefix tree object +/// Prefix tree object, contains 1 field for the `root` node of the tree #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct Trie { +pub struct Trie { /// Root of the prefix tree - nodes: Vec>, - values: Vec, + root: TrieNode, } impl Trie { @@ -26,10 +25,9 @@ impl Trie { /// /// let t = Trie::::new(); /// ``` - pub fn new() -> Trie { + pub fn new() -> Self { Trie { - nodes: Vec::>::new(), - values: Vec::::new(), + root: TrieNode::default(), } } @@ -44,7 +42,7 @@ impl Trie { /// assert!(t.is_empty()); /// ``` pub fn is_empty(&self) -> bool { - self.nodes.is_empty() + self.root.children.is_empty() } /// Adds a new key to the `Trie` @@ -62,41 +60,7 @@ impl Trie { /// assert!(!t.is_empty()); /// ``` pub fn insert>(&mut self, key: I, value: V) { - let mut node_id = if self.is_empty() { - self.create_new_node() - } else { - 0usize - }; - for c in key { - if let Some(id) = self.nodes[node_id].find(&c) { - node_id = id; - } else { - let new_node_id = self.create_new_node(); - self.nodes[node_id].insert(&c, new_node_id); - node_id = new_node_id; - } - } - // NOTE: nicer syntax, but some lines missed by coverage - // for c in key { - // node_id = self.nodes[node_id] - // .find(&c) - // .unwrap_or_else(|| { - // let new_node_id = self.create_new_node(); - // self.nodes[node_id].insert(&c, new_node_id); - // new_node_id - // }); - // } - let value_id = match self.nodes[node_id].get_value() { - Some(id) => { - self.values[id] = value; - id - } - None => { - self.values.push(value); - self.values.len() - 1 - } - }; - self.nodes[node_id].set_value(value_id); + self.root.insert(key, value); } /// Clears the trie @@ -114,8 +78,7 @@ impl Trie { /// assert!(t.is_empty()); /// ``` pub fn clear(&mut self) { - self.nodes.clear(); - self.values.clear(); + self.root = TrieNode::default(); } /// Looks for the key in trie @@ -136,11 +99,12 @@ impl Trie { /// assert!(!t.contains_key(another_data)); /// ``` pub fn contains_key>(&self, key: I) -> bool { - if self.values.is_empty() && self.nodes.is_empty() { + if self.is_empty() { return false; } + // self.root.find_node(key).is_some() match self.find_node(key) { - Some(node_id) => self.nodes[node_id].may_be_leaf(), + Some(node) => node.may_be_leaf(), None => false, } } @@ -162,10 +126,7 @@ impl Trie { /// assert_eq!(t.get(another_data), None); /// ``` pub fn get>(&self, key: I) -> Option<&V> { - self.find_node(key) - .and_then(|node_id| self.nodes[node_id].get_value()) - .and_then(|value_id| self.values.get(value_id)) - // .cloned() + self.find_node(key).and_then(|node| node.get_value()) } /// Sets the value pointed by a key @@ -189,16 +150,9 @@ impl Trie { /// .is_err()); /// ``` pub fn set_value>(&mut self, key: I, value: V) -> Result<(), TrieError> { - self.find_node(key) + self.find_node_mut(key) .ok_or_else(|| TrieError::NotFound("Key not found".to_string())) - .and_then(|node_id| { - self.nodes[node_id] - .get_value() - .ok_or_else(|| TrieError::NotFound(format!("Value not found {}", node_id))) - .map(|value_id| { - self.values[value_id] = value; - }) - }) + .map(|node| node.set_value(value)) } /// Returns a list of all prefixes in the trie for a given string, ordered from smaller to longer. @@ -214,19 +168,19 @@ impl Trie { /// trie.insert("abcde".bytes(), "ABCDE"); /// /// let prefixes = trie.find_prefixes("abcd".bytes()); - /// assert_eq!(prefixes, vec!["ABC", "ABCD"]); - /// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&str>::new()); - /// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&str>::new()); + /// assert_eq!(prefixes, vec![&"ABC", &"ABCD"]); + /// assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&&str>::new()); + /// assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&&str>::new()); /// ``` - pub fn find_prefixes>(&self, key: I) -> Vec { - let mut node_id = 0usize; + pub fn find_prefixes>(&self, key: I) -> Vec<&V> { + let mut node = &self.root; let mut prefixes = Vec::new(); - for c in key { - if let Some(child_id) = self.nodes[node_id].find(&c) { - node_id = child_id; - if let Some(value_id) = self.nodes[node_id].get_value() { - prefixes.push(self.values[value_id].clone()); + for k in key { + if let Some(next) = node.children.iter().find(|(ckey, _)| ckey == &k).map(|(_, n)| n) { + if let Some(value) = &next.value { + prefixes.push(value); } + node = next; } else { break; } @@ -246,31 +200,30 @@ impl Trie { /// trie.insert("http://purl.obolibrary.org/obo/DOID_".bytes(), "doid"); /// trie.insert("http://purl.obolibrary.org/obo/".bytes(), "obo"); /// - /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid")); - /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo")); - /// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None); - /// assert_eq!(trie.find_longest_prefix("httno".bytes()), None); + /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/DOID_1234".bytes()), Some("doid").as_ref()); + /// assert_eq!(trie.find_longest_prefix("http://purl.obolibrary.org/obo/1234".bytes()), Some("obo").as_ref()); + /// assert_eq!(trie.find_longest_prefix("notthere".bytes()), None.as_ref()); + /// assert_eq!(trie.find_longest_prefix("httno".bytes()), None.as_ref()); /// ``` - pub fn find_longest_prefix>(&self, key: I) -> Option { - if self.nodes.is_empty() { - return None; - } - let mut node_id = 0usize; - let mut last_value_id: Option = None; - for c in key { - if let Some(child_id) = self.nodes[node_id].find(&c) { - node_id = child_id; - if self.nodes[node_id].may_be_leaf() { - last_value_id = self.nodes[node_id].get_value(); + pub fn find_longest_prefix>(&self, key: I) -> Option<&V> { + { + let mut current = &self.root; + let mut last_value: Option<&V> = None.as_ref(); + for k in key { + if let Some((_, next_node)) = current.children.iter().find(|(key, _)| key == &k) { + if next_node.value.is_some() { + last_value = next_node.value.as_ref(); + } + current = next_node; + } else { + break; } - } else { - break; } + last_value } - last_value_id.map(|id| self.values[id].clone()) } - /// Returns a list of all strings in the trie that start with the given prefix. + /// Returns a list of all strings in the `Trie` that start with the given prefix. /// /// # Example /// @@ -284,58 +237,40 @@ impl Trie { /// trie.insert("apricot".bytes(), "Apricot"); /// /// let strings = trie.find_postfixes("app".bytes()); - /// assert_eq!(strings, vec!["App", "Apple", "Applet"]); - /// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&str>::new()); - /// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&str>::new()); + /// assert_eq!(strings, vec![&"App", &"Apple", &"Applet"]); + /// assert_eq!(trie.find_postfixes("bpp".bytes()), Vec::<&&str>::new()); + /// assert_eq!(trie.find_postfixes("apzz".bytes()), Vec::<&&str>::new()); /// ``` - pub fn find_postfixes>(&self, prefix: I) -> Vec { - match self.find_node(prefix) { - Some(node_id) => { - // Collects all values from the subtree rooted at the given node. - let mut values = Vec::new(); - self.dfs(node_id, &mut values); - values - } - None => Vec::new(), + pub fn find_postfixes>(&self, prefix: I) -> Vec<&V> { + let mut postfixes = Vec::new(); + if let Some(node) = self.find_node(prefix) { + self.collect_values(node, &mut postfixes); } + postfixes } - /// Depth-first search to collect values. - fn dfs(&self, node_id: usize, values: &mut Vec) { - if let Some(value_id) = self.nodes[node_id].get_value() { - values.push(self.values[value_id].clone()); + #[allow(clippy::only_used_in_recursion)] + fn collect_values<'a>(&self, node: &'a TrieNode, values: &mut Vec<&'a V>) { + if let Some(ref value) = node.value { + values.push(value); } - for &(_, child_id) in &self.nodes[node_id].children { - self.dfs(child_id, values); + for (_, child) in &node.children { + self.collect_values(child, values); } } - /// Finds the node in the trie by the key + /// Finds the node in the `Trie` for a given key /// /// Internal API - fn find_node>(&self, key: I) -> Option { - if self.nodes.is_empty() { - return None; - } - let mut node_id = 0usize; - for c in key { - match self.nodes[node_id].find(&c) { - Some(child_id) => node_id = child_id, - None => return None, - } - } - Some(node_id) + fn find_node>(&self, key: I) -> Option<&TrieNode> { + self.root.find_node(key) } - /// Creates a new node and returns the node id - /// - /// Internal API - fn create_new_node(&mut self) -> usize { - self.nodes.push(TrieNode::new(None)); - self.nodes.len() - 1 + fn find_node_mut>(&mut self, key: I) -> Option<&mut TrieNode> { + self.root.find_node_mut(key) } - /// Iterate the nodes in the trie + /// Iterate the nodes in the `Trie` /// /// # Example /// @@ -366,39 +301,31 @@ impl Default for Trie { } /// Iterator for the `Trie` struct -pub struct TrieIterator<'a, K, V> { - trie: &'a Trie, - stack: Vec<(usize, Vec)>, // Stack with node id and current path +pub struct TrieIterator<'a, K: Eq + Ord + Clone, V> { + stack: Vec<(&'a TrieNode, Vec)>, // Stack with node reference and current path } -impl<'a, K, V> TrieIterator<'a, K, V> { +impl<'a, K: Eq + Ord + Clone, V: Clone> TrieIterator<'a, K, V> { fn new(trie: &'a Trie) -> Self { TrieIterator { - trie, - stack: vec![(0, Vec::new())], // Start with root node and empty path + stack: vec![(&trie.root, Vec::new())], // Start with root node and empty path } } } -impl<'a, K, V> Iterator for TrieIterator<'a, K, V> -where - K: Eq + Ord + Clone, - V: Clone, -{ +impl<'a, K: Eq + Ord + Clone, V: Clone> Iterator for TrieIterator<'a, K, V> { type Item = (Vec, V); // Yield key-value pairs - fn next(&mut self) -> Option { - while let Some((node_id, path)) = self.stack.pop() { - let node = &self.trie.nodes[node_id]; + while let Some((node, path)) = self.stack.pop() { // Push children to the stack with updated path - for &(ref key_part, child_id) in &node.children { + for (key_part, child) in &node.children { let mut new_path = path.clone(); new_path.push(key_part.clone()); - self.stack.push((child_id, new_path)); + self.stack.push((child, new_path)); } // Return value if it exists - if let Some(value_id) = node.get_value() { - return Some((path, self.trie.values[value_id].clone())); + if let Some(ref value) = node.value { + return Some((path, value.clone())); } } None diff --git a/src/trie_node.rs b/src/trie_node.rs index f9e0f1b..14ce666 100644 --- a/src/trie_node.rs +++ b/src/trie_node.rs @@ -5,46 +5,69 @@ use serde::{Deserialize, Serialize}; use std::clone::Clone; use std::cmp::{Eq, Ord}; +/// A node in the `Trie`, it holds a value, and a list of children nodes #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug, Clone)] -pub struct TrieNode { - pub value: Option, - pub children: Vec<(T, usize)>, +pub struct TrieNode { + pub value: Option, + pub children: Vec<(K, TrieNode)>, } -impl TrieNode { - pub fn new(value: Option) -> TrieNode { +impl TrieNode { + pub fn new() -> Self { TrieNode { - value, - children: Vec::<(T, usize)>::new(), + value: None, + children: Vec::new(), } } - pub fn find(&self, key: &T) -> Option { - if self.children.is_empty() { - // Slightly improves performance by avoiding closure creation in further code - return None; + /// Insert a node in the trie + pub fn insert>(&mut self, mut key: I, value: V) { + if let Some(part) = key.next() { + if let Some(child) = self.children.iter_mut().find(|child| child.0 == part) { + child.1.insert(key, value); + } else { + let mut new_node = TrieNode::new(); + new_node.insert(key, value); + self.children.push((part, new_node)); + } + } else { + self.value = Some(value); } - if let Ok(idx) = self.children.binary_search_by(|x| x.0.cmp(key)) { - return Some(self.children[idx].1); + } + + /// Recursively find a node searching through children + pub fn find_node>(&self, mut key: I) -> Option<&Self> { + if let Some(p) = key.next() { + self.children.iter().find(|c| c.0 == p)?.1.find_node(key) + } else { + Some(self) } - None } - pub fn insert(&mut self, key: &T, child_id: usize) { - self.children.push((key.clone(), child_id)); - self.children.sort_by(|a, b| a.0.cmp(&b.0)); + pub fn find_node_mut>(&mut self, mut key: I) -> Option<&mut Self> { + if let Some(p) = key.next() { + self.children.iter_mut().find(|c| c.0 == p)?.1.find_node_mut(key) + } else { + Some(self) + } } - pub fn set_value(&mut self, value: usize) { + pub fn set_value(&mut self, value: V) { self.value = Some(value); } - pub fn get_value(&self) -> Option { - self.value + pub fn get_value(&self) -> Option<&V> { + self.value.as_ref() } pub fn may_be_leaf(&self) -> bool { self.value.is_some() } } + +impl Default for TrieNode { + fn default() -> Self { + Self::new() + } +} diff --git a/tests/trie_tests.rs b/tests/trie_tests.rs index 1c3b16a..71c3d1d 100644 --- a/tests/trie_tests.rs +++ b/tests/trie_tests.rs @@ -64,9 +64,9 @@ mod tests { trie.insert("abcd".bytes(), "ABCD"); trie.insert("abcde".bytes(), "ABCDE"); let prefixes = trie.find_prefixes("abcd".bytes()); - assert_eq!(prefixes, vec!["ABC", "ABCD"]); - assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&str>::new()); - assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&str>::new()); + assert_eq!(prefixes, vec![&"ABC", &"ABCD"]); + assert_eq!(trie.find_prefixes("efghij".bytes()), Vec::<&&str>::new()); + assert_eq!(trie.find_prefixes("abz".bytes()), Vec::<&&str>::new()); } #[test]