From 53a4756b4892d031e2bf95f1f36fe41f9ced26d7 Mon Sep 17 00:00:00 2001 From: Daniel Goldstein Date: Mon, 9 Mar 2020 17:09:07 +0000 Subject: [PATCH] make kc linear in the number of nodes by not doing pairwise mrca queries --- c/tests/test_trees.c | 65 ++++-- c/tskit/trees.c | 371 +++++++++++++++++++++------------- python/CHANGELOG.rst | 4 + python/tests/test_lowlevel.py | 24 ++- python/tests/test_topology.py | 327 ++++++++++++++++++------------ 5 files changed, 500 insertions(+), 291 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index cd786b846f..f14137600a 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -4296,6 +4296,40 @@ test_internal_sample_sample_sets(void) * KC Distance tests. *=======================================================*/ +static void +test_isolated_node_kc(void) +{ + const char *single_leaf = "1 0 0"; + const char *single_internal = "0 0 0"; + const char *edges = ""; + tsk_treeseq_t ts; + tsk_tree_t t; + int ret; + double result = 0; + + tsk_treeseq_from_text(&ts, 1, single_leaf, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + ret = tsk_tree_kc_distance(&t, &t, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(result, 0); + tsk_treeseq_free(&ts); + tsk_tree_free(&t); + + tsk_treeseq_from_text(&ts, 1, single_internal, edges, NULL, NULL, NULL, NULL, NULL); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, 1); + CU_ASSERT_EQUAL_FATAL(t.left_root, TSK_NULL); + ret = tsk_tree_kc_distance(&t, &t, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_ROOTS); + tsk_treeseq_free(&ts); + tsk_tree_free(&t); +} + static void test_single_tree_kc(void) { @@ -4306,11 +4340,11 @@ test_single_tree_kc(void) tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); - ret = tsk_tree_init(&other_t, &ts, 0); + ret = tsk_tree_init(&other_t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&other_t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4349,13 +4383,13 @@ test_two_trees_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); tsk_treeseq_from_text( &other_ts, 1, nodes_other, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&other_t, &other_ts, 0); + ret = tsk_tree_init(&other_t, &other_ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&other_t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4380,18 +4414,18 @@ test_empty_tree_kc(void) int ret; double result = 0; - ret = tsk_table_collection_init(&tables, 0); + ret = tsk_table_collection_init(&tables, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES); + ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES | TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SEQUENCE_LENGTH); tsk_treeseq_free(&ts); tables.sequence_length = 1.0; - ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES); + ret = tsk_treeseq_init(&ts, &tables, TSK_BUILD_INDEXES | TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); verify_empty_tree_sequence(&ts, 1.0); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4423,7 +4457,7 @@ test_nonbinary_tree_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4447,7 +4481,7 @@ test_nonzero_samples_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4471,7 +4505,7 @@ test_internal_samples_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4501,13 +4535,13 @@ test_unequal_sample_size_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); tsk_treeseq_from_text( &other_ts, 1, nodes_other, edges_other, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&other_t, &other_ts, 0); + ret = tsk_tree_init(&other_t, &other_ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&other_t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -4543,13 +4577,13 @@ test_unequal_samples_kc(void) double result = 0; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&t, &ts, 0); + ret = tsk_tree_init(&t, &ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&t); CU_ASSERT_EQUAL_FATAL(ret, 1); tsk_treeseq_from_text( &other_ts, 1, nodes_other, edges_other, NULL, NULL, NULL, NULL, NULL); - ret = tsk_tree_init(&other_t, &other_ts, 0); + ret = tsk_tree_init(&other_t, &other_ts, TSK_SAMPLE_LISTS); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = tsk_tree_first(&other_t); CU_ASSERT_EQUAL_FATAL(ret, 1); @@ -5264,6 +5298,7 @@ main(int argc, char **argv) /*KC distance tests */ { "test_single_tree_kc", test_single_tree_kc }, + { "test_isolated_node_kc", test_isolated_node_kc }, { "test_two_trees_kc", test_two_trees_kc }, { "test_empty_tree_kc", test_empty_tree_kc }, { "test_nonbinary_tree_kc", test_nonbinary_tree_kc }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 435ae4937a..a4933082bb 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -4282,146 +4282,6 @@ tsk_tree_map_mutations(tsk_tree_t *self, int8_t *genotypes, return ret; } -int -tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double *result) -{ - struct stack_elmt { - tsk_id_t node; - int path_depth; - double time_depth; - }; - - tsk_size_t num_nodes_self, num_nodes_other; - int ret = 0; - int stack_top = 0; - int path_depth, tree_index, pair_index, i; - double vT1, vT2, distance_sum, time_depth, root_time; - int *m[2], *path_distance[2]; - double *M[2], *time_distance[2]; - tsk_id_t N, u, v, mrca, n1, n2, num_samples, u_index; - tsk_tree_t *trees[2] = { self, other }; - tsk_tree_t *tree; - const tsk_id_t *samples = self->tree_sequence->samples; - const tsk_id_t *other_samples = other->tree_sequence->samples; - const double *times; - const tsk_id_t *sample_index_map; - struct stack_elmt *stack = NULL; - - memset(path_distance, 0, sizeof(path_distance)); - memset(time_distance, 0, sizeof(time_distance)); - memset(m, 0, sizeof(m)); - memset(M, 0, sizeof(M)); - - if (tsk_tree_get_num_roots(self) != 1 || tsk_tree_get_num_roots(other) != 1) { - ret = TSK_ERR_MULTIPLE_ROOTS; - goto out; - } - if (self->tree_sequence->num_samples != other->tree_sequence->num_samples) { - ret = TSK_ERR_SAMPLE_SIZE_MISMATCH; - goto out; - } - - num_samples = (tsk_id_t) self->tree_sequence->num_samples; - N = (num_samples * (num_samples - 1)) / 2; - num_nodes_self = self->num_nodes; - num_nodes_other = other->num_nodes; - stack = malloc(TSK_MAX(num_nodes_self, num_nodes_other) * sizeof(*stack)); - m[0] = calloc((size_t)(N + num_samples), sizeof(m[0])); - m[1] = calloc((size_t)(N + num_samples), sizeof(m[1])); - M[0] = malloc((size_t)(N + num_samples) * sizeof(M[0])); - M[1] = malloc((size_t)(N + num_samples) * sizeof(M[1])); - path_distance[0] = malloc(num_nodes_self * sizeof(path_distance[0])); - path_distance[1] = malloc(num_nodes_other * sizeof(path_distance[1])); - time_distance[0] = malloc(num_nodes_self * sizeof(time_distance[0])); - time_distance[1] = malloc(num_nodes_other * sizeof(time_distance[1])); - if (stack == NULL || m[0] == NULL || m[1] == NULL || M[0] == NULL || M[1] == NULL - || path_distance[0] == NULL || time_distance[1] == NULL) { - ret = TSK_ERR_NO_MEMORY; - goto out; - } - - for (i = 0; i < num_samples; i++) { - if (samples[i] != other_samples[i]) { - ret = TSK_ERR_SAMPLES_NOT_EQUAL; - goto out; - } - u = samples[i]; - if (self->left_child[u] != TSK_NULL || other->left_child[u] != TSK_NULL) { - /* It's probably possible to support this, but it's too awkward - * to deal with and seems like a fairly niche requirement. */ - ret = TSK_ERR_INTERNAL_SAMPLES; - goto out; - } - } - - for (i = 0; i <= N + num_samples; i++) { - m[0][i] = 1; - m[1][i] = 1; - } - - for (tree_index = 0; tree_index < 2; tree_index++) { - tree = trees[tree_index]; - times = tree->tree_sequence->tables->nodes.time; - sample_index_map = tree->tree_sequence->sample_index_map; - stack_top = 0; - u = tree->left_root; - root_time = times[u]; - stack[stack_top].node = u; - stack[stack_top].path_depth = 0; - stack[stack_top].time_depth = root_time; - while (stack_top >= 0) { - u = stack[stack_top].node; - path_depth = stack[stack_top].path_depth; - time_depth = stack[stack_top].time_depth; - stack_top--; - for (v = tree->left_child[u]; v != TSK_NULL; v = tree->right_sib[v]) { - stack_top++; - stack[stack_top].node = v; - stack[stack_top].path_depth = path_depth + 1; - stack[stack_top].time_depth = times[v]; - } - path_distance[tree_index][u] = path_depth; - time_distance[tree_index][u] = root_time - time_depth; - u_index = sample_index_map[u]; - if (u_index != TSK_NULL) { - M[tree_index][u_index + N] = times[tree->parent[u]] - times[u]; - } - } - for (n1 = 0; n1 < num_samples; n1++) { - for (n2 = n1 + 1; n2 < num_samples; n2++) { - ret = tsk_tree_get_mrca(tree, samples[n1], samples[n2], &mrca); - if (ret != 0) { - goto out; - } - pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * num_samples + 1)) / 2; - assert(m[tree_index][pair_index] == 1); - m[tree_index][pair_index] = path_distance[tree_index][mrca]; - M[tree_index][pair_index] = time_distance[tree_index][mrca]; - } - } - } - - vT1 = 0; - vT2 = 0; - distance_sum = 0; - for (i = 0; i < N + num_samples; i++) { - vT1 = (m[0][i] * (1 - lambda)) + (lambda * M[0][i]); - vT2 = (m[1][i] * (1 - lambda)) + (lambda * M[1][i]); - distance_sum += (vT1 - vT2) * (vT1 - vT2); - } - - *result = sqrt(distance_sum); -out: - tsk_safe_free(stack); - for (i = 0; i < 2; i++) { - tsk_safe_free(m[i]); - tsk_safe_free(M[i]); - tsk_safe_free(path_distance[i]); - tsk_safe_free(time_distance[i]); - } - return ret; -} - /* ======================================================== * * Tree diff iterator. * ======================================================== */ @@ -4559,3 +4419,234 @@ tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right, self->tree_left = right; return ret; } + +/* ======================================================== * + * KC Distance + * ======================================================== */ + +typedef struct { + int *m; + double *M; + tsk_id_t n; + tsk_id_t N; +} kc_vectors; + +static int +kc_vectors_alloc(kc_vectors *self, tsk_id_t n) +{ + int ret = 0; + + self->n = n; + self->N = (n * (n - 1)) / 2; + self->m = calloc((size_t)(self->N + self->n), sizeof(int)); + self->M = calloc((size_t)(self->N + self->n), sizeof(double)); + if (self->m == NULL || self->M == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + +out: + return ret; +} + +static void +kc_vectors_free(kc_vectors *self) +{ + tsk_safe_free(self->m); + tsk_safe_free(self->M); +} + +static inline void +update_kc_vectors_single_leaf( + tsk_treeseq_t *ts, kc_vectors *kc_vecs, tsk_id_t u, double time) +{ + const tsk_id_t *sample_index_map = ts->sample_index_map; + tsk_id_t u_index = sample_index_map[u]; + + kc_vecs->m[kc_vecs->N + u_index] = 1; + kc_vecs->M[kc_vecs->N + u_index] = time; +} + +static inline void +update_kc_vectors_all_pairs(tsk_tree_t *tree, kc_vectors *kc_vecs, tsk_id_t u, + tsk_id_t v, int depth, double time) +{ + tsk_id_t leaf1_index, leaf2_index, leaf1, leaf2, n1, n2, tmp, pair_index; + tsk_treeseq_t *ts = tree->tree_sequence; + const tsk_id_t *restrict samples = ts->samples; + const tsk_id_t *restrict left_sample = tree->left_sample; + const tsk_id_t *restrict right_sample = tree->right_sample; + const tsk_id_t *restrict next_sample = tree->next_sample; + const tsk_id_t *restrict sample_index_map = ts->sample_index_map; + int *restrict kc_m = kc_vecs->m; + double *restrict kc_M = kc_vecs->M; + + leaf1_index = left_sample[u]; + while (true) { + leaf1 = samples[leaf1_index]; + leaf2_index = left_sample[v]; + while (true) { + leaf2 = samples[leaf2_index]; + + n1 = sample_index_map[leaf1]; + n2 = sample_index_map[leaf2]; + if (n1 > n2) { + tmp = n1; + n1 = n2; + n2 = tmp; + } + + /* We spend ~40% of our time here because these accesses + * are not in order and gets very poor cache behavior */ + pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * kc_vecs->n + 1)) / 2; + kc_m[pair_index] = depth; + kc_M[pair_index] = time; + + if (leaf2_index == right_sample[v]) { + break; + } + leaf2_index = next_sample[leaf2_index]; + } + if (leaf1_index == right_sample[u]) { + break; + } + leaf1_index = next_sample[leaf1_index]; + } +} + +static int +fill_kc_vectors(tsk_tree_t *t, kc_vectors *kc_vecs) +{ + struct stack_elmt { + tsk_id_t node; + int depth; + }; + + int stack_top, depth; + double time; + const double *times; + struct stack_elmt *stack; + tsk_id_t root, u, c1, c2; + int ret = 0; + tsk_treeseq_t *ts = t->tree_sequence; + + stack = malloc(t->num_nodes * sizeof(*stack)); + if (stack == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + times = t->tree_sequence->tables->nodes.time; + + for (root = t->left_root; root != TSK_NULL; root = t->right_sib[root]) { + stack_top = 0; + stack[stack_top].node = root; + stack[stack_top].depth = 0; + while (stack_top >= 0) { + u = stack[stack_top].node; + depth = stack[stack_top].depth; + stack_top--; + + if (t->left_child[u] == TSK_NULL) { + if (u == root) { + time = 0; + } else { + time = times[t->parent[u]] - times[u]; + } + update_kc_vectors_single_leaf(ts, kc_vecs, u, time); + } else { + for (c1 = t->left_child[u]; c1 != TSK_NULL; c1 = t->right_sib[c1]) { + stack_top++; + stack[stack_top].node = c1; + stack[stack_top].depth = depth + 1; + + for (c2 = t->right_sib[c1]; c2 != TSK_NULL; c2 = t->right_sib[c2]) { + update_kc_vectors_all_pairs( + t, kc_vecs, c1, c2, depth, times[root] - times[u]); + } + } + } + } + } + +out: + tsk_safe_free(stack); + return ret; +} + +static double +norm_kc_vectors(kc_vectors *self, kc_vectors *other, double lambda) +{ + double vT1, vT2, distance_sum; + tsk_id_t i; + + distance_sum = 0; + for (i = 0; i < self->n + self->N; i++) { + vT1 = (self->m[i] * (1 - lambda)) + (lambda * self->M[i]); + vT2 = (other->m[i] * (1 - lambda)) + (lambda * other->M[i]); + distance_sum += (vT1 - vT2) * (vT1 - vT2); + } + + return sqrt(distance_sum); +} + +int +tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double *result) +{ + tsk_id_t u, n, i; + kc_vectors vecs[2]; + tsk_tree_t *trees[2] = { self, other }; + const tsk_id_t *samples = self->tree_sequence->samples; + const tsk_id_t *other_samples = other->tree_sequence->samples; + int ret = 0; + + for (i = 0; i < 2; i++) { + memset(&vecs[i], 0, sizeof(kc_vectors)); + } + + if (tsk_tree_get_num_roots(self) != 1 || tsk_tree_get_num_roots(other) != 1) { + ret = TSK_ERR_MULTIPLE_ROOTS; + goto out; + } + if (self->tree_sequence->num_samples != other->tree_sequence->num_samples) { + ret = TSK_ERR_SAMPLE_SIZE_MISMATCH; + goto out; + } + if (!tsk_tree_has_sample_lists(self) || !tsk_tree_has_sample_lists(other)) { + ret = TSK_ERR_UNSUPPORTED_OPERATION; + goto out; + } + + n = (tsk_id_t) self->tree_sequence->num_samples; + for (i = 0; i < n; i++) { + if (samples[i] != other_samples[i]) { + ret = TSK_ERR_SAMPLES_NOT_EQUAL; + goto out; + } + u = samples[i]; + if (self->left_child[u] != TSK_NULL || other->left_child[u] != TSK_NULL) { + /* It's probably possible to support this, but it's too awkward + * to deal with and seems like a fairly niche requirement. */ + ret = TSK_ERR_INTERNAL_SAMPLES; + goto out; + } + } + + for (i = 0; i < 2; i++) { + ret = kc_vectors_alloc(&vecs[i], n); + if (ret != 0) { + goto out; + } + ret = fill_kc_vectors(trees[i], &vecs[i]); + if (ret != 0) { + goto out; + } + } + + *result = norm_kc_vectors(&vecs[0], &vecs[1], lambda); +out: + for (i = 0; i < 2; i++) { + kc_vectors_free(&vecs[i]); + } + return ret; +} diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 5f319d06fa..d28f663999 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -6,6 +6,10 @@ In development **New features** +- Improve Kendall-Colijn tree distance algorithm to operate in O(n^2) time + instead of O(n^2 * log(n)) where n is the number of samples + (:user:`daniel-goldstein`, :pr:`490`) + - Add a metadata column to the migrations table. Works similarly to existing metadata columns on other tables(:user:`benjeffery`, :pr:`505`). diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 4833d7772c..aca43b579d 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1899,7 +1899,7 @@ def test_equality(self): def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) - t1 = _tskit.Tree(ts1) + t1 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS) t1.first() self.assertRaises(TypeError, t1.get_kc_distance) self.assertRaises(TypeError, t1.get_kc_distance, t1) @@ -1908,7 +1908,7 @@ def test_kc_distance_errors(self): for bad_value in ["tree", [], None]: self.assertRaises(TypeError, t1.get_kc_distance, t1, lambda_=bad_value) - t2 = _tskit.Tree(ts1) + t2 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS) # If we don't seek to a specific tree, it has multiple roots (i.e., it's # in the null state). This fails because we don't accept multiple roots. with self.assertRaises(_tskit.LibraryError): @@ -1916,10 +1916,15 @@ def test_kc_distance_errors(self): # Different numbers of samples fail. ts2 = self.get_example_tree_sequence(11) + t2 = _tskit.Tree(ts2, options=_tskit.SAMPLE_LISTS) + t2.first() + self.verify_kc_library_error(t1, t2) + + # Error when tree not initialized with sample lists + ts2 = self.get_example_tree_sequence(10) t2 = _tskit.Tree(ts2) t2.first() - with self.assertRaises(_tskit.LibraryError): - t1.get_kc_distance(t2, 0) + self.verify_kc_library_error(t1, t2) # Internal samples cause errors. tables = _tskit.TableCollection(1.0) @@ -1928,17 +1933,20 @@ def test_kc_distance_errors(self): tables.edges.add_row(0, 1, 1, 0) ts = _tskit.TreeSequence() ts.load_tables(tables) - t1 = _tskit.Tree(ts) + t1 = _tskit.Tree(ts, options=_tskit.SAMPLE_LISTS) t1.first() + self.verify_kc_library_error(t1, t1) + + def verify_kc_library_error(self, t1, t2): with self.assertRaises(_tskit.LibraryError): - t1.get_kc_distance(t1, 0) + t1.get_kc_distance(t2, 0) def test_kc_distance(self): ts1 = self.get_example_tree_sequence(10, random_seed=123456) - t1 = _tskit.Tree(ts1) + t1 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS) t1.first() ts2 = self.get_example_tree_sequence(10, random_seed=1234) - t2 = _tskit.Tree(ts2) + t2 = _tskit.Tree(ts2, options=_tskit.SAMPLE_LISTS) t2.first() for lambda_ in [-1, 0, 1, 1000, -1e300]: x1 = t1.get_kc_distance(t2, lambda_) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index 314b7a0d97..116ab1f4a3 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -121,7 +121,7 @@ def generate_segments(n, sequence_length=100, seed=None): return segs -def kc_distance(tree1, tree2, lambda_=0): +def naive_kc_distance(tree1, tree2, lambda_=0): """ Returns the Kendall-Colijn distance between the specified pair of trees. lambda_ determines weight of topology vs branch lengths in calculating @@ -134,43 +134,133 @@ def kc_distance(tree1, tree2, lambda_=0): raise ValueError("Trees must have the same samples") if not len(tree1.roots) == len(tree2.roots) == 1: raise ValueError("Trees must have one root") - sample_index_map = np.zeros(tree1.tree_sequence.num_nodes, dtype=int) - 1 - for j, u in enumerate(samples): - sample_index_map[u] = j + for u in samples: if not tree1.is_leaf(u) or not tree2.is_leaf(u): raise ValueError("Internal samples not supported") - k = samples.shape[0] - n = (k * (k - 1)) // 2 - m = [np.ones(n + k), np.ones(n + k)] - M = [np.zeros(n + k), np.zeros(n + k)] + n = samples.shape[0] + N = (n * (n - 1)) // 2 + m = [np.zeros(N + n), np.zeros(N + n)] + M = [np.zeros(N + n), np.zeros(N + n)] for tree_index, tree in enumerate([tree1, tree2]): - stack = [(tree.root, 0, tree.time(tree.root))] - while len(stack) > 0: - node, depth, time = stack.pop() - children = tree.children(node) - for child in children: - stack.append((child, depth + 1, tree.time(child))) - for c1, c2 in itertools.combinations(children, 2): - for v1 in tree.samples(c1): - index1 = sample_index_map[v1] - for v2 in tree.samples(c2): - index2 = sample_index_map[v2] - a = min(index1, index2) - b = max(index1, index2) - pair_index = a * (a - 2 * k + 1) // -2 + b - a - 1 - assert m[tree_index][pair_index] == 1 - m[tree_index][pair_index] = depth - M[tree_index][pair_index] = tree.time(tree.root) - time - if len(tree.children(node)) == 0: - index = sample_index_map[node] - M[tree_index][index + n] = tree.branch_length(node) + for sample in range(n): + m[tree_index][N + sample] = 1 + M[tree_index][N + sample] = tree.branch_length(sample) + + for n1, n2 in itertools.combinations(range(n), 2): + mrca = tree.mrca(samples[n1], samples[n2]) + depth = 0 + u = tree.parent(mrca) + while u != tskit.NULL: + depth += 1 + u = tree.parent(u) + pair_index = n1 * (n1 - 2 * n + 1) // -2 + n2 - n1 - 1 + m[tree_index][pair_index] = depth + M[tree_index][pair_index] = tree.time(tree.root) - tree.time(mrca) + return np.linalg.norm((1 - lambda_) * (m[0] - m[1]) + lambda_ * (M[0] - M[1])) -def kc_distance_simple(tree1, tree2, lambda_=0): +class KCVectors: """ - Simplified version of the kc_distance() function above. + Manages the two vectors (m and M) of a tree used to compute the + KC distance between trees. For any two samples, u and v, + m and M capture the distance of mrca(u, v) to the root in + number of edges and time, respectively. + + See Kendall & Colijn (2016): + https://academic.oup.com/mbe/article/33/10/2735/2925548 + """ + + def __init__(self, n): + self.n = n + self.N = (self.n * (self.n - 1)) // 2 + self.m = np.zeros(self.N + self.n) + self.M = np.zeros(self.N + self.n) + + +def fill_kc_vectors(tree, kc_vecs): + ts = tree.tree_sequence + sample_index_map = np.zeros(tree.tree_sequence.num_nodes) + for j, u in enumerate(tree.tree_sequence.samples()): + sample_index_map[u] = j + for root in tree.roots: + stack = [(tree.root, 0)] + while len(stack) > 0: + u, depth = stack.pop() + if tree.is_leaf(u): + time = 0 if u is root else tree.branch_length(u) + update_kc_vectors_single_leaf(ts, kc_vecs, u, time, sample_index_map) + else: + c1 = tree.left_child(u) + while c1 != tskit.NULL: + stack.append((c1, depth + 1)) + c2 = tree.right_sib(c1) + while c2 != tskit.NULL: + update_kc_vectors_all_pairs( + tree, + kc_vecs, + c1, + c2, + depth, + tree.time(root) - tree.time(u), + sample_index_map, + ) + c2 = tree.right_sib(c2) + c1 = tree.right_sib(c1) + + +def update_kc_vectors_single_leaf(ts, kc_vecs, u, time, sample_index_map): + u_index = int(sample_index_map[u]) + kc_vecs.m[kc_vecs.N + u_index] = 1 + kc_vecs.M[kc_vecs.N + u_index] = time + + +def update_kc_vectors_all_pairs(tree, kc_vecs, c1, c2, depth, time, sample_index_map): + leaves = tree.tree_sequence.samples() + c1_left = tree.left_sample(c1) + c1_right = tree.right_sample(c1) + leaf1_index = c1_left + while True: + c2_left = tree.left_sample(c2) + c2_right = tree.right_sample(c2) + leaf2_index = c2_left + while True: + leaf1 = leaves[leaf1_index] + leaf2 = leaves[leaf2_index] + update_kc_vectors_pair(kc_vecs, leaf1, leaf2, depth, time, sample_index_map) + if leaf2_index == c2_right: + break + leaf2_index = tree.next_sample(leaf2_index) + if leaf1_index == c1_right: + break + leaf1_index = tree.next_sample(leaf1_index) + + +def update_kc_vectors_pair(kc_vecs, u, v, depth, time, sample_index_map): + n1 = int(min(sample_index_map[u], sample_index_map[v])) + n2 = int(max(sample_index_map[u], sample_index_map[v])) + pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * kc_vecs.n + 1)) // 2 + + kc_vecs.m[pair_index] = depth + kc_vecs.M[pair_index] = time + + +def norm_kc_vectors(kc_vecs1, kc_vecs2, lambda_): + vT1 = 0 + vT2 = 0 + distance_sum = 0 + for i in range(kc_vecs1.n + kc_vecs1.N): + vT1 = (kc_vecs1.m[i] * (1 - lambda_)) + (lambda_ * kc_vecs1.M[i]) + vT2 = (kc_vecs2.m[i] * (1 - lambda_)) + (lambda_ * kc_vecs2.M[i]) + distance_sum += (vT1 - vT2) ** 2 + + return math.sqrt(distance_sum) + + +def c_kc_distance(tree1, tree2, lambda_=0): + """ + Simplified version of the naive_kc_distance() function above. Written without Python features to aid writing C implementation. """ samples = tree1.tree_sequence.samples() @@ -179,48 +269,16 @@ def kc_distance_simple(tree1, tree2, lambda_=0): raise ValueError("Trees must have the same samples") if not len(tree1.roots) == len(tree2.roots) == 1: raise ValueError("Trees must have one root") - sample_index_map = np.zeros(tree1.tree_sequence.num_nodes, dtype=int) - 1 - for j, u in enumerate(samples): - sample_index_map[u] = j + for u in samples: if not tree1.is_leaf(u) or not tree2.is_leaf(u): raise ValueError("Internal samples not supported") - n = samples.shape[0] - N = (n * (n - 1)) // 2 - m = [np.ones(N + n), np.ones(N + n)] - M = [np.zeros(N + n), np.zeros(N + n)] - path_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)] - time_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)] - for tree_index, tree in enumerate([tree1, tree2]): - stack = [(tree.root, 0, tree.time(tree.root))] - while len(stack) > 0: - u, depth, time = stack.pop() - children = tree.children(u) - for v in children: - stack.append((v, depth + 1, tree.time(v))) - path_distance[tree_index][u] = depth - time_distance[tree_index][u] = tree.time(tree.root) - time - if len(tree.children(u)) == 0: - u_index = sample_index_map[u] - M[tree_index][u_index + N] = tree.branch_length(u) - - for n1 in range(n): - for n2 in range(n1 + 1, n): - mrca = tree.mrca(samples[n1], samples[n2]) - pair_index = n1 * (n1 - 2 * n + 1) // -2 + n2 - n1 - 1 - assert m[tree_index][pair_index] == 1 - m[tree_index][pair_index] = path_distance[tree_index][mrca] - M[tree_index][pair_index] = time_distance[tree_index][mrca] - - vT1 = 0 - vT2 = 0 - distance_sum = 0 - for i in range(N + n): - vT1 = (m[0][i] * (1 - lambda_)) + (lambda_ * M[0][i]) - vT2 = (m[1][i] * (1 - lambda_)) + (lambda_ * M[1][i]) - distance_sum += (vT1 - vT2) ** 2 - - return math.sqrt(distance_sum) + n = tree1.tree_sequence.num_samples + vecs1 = KCVectors(n) + fill_kc_vectors(tree1, vecs1) + vecs2 = KCVectors(n) + fill_kc_vectors(tree2, vecs2) + return norm_kc_vectors(vecs1, vecs2, lambda_) class TestKCMetric(unittest.TestCase): @@ -232,52 +290,56 @@ def test_same_tree_zero_distance(self): for n in range(2, 10): for seed in range(1, 10): ts = msprime.simulate(n, random_seed=seed) - tree = ts.first() - self.assertEqual(kc_distance(tree, tree), 0) - self.assertEqual(kc_distance_simple(tree, tree), 0) + tree = next(ts.trees(sample_lists=True)) + self.assertEqual(naive_kc_distance(tree, tree), 0) + self.assertEqual(c_kc_distance(tree, tree), 0) self.assertEqual(tree.kc_distance(tree), 0) ts = msprime.simulate(n, random_seed=seed) - tree2 = ts.first() - self.assertEqual(kc_distance(tree, tree2), 0) - self.assertEqual(kc_distance_simple(tree, tree2), 0) + tree2 = next(ts.trees(sample_lists=True)) + self.assertEqual(naive_kc_distance(tree, tree2), 0) + self.assertEqual(c_kc_distance(tree, tree2), 0) self.assertEqual(tree.kc_distance(tree2), 0) def test_sample_2_zero_distance(self): # All trees with 2 leaves must be equal distance from each other. for seed in range(1, 10): - tree1 = msprime.simulate(2, random_seed=seed).first() - tree2 = msprime.simulate(2, random_seed=seed + 1).first() - self.assertEqual(kc_distance(tree1, tree2, 0), 0) - self.assertEqual(kc_distance_simple(tree1, tree2, 0), 0) + ts1 = msprime.simulate(2, random_seed=seed) + tree1 = next(ts1.trees(sample_lists=True)) + ts2 = msprime.simulate(2, random_seed=seed + 1) + tree2 = next(ts2.trees(sample_lists=True)) + self.assertEqual(naive_kc_distance(tree1, tree2, 0), 0) + self.assertEqual(c_kc_distance(tree1, tree2, 0), 0) self.assertEqual(tree1.kc_distance(tree2, 0), 0) def test_different_samples_error(self): - tree1 = msprime.simulate(10, random_seed=1).first() - tree2 = msprime.simulate(2, random_seed=1).first() - self.assertRaises(ValueError, kc_distance, tree1, tree2) - self.assertRaises(ValueError, kc_distance_simple, tree1, tree2) + tree1 = next(msprime.simulate(10, random_seed=1).trees(sample_lists=True)) + tree2 = next(msprime.simulate(2, random_seed=1).trees(sample_lists=True)) + self.assertRaises(ValueError, naive_kc_distance, tree1, tree2) + self.assertRaises(ValueError, c_kc_distance, tree1, tree2) self.assertRaises(_tskit.LibraryError, tree1.kc_distance, tree2) ts1 = msprime.simulate(10, random_seed=1) nmap = np.arange(0, ts1.num_nodes)[::-1] ts2 = tsutil.permute_nodes(ts1, nmap) - tree1 = ts1.first() - tree2 = ts2.first() - self.assertRaises(ValueError, kc_distance, tree1, tree2) - self.assertRaises(ValueError, kc_distance_simple, tree1, tree2) + tree1 = next(ts1.trees(sample_lists=True)) + tree2 = next(ts2.trees(sample_lists=True)) + self.assertRaises(ValueError, naive_kc_distance, tree1, tree2) + self.assertRaises(ValueError, c_kc_distance, tree1, tree2) self.assertRaises(_tskit.LibraryError, tree1.kc_distance, tree2) def validate_trees(self, n): for seed in range(1, 10): - tree1 = msprime.simulate(n, random_seed=seed).first() - tree2 = msprime.simulate(n, random_seed=seed + 1).first() - kc1 = kc_distance(tree1, tree2) - kc2 = kc_distance_simple(tree1, tree2) + ts1 = msprime.simulate(n, random_seed=seed) + ts2 = msprime.simulate(n, random_seed=seed + 1) + tree1 = next(ts1.trees(sample_lists=True)) + tree2 = next(ts2.trees(sample_lists=True)) + kc1 = naive_kc_distance(tree1, tree2) + kc2 = c_kc_distance(tree1, tree2) kc3 = tree1.kc_distance(tree2) self.assertAlmostEqual(kc1, kc2) self.assertAlmostEqual(kc1, kc3) - self.assertAlmostEqual(kc1, kc_distance(tree2, tree1)) - self.assertAlmostEqual(kc2, kc_distance_simple(tree2, tree1)) + self.assertAlmostEqual(kc1, naive_kc_distance(tree2, tree1)) + self.assertAlmostEqual(kc2, c_kc_distance(tree2, tree1)) self.assertAlmostEqual(kc3, tree2.kc_distance(tree1)) def test_sample_3(self): @@ -309,29 +371,38 @@ def validate_nonbinary_trees(self, n): found = True break self.assertTrue(found) - tree1 = ts.first() + tree1 = next(ts.trees(sample_lists=True)) ts = msprime.simulate( n, random_seed=seed + 1, demographic_events=demographic_events ) - tree2 = ts.first() - self.assertAlmostEqual(kc_distance(tree1, tree2), kc_distance(tree1, tree2)) - self.assertAlmostEqual(kc_distance(tree2, tree1), kc_distance(tree2, tree1)) + tree2 = next(ts.trees(sample_lists=True)) + self.assertAlmostEqual( + naive_kc_distance(tree1, tree2), naive_kc_distance(tree1, tree2) + ) self.assertAlmostEqual( - kc_distance_simple(tree1, tree2), kc_distance_simple(tree1, tree2) + naive_kc_distance(tree2, tree1), naive_kc_distance(tree2, tree1) ) self.assertAlmostEqual( - kc_distance_simple(tree2, tree1), kc_distance_simple(tree2, tree1) + c_kc_distance(tree1, tree2), c_kc_distance(tree1, tree2) + ) + self.assertAlmostEqual( + c_kc_distance(tree2, tree1), c_kc_distance(tree2, tree1) ) # compare to a binary tree also - tree2 = msprime.simulate(n, random_seed=seed + 1).first() - self.assertAlmostEqual(kc_distance(tree1, tree2), kc_distance(tree1, tree2)) - self.assertAlmostEqual(kc_distance(tree2, tree1), kc_distance(tree2, tree1)) + ts = msprime.simulate(n, random_seed=seed + 1) + tree2 = next(ts.trees(sample_lists=True)) + self.assertAlmostEqual( + naive_kc_distance(tree1, tree2), naive_kc_distance(tree1, tree2) + ) + self.assertAlmostEqual( + naive_kc_distance(tree2, tree1), naive_kc_distance(tree2, tree1) + ) self.assertAlmostEqual( - kc_distance_simple(tree1, tree2), kc_distance_simple(tree1, tree2) + c_kc_distance(tree1, tree2), c_kc_distance(tree1, tree2) ) self.assertAlmostEqual( - kc_distance_simple(tree2, tree1), kc_distance_simple(tree2, tree1) + c_kc_distance(tree2, tree1), c_kc_distance(tree2, tree1) ) def test_non_binary_sample_10(self): @@ -344,15 +415,15 @@ def test_non_binary_sample_30(self): self.validate_nonbinary_trees(30) def verify_result(self, tree1, tree2, lambda_, result, places=None): - kc1 = kc_distance(tree1, tree2, lambda_) - kc2 = kc_distance_simple(tree1, tree2, lambda_) + kc1 = naive_kc_distance(tree1, tree2, lambda_) + kc2 = c_kc_distance(tree1, tree2, lambda_) kc3 = tree1.kc_distance(tree2, lambda_) self.assertAlmostEqual(kc1, result, places=places) self.assertAlmostEqual(kc2, result, places=places) self.assertAlmostEqual(kc3, result, places=places) - kc1 = kc_distance(tree2, tree1, lambda_) - kc2 = kc_distance_simple(tree2, tree1, lambda_) + kc1 = naive_kc_distance(tree2, tree1, lambda_) + kc2 = c_kc_distance(tree2, tree1, lambda_) kc3 = tree2.kc_distance(tree1, lambda_) self.assertAlmostEqual(kc1, result, places=places) self.assertAlmostEqual(kc2, result, places=places) @@ -383,8 +454,8 @@ def test_known_kc_sample_3(self): tables_1.edges.add_row(left=l, right=r, parent=p, child=c) tables_2.edges.add_row(left=l, right=r, parent=p, child=c) - tree_1 = tables_1.tree_sequence().first() - tree_2 = tables_2.tree_sequence().first() + tree_1 = next(tables_1.tree_sequence().trees(sample_lists=True)) + tree_2 = next(tables_2.tree_sequence().trees(sample_lists=True)) self.verify_result(tree_1, tree_2, 0, 0) self.verify_result(tree_1, tree_2, 1, 4.243, places=3) @@ -490,8 +561,8 @@ def test_10_samples(self): nodes_2, edges_2, sequence_length=10000, strict=False, base64_metadata=False ) - tree_1 = ts_1.first() - tree_2 = ts_2.first() + tree_1 = next(ts_1.trees(sample_lists=True)) + tree_2 = next(ts_2.trees(sample_lists=True)) self.verify_result(tree_1, tree_2, 0, 12.85, places=2) self.verify_result(tree_1, tree_2, 1, 10.64, places=2) @@ -638,8 +709,8 @@ def test_15_samples(self): nodes_2, edges_2, sequence_length=10000, strict=False, base64_metadata=False ) - tree_1 = ts_1.first() - tree_2 = ts_2.first() + tree_1 = next(ts_1.trees(sample_lists=True)) + tree_2 = next(ts_2.trees(sample_lists=True)) self.verify_result(tree_1, tree_2, 0, 19.95, places=2) self.verify_result(tree_1, tree_2, 1, 17.74, places=2) @@ -754,8 +825,8 @@ def test_nobinary_trees(self): ts_2 = tskit.load_text( nodes_2, edges_2, sequence_length=10000, strict=False, base64_metadata=False ) - tree_1 = ts_1.first() - tree_2 = ts_2.first() + tree_1 = next(ts_1.trees(sample_lists=True)) + tree_2 = next(ts_2.trees(sample_lists=True)) self.verify_result(tree_1, tree_2, 0, 9.434, places=3) self.verify_result(tree_1, tree_2, 1, 44, places=1) @@ -773,21 +844,21 @@ def test_multiple_roots(self): ts = tables.tree_sequence() with self.assertRaises(ValueError): - kc_distance(ts.first(), ts.first(), 0) + naive_kc_distance(ts.first(), ts.first(), 0) with self.assertRaises(ValueError): - kc_distance_simple(ts.first(), ts.first(), 0) + c_kc_distance(ts.first(), ts.first(), 0) with self.assertRaises(_tskit.LibraryError): ts.first().kc_distance(ts.first(), 0) def do_kc_distance(self, t1, t2, lambda_=0): - kc1 = kc_distance(t1, t2, lambda_) - kc2 = kc_distance_simple(t1, t2, lambda_) + kc1 = naive_kc_distance(t1, t2, lambda_) + kc2 = c_kc_distance(t1, t2, lambda_) kc3 = t1.kc_distance(t2, lambda_) self.assertAlmostEqual(kc1, kc2) self.assertAlmostEqual(kc1, kc3) - kc1 = kc_distance(t2, t1, lambda_) - kc2 = kc_distance_simple(t1, t1, lambda_) + kc1 = naive_kc_distance(t2, t1, lambda_) + kc2 = c_kc_distance(t1, t1, lambda_) kc3 = t2.kc_distance(t1, lambda_) self.assertAlmostEqual(kc1, kc2) self.assertAlmostEqual(kc1, kc3) @@ -796,20 +867,20 @@ def test_non_initial_samples(self): ts1 = msprime.simulate(10, random_seed=1) nmap = np.arange(0, ts1.num_nodes)[::-1] ts2 = tsutil.permute_nodes(ts1, nmap) - t1 = ts2.first() - t2 = ts2.first() + t1 = next(ts2.trees(sample_lists=True)) + t2 = next(ts2.trees(sample_lists=True)) self.do_kc_distance(t1, t2) def test_internal_samples(self): ts1 = msprime.simulate(10, random_seed=1) ts2 = tsutil.jiggle_samples(ts1) - t1 = ts2.first() - t2 = ts2.first() + t1 = next(ts2.trees(sample_lists=True)) + t2 = next(ts2.trees(sample_lists=True)) with self.assertRaises(ValueError): - kc_distance(t1, t2) + naive_kc_distance(t1, t2) with self.assertRaises(ValueError): - kc_distance_simple(t1, t2) + c_kc_distance(t1, t2) with self.assertRaises(_tskit.LibraryError): t1.kc_distance(t2)