diff --git a/c/CHANGELOG.rst b/c/CHANGELOG.rst index 9e5f54ca25..5192029a85 100644 --- a/c/CHANGELOG.rst +++ b/c/CHANGELOG.rst @@ -19,11 +19,26 @@ In development. - Change the ``tsk_vargen_init`` method to take an extra parameter ``alleles``. To keep the current behaviour, set this parameter to NULL. + **New features** - Add the ``TSK_KEEP_UNARY`` option to simplify (:user:`gtsambos`). See :issue:`1` and :pr:`143`. +- Add a ``set_root_threshold`` option to tsk_tree_t which allows us to set the + number of samples a node must be an ancestor of to be considered a root + (:pr:`462`). + +- Change the semantics of tsk_tree_t so that sample counts are always + computed, and add a new ``TSK_NO_SAMPLE_COUNTS`` option to turn this + off (:pr:`462`). + + +**Deprecated** + +- The ``TSK_SAMPLE_COUNTS`` option is now ignored and will print out a warning + if used (:pr:`462`). + --------------------- [0.99.2] - 2019-03-27 --------------------- diff --git a/c/dev-tools/dev-cli.c b/c/dev-tools/dev-cli.c index 968b91fc37..92656433ae 100644 --- a/c/dev-tools/dev-cli.c +++ b/c/dev-tools/dev-cli.c @@ -178,7 +178,7 @@ print_tree_sequence(tsk_treeseq_t *ts, int verbose) printf("========================\n"); printf("trees\n"); printf("========================\n"); - ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_LISTS); if (ret != 0) { fatal_error("ERROR: %d: %s\n", ret, tsk_strerror(ret)); } diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index 287adfcb6c..90fbb026a4 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -683,7 +683,7 @@ verify_branch_general_stat_identity(tsk_treeseq_t *ts) sigma, TSK_STAT_BRANCH|TSK_STAT_POLARISED|TSK_STAT_SPAN_NORMALISE); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_COUNTS); + ret = tsk_tree_init(&tree, ts, 0); CU_ASSERT_EQUAL(ret, 0); for (ret = tsk_tree_first(&tree); ret == 1; ret = tsk_tree_next(&tree)) { diff --git 
a/c/tests/test_trees.c b/c/tests/test_trees.c index 11eae1fbe6..43bc5ae5ba 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -67,8 +67,6 @@ check_trees_identical(tsk_tree_t *self, tsk_tree_t *other) CU_ASSERT_FATAL(memcmp(self->right_child, other->right_child, N * sizeof(tsk_id_t)) == 0); CU_ASSERT_FATAL(memcmp(self->left_sib, other->left_sib, N * sizeof(tsk_id_t)) == 0); CU_ASSERT_FATAL(memcmp(self->right_sib, other->right_sib, N * sizeof(tsk_id_t)) == 0); - CU_ASSERT_FATAL(memcmp(self->above_sample, other->above_sample, - N * sizeof(*self->above_sample)) == 0); CU_ASSERT_EQUAL_FATAL(self->num_samples == NULL, other->num_samples == NULL) CU_ASSERT_EQUAL_FATAL(self->num_tracked_samples == NULL, @@ -645,8 +643,8 @@ verify_sample_counts(tsk_treeseq_t *ts, size_t num_tests, sample_count_test_t *t n = tsk_treeseq_get_num_samples(ts); samples = tsk_treeseq_get_samples(ts); - /* First run without the TSK_SAMPLE_COUNTS feature */ - ret = tsk_tree_init(&tree, ts, 0); + /* First run with the TSK_NO_SAMPLE_COUNTS feature */ + ret = tsk_tree_init(&tree, ts, TSK_NO_SAMPLE_COUNTS); CU_ASSERT_EQUAL(ret, 0); ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -661,11 +659,13 @@ verify_sample_counts(tsk_treeseq_t *ts, size_t num_tests, sample_count_test_t *t /* all operations depending on tracked samples should fail. */ ret = tsk_tree_get_num_tracked_samples(&tree, 0, &num_samples); CU_ASSERT_EQUAL(ret, TSK_ERR_UNSUPPORTED_OPERATION); + /* The root should be NULL */ + CU_ASSERT_EQUAL(tree.left_root, TSK_NULL); } tsk_tree_free(&tree); /* Now run with TSK_SAMPLE_COUNTS but with no samples tracked. 
*/ - ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_COUNTS); + ret = tsk_tree_init(&tree, ts, 0); CU_ASSERT_EQUAL(ret, 0); ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -681,11 +681,13 @@ verify_sample_counts(tsk_treeseq_t *ts, size_t num_tests, sample_count_test_t *t ret = tsk_tree_get_num_tracked_samples(&tree, 0, &num_samples); CU_ASSERT_EQUAL(ret, 0); CU_ASSERT_EQUAL(num_samples, 0); + /* The root should not be NULL */ + CU_ASSERT_NOT_EQUAL(tree.left_root, TSK_NULL); } tsk_tree_free(&tree); - /* Run with TSK_SAMPLE_LISTS, but without TSK_SAMPLE_COUNTS */ - ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_LISTS); + /* Run with TSK_SAMPLE_LISTS and TSK_NO_SAMPLE_COUNTS */ + ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_LISTS|TSK_NO_SAMPLE_COUNTS); CU_ASSERT_EQUAL(ret, 0); ret = tsk_tree_first(&tree); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -718,8 +720,8 @@ verify_sample_counts(tsk_treeseq_t *ts, size_t num_tests, sample_count_test_t *t } tsk_tree_free(&tree); - /* Now use TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS */ - ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + /* Now use TSK_SAMPLE_LISTS */ + ret = tsk_tree_init(&tree, ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL(ret, 0); ret = tsk_tree_set_tracked_samples(&tree, n, samples); CU_ASSERT_EQUAL(ret, 0); @@ -827,7 +829,7 @@ verify_sample_sets(tsk_treeseq_t *ts) int ret; tsk_tree_t t; - ret = tsk_tree_init(&t, ts, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + ret = tsk_tree_init(&t, ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL(ret, 0); for (ret = tsk_tree_first(&t); ret == 1; ret = tsk_tree_next(&t)) { @@ -1181,6 +1183,64 @@ test_simplest_zero_root_tree(void) tsk_treeseq_free(&ts); } +static void +test_simplest_multi_root_tree(void) +{ + int ret; + const char *nodes = + "1 0 0\n" + "1 0 0\n" + "1 0 0\n" + "0 1 0\n"; + const char *edges = + "0 1 3 1,2\n"; + tsk_treeseq_t ts; + tsk_tree_t t; + + tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 3); 
+ CU_ASSERT_EQUAL(tsk_treeseq_get_sequence_length(&ts), 1.0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 4); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 1); + + ret = tsk_tree_init(&t, &ts, 0); + tsk_tree_print_state(&t, _devnull); + /* Make sure the initial roots are set correctly */ + CU_ASSERT_EQUAL(t.left_root, 0); + CU_ASSERT_EQUAL(t.left_sib[0], TSK_NULL); + CU_ASSERT_EQUAL(t.right_sib[0], 1); + CU_ASSERT_EQUAL(t.left_sib[1], 0); + CU_ASSERT_EQUAL(t.right_sib[1], 2); + CU_ASSERT_EQUAL(t.left_sib[2], 1); + CU_ASSERT_EQUAL(t.right_sib[2], TSK_NULL); + + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&t), 2); + CU_ASSERT_EQUAL(t.left_root, 0); + CU_ASSERT_EQUAL(t.right_sib[0], 3); + + tsk_tree_print_state(&t, _devnull); + + CU_ASSERT_EQUAL(tsk_tree_set_root_threshold(&t, 1), TSK_ERR_UNSUPPORTED_OPERATION); + ret = tsk_tree_next(&t); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_tree_set_root_threshold(&t, 0), TSK_ERR_BAD_PARAM_VALUE); + ret = tsk_tree_set_root_threshold(&t, 2); + CU_ASSERT_EQUAL(ret, 0); + CU_ASSERT_EQUAL(tsk_tree_get_root_threshold(&t), 2); + + ret = tsk_tree_next(&t); + CU_ASSERT_EQUAL(ret, 1); + CU_ASSERT_EQUAL(tsk_tree_get_num_roots(&t), 1); + CU_ASSERT_EQUAL(t.left_root, 3); + + tsk_tree_free(&t); + tsk_treeseq_free(&ts); +} + static void test_simplest_root_mutations(void) { @@ -4714,7 +4774,7 @@ test_tree_errors(void) ret = tsk_tree_init(&t, NULL, 0); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); - ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_COUNTS); + ret = tsk_tree_init(&t, &ts, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4762,9 +4822,9 @@ test_tree_errors(void) tsk_tree_free(&t); tsk_tree_free(&other_t); - ret = tsk_tree_init(&t, &other_ts, 0); + ret = tsk_tree_init(&t, &other_ts, TSK_NO_SAMPLE_COUNTS); CU_ASSERT_EQUAL(ret, 0); - ret = tsk_tree_copy(&t, &other_t, 
TSK_SAMPLE_COUNTS); + ret = tsk_tree_copy(&t, &other_t, 0); CU_ASSERT_EQUAL(ret, TSK_ERR_UNSUPPORTED_OPERATION); tsk_tree_free(&other_t); ret = tsk_tree_copy(&t, &other_t, TSK_SAMPLE_LISTS); @@ -4784,7 +4844,7 @@ test_tree_copy_flags(void) tsk_treeseq_t ts; tsk_tree_t t, other_t; tsk_flags_t options[] = { - 0, TSK_SAMPLE_COUNTS, TSK_SAMPLE_LISTS, TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS}; + 0, TSK_NO_SAMPLE_COUNTS, TSK_SAMPLE_LISTS, TSK_NO_SAMPLE_COUNTS|TSK_SAMPLE_LISTS}; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, NULL, NULL, paper_ex_individuals, NULL); @@ -5177,6 +5237,33 @@ test_zero_edges(void) tsk_treeseq_free(&tss); } +static void +test_sample_counts_deprecated(void) +{ + tsk_treeseq_t ts; + tsk_tree_t tree; + int ret; + FILE *f = fopen(_tmp_file_name, "w"); + FILE *tmp = stderr; + + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_samples(&ts), 4); + + stderr = f; + ret = tsk_tree_init(&tree, &ts, TSK_SAMPLE_COUNTS); + stderr = tmp; + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_FATAL(ftell(f) > 0); + + fclose(f); + tsk_tree_free(&tree); + tsk_treeseq_free(&ts); +} + + int main(int argc, char **argv) { @@ -5190,6 +5277,7 @@ main(int argc, char **argv) test_simplest_degenerate_multiple_root_records}, {"test_simplest_multiple_root_records", test_simplest_multiple_root_records}, {"test_simplest_zero_root_tree", test_simplest_zero_root_tree}, + {"test_simplest_multi_root_tree", test_simplest_multi_root_tree}, {"test_simplest_root_mutations", test_simplest_root_mutations}, {"test_simplest_back_mutations", test_simplest_back_mutations}, {"test_simplest_general_samples", test_simplest_general_samples}, @@ -5275,7 +5363,7 @@ main(int argc, char **argv) {"test_nonbinary_sample_sets", test_nonbinary_sample_sets}, {"test_internal_sample_sample_sets", test_internal_sample_sample_sets}, - /*KC_Distance tests */ + /*KC distance tests */ 
{"test_single_tree_kc", test_single_tree_kc}, {"test_two_trees_kc", test_two_trees_kc}, {"test_empty_tree_kc", test_empty_tree_kc}, @@ -5296,6 +5384,7 @@ main(int argc, char **argv) {"test_deduplicate_sites_multichar", test_deduplicate_sites_multichar}, {"test_empty_tree_sequence", test_empty_tree_sequence}, {"test_zero_edges", test_zero_edges}, + {"test_sample_counts_deprecated", test_sample_counts_deprecated}, {NULL, NULL}, }; diff --git a/c/tskit/core.h b/c/tskit/core.h index ec42f523c9..2e7c6df807 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -80,7 +80,7 @@ to the API or ABI are introduced, i.e., the addition of a new function. The library patch version. Incremented when any changes not relevant to the to the API or ABI are introduced, i.e., internal refactors of bugfixes. */ -#define TSK_VERSION_PATCH 2 +#define TSK_VERSION_PATCH 3 /** @} */ /* Node flags */ diff --git a/c/tskit/haplotype_matching.c b/c/tskit/haplotype_matching.c index db4b094d6b..d75ecb9392 100644 --- a/c/tskit/haplotype_matching.c +++ b/c/tskit/haplotype_matching.c @@ -176,7 +176,7 @@ tsk_ls_hmm_init(tsk_ls_hmm_t *self, tsk_treeseq_t *tree_sequence, self->alleles[l] = _zero_one_alleles; } } - ret = tsk_tree_init(&self->tree, self->tree_sequence, TSK_SAMPLE_COUNTS); + ret = tsk_tree_init(&self->tree, self->tree_sequence, 0); if (ret != 0) { goto out; } diff --git a/c/tskit/stats.c b/c/tskit/stats.c index 784f5279f1..483137a794 100644 --- a/c/tskit/stats.c +++ b/c/tskit/stats.c @@ -78,13 +78,11 @@ tsk_ld_calc_init(tsk_ld_calc_t *self, tsk_treeseq_t *tree_sequence) ret = TSK_ERR_NO_MEMORY; goto out; } - ret = tsk_tree_init(self->outer_tree, self->tree_sequence, - TSK_SAMPLE_COUNTS|TSK_SAMPLE_LISTS); + ret = tsk_tree_init(self->outer_tree, self->tree_sequence, TSK_SAMPLE_LISTS); if (ret != 0) { goto out; } - ret = tsk_tree_init(self->inner_tree, self->tree_sequence, - TSK_SAMPLE_COUNTS); + ret = tsk_tree_init(self->inner_tree, self->tree_sequence, 0); if (ret != 0) { goto out; } diff --git 
a/c/tskit/trees.c b/c/tskit/trees.c index 3729bde615..69f2c58320 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3053,7 +3053,7 @@ tsk_tree_clear(tsk_tree_t *self) tsk_id_t u; const tsk_size_t N = self->num_nodes; const tsk_size_t num_samples = self->tree_sequence->num_samples; - const bool sample_counts = !!(self->options & TSK_SAMPLE_COUNTS); + const bool sample_counts = !(self->options & TSK_NO_SAMPLE_COUNTS); const bool sample_lists = !!(self->options & TSK_SAMPLE_LISTS); self->left = 0; @@ -3067,7 +3067,7 @@ tsk_tree_clear(tsk_tree_t *self) memset(self->right_child, 0xff, N * sizeof(tsk_id_t)); memset(self->left_sib, 0xff, N * sizeof(tsk_id_t)); memset(self->right_sib, 0xff, N * sizeof(tsk_id_t)); - memset(self->above_sample, 0, N * sizeof(bool)); + if (sample_counts) { memset(self->num_samples, 0, N * sizeof(tsk_id_t)); memset(self->marked, 0, N * sizeof(uint8_t)); @@ -3086,13 +3086,8 @@ tsk_tree_clear(tsk_tree_t *self) memset(self->next_sample, 0xff, num_samples * sizeof(tsk_id_t)); } /* Set the sample attributes */ - self->left_root = TSK_NULL; - if (num_samples > 0) { - self->left_root = self->samples[0]; - } for (j = 0; j < num_samples; j++) { u = self->samples[j]; - self->above_sample[u] = true; if (sample_counts) { self->num_samples[u] = 1; } @@ -3101,12 +3096,19 @@ tsk_tree_clear(tsk_tree_t *self) self->left_sample[u] = (tsk_id_t) j; self->right_sample[u] = (tsk_id_t) j; } - /* Set initial roots */ - if (j < num_samples - 1) { - self->right_sib[u] = self->samples[j + 1]; - } - if (j > 0) { - self->left_sib[u] = self->samples[j - 1]; + } + self->left_root = TSK_NULL; + if (sample_counts && self->root_threshold == 1 && num_samples > 0) { + self->left_root = self->samples[0]; + for (j = 0; j < num_samples; j++) { + /* Set initial roots */ + u = self->samples[j]; + if (j < num_samples - 1) { + self->right_sib[u] = self->samples[j + 1]; + } + if (j > 0) { + self->left_sib[u] = self->samples[j - 1]; + } } } return ret; @@ -3124,21 +3126,26 @@ 
tsk_tree_init(tsk_tree_t *self, tsk_treeseq_t *tree_sequence, tsk_flags_t option ret = TSK_ERR_BAD_PARAM_VALUE; goto out; } + if (options & TSK_SAMPLE_COUNTS) { + fprintf(stderr, + "TSK_SAMPLE_COUNTS is no longer supported. " + "Sample counts are tracked by default since 0.99.3, " + "Please use TSK_NO_SAMPLE_COUNTS to turn off sample counts."); + } num_nodes = tree_sequence->tables->nodes.num_rows; num_samples = tree_sequence->num_samples; self->num_nodes = num_nodes; self->tree_sequence = tree_sequence; self->samples = tree_sequence->samples; self->options = options; + self->root_threshold = 1; self->parent = malloc(num_nodes * sizeof(tsk_id_t)); self->left_child = malloc(num_nodes * sizeof(tsk_id_t)); self->right_child = malloc(num_nodes * sizeof(tsk_id_t)); self->left_sib = malloc(num_nodes * sizeof(tsk_id_t)); self->right_sib = malloc(num_nodes * sizeof(tsk_id_t)); - self->above_sample = malloc(num_nodes * sizeof(bool)); if (self->parent == NULL || self->left_child == NULL || self->right_child == NULL - || self->left_sib == NULL || self->right_sib == NULL - || self->above_sample == NULL) { + || self->left_sib == NULL || self->right_sib == NULL) { goto out; } /* the maximum possible height of the tree is num_nodes + 1, including @@ -3148,7 +3155,7 @@ tsk_tree_init(tsk_tree_t *self, tsk_treeseq_t *tree_sequence, tsk_flags_t option if (self->stack1 == NULL || self->stack2 == NULL) { goto out; } - if (self->options & TSK_SAMPLE_COUNTS) { + if (! 
(self->options & TSK_NO_SAMPLE_COUNTS)) { self->num_samples = calloc(num_nodes, sizeof(tsk_id_t)); self->num_tracked_samples = calloc(num_nodes, sizeof(tsk_id_t)); self->marked = calloc(num_nodes, sizeof(uint8_t)); @@ -3171,7 +3178,33 @@ tsk_tree_init(tsk_tree_t *self, tsk_treeseq_t *tree_sequence, tsk_flags_t option return ret; } -int TSK_WARN_UNUSED +int +tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold) +{ + int ret = 0; + + if (root_threshold == 0) { + ret = TSK_ERR_BAD_PARAM_VALUE; + goto out; + } + /* Don't allow the value to be set when the tree is out of the null + * state */ + if (self->index != -1) { + ret = TSK_ERR_UNSUPPORTED_OPERATION; + goto out; + } + self->root_threshold = root_threshold; +out: + return ret; +} + +tsk_size_t +tsk_tree_get_root_threshold(tsk_tree_t *self) +{ + return self->root_threshold; +} + +int tsk_tree_free(tsk_tree_t *self) { tsk_safe_free(self->parent); @@ -3179,7 +3212,6 @@ tsk_tree_free(tsk_tree_t *self) tsk_safe_free(self->right_child); tsk_safe_free(self->left_sib); tsk_safe_free(self->right_sib); - tsk_safe_free(self->above_sample); tsk_safe_free(self->stack1); tsk_safe_free(self->stack2); tsk_safe_free(self->num_samples); @@ -3200,7 +3232,7 @@ tsk_tree_has_sample_lists(tsk_tree_t *self) bool tsk_tree_has_sample_counts(tsk_tree_t *self) { - return !!(self->options & TSK_SAMPLE_COUNTS); + return !(self->options & TSK_NO_SAMPLE_COUNTS); } static int TSK_WARN_UNUSED @@ -3323,15 +3355,15 @@ tsk_tree_copy(tsk_tree_t *self, tsk_tree_t *dest, tsk_flags_t options) dest->index = self->index; dest->sites = self->sites; dest->sites_length = self->sites_length; + dest->root_threshold = self->root_threshold; memcpy(dest->parent, self->parent, N * sizeof(tsk_id_t)); memcpy(dest->left_child, self->left_child, N * sizeof(tsk_id_t)); memcpy(dest->right_child, self->right_child, N * sizeof(tsk_id_t)); memcpy(dest->left_sib, self->left_sib, N * sizeof(tsk_id_t)); memcpy(dest->right_sib, self->right_sib, N * 
sizeof(tsk_id_t)); - memcpy(dest->above_sample, self->above_sample, N * sizeof(*self->above_sample)); - if (dest->options & TSK_SAMPLE_COUNTS) { - if (!(self->options & TSK_SAMPLE_COUNTS)) { + if (!(dest->options & TSK_NO_SAMPLE_COUNTS)) { + if (self->options & TSK_NO_SAMPLE_COUNTS) { ret = TSK_ERR_UNSUPPORTED_OPERATION; goto out; } @@ -3481,7 +3513,7 @@ tsk_tree_get_num_samples(tsk_tree_t *self, tsk_id_t u, size_t *num_samples) goto out; } - if (self->options & TSK_SAMPLE_COUNTS) { + if (! (self->options & TSK_NO_SAMPLE_COUNTS)) { *num_samples = (size_t) self->num_samples[u]; } else { ret = tsk_tree_get_num_samples_by_traversal(self, u, num_samples); @@ -3500,7 +3532,7 @@ tsk_tree_get_num_tracked_samples(tsk_tree_t *self, tsk_id_t u, if (ret != 0) { goto out; } - if (! (self->options & TSK_SAMPLE_COUNTS)) { + if (self->options & TSK_NO_SAMPLE_COUNTS) { ret = TSK_ERR_UNSUPPORTED_OPERATION; goto out; } @@ -3621,7 +3653,7 @@ tsk_tree_check_state(tsk_tree_t *self) assert(site.position < self->right); } - if (self->options & TSK_SAMPLE_COUNTS) { + if (! 
(self->options & TSK_NO_SAMPLE_COUNTS)) { assert(self->num_samples != NULL); assert(self->num_tracked_samples != NULL); for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { @@ -3655,6 +3687,7 @@ tsk_tree_print_state(tsk_tree_t *self, FILE *out) fprintf(out, "Tree state:\n"); fprintf(out, "options = %d\n", self->options); + fprintf(out, "root_threshold = %d\n", self->root_threshold); fprintf(out, "left = %f\n", self->left); fprintf(out, "right = %f\n", self->right); fprintf(out, "left_root = %d\n", (int) self->left_root); @@ -3666,13 +3699,14 @@ tsk_tree_print_state(tsk_tree_t *self, FILE *out) fprintf(out, "\n"); for (j = 0; j < self->num_nodes; j++) { - fprintf(out, "%d\t%d\t%d\t%d\t%d\t%d", (int) j, self->parent[j], self->left_child[j], - self->right_child[j], self->left_sib[j], self->right_sib[j]); + fprintf(out, "%d\t%d\t%d\t%d\t%d\t%d", (int) j, self->parent[j], + self->left_child[j], self->right_child[j], + self->left_sib[j], self->right_sib[j]); if (self->options & TSK_SAMPLE_LISTS) { fprintf(out, "\t%d\t%d\t", self->left_sample[j], self->right_sample[j]); } - if (self->options & TSK_SAMPLE_COUNTS) { + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { fprintf(out, "\t%d\t%d\t%d", (int) self->num_samples[j], (int) self->num_tracked_samples[j], self->marked[j]); } @@ -3688,56 +3722,6 @@ tsk_tree_print_state(tsk_tree_t *self, FILE *out) /* Methods for positioning the tree along the sequence */ -/* Implementation note: we're passing the parent array as a restrict pointer - * argument here for performance reasons. The num_samples and num_tracked_samples - * arrays can be accessed through local restrict pointers here because we're not - * accessing them from the calling function. 
- */ - -static inline void -tsk_tree_propagate_sample_count_loss(tsk_tree_t *self, - const tsk_id_t * restrict tree_parent, tsk_id_t parent, tsk_id_t child) -{ - tsk_id_t v; - const tsk_id_t all_samples_diff = self->num_samples[child]; - const tsk_id_t tracked_samples_diff = self->num_tracked_samples[child]; - const uint8_t mark = self->mark; - tsk_id_t * restrict num_samples = self->num_samples; - tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; - uint8_t * restrict marked = self->marked; - - /* propagate this loss up as far as we can */ - v = parent; - while (v != TSK_NULL) { - num_samples[v] -= all_samples_diff; - num_tracked_samples[v] -= tracked_samples_diff; - marked[v] = mark; - v = tree_parent[v]; - } -} - -static inline void -tsk_tree_propagate_sample_count_gain(tsk_tree_t *self, - const tsk_id_t * restrict tree_parent, tsk_id_t parent, tsk_id_t child) -{ - tsk_id_t v; - const tsk_id_t all_samples_diff = self->num_samples[child]; - const tsk_id_t tracked_samples_diff = self->num_tracked_samples[child]; - const uint8_t mark = self->mark; - tsk_id_t * restrict num_samples = self->num_samples; - tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; - uint8_t * restrict marked = self->marked; - - /* propogate this gain up as far as we can */ - v = parent; - while (v != TSK_NULL) { - num_samples[v] += all_samples_diff; - num_tracked_samples[v] += tracked_samples_diff; - marked[v] = mark; - v = tree_parent[v]; - } -} - /* parent, left_child and right_sib are restrict pointers in the calling function, * so we pass these as parameters to ensure the relationships are clear to the * compiler. 
*/ @@ -3778,102 +3762,70 @@ tsk_tree_update_sample_lists(tsk_tree_t *self, } static int -tsk_tree_advance(tsk_tree_t *self, int direction, - const double * restrict out_breakpoints, - const tsk_id_t * restrict out_order, - tsk_id_t *out_index, - const double * restrict in_breakpoints, - const tsk_id_t * restrict in_order, - tsk_id_t *in_index) +tsk_tree_remove_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c) { - int ret = 0; - const int direction_change = direction * (direction != self->direction); - tsk_id_t in = *in_index + direction_change; - tsk_id_t out = *out_index + direction_change; - tsk_id_t k, p, c, u, v, root, lsib, rsib, lroot, rroot; - const tsk_table_collection_t *tables = self->tree_sequence->tables; - const double sequence_length = tables->sequence_length; - const tsk_id_t num_edges = (tsk_id_t) tables->edges.num_rows; - const tsk_id_t * restrict edge_parent = tables->edges.parent; - const tsk_id_t * restrict edge_child = tables->edges.child; - const tsk_flags_t * restrict node_flags = tables->nodes.flags; tsk_id_t * restrict parent = self->parent; tsk_id_t * restrict left_child = self->left_child; tsk_id_t * restrict right_child = self->right_child; tsk_id_t * restrict left_sib = self->left_sib; tsk_id_t * restrict right_sib = self->right_sib; - bool * restrict above_sample = self->above_sample; - bool currently_above_sample; - double x; + tsk_id_t * restrict num_samples = self->num_samples; + tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; + uint8_t * restrict marked = self->marked; + const uint8_t mark = self->mark; + const tsk_id_t root_threshold = (tsk_id_t) self->root_threshold; + tsk_id_t lsib, rsib, u, path_end, lroot, rroot; + bool path_end_was_root; - if (direction == TSK_DIR_FORWARD) { - x = self->right; +#define IS_ROOT(U) (num_samples[U] >= root_threshold) + + lsib = left_sib[c]; + rsib = right_sib[c]; + if (lsib == TSK_NULL) { + left_child[p] = rsib; } else { - x = self->left; + right_sib[lsib] = rsib; } - while 
(out >= 0 && out < num_edges && out_breakpoints[out_order[out]] == x) { - assert(out < num_edges); - k = out_order[out]; - out += direction; - p = edge_parent[k]; - c = edge_child[k]; - lsib = left_sib[c]; - rsib = right_sib[c]; - if (lsib == TSK_NULL) { - left_child[p] = rsib; - } else { - right_sib[lsib] = rsib; - } - if (rsib == TSK_NULL) { - right_child[p] = lsib; - } else { - left_sib[rsib] = lsib; - } - parent[c] = TSK_NULL; - left_sib[c] = TSK_NULL; - right_sib[c] = TSK_NULL; - if (self->options & TSK_SAMPLE_COUNTS) { - tsk_tree_propagate_sample_count_loss(self, parent, p, c); - } - if (self->options & TSK_SAMPLE_LISTS) { - tsk_tree_update_sample_lists(self, parent, left_child, right_sib, p); - } - - /* Update the roots. If c is not above a sample then we have nothing to do - * as we cannot affect the status of any roots. */ - if (above_sample[c]) { - /* Compute the new above sample status for the nodes from p up to root. */ - v = p; - root = v; - currently_above_sample = false; - while (v != TSK_NULL && !currently_above_sample) { - currently_above_sample = !!(node_flags[v] & TSK_NODE_IS_SAMPLE); - u = left_child[v]; - while (u != TSK_NULL && !currently_above_sample) { - currently_above_sample = above_sample[u]; - u = right_sib[u]; - } - above_sample[v] = currently_above_sample; - root = v; - v = parent[v]; + if (rsib == TSK_NULL) { + right_child[p] = lsib; + } else { + left_sib[rsib] = lsib; + } + parent[c] = TSK_NULL; + left_sib[c] = TSK_NULL; + right_sib[c] = TSK_NULL; + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + /* keep the compiler happy */ + path_end_was_root = false; + path_end = TSK_NULL; + + u = p; + while (u != TSK_NULL) { + path_end = u; + path_end_was_root = IS_ROOT(u); + num_samples[u] -= num_samples[c]; + num_tracked_samples[u] -= num_tracked_samples[c]; + marked[u] = mark; + u = parent[u]; + } + if (path_end_was_root && !IS_ROOT(path_end)) { + /* remove path_end from the list of roots */ + lroot = left_sib[path_end]; + rroot = 
right_sib[path_end]; + self->left_root = TSK_NULL; + if (lroot != TSK_NULL) { + right_sib[lroot] = rroot; + self->left_root = lroot; } - if (!currently_above_sample) { - /* root is no longer above samples. Remove it from the root list */ - lroot = left_sib[root]; - rroot = right_sib[root]; - self->left_root = TSK_NULL; - if (lroot != TSK_NULL) { - right_sib[lroot] = rroot; - self->left_root = lroot; - } - if (rroot != TSK_NULL) { - left_sib[rroot] = lroot; - self->left_root = rroot; - } - left_sib[root] = TSK_NULL; - right_sib[root] = TSK_NULL; + if (rroot != TSK_NULL) { + left_sib[rroot] = lroot; + self->left_root = rroot; } - /* Add c to the root list */ + left_sib[path_end] = TSK_NULL; + right_sib[path_end] = TSK_NULL; + } + if (IS_ROOT(c)) { if (self->left_root != TSK_NULL) { lroot = left_sib[self->left_root]; if (lroot != TSK_NULL) { @@ -3887,60 +3839,70 @@ tsk_tree_advance(tsk_tree_t *self, int direction, } } - while (in >= 0 && in < num_edges && in_breakpoints[in_order[in]] == x) { - k = in_order[in]; - in += direction; - p = edge_parent[k]; - c = edge_child[k]; - if (parent[c] != TSK_NULL) { - ret = TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN; - goto out; - } - parent[c] = p; - u = right_child[p]; - lsib = left_sib[c]; - rsib = right_sib[c]; - if (u == TSK_NULL) { - left_child[p] = c; - left_sib[c] = TSK_NULL; - right_sib[c] = TSK_NULL; - } else { - right_sib[u] = c; - left_sib[c] = u; - right_sib[c] = TSK_NULL; - } - right_child[p] = c; - if (self->options & TSK_SAMPLE_COUNTS) { - tsk_tree_propagate_sample_count_gain(self, parent, p, c); + if (self->options & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, parent, left_child, right_sib, p); + } + + return 0; +} + +static int +tsk_tree_insert_edge(tsk_tree_t *self, tsk_id_t p, tsk_id_t c) +{ + int ret = 0; + tsk_id_t * restrict parent = self->parent; + tsk_id_t * restrict left_child = self->left_child; + tsk_id_t * restrict right_child = self->right_child; + tsk_id_t * restrict left_sib = 
self->left_sib; + tsk_id_t * restrict right_sib = self->right_sib; + tsk_id_t * restrict num_samples = self->num_samples; + tsk_id_t * restrict num_tracked_samples = self->num_tracked_samples; + uint8_t * restrict marked = self->marked; + const uint8_t mark = self->mark; + const tsk_id_t root_threshold = (tsk_id_t) self->root_threshold; + tsk_id_t lsib, rsib, u, path_end, lroot; + bool path_end_was_root; + +#define IS_ROOT(U) (num_samples[U] >= root_threshold) + + if (parent[c] != TSK_NULL) { + ret = TSK_ERR_BAD_EDGES_CONTRADICTORY_CHILDREN; + goto out; + } + parent[c] = p; + u = right_child[p]; + lsib = left_sib[c]; + rsib = right_sib[c]; + if (u == TSK_NULL) { + left_child[p] = c; + left_sib[c] = TSK_NULL; + right_sib[c] = TSK_NULL; + } else { + right_sib[u] = c; + left_sib[c] = u; + right_sib[c] = TSK_NULL; + } + right_child[p] = c; + + if (!(self->options & TSK_NO_SAMPLE_COUNTS)) { + + /* keep compiler happy */ + path_end = TSK_NULL; + path_end_was_root = false; + + u = p; + while (u != TSK_NULL) { + path_end = u; + path_end_was_root = IS_ROOT(u); + num_samples[u] += num_samples[c]; + num_tracked_samples[u] += num_tracked_samples[c]; + marked[u] = mark; + u = parent[u]; } - if (self->options & TSK_SAMPLE_LISTS) { - tsk_tree_update_sample_lists(self, parent, left_child, right_sib, p); - } - - /* Update the roots. 
*/ - if (above_sample[c]) { - v = p; - root = v; - currently_above_sample = false; - while (v != TSK_NULL && !currently_above_sample) { - currently_above_sample = above_sample[v]; - above_sample[v] = true; - root = v; - v = parent[v]; - } - if (!currently_above_sample) { - /* Replace c with root in root list */ - if (lsib != TSK_NULL) { - right_sib[lsib] = root; - } - if (rsib != TSK_NULL) { - left_sib[rsib] = root; - } - left_sib[root] = lsib; - right_sib[root] = rsib; - self->left_root = root; - } else { - /* Remove c from root list */ + + if (IS_ROOT(c)) { + if (path_end_was_root) { + /* Remove c from the root list */ self->left_root = TSK_NULL; if (lsib != TSK_NULL) { right_sib[lsib] = rsib; @@ -3950,14 +3912,90 @@ tsk_tree_advance(tsk_tree_t *self, int direction, left_sib[rsib] = lsib; self->left_root = rsib; } + } else { + /* Replace c with path_end in root list */ + if (lsib != TSK_NULL) { + right_sib[lsib] = path_end; + } + if (rsib != TSK_NULL) { + left_sib[rsib] = path_end; + } + left_sib[path_end] = lsib; + right_sib[path_end] = rsib; + self->left_root = path_end; + } + } else { + if (IS_ROOT(path_end) && ! 
path_end_was_root) { + /* Add a path_end as new root */ + if (self->left_root != TSK_NULL) { + lroot = left_sib[self->left_root]; + if (lroot != TSK_NULL) { + right_sib[lroot] = path_end; + } + left_sib[path_end] = lroot; + left_sib[self->left_root] = path_end; + } + right_sib[path_end] = self->left_root; + self->left_root = path_end; } } } + if (self->options & TSK_SAMPLE_LISTS) { + tsk_tree_update_sample_lists(self, parent, left_child, right_sib, p); + } +out: + return ret; +} + +static int +tsk_tree_advance(tsk_tree_t *self, int direction, + const double * restrict out_breakpoints, + const tsk_id_t * restrict out_order, + tsk_id_t *out_index, + const double * restrict in_breakpoints, + const tsk_id_t * restrict in_order, + tsk_id_t *in_index) +{ + int ret = 0; + const int direction_change = direction * (direction != self->direction); + tsk_id_t in = *in_index + direction_change; + tsk_id_t out = *out_index + direction_change; + tsk_id_t k; + const tsk_table_collection_t *tables = self->tree_sequence->tables; + const double sequence_length = tables->sequence_length; + const tsk_id_t num_edges = (tsk_id_t) tables->edges.num_rows; + const tsk_id_t * restrict edge_parent = tables->edges.parent; + const tsk_id_t * restrict edge_child = tables->edges.child; + double x; + + if (direction == TSK_DIR_FORWARD) { + x = self->right; + } else { + x = self->left; + } + while (out >= 0 && out < num_edges && out_breakpoints[out_order[out]] == x) { + assert(out < num_edges); + k = out_order[out]; + out += direction; + ret = tsk_tree_remove_edge(self, edge_parent[k], edge_child[k]); + if (ret != 0) { + goto out; + } + } + + while (in >= 0 && in < num_edges && in_breakpoints[in_order[in]] == x) { + k = in_order[in]; + in += direction; + ret = tsk_tree_insert_edge(self, edge_parent[k], edge_child[k]); + if (ret != 0) { + goto out; + } + } if (self->left_root != TSK_NULL) { /* Ensure that left_root is the left-most root */ - while (left_sib[self->left_root] != TSK_NULL) { - 
self->left_root = left_sib[self->left_root]; + while (self->left_sib[self->left_root] != TSK_NULL) { + self->left_root = self->left_sib[self->left_root]; } } @@ -4293,141 +4331,6 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes, return ret; } -/* ======================================================== * - * Tree diff iterator. - * ======================================================== */ - -int TSK_WARN_UNUSED -tsk_diff_iter_init(tsk_diff_iter_t *self, tsk_treeseq_t *tree_sequence) -{ - int ret = 0; - - assert(tree_sequence != NULL); - memset(self, 0, sizeof(tsk_diff_iter_t)); - self->num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); - self->num_edges = tsk_treeseq_get_num_edges(tree_sequence); - self->tree_sequence = tree_sequence; - self->insertion_index = 0; - self->removal_index = 0; - self->tree_left = 0; - self->tree_index = (size_t) -1; - self->edge_list_nodes = malloc(self->num_edges * sizeof(tsk_edge_list_t)); - if (self->edge_list_nodes == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } -out: - return ret; -} - -int TSK_WARN_UNUSED -tsk_diff_iter_free(tsk_diff_iter_t *self) -{ - int ret = 0; - tsk_safe_free(self->edge_list_nodes); - return ret; -} - -void -tsk_diff_iter_print_state(tsk_diff_iter_t *self, FILE *out) -{ - fprintf(out, "tree_diff_iterator state\n"); - fprintf(out, "num_edges = %d\n", (int) self->num_edges); - fprintf(out, "insertion_index = %d\n", (int) self->insertion_index); - fprintf(out, "removal_index = %d\n", (int) self->removal_index); - fprintf(out, "tree_left = %f\n", self->tree_left); - fprintf(out, "tree_index = %d\n", (int) self->tree_index); -} - -int TSK_WARN_UNUSED -tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, - tsk_edge_list_t **edges_out, tsk_edge_list_t **edges_in) -{ - int ret = 0; - tsk_id_t k; - const double sequence_length = self->tree_sequence->tables->sequence_length; - double left = self->tree_left; - double right = sequence_length; - size_t next_edge_list_node 
= 0; - tsk_treeseq_t *s = self->tree_sequence; - tsk_edge_list_t *out_head = NULL; - tsk_edge_list_t *out_tail = NULL; - tsk_edge_list_t *in_head = NULL; - tsk_edge_list_t *in_tail = NULL; - tsk_edge_list_t *w = NULL; - size_t num_trees = tsk_treeseq_get_num_trees(s); - const tsk_edge_table_t *edges = &s->tables->edges; - const tsk_id_t *insertion_order = s->tables->indexes.edge_insertion_order; - const tsk_id_t *removal_order = s->tables->indexes.edge_removal_order; - - if (self->tree_index + 1 < num_trees) { - /* First we remove the stale records */ - while (self->removal_index < self->num_edges && - left == edges->right[removal_order[self->removal_index]]) { - k = removal_order[self->removal_index]; - assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->next = NULL; - if (out_head == NULL) { - out_head = w; - out_tail = w; - } else { - out_tail->next = w; - out_tail = w; - } - self->removal_index++; - } - - /* Now insert the new records */ - while (self->insertion_index < self->num_edges && - left == edges->left[insertion_order[self->insertion_index]]) { - k = insertion_order[self->insertion_index]; - assert(next_edge_list_node < self->num_edges); - w = &self->edge_list_nodes[next_edge_list_node]; - next_edge_list_node++; - w->edge.id = k; - w->edge.left = edges->left[k]; - w->edge.right = edges->right[k]; - w->edge.parent = edges->parent[k]; - w->edge.child = edges->child[k]; - w->next = NULL; - if (in_head == NULL) { - in_head = w; - in_tail = w; - } else { - in_tail->next = w; - in_tail = w; - } - self->insertion_index++; - } - right = sequence_length; - if (self->insertion_index < self->num_edges) { - right = TSK_MIN(right, edges->left[ - insertion_order[self->insertion_index]]); - } - if (self->removal_index < self->num_edges) { 
- right = TSK_MIN(right, edges->right[ - removal_order[self->removal_index]]); - } - self->tree_index++; - ret = 1; - } - *edges_out = out_head; - *edges_in = in_head; - *ret_left = left; - *ret_right = right; - /* Set the left coordinate for the next tree */ - self->tree_left = right; - return ret; -} - - int tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double *result) { @@ -4572,3 +4475,137 @@ tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double } return ret; } + +/* ======================================================== * + * Tree diff iterator. + * ======================================================== */ + +int TSK_WARN_UNUSED +tsk_diff_iter_init(tsk_diff_iter_t *self, tsk_treeseq_t *tree_sequence) +{ + int ret = 0; + + assert(tree_sequence != NULL); + memset(self, 0, sizeof(tsk_diff_iter_t)); + self->num_nodes = tsk_treeseq_get_num_nodes(tree_sequence); + self->num_edges = tsk_treeseq_get_num_edges(tree_sequence); + self->tree_sequence = tree_sequence; + self->insertion_index = 0; + self->removal_index = 0; + self->tree_left = 0; + self->tree_index = (size_t) -1; + self->edge_list_nodes = malloc(self->num_edges * sizeof(tsk_edge_list_t)); + if (self->edge_list_nodes == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_diff_iter_free(tsk_diff_iter_t *self) +{ + int ret = 0; + tsk_safe_free(self->edge_list_nodes); + return ret; +} + +void +tsk_diff_iter_print_state(tsk_diff_iter_t *self, FILE *out) +{ + fprintf(out, "tree_diff_iterator state\n"); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + fprintf(out, "insertion_index = %d\n", (int) self->insertion_index); + fprintf(out, "removal_index = %d\n", (int) self->removal_index); + fprintf(out, "tree_left = %f\n", self->tree_left); + fprintf(out, "tree_index = %d\n", (int) self->tree_index); +} + +int TSK_WARN_UNUSED +tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double 
*ret_right, + tsk_edge_list_t **edges_out, tsk_edge_list_t **edges_in) +{ + int ret = 0; + tsk_id_t k; + const double sequence_length = self->tree_sequence->tables->sequence_length; + double left = self->tree_left; + double right = sequence_length; + size_t next_edge_list_node = 0; + tsk_treeseq_t *s = self->tree_sequence; + tsk_edge_list_t *out_head = NULL; + tsk_edge_list_t *out_tail = NULL; + tsk_edge_list_t *in_head = NULL; + tsk_edge_list_t *in_tail = NULL; + tsk_edge_list_t *w = NULL; + size_t num_trees = tsk_treeseq_get_num_trees(s); + const tsk_edge_table_t *edges = &s->tables->edges; + const tsk_id_t *insertion_order = s->tables->indexes.edge_insertion_order; + const tsk_id_t *removal_order = s->tables->indexes.edge_removal_order; + + if (self->tree_index + 1 < num_trees) { + /* First we remove the stale records */ + while (self->removal_index < self->num_edges && + left == edges->right[removal_order[self->removal_index]]) { + k = removal_order[self->removal_index]; + assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->next = NULL; + if (out_head == NULL) { + out_head = w; + out_tail = w; + } else { + out_tail->next = w; + out_tail = w; + } + self->removal_index++; + } + + /* Now insert the new records */ + while (self->insertion_index < self->num_edges && + left == edges->left[insertion_order[self->insertion_index]]) { + k = insertion_order[self->insertion_index]; + assert(next_edge_list_node < self->num_edges); + w = &self->edge_list_nodes[next_edge_list_node]; + next_edge_list_node++; + w->edge.id = k; + w->edge.left = edges->left[k]; + w->edge.right = edges->right[k]; + w->edge.parent = edges->parent[k]; + w->edge.child = edges->child[k]; + w->next = NULL; + if (in_head == NULL) { + in_head = w; + in_tail = w; + } else { + 
in_tail->next = w; + in_tail = w; + } + self->insertion_index++; + } + right = sequence_length; + if (self->insertion_index < self->num_edges) { + right = TSK_MIN(right, edges->left[ + insertion_order[self->insertion_index]]); + } + if (self->removal_index < self->num_edges) { + right = TSK_MIN(right, edges->right[ + removal_order[self->removal_index]]); + } + self->tree_index++; + ret = 1; + } + *edges_out = out_head; + *edges_in = in_head; + *ret_left = left; + *ret_right = right; + /* Set the left coordinate for the next tree */ + self->tree_left = right; + return ret; +} diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 1443d31f6d..f85c2eae34 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -36,8 +36,14 @@ extern "C" { #include -#define TSK_SAMPLE_COUNTS (1 << 0) -#define TSK_SAMPLE_LISTS (1 << 1) +/* The TSK_SAMPLE_COUNTS was removed in version 0.99.3, where + * the default is now to always count samples except when + * TSK_NO_SAMPLE_COUNTS is specified. This macro can be undefined + * at some point in the future and the option reused for something + * else. */ +#define TSK_SAMPLE_COUNTS (1 << 0) +#define TSK_SAMPLE_LISTS (1 << 1) +#define TSK_NO_SAMPLE_COUNTS (1 << 2) #define TSK_STAT_SITE (1 << 0) #define TSK_STAT_BRANCH (1 << 1) @@ -143,18 +149,23 @@ typedef struct { tsk_size_t num_nodes; tsk_flags_t options; + tsk_size_t root_threshold; tsk_id_t *samples; /* TODO before documenting this should be change to interval. */ /* Left and right physical coordinates of the tree */ double left; double right; - bool *above_sample; tsk_id_t index; /* These are involved in the optional sample tracking; num_samples counts * all samples below a give node, and num_tracked_samples counts those - * from a specific subset. */ + * from a specific subset. By default sample counts are tracked and roots + * maintained. If TSK_NO_SAMPLE_COUNTS is specified, then neither sample + * counts or roots are available. 
*/ tsk_id_t *num_samples; tsk_id_t *num_tracked_samples; + /* TODO the only place this feature seems to be used is in the ld_calculator. + * when this is being replaced we should come up with a better way of doing + * whatever this is being used for. */ /* All nodes that are marked during a particular transition are marked * with a given value. */ uint8_t *marked; @@ -352,6 +363,9 @@ int tsk_tree_clear(tsk_tree_t *self); void tsk_tree_print_state(tsk_tree_t *self, FILE *out); /** @} */ +int tsk_tree_set_root_threshold(tsk_tree_t *self, tsk_size_t root_threshold); +tsk_size_t tsk_tree_get_root_threshold(tsk_tree_t *self); + bool tsk_tree_has_sample_lists(tsk_tree_t *self); bool tsk_tree_has_sample_counts(tsk_tree_t *self); bool tsk_tree_equals(tsk_tree_t *self, tsk_tree_t *other); diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 157fc78c89..701bcfb2e4 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -19,8 +19,18 @@ In development - User specified allele mapping for genotypes in ``variants`` and ``genotype_matrix`` (:user:`jeromekelleher`, :pr:`430`). +- New ``root_threshold`` option for the Tree class, which allows + us to efficiently iterate over 'real' roots when we have + missing data (:user:`jeromekelleher`, :pr:`462`). + **Bugfixes** +**Deprecated** + +- The ``sample_counts`` feature has been deprecated and is now + ignored. Sample counts are now always computed. 
+ + -------------------- [0.2.3] - 2019-11-22 -------------------- diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index b7a46529dd..80ba4f1ebb 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7715,7 +7715,7 @@ Tree_init(Tree *self, PyObject *args, PyObject *kwds) num_nodes = tsk_treeseq_get_num_nodes(tree_sequence->tree_sequence); num_tracked_samples = 0; if (py_tracked_samples != NULL) { - if (!(options & TSK_SAMPLE_COUNTS)) { + if ((options & TSK_NO_SAMPLE_COUNTS)) { PyErr_SetString(PyExc_ValueError, "Cannot specified tracked_samples without count_samples flag"); goto out; @@ -7749,7 +7749,7 @@ Tree_init(Tree *self, PyObject *args, PyObject *kwds) handle_library_error(err); goto out; } - if (!!(options & TSK_SAMPLE_COUNTS)) { + if (!(options & TSK_NO_SAMPLE_COUNTS)) { err = tsk_tree_set_tracked_samples(self->tree, num_tracked_samples, tracked_samples); if (err != 0) { @@ -8534,6 +8534,37 @@ Tree_get_kc_distance(Tree *self, PyObject *args, PyObject *kwds) return ret; } +static PyObject * +Tree_get_root_threshold(Tree *self) +{ + PyObject *ret = NULL; + + ret = Py_BuildValue("I", + (unsigned int) tsk_tree_get_root_threshold(self->tree)); + return ret; +} + +static PyObject * +Tree_set_root_threshold(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + unsigned int threshold = 0; + + if (!PyArg_ParseTuple(args, "I", &threshold)) { + goto out; + } + + err = tsk_tree_set_root_threshold(self->tree, threshold); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyMemberDef Tree_members[] = { {NULL} /* Sentinel */ }; @@ -8617,6 +8648,12 @@ static PyMethodDef Tree_methods[] = { {"get_kc_distance", (PyCFunction) Tree_get_kc_distance, METH_VARARGS|METH_KEYWORDS, "Returns the KC distance between this tree and another." }, + {"set_root_threshold", (PyCFunction) Tree_set_root_threshold, + METH_VARARGS, + "Sets the root threshold to the specified value." 
}, + {"get_root_threshold", (PyCFunction) Tree_get_root_threshold, + METH_NOARGS, + "Returns the root threshold for this tree." }, {NULL} /* Sentinel */ }; @@ -10065,7 +10102,7 @@ PyInit__tskit(void) /* Node flags */ PyModule_AddIntConstant(module, "NODE_IS_SAMPLE", TSK_NODE_IS_SAMPLE); /* Tree flags */ - PyModule_AddIntConstant(module, "SAMPLE_COUNTS", TSK_SAMPLE_COUNTS); + PyModule_AddIntConstant(module, "NO_SAMPLE_COUNTS", TSK_NO_SAMPLE_COUNTS); PyModule_AddIntConstant(module, "SAMPLE_LISTS", TSK_SAMPLE_LISTS); /* Directions */ PyModule_AddIntConstant(module, "FORWARD", TSK_DIR_FORWARD); diff --git a/python/tests/__init__.py b/python/tests/__init__.py index 97d21a22b9..171d1e834c 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2019 Tskit Developers +# Copyright (c) 2018-2020 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,7 @@ # TODO remove this code and refactor elsewhere. from .simplify import * # NOQA +from . import tsutil import tskit @@ -42,11 +43,8 @@ def __init__(self, num_nodes): self.right_child = [tskit.NULL for _ in range(num_nodes)] self.left_sib = [tskit.NULL for _ in range(num_nodes)] self.right_sib = [tskit.NULL for _ in range(num_nodes)] - self.above_sample = [False for _ in range(num_nodes)] - self.is_sample = [False for _ in range(num_nodes)] self.left = 0 self.right = 0 - self.root = 0 self.index = -1 self.left_root = -1 # We need a sites function, so this name is taken. @@ -191,30 +189,37 @@ def _build_newick(self, node, precision, node_labels): class PythonTreeSequence(object): """ A python implementation of the TreeSequence object. + + TODO this class is of limited use now and should be factored out as + part of a drive towards more modular versions of the tests currently + in tests_highlevel.py. 
""" def __init__(self, tree_sequence, breakpoints=None): self._tree_sequence = tree_sequence - self._num_samples = tree_sequence.get_num_samples() - self._breakpoints = breakpoints self._sites = [] + # TODO this code here is expressed in terms of the low-level + # tree sequence for legacy reasons. It probably makes more sense + # to describe it in terms of the tables now if we want to have an + # independent implementation. + ll_ts = self._tree_sequence._ll_tree_sequence def make_mutation(id_): - site, node, derived_state, parent, metadata = tree_sequence.get_mutation(id_) + site, node, derived_state, parent, metadata = ll_ts.get_mutation(id_) return tskit.Mutation( id_=id_, site=site, node=node, derived_state=derived_state, parent=parent, metadata=metadata) - for j in range(tree_sequence.get_num_sites()): - pos, ancestral_state, ll_mutations, id_, metadata = tree_sequence.get_site(j) + for j in range(tree_sequence.num_sites): + pos, ancestral_state, ll_mutations, id_, metadata = ll_ts.get_site(j) self._sites.append(tskit.Site( id_=id_, position=pos, ancestral_state=ancestral_state, mutations=[make_mutation(ll_mut) for ll_mut in ll_mutations], metadata=metadata)) def edge_diffs(self): - M = self._tree_sequence.get_num_edges() - sequence_length = self._tree_sequence.get_sequence_length() - edges = [tskit.Edge(*self._tree_sequence.get_edge(j), j) for j in range(M)] - time = [self._tree_sequence.get_node(edge.parent)[1] for edge in edges] + M = self._tree_sequence.num_edges + sequence_length = self._tree_sequence.sequence_length + edges = list(self._tree_sequence.edges()) + time = [self._tree_sequence.node(edge.parent).time for edge in edges] in_order = sorted(range(M), key=lambda j: ( edges[j].left, time[j], edges[j].parent, edges[j].child)) out_order = sorted(range(M), key=lambda j: ( @@ -242,174 +247,24 @@ def edge_diffs(self): left = right def trees(self): - M = self._tree_sequence.get_num_edges() - sequence_length = self._tree_sequence.get_sequence_length() - 
edges = [ - tskit.Edge(*self._tree_sequence.get_edge(j), j) for j in range(M)] - t = [ - self._tree_sequence.get_node(j)[1] - for j in range(self._tree_sequence.get_num_nodes())] - in_order = sorted( - range(M), key=lambda j: ( - edges[j].left, t[edges[j].parent], edges[j].parent, edges[j].child)) - out_order = sorted( - range(M), key=lambda j: ( - edges[j].right, -t[edges[j].parent], -edges[j].parent, -edges[j].child)) - j = 0 - k = 0 - N = self._tree_sequence.get_num_nodes() - st = PythonTree(N) - - samples = list(self._tree_sequence.get_samples()) - for l in range(len(samples)): - if l < len(samples) - 1: - st.right_sib[samples[l]] = samples[l + 1] - if l > 0: - st.left_sib[samples[l]] = samples[l - 1] - st.above_sample[samples[l]] = True - st.is_sample[samples[l]] = True - - st.left_root = tskit.NULL - if len(samples) > 0: - st.left_root = samples[0] - - u = st.left_root - roots = [] - while u != -1: - roots.append(u) - v = st.right_sib[u] - if v != -1: - assert st.left_sib[v] == u - u = v - - st.left = 0 - while j < M or st.left < sequence_length: - while k < M and edges[out_order[k]].right == st.left: - p = edges[out_order[k]].parent - c = edges[out_order[k]].child - k += 1 - - lsib = st.left_sib[c] - rsib = st.right_sib[c] - if lsib == tskit.NULL: - st.left_child[p] = rsib - else: - st.right_sib[lsib] = rsib - if rsib == tskit.NULL: - st.right_child[p] = lsib - else: - st.left_sib[rsib] = lsib - st.parent[c] = tskit.NULL - st.left_sib[c] = tskit.NULL - st.right_sib[c] = tskit.NULL - - # If c is not above a sample then we have nothing to do as we - # cannot affect the status of any roots. - if st.above_sample[c]: - # Compute the new above sample status for the nodes from - # p up to root. 
- v = p - above_sample = False - while v != tskit.NULL and not above_sample: - above_sample = st.is_sample[v] - u = st.left_child[v] - while u != tskit.NULL: - above_sample = above_sample or st.above_sample[u] - u = st.right_sib[u] - st.above_sample[v] = above_sample - root = v - v = st.parent[v] - - if not above_sample: - # root is no longer above samples. Remove it from the root list. - lroot = st.left_sib[root] - rroot = st.right_sib[root] - st.left_root = tskit.NULL - if lroot != tskit.NULL: - st.right_sib[lroot] = rroot - st.left_root = lroot - if rroot != tskit.NULL: - st.left_sib[rroot] = lroot - st.left_root = rroot - st.left_sib[root] = tskit.NULL - st.right_sib[root] = tskit.NULL - - # Add c to the root list. - # print("Insert ", c, "into root list") - if st.left_root != tskit.NULL: - lroot = st.left_sib[st.left_root] - if lroot != tskit.NULL: - st.right_sib[lroot] = c - st.left_sib[c] = lroot - st.left_sib[st.left_root] = c - st.right_sib[c] = st.left_root - st.left_root = c - - while j < M and edges[in_order[j]].left == st.left: - p = edges[in_order[j]].parent - c = edges[in_order[j]].child - j += 1 - - # print("insert ", c, "->", p) - st.parent[c] = p - u = st.right_child[p] - lsib = st.left_sib[c] - rsib = st.right_sib[c] - if u == tskit.NULL: - st.left_child[p] = c - st.left_sib[c] = tskit.NULL - st.right_sib[c] = tskit.NULL - else: - st.right_sib[u] = c - st.left_sib[c] = u - st.right_sib[c] = tskit.NULL - st.right_child[p] = c - - if st.above_sample[c]: - v = p - above_sample = False - while v != tskit.NULL and not above_sample: - above_sample = st.above_sample[v] - st.above_sample[v] = st.above_sample[v] or st.above_sample[c] - root = v - v = st.parent[v] - # print("root = ", root, st.above_sample[root]) - - if not above_sample: - # Replace c with root in root list. 
- # print("replacing", root, "with ", c ," in root list") - if lsib != tskit.NULL: - st.right_sib[lsib] = root - if rsib != tskit.NULL: - st.left_sib[rsib] = root - st.left_sib[root] = lsib - st.right_sib[root] = rsib - st.left_root = root - else: - # Remove c from root list. - # print("remove ", c ," from root list") - st.left_root = tskit.NULL - if lsib != tskit.NULL: - st.right_sib[lsib] = rsib - st.left_root = lsib - if rsib != tskit.NULL: - st.left_sib[rsib] = lsib - st.left_root = rsib - - st.right = sequence_length - if j < M: - st.right = min(st.right, edges[in_order[j]].left) - if k < M: - st.right = min(st.right, edges[out_order[k]].right) - assert st.left_root != tskit.NULL - while st.left_sib[st.left_root] != tskit.NULL: - st.left_root = st.left_sib[st.left_root] - st.index += 1 + rtt = tsutil.RootThresholdTree(self._tree_sequence) + pt = PythonTree(self._tree_sequence.get_num_nodes()) + pt.index = 0 + for left, right in rtt.iterate(): + pt.parent[:] = rtt.parent + pt.left_child[:] = rtt.left_child + pt.right_child[:] = rtt.right_child + pt.left_sib[:] = rtt.left_sib + pt.right_sib[:] = rtt.right_sib + pt.left_root = rtt.left_root + pt.left = left + pt.right = right # Add in all the sites - st.site_list = [ - site for site in self._sites if st.left <= site.position < st.right] - yield st - st.left = st.right + pt.site_list = [ + site for site in self._sites if left <= site.position < right] + yield pt + pt.index += 1 + pt.index = -1 class MRCACalculator(object): diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 9b48dc49d2..e7d3d7e092 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -918,6 +918,17 @@ def test_jukes_cantor_n_15(self): ts = tsutil.jukes_cantor(ts, num_sites=10, mu=0.1, seed=10) self.verify(ts, tskit.ALLELES_ACGT) + @unittest.skip("Not supporting internal samples yet") + def test_ancestors_n_3(self): + ts = msprime.simulate(3, 
recombination_rate=2, mutation_rate=7, random_seed=2) + self.assertGreater(ts.num_sites, 5) + tables = ts.dump_tables() + print(tables.nodes) + tables.nodes.flags = np.ones_like(tables.nodes.flags) + print(tables.nodes) + ts = tables.tree_sequence() + self.verify(ts) + class ForwardAlgorithmBase(LiStephensBase): """ @@ -949,6 +960,7 @@ class TestExactMatchViterbi(ViterbiAlgorithmBase, unittest.TestCase): def verify(self, ts, alleles=tskit.ALLELES_01): G = ts.genotype_matrix(alleles=alleles) H = G.T + # print(H) rho = np.zeros(ts.num_sites) + 0.1 mu = np.zeros(ts.num_sites) rho[0] = 0 diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 3602fc636c..210b83fb74 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2019 Tskit Developers +# Copyright (c) 2018-2020 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -39,13 +39,13 @@ import numpy as np import msprime +import networkx as nx import tskit import _tskit import tests as tests import tests.tsutil as tsutil import tests.simplify as simplify -import networkx as nx def insert_uniform_mutations(tables, num_mutations, nodes): @@ -426,7 +426,7 @@ def verify_tree(self, st): self.verify_tree_structure(st) def verify_trees(self, ts): - pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + pts = tests.PythonTreeSequence(ts) iter1 = ts.trees() iter2 = pts.trees() length = 0 @@ -440,7 +440,9 @@ def verify_trees(self, ts): while st1.get_parent(root) != tskit.NULL: root = st1.get_parent(root) roots.add(root) + self.assertEqual(st1.left_root, st2.left_root) self.assertEqual(sorted(list(roots)), sorted(st1.roots)) + self.assertEqual(st1.roots, st2.roots) if len(roots) > 1: with self.assertRaises(ValueError): st1.root @@ -580,7 +582,7 @@ def test_pairwise_diversity(self): self.verify_pairwise_diversity(ts) def 
verify_edge_diffs(self, ts): - pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + pts = tests.PythonTreeSequence(ts) d1 = list(ts.edge_diffs()) d2 = list(pts.edge_diffs()) self.assertEqual(d1, d2) @@ -744,15 +746,14 @@ def test_deprecated_sample_aliases(self): for u in t_new.nodes(): self.assertEqual( t_new.num_tracked_samples(u), t_old.get_num_tracked_leaves(u)) + trees_new = ts.trees() + trees_old = ts.trees() + for t_new, t_old in zip(trees_new, trees_old): + for u in t_new.nodes(): + self.assertEqual(t_new.num_samples(u), t_old.get_num_leaves(u)) + self.assertEqual( + list(t_new.samples(u)), list(t_old.get_leaves(u))) for on in [True, False]: - # sample/leaf counts - trees_new = ts.trees(sample_counts=on) - trees_old = ts.trees(leaf_counts=on) - for t_new, t_old in zip(trees_new, trees_old): - for u in t_new.nodes(): - self.assertEqual(t_new.num_samples(u), t_old.get_num_leaves(u)) - self.assertEqual( - list(t_new.samples(u)), list(t_old.get_leaves(u))) trees_new = ts.trees(sample_lists=on) trees_old = ts.trees(leaf_lists=on) for t_new, t_old in zip(trees_new, trees_old): @@ -802,42 +803,22 @@ def test_first_last(self): def test_trees_interface(self): ts = list(get_example_tree_sequences())[0] - # The defaults should make sense and count samples. 
- # get_num_tracked_samples for t in ts.trees(): self.assertEqual(t.get_num_samples(0), 1) self.assertEqual(t.get_num_tracked_samples(0), 0) self.assertEqual(list(t.samples(0)), [0]) self.assertIs(t.tree_sequence, ts) - for t in ts.trees(sample_counts=False): - self.assertEqual(t.get_num_samples(0), 1) - self.assertRaises(RuntimeError, t.get_num_tracked_samples, 0) - self.assertEqual(list(t.samples(0)), [0]) - - for t in ts.trees(sample_counts=True): - self.assertEqual(t.get_num_samples(0), 1) - self.assertEqual(t.get_num_tracked_samples(0), 0) - self.assertEqual(list(t.samples(0)), [0]) - - for t in ts.trees(sample_counts=True, tracked_samples=[0]): + for t in ts.trees(tracked_samples=[0]): self.assertEqual(t.get_num_samples(0), 1) self.assertEqual(t.get_num_tracked_samples(0), 1) self.assertEqual(list(t.samples(0)), [0]) - for t in ts.trees(sample_lists=True, sample_counts=True): + for t in ts.trees(sample_lists=True): self.assertEqual(t.get_num_samples(0), 1) self.assertEqual(t.get_num_tracked_samples(0), 0) self.assertEqual(list(t.samples(0)), [0]) - for t in ts.trees(sample_lists=True, sample_counts=False): - self.assertEqual(t.get_num_samples(0), 1) - self.assertRaises(RuntimeError, t.get_num_tracked_samples, 0) - self.assertEqual(list(t.samples(0)), [0]) - - self.assertRaises( - ValueError, ts.trees, sample_counts=False, tracked_samples=[0]) - def test_get_pairwise_diversity(self): for ts in get_example_tree_sequences(): self.assertRaises(ValueError, ts.get_pairwise_diversity, []) @@ -2140,6 +2121,22 @@ def test_copy_tracked_samples(self): self.assertEqual( tree.num_tracked_samples(j), copy.num_tracked_samples(j)) + def test_copy_multiple_roots(self): + ts = msprime.simulate(20, recombination_rate=2, length=3, random_seed=42) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + for root_threshold in [1, 2, 100]: + tree = tskit.Tree(ts, root_threshold=root_threshold) + copy = tree.copy() + self.assertEqual(copy.roots, tree.roots) + 
self.assertEqual(copy.root_threshold, root_threshold) + while tree.next(): + copy = tree.copy() + self.assertEqual(copy.roots, tree.roots) + self.assertEqual(copy.root_threshold, root_threshold) + copy = tree.copy() + self.assertEqual(copy.roots, tree.roots) + self.assertEqual(copy.root_threshold, root_threshold) + def test_map_mutations(self): ts = msprime.simulate(5, random_seed=42) tree = ts.first() @@ -2160,6 +2157,18 @@ def test_map_mutations(self): tree.map_mutations([0] * 5, alleles) tree.map_mutations(np.zeros(5, dtype=int), alleles) + def test_sample_count_deprecated(self): + ts = msprime.simulate(5, random_seed=42) + with warnings.catch_warnings(record=True) as w: + ts.trees(sample_counts=True) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, RuntimeWarning)) + + with warnings.catch_warnings(record=True) as w: + tskit.Tree(ts, sample_counts=False) + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, RuntimeWarning)) + class TestNodeOrdering(HighLevelTestCase): """ diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 01b0de6433..1a0a936756 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2019 Tskit Developers +# Copyright (c) 2018-2020 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -1329,22 +1329,21 @@ def test_options(self): ts = self.get_example_tree_sequence() st = _tskit.Tree(ts) self.assertEqual(st.get_options(), 0) - # We should still be able to count the samples, just inefficiently. 
- self.assertEqual(st.get_num_samples(0), 1) - self.assertRaises(_tskit.LibraryError, st.get_num_tracked_samples, 0) all_options = [ - 0, _tskit.SAMPLE_COUNTS, _tskit.SAMPLE_LISTS, - _tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS] + 0, _tskit.NO_SAMPLE_COUNTS, _tskit.SAMPLE_LISTS, + _tskit.NO_SAMPLE_COUNTS | _tskit.SAMPLE_LISTS] for options in all_options: tree = _tskit.Tree(ts, options=options) copy = tree.copy() for st in [tree, copy]: self.assertEqual(st.get_options(), options) self.assertEqual(st.get_num_samples(0), 1) - if options & _tskit.SAMPLE_COUNTS: - self.assertEqual(st.get_num_tracked_samples(0), 0) - else: + if options & _tskit.NO_SAMPLE_COUNTS: + # We should still be able to count the samples, just inefficiently. + self.assertEqual(st.get_num_samples(0), 1) self.assertRaises(_tskit.LibraryError, st.get_num_tracked_samples, 0) + else: + self.assertEqual(st.get_num_tracked_samples(0), 0) if options & _tskit.SAMPLE_LISTS: self.assertEqual(0, st.get_left_sample(0)) self.assertEqual(0, st.get_right_sample(0)) @@ -1404,6 +1403,33 @@ def test_sites(self): j += 1 self.assertEqual(all_tree_sites, all_sites) + def test_root_threshold_errors(self): + ts = self.get_example_tree_sequence() + tree = _tskit.Tree(ts) + for bad_type in ["", "x", {}]: + with self.assertRaises(TypeError): + tree.set_root_threshold(bad_type) + + with self.assertRaises(_tskit.LibraryError): + tree.set_root_threshold(0) + tree.set_root_threshold(2) + # Setting when not in the null state raises an error + tree.next() + with self.assertRaises(_tskit.LibraryError): + tree.set_root_threshold(2) + + def test_root_threshold(self): + for ts in self.get_example_tree_sequences(): + tree = _tskit.Tree(ts) + for root_threshold in [1, 2, ts.get_num_samples() * 2]: + tree.set_root_threshold(root_threshold) + self.assertEqual(tree.get_root_threshold(), root_threshold) + while tree.next(): + self.assertEqual(tree.get_root_threshold(), root_threshold) + with self.assertRaises(_tskit.LibraryError): + 
tree.set_root_threshold(2) + self.assertEqual(tree.get_root_threshold(), root_threshold) + def test_constructor(self): self.assertRaises(TypeError, _tskit.Tree) for bad_type in ["", {}, [], None, 0]: @@ -1430,7 +1456,7 @@ def test_constructor(self): def test_bad_tracked_samples(self): ts = self.get_example_tree_sequence() - options = _tskit.SAMPLE_COUNTS + options = 0 for bad_type in ["", {}, [], None]: self.assertRaises( TypeError, _tskit.Tree, ts, options=options, @@ -1474,7 +1500,7 @@ def test_while_loop_semantics(self): def test_count_all_samples(self): for ts in self.get_example_tree_sequences(): self.verify_iterator(_tskit.TreeDiffIterator(ts)) - st = _tskit.Tree(ts, options=_tskit.SAMPLE_COUNTS) + st = _tskit.Tree(ts) # Without initialisation we should be 0 samples for every node # that is not a sample. for j in range(st.get_num_nodes()): @@ -1505,8 +1531,7 @@ def test_count_tracked_samples(self): for _, subset in zip(range(max_sets), map(list, powerset)): # Ordering shouldn't make any difference. 
random.shuffle(subset) - st = _tskit.Tree( - ts, options=_tskit.SAMPLE_COUNTS, tracked_samples=subset) + st = _tskit.Tree(ts, tracked_samples=subset) while st.next(): nu = get_tracked_sample_counts(st, subset) nu_prime = [ @@ -1518,15 +1543,14 @@ def test_count_tracked_samples(self): for j in range(2, 20): tracked_samples = [sample for _ in range(j)] self.assertRaises( - _tskit.LibraryError, _tskit.Tree, - ts, options=_tskit.SAMPLE_COUNTS, + _tskit.LibraryError, _tskit.Tree, ts, tracked_samples=tracked_samples) self.assertTrue(non_binary) def test_bounds_checking(self): for ts in self.get_example_tree_sequences(): n = ts.get_num_nodes() - st = _tskit.Tree(ts, options=_tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS) + st = _tskit.Tree(ts, options=_tskit.SAMPLE_LISTS) for v in [-100, -1, n + 1, n + 100, n * 100]: self.assertRaises(ValueError, st.get_parent, v) self.assertRaises(ValueError, st.get_children, v) @@ -1669,7 +1693,7 @@ def f(mutations): self.assertRaises(_tskit.LibraryError, f, [(length, 0)]) def test_sample_list(self): - options = _tskit.SAMPLE_COUNTS | _tskit.SAMPLE_LISTS + options = _tskit.SAMPLE_LISTS # Note: we're assuming that samples are 0-n here. for ts in self.get_example_tree_sequences(): t = _tskit.Tree(ts, options=options) @@ -1839,4 +1863,4 @@ def test_kastore_version(self): def test_tskit_version(self): version = _tskit.get_tskit_version() - self.assertEqual(version, (0, 99, 2)) + self.assertEqual(version, (0, 99, 3)) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index ca51b37a1d..a13dd8fdd8 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -931,7 +931,7 @@ def check_num_samples(self, ts, x): `(tree number, parent, number of samples)`. 
""" k = 0 - tss = ts.trees(sample_counts=True) + tss = ts.trees() t = next(tss) for j, node, nl in x: while k < j: @@ -941,7 +941,7 @@ def check_num_samples(self, ts, x): def check_num_tracked_samples(self, ts, tracked_samples, x): k = 0 - tss = ts.trees(sample_counts=True, tracked_samples=tracked_samples) + tss = ts.trees(tracked_samples=tracked_samples) t = next(tss) for j, node, nl in x: while k < j: @@ -1371,7 +1371,7 @@ def test_no_last_tree(self): 28 0.00000000 200000.00000000 0 4 """) ts = tskit.load_text(nodes, edges, sequence_length=200000, strict=False) - pts = tests.PythonTreeSequence(ts.get_ll_tree_sequence()) + pts = tests.PythonTreeSequence(ts) num_trees = 0 for t in pts.trees(): num_trees += 1 @@ -4618,48 +4618,10 @@ def test_coalescent_trees(self): self.assertRaises(StopIteration, next, new_trees) -class TestSampleLists(unittest.TestCase): +class ExampleTopologyMixin(object): """ - Tests for the sample lists algorithm. + Some example topologies for tests cases. """ - def verify(self, ts): - tree1 = tsutil.LinkedTree(ts) - s = str(tree1) - self.assertIsNotNone(s) - trees = ts.trees(sample_lists=True) - for left, right in tree1.sample_lists(): - tree2 = next(trees) - assert (left, right) == tree2.interval - for u in tree2.nodes(): - self.assertEqual(tree1.left_sample[u], tree2.left_sample(u)) - self.assertEqual(tree1.right_sample[u], tree2.right_sample(u)) - for j in range(ts.num_samples): - self.assertEqual(tree1.next_sample[j], tree2.next_sample(j)) - assert right == ts.sequence_length - - tree1 = tsutil.LinkedTree(ts) - trees = ts.trees(sample_lists=False) - sample_index_map = ts.samples() - for left, right in tree1.sample_lists(): - tree2 = next(trees) - for u in range(ts.num_nodes): - samples2 = list(tree2.samples(u)) - samples1 = [] - index = tree1.left_sample[u] - if index != tskit.NULL: - self.assertEqual( - sample_index_map[tree1.left_sample[u]], samples2[0]) - self.assertEqual( - sample_index_map[tree1.right_sample[u]], samples2[-1]) - stop 
= tree1.right_sample[u] - while True: - assert index != -1 - samples1.append(sample_index_map[index]) - if index == stop: - break - index = tree1.next_sample[index] - self.assertEqual(samples1, samples2) - assert right == ts.sequence_length def test_single_coalescent_tree(self): ts = msprime.simulate(10, random_seed=1, length=10) @@ -4722,6 +4684,105 @@ def test_many_multiroot_trees(self): ts = tsutil.decapitate(ts, ts.num_edges // 2) self.verify(ts) + def test_multiroot_tree(self): + ts = msprime.simulate(15, random_seed=10) + ts = tsutil.decapitate(ts, ts.num_edges // 2) + self.verify(ts) + + +class TestSampleLists(unittest.TestCase, ExampleTopologyMixin): + """ + Tests for the sample lists algorithm. + """ + def verify(self, ts): + tree1 = tsutil.SampleListTree(ts) + s = str(tree1) + self.assertIsNotNone(s) + trees = ts.trees(sample_lists=True) + for left, right in tree1.sample_lists(): + tree2 = next(trees) + assert (left, right) == tree2.interval + for u in tree2.nodes(): + self.assertEqual(tree1.left_sample[u], tree2.left_sample(u)) + self.assertEqual(tree1.right_sample[u], tree2.right_sample(u)) + for j in range(ts.num_samples): + self.assertEqual(tree1.next_sample[j], tree2.next_sample(j)) + assert right == ts.sequence_length + + tree1 = tsutil.SampleListTree(ts) + trees = ts.trees(sample_lists=False) + sample_index_map = ts.samples() + for left, right in tree1.sample_lists(): + tree2 = next(trees) + for u in range(ts.num_nodes): + samples2 = list(tree2.samples(u)) + samples1 = [] + index = tree1.left_sample[u] + if index != tskit.NULL: + self.assertEqual( + sample_index_map[tree1.left_sample[u]], samples2[0]) + self.assertEqual( + sample_index_map[tree1.right_sample[u]], samples2[-1]) + stop = tree1.right_sample[u] + while True: + assert index != -1 + samples1.append(sample_index_map[index]) + if index == stop: + break + index = tree1.next_sample[index] + self.assertEqual(samples1, samples2) + assert right == ts.sequence_length + + +class 
TestOneSampleRoot(unittest.TestCase, ExampleTopologyMixin): + """ + Tests for the standard root threshold of subtending at least + one sample. + """ + def verify(self, ts): + tree1 = tsutil.RootThresholdTree(ts, root_threshold=1) + tree2 = tskit.Tree(ts) + tree2.first() + for interval in tree1.iterate(): + self.assertEqual(interval, tree2.interval) + self.assertEqual(tree1.roots(), tree2.roots) + # Definition here is the set of unique path ends from samples + roots = set() + for u in ts.samples(): + while u != tskit.NULL: + path_end = u + u = tree2.parent(u) + roots.add(path_end) + self.assertEqual(set(tree1.roots()), roots) + tree2.next() + self.assertEqual(tree2.index, -1) + + +class TestKSamplesRoot(unittest.TestCase, ExampleTopologyMixin): + """ + Tests for the root criterion of subtending at least k samples. + """ + def verify(self, ts): + for k in range(1, 5): + tree1 = tsutil.RootThresholdTree(ts, root_threshold=k) + tree2 = tskit.Tree(ts, root_threshold=k) + tree2.first() + for interval in tree1.iterate(): + self.assertEqual(interval, tree2.interval) + # Definition here is the set of unique path ends from samples + # that subtend at least k samples + roots = set() + for u in ts.samples(): + while u != tskit.NULL: + path_end = u + u = tree2.parent(u) + if tree2.num_samples(path_end) >= k: + roots.add(path_end) + self.assertEqual(set(tree1.roots()), roots) + self.assertEqual(tree1.roots(), tree2.roots) + tree2.next() + self.assertEqual(tree2.index, -1) + class TestSquashEdges(unittest.TestCase): """ diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index 4d852d091a..d12ad4d4d8 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -2752,7 +2752,7 @@ def naive_branch_allele_frequency_spectrum( for set_index, sample_set in enumerate(sample_sets): S = np.zeros(out_dim) trees = [ - next(ts.trees(tracked_samples=sample_set, sample_counts=True)) + next(ts.trees(tracked_samples=sample_set)) for sample_set in
sample_sets] t = trees[0] while True: diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 65de764244..7c7fbc0ea4 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -535,7 +535,7 @@ def algorithm_T(ts): left = right -class LinkedTree(object): +class SampleListTree(object): """ Straightforward implementation of the quintuply linked tree for developing and testing the sample lists feature. @@ -689,6 +689,195 @@ def sample_lists(self): left = right +class RootThresholdTree(object): + """ + Straightforward implementation of the quintuply linked tree for developing + and testing the root_threshold feature. + + NOTE: The interface is pretty awkward; it's not intended for anything other + than testing. + """ + def __init__(self, tree_sequence, root_threshold=1): + self.tree_sequence = tree_sequence + self.root_threshold = root_threshold + num_nodes = tree_sequence.num_nodes + # Quintuply linked tree. + self.parent = [-1 for _ in range(num_nodes)] + self.left_sib = [-1 for _ in range(num_nodes)] + self.right_sib = [-1 for _ in range(num_nodes)] + self.left_child = [-1 for _ in range(num_nodes)] + self.right_child = [-1 for _ in range(num_nodes)] + self.num_samples = [0 for _ in range(num_nodes)] + self.left_root = -1 + for u in tree_sequence.samples()[::-1]: + self.num_samples[u] = 1 + if self.root_threshold == 1: + self.add_root(u) + + def __str__(self): + fmt = "{:<5}{:>8}{:>8}{:>8}{:>8}{:>8}{:>8}\n" + s = f"roots = {self.roots()}\n" + s += fmt.format( + "node", "parent", "lsib", "rsib", "lchild", "rchild", "nsamp") + for u in range(self.tree_sequence.num_nodes): + s += fmt.format( + u, self.parent[u], + self.left_sib[u], self.right_sib[u], + self.left_child[u], self.right_child[u], + self.num_samples[u]) + # Strip off trailing newline + return s[:-1] + + def is_root(self, u): + return self.num_samples[u] >= self.root_threshold + + def roots(self): + roots = [] + u = self.left_root + while u != -1: + roots.append(u) + u = self.right_sib[u] 
+ return roots + + def add_root(self, root): + if self.left_root != tskit.NULL: + lroot = self.left_sib[self.left_root] + if lroot != tskit.NULL: + self.right_sib[lroot] = root + self.left_sib[root] = lroot + self.left_sib[self.left_root] = root + self.right_sib[root] = self.left_root + self.left_root = root + + def remove_root(self, root): + lroot = self.left_sib[root] + rroot = self.right_sib[root] + self.left_root = tskit.NULL + if lroot != tskit.NULL: + self.right_sib[lroot] = rroot + self.left_root = lroot + if rroot != tskit.NULL: + self.left_sib[rroot] = lroot + self.left_root = rroot + self.left_sib[root] = tskit.NULL + self.right_sib[root] = tskit.NULL + + def remove_edge(self, edge): + p = edge.parent + c = edge.child + lsib = self.left_sib[c] + rsib = self.right_sib[c] + if lsib == -1: + self.left_child[p] = rsib + else: + self.right_sib[lsib] = rsib + if rsib == -1: + self.right_child[p] = lsib + else: + self.left_sib[rsib] = lsib + self.parent[c] = -1 + self.left_sib[c] = -1 + self.right_sib[c] = -1 + + u = edge.parent + while u != -1: + path_end = u + path_end_was_root = self.is_root(u) + self.num_samples[u] -= self.num_samples[c] + u = self.parent[u] + if path_end_was_root and not self.is_root(path_end): + self.remove_root(path_end) + if self.is_root(c): + self.add_root(c) + + def insert_edge(self, edge): + p = edge.parent + c = edge.child + assert self.parent[c] == -1, "contradictory edges" + self.parent[c] = p + u = self.right_child[p] + lsib = self.left_sib[c] + rsib = self.right_sib[c] + if u == -1: + self.left_child[p] = c + self.left_sib[c] = -1 + self.right_sib[c] = -1 + else: + self.right_sib[u] = c + self.left_sib[c] = u + self.right_sib[c] = -1 + self.right_child[p] = c + + u = edge.parent + while u != -1: + path_end = u + path_end_was_root = self.is_root(u) + self.num_samples[u] += self.num_samples[c] + u = self.parent[u] + + if self.is_root(c): + if path_end_was_root: + # Remove c from root list. 
+ # Note: we don't use the remove_root function here because + # it assumes that the node is at the end of a path + self.left_root = tskit.NULL + if lsib != tskit.NULL: + self.right_sib[lsib] = rsib + self.left_root = lsib + if rsib != tskit.NULL: + self.left_sib[rsib] = lsib + self.left_root = rsib + else: + # Replace c with path_end in the root list + if lsib != tskit.NULL: + self.right_sib[lsib] = path_end + if rsib != tskit.NULL: + self.left_sib[rsib] = path_end + self.left_sib[path_end] = lsib + self.right_sib[path_end] = rsib + self.left_root = path_end + else: + if self.is_root(path_end) and not path_end_was_root: + self.add_root(path_end) + + def iterate(self): + """ + Iterate over the trees in this tree sequence, yielding the (left, right) + interval tuples. The tree state is maintained internally. + """ + ts = self.tree_sequence + sequence_length = ts.sequence_length + edges = list(ts.edges()) + M = len(edges) + time = [ts.node(edge.parent).time for edge in edges] + in_order = sorted(range(M), key=lambda j: ( + edges[j].left, time[j], edges[j].parent, edges[j].child)) + out_order = sorted(range(M), key=lambda j: ( + edges[j].right, -time[j], -edges[j].parent, -edges[j].child)) + j = 0 + k = 0 + left = 0 + + while j < M or left < sequence_length: + while k < M and edges[out_order[k]].right == left: + edge = edges[out_order[k]] + self.remove_edge(edge) + k += 1 + while j < M and edges[in_order[j]].left == left: + edge = edges[in_order[j]] + self.insert_edge(edge) + j += 1 + while self.left_sib[self.left_root] != tskit.NULL: + self.left_root = self.left_sib[self.left_root] + right = sequence_length + if j < M: + right = min(right, edges[in_order[j]].left) + if k < M: + right = min(right, edges[out_order[k]].right) + yield left, right + left = right + + def mean_descendants(ts, reference_sets): """ Returns the mean number of nodes from the specified reference sets diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 3f45317690..b9e0bb03b1
100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -445,18 +445,13 @@ class Tree(object): on how efficiently access trees sequentially or obtain a list of individual trees in a tree sequence. - The ``sample_counts`` and ``sample_lists`` parameters control the - features that are enabled for this tree. If ``sample_counts`` - is True, then it is possible to count the number of samples underneath - a particular node in constant time using the :meth:`num_samples` - method. If ``sample_lists`` is True a more efficient algorithm is + The ``sample_lists`` parameter controls the features that are enabled + for this tree. If ``sample_lists`` is True a more efficient algorithm is used in the :meth:`Tree.samples` method. The ``tracked_samples`` parameter can be used to efficiently count the number of samples in a given set that exist in a particular subtree - using the :meth:`Tree.num_tracked_samples` method. It is an - error to use the ``tracked_samples`` parameter when the ``sample_counts`` - flag is False. + using the :meth:`Tree.num_tracked_samples` method. The :class:`Tree` class is a state-machine which has a state corresponding to each of the trees in the parent tree sequence. We @@ -477,21 +472,26 @@ class Tree(object): :param TreeSequence tree_sequence: The parent tree sequence. :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. - :param bool sample_counts: If True, support constant time sample counts - via the :meth:`Tree.num_samples` and - :meth:`Tree.num_tracked_samples` methods. + :param bool sample_counts: Deprecated since 0.2.4. :param bool sample_lists: If True, provide more efficient access to the samples beneath a give node using the :meth:`Tree.samples` method. + :param int root_threshold: The minimum number of samples that a node + must be ancestral to for it to be in the list of roots. 
By default + this is 1, so that isolated samples (representing missing data) + are roots. To efficiently restrict the roots of the tree to + those subtending meaningful topology, set this to 2. This value + is only relevant when trees have multiple roots. """ def __init__( self, tree_sequence, - tracked_samples=None, sample_counts=True, sample_lists=False): + tracked_samples=None, sample_counts=None, sample_lists=False, + root_threshold=1): options = 0 - if sample_counts: - options |= _tskit.SAMPLE_COUNTS - elif tracked_samples is not None: - raise ValueError("Cannot set tracked_samples without sample_counts") + if sample_counts is not None: + warnings.warn( + "The sample_counts option is not supported since 0.2.4 " + "and is ignored", RuntimeWarning) if sample_lists: options |= _tskit.SAMPLE_LISTS kwargs = {"options": options} @@ -501,6 +501,7 @@ def __init__( self._tree_sequence = tree_sequence self._ll_tree = _tskit.Tree(tree_sequence.ll_tree_sequence, **kwargs) + self._ll_tree.set_root_threshold(root_threshold) def copy(self): """ @@ -525,6 +526,17 @@ def tree_sequence(self): """ return self._tree_sequence + @property + def root_threshold(self): + """ + Returns the minimum number of samples that a node must be an ancestor + of to be considered a potential root. + + :return: The root threshold. + :rtype: int + """ + return self._ll_tree.get_root_threshold() + def __eq__(self, other): ret = False if type(other) is type(self): @@ -1394,9 +1406,7 @@ def num_samples(self, u=None): node (including the node itself). If u is not specified return the total number of samples in the tree. - If the :meth:`TreeSequence.trees` method is called with - ``sample_counts=True`` this method is a constant time operation. If not, - a slower traversal based algorithm is used to count the samples. + This is a constant time operation. :param int u: The node of interest. :return: The number of samples in the subtree rooted at u.
@@ -1430,16 +1440,10 @@ def num_tracked_samples(self, u=None): :return: The number of samples within the set of tracked samples in the subtree rooted at u. :rtype: int - :raises RuntimeError: if the :meth:`TreeSequence.trees` - method is not called with ``sample_counts=True``. """ roots = [u] if u is None: roots = self.roots - if not (self._ll_tree.get_options() & _tskit.SAMPLE_COUNTS): - raise RuntimeError( - "The get_num_tracked_samples method is only supported " - "when sample_counts=True.") return sum(self._ll_tree.get_num_tracked_samples(root) for root in roots) def _preorder_traversal(self, u): @@ -2875,7 +2879,7 @@ def last(self): return tree def trees( - self, tracked_samples=None, sample_counts=True, sample_lists=False, + self, tracked_samples=None, sample_counts=None, sample_lists=False, tracked_leaves=None, leaf_counts=None, leaf_lists=None): """ Returns an iterator over the trees in this tree sequence. Each value @@ -2883,8 +2887,8 @@ def trees( successful termination of the iterator, the tree will be in the "cleared" null state. - The ``sample_counts``, ``sample_lists`` and ``tracked_samples`` - parameters are passed to the :class:`Tree` constructor, and control + The ``sample_lists`` and ``tracked_samples`` parameters are passed + to the :class:`Tree` constructor, and control the options that are set in the returned tree instance. :warning: Do not store the results of this iterator in a list! @@ -2895,9 +2899,7 @@ def trees( :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. - :param bool sample_counts: If True, support constant time sample counts - via the :meth:`Tree.num_samples` and - :meth:`Tree.num_tracked_samples` methods. + :param bool sample_counts: Deprecated since 0.2.4. :param bool sample_lists: If True, provide more efficient access to the samples beneath a give node using the :meth:`Tree.samples` method.