Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 146 additions & 1 deletion src/trees.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,22 @@ impl Tree {

fn new(ts: &TreeSequence, flags: TreeFlags) -> Result<Self, TskitError> {
let mut tree = Self::wrap(ts.consumed.nodes().num_rows(), flags);
let rv = unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
let mut rv =
unsafe { ll_bindings::tsk_tree_init(tree.as_mut_ptr(), ts.as_ptr(), flags.bits) };
if rv < 0 {
return Err(TskitError::ErrorCode { code: rv });
}
// Gotta ask Jerome about this one--why isn't this handled in tsk_tree_init??
if !flags.contains(TreeFlags::NO_SAMPLE_COUNTS) {
rv = unsafe {
ll_bindings::tsk_tree_set_tracked_samples(
tree.as_mut_ptr(),
ts.num_samples() as u64,
tree.inner.samples,
)
};
}

handle_tsk_return_value!(rv, tree)
}

Expand Down Expand Up @@ -355,6 +370,39 @@ impl Tree {
false => Ok(b),
}
}

/// Get the number of samples below node `u`.
///
/// # Errors
///
/// * [`TskitError`] if [`TreeFlags::NO_SAMPLE_COUNTS`].
pub fn num_tracked_samples(&self, u: tsk_id_t) -> Result<u64, TskitError> {
let mut n = u64::MAX;
let np: *mut u64 = &mut n;
let code = unsafe { ll_bindings::tsk_tree_get_num_tracked_samples(self.as_ptr(), u, np) };
handle_tsk_return_value!(code, n)
}

/// Calculate the average Kendall-Colijn (`K-C`) distance between
/// pairs of trees whose intervals overlap.
///
/// # Note
///
/// * [Citation](https://doi.org/10.1093/molbev/msw124)
///
/// # Parameters
///
/// * `lambda` specifies the relative weight of topology and branch length.
/// If `lambda` is 0, we only consider topology.
/// If `lambda` is 1, we only consider branch lengths.
pub fn kc_distance(&self, other: &Tree, lambda: f64) -> Result<f64, TskitError> {
let mut kc = f64::NAN;
let kcp: *mut f64 = &mut kc;
let code = unsafe {
ll_bindings::tsk_tree_kc_distance(self.as_ptr(), other.as_ptr(), lambda, kcp)
};
handle_tsk_return_value!(code, kc)
}
}

impl streaming_iterator::StreamingIterator for Tree {
Expand Down Expand Up @@ -763,6 +811,7 @@ impl TreeSequence {
/// # Parameters
///
/// * `lambda` specifies the relative weight of topology and branch length.
/// See [`Tree::kc_distance`] for more details.
pub fn kc_distance(&self, other: &TreeSequence, lambda: f64) -> Result<f64, TskitError> {
let mut kc: f64 = f64::NAN;
let kcp: *mut f64 = &mut kc;
Expand All @@ -771,6 +820,11 @@ impl TreeSequence {
};
handle_tsk_return_value!(code, kc)
}

// FIXME: document
pub fn num_samples(&self) -> tsk_size_t {
unsafe { ll_bindings::tsk_treeseq_get_num_samples(self.as_ptr()) }
}
}

#[cfg(test)]
Expand Down Expand Up @@ -799,6 +853,51 @@ mod test_trees {
tables.tree_sequence().unwrap()
}

fn make_small_table_collection_two_trees() -> TableCollection {
// The two trees are:
// 0
// +++
// | | 1
// | | +++
// 2 3 4 5

// 0
// +-+-+
// 1 |
// +-+-+ |
// 2 4 5 3

let mut tables = TableCollection::new(1000.).unwrap();
tables.add_node(0, 2.0, TSK_NULL, TSK_NULL).unwrap();
tables.add_node(0, 1.0, TSK_NULL, TSK_NULL).unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables
.add_node(TSK_NODE_IS_SAMPLE, 0.0, TSK_NULL, TSK_NULL)
.unwrap();
tables.add_edge(500., 1000., 0, 1).unwrap();
tables.add_edge(0., 500., 0, 2).unwrap();
tables.add_edge(0., 1000., 0, 3).unwrap();
tables.add_edge(500., 1000., 1, 2).unwrap();
tables.add_edge(0., 1000., 1, 4).unwrap();
tables.add_edge(0., 1000., 1, 5).unwrap();
tables.full_sort().unwrap();
tables.build_index(0).unwrap();
tables
}

fn treeseq_from_small_table_collection_two_trees() -> TreeSequence {
let tables = make_small_table_collection_two_trees();
tables.tree_sequence().unwrap()
}

#[test]
fn test_create_treeseq_new_from_tables() {
let tables = make_small_table_collection();
Expand Down Expand Up @@ -877,18 +976,46 @@ mod test_trees {
}
}

#[test]
fn test_num_tracked_samples() {
let treeseq = treeseq_from_small_table_collection();
assert_eq!(treeseq.inner.num_samples, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::default()).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 1);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 2);
}
}

#[should_panic]
#[test]
fn test_num_tracked_samples_not_tracking_samples() {
let treeseq = treeseq_from_small_table_collection();
assert_eq!(treeseq.inner.num_samples, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::NO_SAMPLE_COUNTS).unwrap();
if let Some(tree) = tree_iter.next() {
assert_eq!(tree.num_tracked_samples(2).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(1).unwrap(), 0);
assert_eq!(tree.num_tracked_samples(0).unwrap(), 0);
}
}

#[test]
fn test_iterate_samples() {
let tables = make_small_table_collection();
let treeseq = tables.tree_sequence().unwrap();

let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
if let Some(tree) = tree_iter.next() {
assert!(!tree.flags.contains(TreeFlags::NO_SAMPLE_COUNTS));
assert!(tree.flags.contains(TreeFlags::SAMPLE_LISTS));
let mut s = vec![];
for i in tree.samples(0).unwrap() {
s.push(i);
}
assert_eq!(s.len(), 2);
assert_eq!(s.len(), tree.num_tracked_samples(0).unwrap() as usize);
assert_eq!(s[0], 1);
assert_eq!(s[1], 2);

Expand All @@ -899,12 +1026,30 @@ mod test_trees {
}
assert_eq!(s.len(), 1);
assert_eq!(s[0], u);
assert_eq!(s.len(), tree.num_tracked_samples(u).unwrap() as usize);
}
} else {
panic!("Expected a tree");
}
}

#[test]
fn test_iterate_samples_two_trees() {
let treeseq = treeseq_from_small_table_collection_two_trees();
assert_eq!(treeseq.inner.num_trees, 2);
let mut tree_iter = treeseq.tree_iterator(TreeFlags::SAMPLE_LISTS).unwrap();
while let Some(tree) = tree_iter.next() {
for n in tree.nodes(NodeTraversalOrder::Preorder) {
let mut nsamples = 0;
for _ in tree.samples(n).unwrap() {
nsamples += 1;
}
assert!(nsamples > 0);
assert_eq!(nsamples, tree.num_tracked_samples(n).unwrap());
}
}
}

#[test]
fn test_kc_distance_naive_test() {
let ts1 = treeseq_from_small_table_collection();
Expand Down