Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: improve Tree iterator ergonomics #384

Merged
merged 1 commit into from Nov 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/tree_traversals.rs
Expand Up @@ -16,7 +16,7 @@ fn traverse_upwards_with_iterator(tree: &tskit::Tree) {
for &s in tree.sample_nodes() {
// _steps_to_root counts the number of steps,
// including the starting node s.
for (_steps_to_root, _) in tree.parents(s).unwrap().enumerate() {}
for (_steps_to_root, _) in tree.parents(s).enumerate() {}
}
}

Expand Down
69 changes: 31 additions & 38 deletions src/tree_interface.rs
Expand Up @@ -399,25 +399,14 @@ impl TreeInterface {
tree_array_slice!(self, samples, num_samples)
}

/// Return an [`Iterator`] from the node `u` to the root of the tree.
///
/// # Returns
///
/// * `Some(iterator)` if `u` us valid
/// * `None` otherwise
#[deprecated(since = "0.2.3", note = "Please use Tree::parents instead")]
pub fn path_to_root(&self, u: NodeId) -> Option<impl Iterator<Item = NodeId> + '_> {
self.parents(u)
}

/// Return an [`Iterator`] from the node `u` to the root of the tree,
/// travering all parent nodes.
///
/// # Returns
///
/// * `Some(iterator)` if `u` is valid
/// * `None` otherwise
pub fn parents(&self, u: NodeId) -> Option<impl Iterator<Item = NodeId> + '_> {
pub fn parents(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
ParentsIterator::new(self, u)
}

Expand All @@ -426,7 +415,7 @@ impl TreeInterface {
///
/// * `Some(iterator)` if `u` is valid
/// * `None` otherwise
pub fn children(&self, u: NodeId) -> Option<impl Iterator<Item = NodeId> + '_> {
pub fn children(&self, u: NodeId) -> impl Iterator<Item = NodeId> + '_ {
ChildIterator::new(self, u)
}

Expand All @@ -441,10 +430,7 @@ impl TreeInterface {
/// * Some(Ok(iterator)) if [`TreeFlags::SAMPLE_LISTS`] is in [`TreeInterface::flags`]
/// * Some(Err(_)) if [`TreeFlags::SAMPLE_LISTS`] is not in [`TreeInterface::flags`]
/// * None if `u` is not valid.
pub fn samples(
&self,
u: NodeId,
) -> Option<Result<impl Iterator<Item = NodeId> + '_, TskitError>> {
pub fn samples(&self, u: NodeId) -> Result<impl Iterator<Item = NodeId> + '_, TskitError> {
SamplesIterator::new(self, u)
}

Expand Down Expand Up @@ -737,14 +723,17 @@ struct ChildIterator<'a> {
}

impl<'a> ChildIterator<'a> {
fn new(tree: &'a TreeInterface, u: NodeId) -> Option<Self> {
let c = tree.left_child(u)?;
fn new(tree: &'a TreeInterface, u: NodeId) -> Self {
let c = match tree.left_child(u) {
Some(x) => x,
None => NodeId::NULL,
};

Some(ChildIterator {
ChildIterator {
current_child: None,
next_child: c,
tree,
})
}
}
}

Expand Down Expand Up @@ -776,16 +765,15 @@ struct ParentsIterator<'a> {
}

