From 5a734e7bacab85cc407d90f5d1a388bb1d0c97fe Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 27 Oct 2021 15:38:44 +0100 Subject: [PATCH] Fixup num_tracked_samples to work with virtual_root Closes #1724 --- c/tests/test_trees.c | 84 ++++++++++++++++++++++++++++++++++ c/tskit/trees.c | 7 ++- python/tests/test_highlevel.py | 33 ++++++------- python/tskit/trees.py | 10 +--- 4 files changed, 104 insertions(+), 30 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 59e45e7ddb..cdaa2449f0 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -4639,6 +4639,89 @@ test_single_tree_map_mutations_internal_samples(void) tsk_tree_free(&t); } +static void +test_single_tree_tracked_samples(void) +{ + tsk_treeseq_t ts; + tsk_tree_t tree; + tsk_id_t samples[] = { 0, 1 }; + tsk_size_t n; + int ret; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, + single_tree_ex_sites, single_tree_ex_mutations, NULL, NULL, 0); + + ret = tsk_tree_init(&tree, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_tree_set_tracked_samples(&tree, 2, samples); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 1); + ret = tsk_tree_get_num_tracked_samples(&tree, 4, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, tree.virtual_root, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + + ret = tsk_tree_first(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 1); + ret = tsk_tree_get_num_tracked_samples(&tree, 4, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + ret = tsk_tree_get_num_tracked_samples(&tree, 5, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, 6, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + ret = tsk_tree_get_num_tracked_samples(&tree, tree.virtual_root, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 1); + ret = tsk_tree_get_num_tracked_samples(&tree, 4, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, tree.virtual_root, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + + ret = tsk_tree_next(&tree); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 1); + ret = tsk_tree_get_num_tracked_samples(&tree, 4, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + ret = tsk_tree_get_num_tracked_samples(&tree, tree.virtual_root, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 2); + + ret = tsk_tree_set_tracked_samples(&tree, 0, NULL); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, 0, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 0); + ret = tsk_tree_get_num_tracked_samples(&tree, tree.virtual_root, &n); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(n, 0); + + tsk_treeseq_free(&ts); + tsk_tree_free(&tree); +} + /*======================================================= * Multi tree tests. *======================================================*/ @@ -6746,6 +6829,7 @@ main(int argc, char **argv) { "test_single_tree_map_mutations", test_single_tree_map_mutations }, { "test_single_tree_map_mutations_internal_samples", test_single_tree_map_mutations_internal_samples }, + { "test_single_tree_tracked_samples", test_single_tree_tracked_samples }, /* Multi tree tests */ { "test_simple_multi_tree", test_simple_multi_tree }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 82546dbfe0..300d36e652 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3314,7 +3314,7 @@ tsk_tree_reset_tracked_samples(tsk_tree_t *self) goto out; } tsk_memset(self->num_tracked_samples, 0, - self->num_nodes * sizeof(*self->num_tracked_samples)); + (self->num_nodes + 1) * sizeof(*self->num_tracked_samples)); out: return ret; } @@ -3367,6 +3367,7 @@ tsk_tree_set_tracked_samples_from_sample_list( tsk_id_t u, stop, index; const tsk_id_t *next = other->next_sample; const tsk_id_t *samples = other->tree_sequence->samples; + tsk_size_t num_tracked_samples = 0; if (!tsk_tree_has_sample_lists(other)) { ret = TSK_ERR_UNSUPPORTED_OPERATION; @@ -3385,6 +3386,7 @@ tsk_tree_set_tracked_samples_from_sample_list( stop = other->right_sample[node]; while (true) { u = samples[index]; + num_tracked_samples++; tsk_bug_assert(self->num_tracked_samples[u] == 0); /* Propagate this upwards */ while (u != TSK_NULL) { @@ -3397,6 +3399,7 @@ tsk_tree_set_tracked_samples_from_sample_list( index = next[index]; } } + self->num_tracked_samples[self->virtual_root] = num_tracked_samples; out: return ret; } @@ -4293,7 +4296,7 @@ tsk_tree_clear(tsk_tree_t *self) self->num_tracked_samples[j] = 0; } } - self->num_tracked_samples[self->virtual_root] = 0; + /* The total tracked_samples gets set in set_tracked_samples */ self->num_samples[self->virtual_root] = num_samples; } if (sample_lists) { diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index e903fe8261..54f35fbacc 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1464,39 +1464,32 @@ def test_compute_mutation_time(self): # Check we have valid times tables.tree_sequence() - def verify_tracked_samples(self, ts): + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_tracked_samples(self, ts): # Should be empty list by default. for tree in ts.trees(): - assert tree.get_num_tracked_samples() == 0 + assert tree.num_tracked_samples() == 0 for u in tree.nodes(): - assert tree.get_num_tracked_samples(u) == 0 + assert tree.num_tracked_samples(u) == 0 samples = list(ts.samples()) tracked_samples = samples[:2] for tree in ts.trees(tracked_samples=tracked_samples): - if len(tree.parent_dict) == 0: - # This is a crude way of checking if we have multiple roots. - # We'll need to fix this code up properly when we support multiple - # roots and remove this check - break - nu = [0 for j in range(ts.get_num_nodes())] - assert tree.get_num_tracked_samples() == len(tracked_samples) + nu = [0 for j in range(ts.num_nodes)] + assert tree.num_tracked_samples() == len(tracked_samples) for j in tracked_samples: u = j while u != tskit.NULL: nu[u] += 1 - u = tree.get_parent(u) + u = tree.parent(u) for u, count in enumerate(nu): - assert tree.get_num_tracked_samples(u) == count - - def test_tracked_samples(self): - for ts in get_example_tree_sequences(): - self.verify_tracked_samples(ts) + assert tree.num_tracked_samples(u) == count + assert tree.num_tracked_samples(tree.virtual_root) == len(tracked_samples) def test_tracked_samples_is_first_arg(self): - for ts in get_example_tree_sequences(): - samples = list(ts.samples())[:2] - for a, b in zip(ts.trees(samples), ts.trees(tracked_samples=samples)): - assert a.get_num_tracked_samples() == b.get_num_tracked_samples() + ts = tskit.Tree.generate_balanced(6).tree_sequence + samples = [0, 1, 2] + tree = next(ts.trees(samples)) + assert tree.num_tracked_samples() == 3 def test_deprecated_sample_aliases(self): for ts in get_example_tree_sequences(): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index fbfb4a104e..81af2a8c0e 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2106,14 +2106,8 @@ def num_tracked_samples(self, u=None): the subtree rooted at u. :rtype: int """ - # This should work, there's a but somethings wrong somewhere - # https://github.com/tskit-dev/tskit/issues/1724 - # u = self.virtual_root if u is None else u - # return self._ll_tree.get_num_tracked_samples(u) - roots = [u] - if u is None: - roots = self.roots - return sum(self._ll_tree.get_num_tracked_samples(root) for root in roots) + u = self.virtual_root if u is None else u + return self._ll_tree.get_num_tracked_samples(u) # TODO document these traversal arrays # https://github.com/tskit-dev/tskit/issues/1788