Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 32 additions & 36 deletions src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct Tree {
current_tree: i32,
advanced: bool,
num_nodes: tsk_size_t,
array_len: tsk_size_t,
flags: TreeFlags,
}

Expand Down Expand Up @@ -58,6 +59,7 @@ impl Tree {
current_tree: 0,
advanced: false,
num_nodes,
array_len: num_nodes + 1,
flags,
}
}
Expand Down Expand Up @@ -100,7 +102,7 @@ impl Tree {
/// }
/// ```
pub fn parent_array(&self) -> &[NodeId] {
tree_array_slice!(self, parent, (*self.inner).num_nodes)
tree_array_slice!(self, parent, self.array_len)
}

/// # Failing examples
Expand Down Expand Up @@ -267,7 +269,7 @@ impl Tree {
/// }
/// ```
pub fn left_sib_array(&self) -> &[NodeId] {
tree_array_slice!(self, left_sib, (*self.inner).num_nodes)
tree_array_slice!(self, left_sib, self.array_len)
}

/// # Failing examples
Expand All @@ -287,7 +289,7 @@ impl Tree {
/// }
/// ```
pub fn right_sib_array(&self) -> &[NodeId] {
tree_array_slice!(self, right_sib, (*self.inner).num_nodes)
tree_array_slice!(self, right_sib, self.array_len)
}

/// # Failing examples
Expand All @@ -307,7 +309,7 @@ impl Tree {
/// }
/// ```
pub fn left_child_array(&self) -> &[NodeId] {
tree_array_slice!(self, left_child, (*self.inner).num_nodes)
tree_array_slice!(self, left_child, self.array_len)
}

/// # Failing examples
Expand All @@ -327,7 +329,7 @@ impl Tree {
/// }
/// ```
pub fn right_child_array(&self) -> &[NodeId] {
tree_array_slice!(self, right_child, (*self.inner).num_nodes)
tree_array_slice!(self, right_child, self.array_len)
}

fn left_sample(&self, u: NodeId) -> Result<NodeId, TskitError> {
Expand Down Expand Up @@ -367,7 +369,7 @@ impl Tree {
///
/// [`TskitError`] if `u` is out of range.
pub fn parent(&self, u: NodeId) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, (*self.inner).parent, NodeId)
unsafe_tsk_column_access!(u.0, 0, self.array_len, (*self.inner).parent, NodeId)
}

/// Get the left child of node `u`.
Expand All @@ -376,7 +378,7 @@ impl Tree {
///
/// [`TskitError`] if `u` is out of range.
pub fn left_child(&self, u: NodeId) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, (*self.inner).left_child, NodeId)
unsafe_tsk_column_access!(u.0, 0, self.array_len, (*self.inner).left_child, NodeId)
}

/// Get the right child of node `u`.
Expand All @@ -385,7 +387,7 @@ impl Tree {
///
/// [`TskitError`] if `u` is out of range.
pub fn right_child(&self, u: NodeId) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, (*self.inner).right_child, NodeId)
unsafe_tsk_column_access!(u.0, 0, self.array_len, (*self.inner).right_child, NodeId)
}

/// Get the left sib of node `u`.
Expand All @@ -394,7 +396,7 @@ impl Tree {
///
/// [`TskitError`] if `u` is out of range.
pub fn left_sib(&self, u: NodeId) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, (*self.inner).left_sib, NodeId)
unsafe_tsk_column_access!(u.0, 0, self.array_len, (*self.inner).left_sib, NodeId)
}

/// Get the right sib of node `u`.
Expand All @@ -403,7 +405,7 @@ impl Tree {
///
/// [`TskitError::IndexError`] if `u` is out of range.
pub fn right_sib(&self, u: NodeId) -> Result<NodeId, TskitError> {
unsafe_tsk_column_access!(u.0, 0, self.num_nodes, (*self.inner).right_sib, NodeId)
unsafe_tsk_column_access!(u.0, 0, self.array_len, (*self.inner).right_sib, NodeId)
}

/// Obtain the list of samples for the current tree/tree sequence
Expand Down Expand Up @@ -577,6 +579,11 @@ impl Tree {
};
handle_tsk_return_value!(code, kc)
}

/// Return the virtual root of the tree.
pub fn virtual_root(&self) -> NodeId {
(*self.inner).virtual_root.into()
}
}

impl streaming_iterator::StreamingIterator for Tree {
Expand Down Expand Up @@ -639,7 +646,7 @@ pub enum NodeTraversalOrder {
}

struct PreorderNodeIterator<'a> {
root_stack: Vec<NodeId>,
current_root: NodeId,
node_stack: Vec<NodeId>,
tree: &'a Tree,
current_node_: Option<NodeId>,
Expand All @@ -648,14 +655,15 @@ struct PreorderNodeIterator<'a> {
impl<'a> PreorderNodeIterator<'a> {
fn new(tree: &'a Tree) -> Self {
let mut rv = PreorderNodeIterator {
root_stack: tree.roots_to_vec(),
current_root: tree.right_child(tree.virtual_root()).unwrap(),
node_stack: vec![],
tree,
current_node_: None,
};
rv.root_stack.reverse();
if let Some(root) = rv.root_stack.pop() {
rv.node_stack.push(root);
let mut c = rv.current_root;
while !c.is_null() {
rv.node_stack.push(c);
c = rv.tree.left_sib(c).unwrap();
}
rv
}
Expand All @@ -664,26 +672,14 @@ impl<'a> PreorderNodeIterator<'a> {
impl NodeIterator for PreorderNodeIterator<'_> {
fn next_node(&mut self) {
self.current_node_ = self.node_stack.pop();
match self.current_node_ {
Some(u) => {
// NOTE: process children right-to-left
// because we later pop them from a steck
// to generate the expected left-to-right ordering.
let mut c = self.tree.right_child(u).unwrap();
while c != NodeId::NULL {
self.node_stack.push(c);
c = self.tree.left_sib(c).unwrap();
}
}
None => {
if let Some(r) = self.root_stack.pop() {
self.current_node_ = Some(r);
let mut c = self.tree.right_child(r).unwrap();
while c != NodeId::NULL {
self.node_stack.push(c);
c = self.tree.left_sib(c).unwrap();
}
}
if let Some(u) = self.current_node_ {
// NOTE: process children right-to-left
// because we later pop them from a steck
// to generate the expected left-to-right ordering.
let mut c = self.tree.right_child(u).unwrap();
while c != NodeId::NULL {
self.node_stack.push(c);
c = self.tree.left_sib(c).unwrap();
}
};
}
Expand Down Expand Up @@ -763,7 +759,7 @@ impl<'a> RootIterator<'a> {
fn new(tree: &'a Tree) -> Self {
RootIterator {
current_root: None,
next_root: unsafe { ll_bindings::tsk_tree_get_left_root(tree.as_ptr()).into() },
next_root: tree.left_child(tree.virtual_root()).unwrap(),
tree,
}
}
Expand Down