impl<'a> ParentsIterator<'a> {
fn new(tree: &'a TreeInterface, u: NodeId) -> Option<Self> {
let num_nodes = tsk_id_t::try_from(tree.num_nodes).ok()?;

match u {
x if x < num_nodes => Some(ParentsIterator {
current_node: None,
next_node: u,
tree,
}),
_ => None,
fn new(tree: &'a TreeInterface, u: NodeId) -> Self {
let u = match tsk_id_t::try_from(tree.num_nodes) {
Ok(num_nodes) if u < num_nodes => u,
_ => NodeId::NULL,
};
ParentsIterator {
current_node: None,
next_node: u,
tree,
}
}
}
Expand All @@ -797,7 +785,6 @@ impl NodeIterator for ParentsIterator<'_> {
r => {
assert!(r >= 0);
let cr = Some(r);
debug_assert!(self.tree.parent(r).is_some());
self.next_node = self.tree.parent(r).unwrap_or(NodeId::NULL);
cr
}
Expand All @@ -821,18 +808,24 @@ struct SamplesIterator<'a> {
}

impl<'a> SamplesIterator<'a> {
fn new(tree: &'a TreeInterface, u: NodeId) -> Option<Result<Self, TskitError>> {
fn new(tree: &'a TreeInterface, u: NodeId) -> Result<Self, TskitError> {
match tree.flags.contains(TreeFlags::SAMPLE_LISTS) {
false => Some(Err(TskitError::NotTrackingSamples {})),
false => Err(TskitError::NotTrackingSamples {}),
true => {
let next_sample_index = tree.left_sample(u)?;
let last_sample_index = tree.right_sample(u)?;
Some(Ok(SamplesIterator {
let next_sample_index = match tree.left_sample(u) {
Some(x) => x,
None => NodeId::NULL,
};
let last_sample_index = match tree.right_sample(u) {
Some(x) => x,
None => NodeId::NULL,
};
Ok(SamplesIterator {
current_node: None,
next_sample_index,
last_sample_index,
tree,
}))
})
}
}
}
Expand Down
27 changes: 20 additions & 7 deletions src/trees.rs
Expand Up @@ -598,15 +598,28 @@ pub(crate) mod test_trees {
assert_eq!(samples[i - 1], NodeId::from(i as tsk_id_t));

let mut nsteps = 0;
for _ in tree.parents(samples[i - 1]).unwrap() {
for _ in tree.parents(samples[i - 1]) {
nsteps += 1;
}
assert_eq!(nsteps, 2);
}

// These nodes are all out of range
for i in 100..110 {
let mut nsteps = 0;
for _ in tree.parents(i.into()) {
nsteps += 1;
}
assert_eq!(nsteps, 0);
}

assert_eq!(tree.parents((-1_i32).into()).count(), 0);
assert_eq!(tree.children((-1_i32).into()).count(), 0);

let roots = tree.roots_to_vec();
for r in roots.iter() {
let mut num_children = 0;
for _ in tree.children(*r).unwrap() {
for _ in tree.children(*r) {
num_children += 1;
}
assert_eq!(num_children, 2);
Expand Down Expand Up @@ -639,7 +652,7 @@ pub(crate) mod test_trees {
if let Some(tree) = tree_iter.next() {
for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
match tree.samples(n) {
Some(Err(_)) => (),
Err(_) => (),
_ => panic!("should not be Ok(_) or None"),
}
}
Expand Down Expand Up @@ -682,7 +695,7 @@ pub(crate) mod test_trees {
assert!(tree.flags().contains(TreeFlags::SAMPLE_LISTS));
let mut s = vec![];

if let Some(Ok(iter)) = tree.samples(0.into()) {
if let Ok(iter) = tree.samples(0.into()) {
for i in iter {
s.push(i);
}
Expand All @@ -697,7 +710,7 @@ pub(crate) mod test_trees {

for u in 1..3 {
let mut s = vec![];
if let Some(Ok(iter)) = tree.samples(u.into()) {
if let Ok(iter) = tree.samples(u.into()) {
for i in iter {
s.push(i);
}
Expand Down Expand Up @@ -741,7 +754,7 @@ pub(crate) mod test_trees {
for n in tree.traverse_nodes(NodeTraversalOrder::Preorder) {
let mut nsamples = 0;
preoder_nodes.push(n);
if let Some(Ok(iter)) = tree.samples(n) {
if let Ok(iter) = tree.samples(n) {
for _ in iter {
nsamples += 1;
}
Expand All @@ -752,7 +765,7 @@ pub(crate) mod test_trees {
for n in tree.traverse_nodes(NodeTraversalOrder::Postorder) {
let mut nsamples = 0;
postoder_nodes.push(n);
if let Some(Ok(iter)) = tree.samples(n) {
if let Ok(iter) = tree.samples(n) {
for _ in iter {
nsamples += 1;
}
Expand Down
12 changes: 4 additions & 8 deletions tests/book_trees.rs
Expand Up @@ -84,14 +84,10 @@ fn initialize_from_table_collection() {
if let Some(parent) = tree.parent(node) {
// Collect the siblings of node into a Vec
// The children function returns another iterator
let _siblings = if let Some(child_iterator) = tree.children(parent) {
child_iterator
.filter(|child| child != &node)
.collect::<Vec<_>>()
} else {
// assign empty vector
vec![]
};
let _siblings = tree
.children(parent)
.filter(|child| child != &node)
.collect::<Vec<_>>();
}
}
}
Expand Down