From fbdabe45a1f947b489fbd1d10c0ea849dd3837d3 Mon Sep 17 00:00:00 2001 From: molpopgen Date: Tue, 7 Dec 2021 15:35:05 -0800 Subject: [PATCH] Add preorder node traversal using latest tskit C API. Closes #177 Fix preorder stack population for when root changes within a tree. Fixes #187 --- src/trees.rs | 104 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/src/trees.rs b/src/trees.rs index a62f6b651..335f20f4c 100644 --- a/src/trees.rs +++ b/src/trees.rs @@ -497,6 +497,7 @@ impl Tree { ) -> Box + '_> { match order { NodeTraversalOrder::Preorder => Box::new(PreorderNodeIterator::new(self)), + NodeTraversalOrder::Postorder => Box::new(PostorderNodeIterator::new(self)), } } @@ -621,6 +622,10 @@ pub enum NodeTraversalOrder { ///For trees with multiple roots, start at the left root, ///traverse to tips, proceeed to the next root, etc.. Preorder, + ///Postorder traversal, starting at the root(s) of a [`Tree`]. + ///For trees with multiple roots, start at the left root, + ///traverse to tips, proceeed to the next root, etc.. + Postorder, } struct PreorderNodeIterator<'a> { @@ -663,6 +668,11 @@ impl NodeIterator for PreorderNodeIterator<'_> { None => { if let Some(r) = self.root_stack.pop() { self.current_node_ = Some(r); + let mut c = self.tree.right_child(r).unwrap(); + while c != NodeId::NULL { + self.node_stack.push(c); + c = self.tree.left_sib(c).unwrap(); + } } } }; @@ -675,6 +685,64 @@ impl NodeIterator for PreorderNodeIterator<'_> { iterator_for_nodeiterator!(PreorderNodeIterator<'_>); +struct PostorderNodeIterator<'a> { + nodes: Vec, + current_node_index: usize, + num_nodes_current_tree: usize, + // Make the lifetime checker happy. + tree: std::marker::PhantomData<&'a Tree>, +} + +impl<'a> PostorderNodeIterator<'a> { + fn new(tree: &'a Tree) -> Self { + let mut num_nodes_current_tree: usize = 0; + let ptr = std::ptr::addr_of_mut!(num_nodes_current_tree); + let mut nodes = vec![ + NodeId::NULL; + // NOTE: this fn does not return error codes + unsafe { ll_bindings::tsk_tree_get_size_bound(tree.inner) } as usize + ]; + + let rv = unsafe { + ll_bindings::tsk_tree_postorder( + tree.inner, + NodeId::NULL.into(), // start from virtual root + nodes.as_mut_ptr() as *mut tsk_id_t, + ptr as *mut tsk_size_t, + ) + }; + + // This is either out of memory + // or node out of range. + // The former is fatal, and the latter + // not relevant (for now), as we start at roots. + if rv < 0 { + panic!("fatal error calculating postoder node list"); + } + + Self { + nodes, + current_node_index: 0, + num_nodes_current_tree, + tree: std::marker::PhantomData, + } + } +} + +impl<'a> Iterator for PostorderNodeIterator<'a> { + type Item = NodeId; + fn next(&mut self) -> Option { + match self.current_node_index < self.num_nodes_current_tree { + true => { + let rv = self.nodes[self.current_node_index]; + self.current_node_index += 1; + Some(rv) + } + false => None, + } + } +} + struct RootIterator<'a> { current_root: Option, next_root: NodeId, @@ -1320,14 +1388,50 @@ pub(crate) mod test_trees { assert_eq!(treeseq.num_trees(), 2); let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap(); while let Some(tree) = tree_iter.next() { + let mut preoder_nodes = vec![]; + let mut postoder_nodes = vec![]; for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) { let mut nsamples = 0; + preoder_nodes.push(n); for _ in tree.samples(n).unwrap() { nsamples += 1; } assert!(nsamples > 0); assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap()); } + for n in tree.traverse_nodes(NodeTraversalOrder::Postorder) { + let mut nsamples = 0; + postoder_nodes.push(n); + for _ in tree.samples(n).unwrap() { + nsamples += 1; + } + assert!(nsamples > 0); + assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap()); + } + assert_eq!(preoder_nodes.len(), postoder_nodes.len()); + + // Test our preorder against the tskit functions in 0.99.15 + { + let mut nodes: Vec = vec![ + NodeId::NULL; + unsafe { ll_bindings::tsk_tree_get_size_bound(tree.as_ptr()) } + as usize + ]; + let mut num_nodes: tsk_size_t = 0; + let ptr = std::ptr::addr_of_mut!(num_nodes); + unsafe { + ll_bindings::tsk_tree_preorder( + tree.as_ptr(), + -1, + nodes.as_mut_ptr() as *mut tsk_id_t, + ptr, + ); + } + assert_eq!(num_nodes as usize, preoder_nodes.len()); + for i in 0..num_nodes as usize { + assert_eq!(preoder_nodes[i], nodes[i]); + } + } } }