diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 0776a0f5a9..4a7958471f 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -25,6 +25,7 @@ #include "testlib.h" #include +#include #include #include @@ -2851,6 +2852,367 @@ test_link_ancestors_multiple_to_single_tree(void) tsk_treeseq_free(&ts); } +/* Helper method for running IBD tests */ +static int TSK_WARN_UNUSED +ibd_finder_init_and_run(tsk_ibd_finder_t *ibd_finder, tsk_table_collection_t *tables, + tsk_id_t *samples, tsk_size_t num_samples, double min_length, double max_time) +{ + int ret = 0; + + ret = tsk_ibd_finder_init(ibd_finder, tables, samples, num_samples); + if (ret != 0) { + goto out; + } + ret = tsk_ibd_finder_set_min_length(ibd_finder, min_length); + if (ret != 0) { + goto out; + } + ret = tsk_ibd_finder_set_max_time(ibd_finder, max_time); + if (ret != 0) { + goto out; + } + ret = tsk_ibd_finder_run(ibd_finder); + if (ret != 0) { + goto out; + } + +out: + return ret; +} + +static void +test_ibd_finder(void) +{ + int ret; + int j, k; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1 }; + tsk_ibd_finder_t ibd_finder; + double true_left[] = { 0.0 }; + double true_right[] = { 1.0 }; + tsk_id_t true_node[] = { 4 }; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 1, 0.0, DBL_MAX); + + // Check the output. + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_EQUAL_FATAL(ret, 0); + k = 0; + while (seg != NULL) { + CU_ASSERT_EQUAL_FATAL(seg->left, true_left[k]); + CU_ASSERT_EQUAL_FATAL(seg->right, true_right[k]); + CU_ASSERT_EQUAL_FATAL(seg->node, true_node[k]); + k++; + seg = seg->next; + } + } + tsk_ibd_finder_print_state(&ibd_finder, _devnull); + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_multiple_trees(void) +{ + int ret; + int j, k; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1, 0, 2 }; + tsk_ibd_finder_t ibd_finder; + double true_left[2][2] = { { 0.0, 0.7 }, { 0.7, 0.0 } }; + double true_right[2][2] = { { 0.7, 1.0 }, { 1.0, 0.7 } }; + double true_node[2][2] = { { 4, 5 }, { 5, 6 } }; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 2, multiple_tree_ex_nodes, multiple_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Run ibd_finder. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 2, 0.0, DBL_MAX); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Check the output. + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + ret = tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_EQUAL_FATAL(ret, 0); + k = 0; + while (seg != NULL) { + CU_ASSERT_EQUAL_FATAL(seg->left, true_left[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->right, true_right[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->node, true_node[j][k]); + k++; + seg = seg->next; + } + } + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_empty_result(void) +{ + int ret; + int j; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1 }; + tsk_ibd_finder_t ibd_finder; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Run ibd_finder. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 1, 0.0, 0.5); + + // Check the output. + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(seg, NULL); + } + tsk_ibd_finder_print_state(&ibd_finder, _devnull); + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_min_length_max_time(void) +{ + int ret; + int j, k; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1, 1, 2, 2, 0 }; + tsk_ibd_finder_t ibd_finder; + double true_left[3][1] = { { 0.0 }, { -1 }, { -1 } }; + double true_right[3][1] = { { 0.7 }, { -1 }, { -1 } }; + double true_node[3][1] = { { 4 }, { -1 }, { -1 } }; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 2, multiple_tree_ex_nodes, multiple_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Run ibd_finder. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 3, 0.5, 3.0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Check the output. + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + ret = tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_TRUE_FATAL((ret == 0) || (ret == -1)); + if (ret == -1) { + continue; + } + k = 0; + while (seg != NULL) { + CU_ASSERT_EQUAL_FATAL(seg->left, true_left[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->right, true_right[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->node, true_node[j][k]); + k++; + seg = seg->next; + } + } + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1, 2, 0 }; + tsk_id_t samples2[] = { -1, 1 }; + tsk_id_t samples3[] = { 0 }; + tsk_ibd_finder_t ibd_finder; + + tsk_treeseq_from_text(&ts, 2, multiple_tree_ex_nodes, multiple_tree_ex_edges, NULL, + NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Invalid sample IDs + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples2, 1, 0.0, DBL_MAX); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); + tsk_ibd_finder_free(&ibd_finder); + + // Only 1 sample + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples3, 0, 0.0, DBL_MAX); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NO_SAMPLE_PAIRS); + tsk_ibd_finder_free(&ibd_finder); + + // Bad length or time + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 2, 0.0, -1); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + tsk_ibd_finder_free(&ibd_finder); + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 2, -1, 0.0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE); + tsk_ibd_finder_free(&ibd_finder); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_samples_are_descendants(void) +{ + int ret; + int j, k; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 2, 0, 4, 2, 4, 1, 3, 1, 5, 3, 5 }; + tsk_ibd_finder_t ibd_finder; + double true_left[6][1] = { { 0.0 }, { 0.0 }, { 0.0 }, { 0.0 }, { 0.0 }, { 0.0 } }; + double true_right[6][1] = { { 1.0 }, { 1.0 }, { 1.0 }, { 1.0 }, { 1.0 }, { 1.0 } }; + tsk_id_t true_node[6][1] = { { 2 }, { 4 }, { 4 }, { 3 }, { 5 }, { 5 } }; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 1, multi_root_tree_ex_nodes, multi_root_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Run ibd_finder. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 6, 0.0, DBL_MAX); + + // Check the output. + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_EQUAL_FATAL(ret, 0); + k = 0; + while (seg != NULL) { + CU_ASSERT_EQUAL_FATAL(seg->left, true_left[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->right, true_right[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->node, true_node[j][k]); + k++; + seg = seg->next; + } + } + tsk_ibd_finder_print_state(&ibd_finder, _devnull); + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_multiple_ibd_paths(void) +{ + int ret; + int j, k; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1, 0, 2, 1, 2 }; + tsk_ibd_finder_t ibd_finder; + double true_left[3][2] = { { 0.2, 0.0 }, { 0.2, 0.0 }, { 0.0, 0.2 } }; + double true_right[3][2] = { { 1.0, 0.2 }, { 1.0, 0.2 }, { 0.2, 1.0 } }; + double true_node[3][2] = { { 4, 5 }, { 3, 5 }, { 4, 4 } }; + tsk_segment_t *seg = NULL; + + // Read in the tree sequence. + tsk_treeseq_from_text(&ts, 2, multi_path_tree_ex_nodes, multi_path_tree_ex_edges, + NULL, NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Run ibd_finder. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 3, 0.0, 0.0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Check the output. + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < (int) ibd_finder.num_pairs; j++) { + tsk_ibd_finder_get_ibd_segments(&ibd_finder, j, &seg); + CU_ASSERT_EQUAL_FATAL(ret, 0); + k = 0; + while (seg != NULL) { + CU_ASSERT_EQUAL_FATAL(seg->left, true_left[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->right, true_right[j][k]); + CU_ASSERT_EQUAL_FATAL(seg->node, true_node[j][k]); + k++; + seg = seg->next; + } + } + tsk_ibd_finder_print_state(&ibd_finder, _devnull); + + // Free. + tsk_ibd_finder_free(&ibd_finder); + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_ibd_finder_odd_topologies(void) +{ + int ret; + // int j; + tsk_treeseq_t ts; + tsk_table_collection_t tables; + tsk_id_t samples[] = { 0, 1 }; + tsk_id_t samples1[] = { 0, 2 }; + tsk_ibd_finder_t ibd_finder; + + tsk_treeseq_from_text( + &ts, 1, odd_tree1_ex_nodes, odd_tree1_ex_edges, NULL, NULL, NULL, NULL, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + // Multiple roots. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples, 1, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_ibd_finder_free(&ibd_finder); + + // // Parent is a sample. + ret = ibd_finder_init_and_run(&ibd_finder, &tables, samples1, 1, 0, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_ibd_finder_free(&ibd_finder); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + static void test_simplify_tables_drops_indexes(void) { @@ -4627,6 +4989,15 @@ main(int argc, char **argv) { "test_link_ancestors_multiple_to_single_tree", test_link_ancestors_multiple_to_single_tree }, { "test_sort_tables_offsets", test_sort_tables_offsets }, + { "test_ibd_finder", test_ibd_finder }, + { "test_ibd_finder_multiple_trees", test_ibd_finder_multiple_trees }, + { "test_ibd_finder_empty_result", test_ibd_finder_empty_result }, + { "test_ibd_finder_min_length_max_time", test_ibd_finder_min_length_max_time }, + { "test_ibd_finder_samples_are_descendants", + test_ibd_finder_samples_are_descendants }, + { "test_ibd_finder_multiple_ibd_paths", test_ibd_finder_multiple_ibd_paths }, + { "test_ibd_finder_odd_topologies", test_ibd_finder_odd_topologies }, + { "test_ibd_finder_errors", test_ibd_finder_errors }, { "test_sort_tables_drops_indexes", test_sort_tables_drops_indexes }, { "test_sort_tables_edge_metadata", test_sort_tables_edge_metadata }, { "test_sort_tables_no_edge_metadata", test_sort_tables_no_edge_metadata }, diff --git a/c/tests/testlib.c b/c/tests/testlib.c index 5a4b6552c4..aca2ad3bbb 100644 --- a/c/tests/testlib.c +++ b/c/tests/testlib.c @@ -27,7 +27,7 @@ char *_tmp_file_name; FILE *_devnull; -/*** Simple single tree example. ***/ +/* Simple single tree example. */ const char *single_tree_ex_nodes = /* 6 */ "1 0 -1 -1\n" /* / \ */ "1 0 -1 -1\n" /* / \ */ @@ -178,7 +178,89 @@ const char *unary_ex_mutations = "0 2 1\n" "1 6 1\n" "2 5 1\n"; -/*** An example of a tree sequence with internally sampled nodes. ***/ +/* An example of a simple tree sequence with multiple marginal trees. */ + +/* Simple single tree example. */ +const char *multiple_tree_ex_nodes = /* */ + "1 0 -1 -1\n" /* 6 | */ + "1 0 -1 -1\n" /* / \ | */ + "1 0 -1 -1\n" /* / \ | 5 */ + "0 1 -1 -1\n" /* 4 \ | / \ */ + "0 2 -1 -1\n" /* / \ \ | / 3 */ + "0 3 -1 -1\n" /* / \ \ | / / \ */ + "0 4 -1 -1\n"; /* 0 1 2 | 0 1 2 */ + /* |----------------|---------------| */ + /* 0 1 2 */ + +const char *multiple_tree_ex_edges = "0.7 1.0 3 1,2\n" + "0.0 0.7 4 0,1\n" + "0.7 1.0 5 0,3\n" + "0.0 0.7 6 2,4\n"; + +/* Odd topology -- different roots. */ + +const char *odd_tree1_ex_nodes = /* | | 5 */ + "1 0 -1 -1\n" /* | 4 | | */ + "1 0 -1 -1\n" /* 3 | | | | */ + "0 1 -1 -1\n" /* | | | | | */ + "0 2 -1 -1\n" /* 2 | 2 | 2 */ + "0 3 -1 -1\n" /* / \ | / \ | / \ */ + "0 4 -1 -1\n"; /* 0 1 | 0 1 | 0 1 */ + /* |------|-------|------| */ + /* 0.0 0.2 0.7 1.0*/ + +const char *odd_tree1_ex_edges = "0.0 1.0 2 0,1\n" + "0.0 0.2 3 2\n" + "0.2 0.7 4 2\n" + "0.7 1.0 4 2\n"; + +/* An example where some samples descend from other samples, and multiple roots */ + +const char *multi_root_tree_ex_nodes = "1 0 -1 -1\n" /* 4 5 */ + "1 0 -1 -1\n" /* | | */ + "1 1 -1 -1\n" /* 2 3 */ + "1 1 -1 -1\n" /* | | */ + "0 2 -1 -1\n" /* 0 1 */ + "0 2 -1 -1\n"; + +const char *multi_root_tree_ex_edges = "0 1 2 0\n" + "0 1 3 1\n" + "0 1 4 2\n" + "0 1 5 3\n"; + +/* Examples of tree sequences where samples have different paths to the same ancestor. */ + +const char *multi_path_tree_ex_nodes = /* 5 | */ + "1 0 -1 -1\n" /* / \ | */ + "1 0 -1 -1\n" /* / 4 | 4 */ + "1 0 -1 -1\n" /* / / \ | / \ */ + "0 1 -1 -1\n" /* / / \ | 3 \ */ + "0 2 -1 -1\n" /* / / \ | / \ \ */ + "0 3 -1 -1\n"; /* 0 2 1 | 0 2 1 */ + /*----------------|------------ */ + /*0.0 0.2 1.0*/ + +const char *multi_path_tree_ex_edges = "0.2 1.0 3 0\n" + "0.2 1.0 3 2\n" + "0.0 1.0 4 1\n" + "0.0 0.2 4 2\n" + "0.2 1.0 4 3\n" + "0.0 0.2 5 0\n" + "0.0 0.2 5 4\n"; + +const char *multi_path_tree_ex2_nodes = "1 0 -1 -1\n" + "1 0 -1 -1\n" + "0 1 -1 -1\n" + "0 2 -1 -1\n" + "0 3 -1 -1\n"; + +const char *multi_path_tree_ex2_edges = "0.6 1.0 2 1\n" + "0.0 1.0 3 0\n" + "0.0 0.6 4 1\n" + "0.6 1.0 4 2\n" + "0.0 1.0 4 3\n"; + +/* An example of a tree sequence with internally sampled nodes. */ /* 1.20┊ ┊ 8 ┊ ┊ diff --git a/c/tests/testlib.h b/c/tests/testlib.h index e70a0f37bf..266f6dd239 100644 --- a/c/tests/testlib.h +++ b/c/tests/testlib.h @@ -60,6 +60,18 @@ extern const char *single_tree_ex_edges; extern const char *single_tree_ex_sites; extern const char *single_tree_ex_mutations; +extern const char *multiple_tree_ex_nodes; +extern const char *multiple_tree_ex_edges; + +extern const char *odd_tree1_ex_nodes; +extern const char *odd_tree1_ex_edges; + +extern const char *multi_root_tree_ex_nodes; +extern const char *multi_root_tree_ex_edges; + +extern const char *multi_path_tree_ex_nodes; +extern const char *multi_path_tree_ex_edges; + extern const char *nonbinary_ex_nodes; extern const char *nonbinary_ex_edges; extern const char *nonbinary_ex_sites; diff --git a/c/tskit/core.c b/c/tskit/core.c index 162c76be9a..9b35da049a 100644 --- a/c/tskit/core.c +++ b/c/tskit/core.c @@ -461,6 +461,12 @@ tsk_strerror_internal(int err) // histories could be equivalent, because subset does not reorder // edges (if not sorted) or mutations. ret = "Shared portions of the tree sequences are not equal."; + break; + + /* IBD errors */ + case TSK_ERR_NO_SAMPLE_PAIRS: + ret = "There are no possible sample pairs."; + break; } return ret; } diff --git a/c/tskit/core.h b/c/tskit/core.h index efa040f840..871d10145f 100644 --- a/c/tskit/core.h +++ b/c/tskit/core.h @@ -315,6 +315,9 @@ not found in the file. #define TSK_ERR_UNION_BAD_MAP -1400 #define TSK_ERR_UNION_DIFF_HISTORIES -1401 +/* IBD errors */ +#define TSK_ERR_NO_SAMPLE_PAIRS -1500 + // clang-format on /* This bit is 0 for any errors originating from kastore */ diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 742e49af68..6638b31a74 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -4685,13 +4685,6 @@ tsk_table_sorter_free(tsk_table_sorter_t *self) * segment overlapper *************************/ -typedef struct _tsk_segment_t { - double left; - double right; - struct _tsk_segment_t *next; - tsk_id_t node; -} tsk_segment_t; - typedef struct _interval_list_t { double left; double right; @@ -5388,6 +5381,470 @@ ancestor_mapper_run(ancestor_mapper_t *self) return ret; } +/************************* + * IBD finder + *************************/ + +static tsk_segment_t *TSK_WARN_UNUSED +tsk_ibd_finder_alloc_segment( + tsk_ibd_finder_t *self, double left, double right, tsk_id_t node) +{ + tsk_segment_t *seg = NULL; + + seg = tsk_blkalloc_get(&self->segment_heap, sizeof(*seg)); + if (seg == NULL) { + goto out; + } + seg->next = NULL; + seg->left = left; + seg->right = right; + seg->node = node; + +out: + return seg; +} + +static int TSK_WARN_UNUSED +tsk_ibd_finder_add_output( + tsk_ibd_finder_t *self, double left, double right, tsk_id_t node_id, int pair_num) +{ + int ret = 0; + tsk_segment_t *tail = self->ibd_segments_tail[pair_num]; + tsk_segment_t *x; + + assert(left < right); + if (tail == NULL) { + x = tsk_ibd_finder_alloc_segment(self, left, right, node_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->ibd_segments_head[pair_num] = x; + self->ibd_segments_tail[pair_num] = x; + } else { + x = tsk_ibd_finder_alloc_segment(self, left, right, node_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tail->next = x; + self->ibd_segments_tail[pair_num] = x; + } +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_ibd_finder_add_ancestry(tsk_ibd_finder_t *self, tsk_id_t input_id, double left, + double right, tsk_id_t output_id) +{ + int ret = 0; + tsk_segment_t *tail = self->ancestor_map_tail[input_id]; + tsk_segment_t *x = NULL; + + assert(left < right); + if (tail == NULL) { + x = tsk_ibd_finder_alloc_segment(self, left, right, output_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->ancestor_map_head[input_id] = x; + self->ancestor_map_tail[input_id] = x; + } else { + x = tsk_ibd_finder_alloc_segment(self, left, right, output_id); + if (x == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + tail->next = x; + self->ancestor_map_tail[input_id] = x; + } +out: + return ret; +} + +static int +tsk_ibd_finder_init_samples(tsk_ibd_finder_t *self) +{ + int ret = 0; + size_t j; + tsk_id_t u; + + /* Go through the sample pairs to define samples. */ + for (j = 0; j < 2 * self->num_pairs; j++) { + u = self->pairs[j]; + + if (u < 0 || u > (tsk_id_t) self->tables->nodes.num_rows) { + ret = TSK_ERR_NODE_OUT_OF_BOUNDS; + goto out; + } + + if (!self->is_sample[u]) { + self->is_sample[u] = true; + ret = tsk_ibd_finder_add_ancestry( + self, u, 0, self->tables->sequence_length, u); + if (ret != 0) { + goto out; + } + } + } + +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_init(tsk_ibd_finder_t *self, tsk_table_collection_t *tables, + tsk_id_t *pairs, tsk_size_t num_pairs) +{ + int ret = 0; + size_t num_nodes_alloc; + + memset(self, 0, sizeof(tsk_ibd_finder_t)); + self->pairs = pairs; + self->num_pairs = num_pairs; + self->sequence_length = tables->sequence_length; + self->num_nodes = tables->nodes.num_rows; + self->tables = tables; + self->max_time = DBL_MAX; + self->min_length = 0; + + if (pairs == NULL || num_pairs < 1) { + ret = TSK_ERR_NO_SAMPLE_PAIRS; + goto out; + } + + // Allocate the heaps used for small objects. + ret = tsk_blkalloc_init(&self->segment_heap, 8192); + if (ret != 0) { + goto out; + } + + // Mallocing and callocing. + num_nodes_alloc = 1 + tables->nodes.num_rows; + self->ancestor_map_head = calloc(num_nodes_alloc, sizeof(*self->ancestor_map_head)); + self->ancestor_map_tail = calloc(num_nodes_alloc, sizeof(*self->ancestor_map_tail)); + self->ibd_segments_head = calloc(self->num_pairs, sizeof(*self->ibd_segments_head)); + self->ibd_segments_tail = calloc(self->num_pairs, sizeof(*self->ibd_segments_tail)); + self->is_sample = calloc(num_nodes_alloc, sizeof(*self->is_sample)); + self->segment_queue_size = 0; + self->max_segment_queue_size = 64; + self->segment_queue + = malloc(self->max_segment_queue_size * sizeof(*self->segment_queue)); + if (self->ancestor_map_head == NULL || self->ancestor_map_tail == NULL + || self->ibd_segments_head == NULL || self->ibd_segments_tail == NULL + || self->is_sample == NULL || self->segment_queue == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + + ret = tsk_ibd_finder_init_samples(self); + if (ret != 0) { + goto out; + } + +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_set_min_length(tsk_ibd_finder_t *self, double min_length) +{ + int ret = 0; + + if (min_length < 0) { + ret = TSK_ERR_BAD_PARAM_VALUE; + } + self->min_length = min_length; + return ret; +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_set_max_time(tsk_ibd_finder_t *self, double max_time) +{ + int ret = 0; + + if (max_time < 0) { + ret = TSK_ERR_BAD_PARAM_VALUE; + } + self->max_time = max_time; + return ret; +} + +static int TSK_WARN_UNUSED +tsk_ibd_finder_enqueue_segment( + tsk_ibd_finder_t *self, double left, double right, tsk_id_t node) +{ + int ret = 0; + tsk_segment_t *seg; + void *p; + + assert(left < right); + /* Make sure we always have room for one more segment in the queue so we + * can put a tail sentinel on it */ + if (self->segment_queue_size == self->max_segment_queue_size - 1) { + self->max_segment_queue_size *= 2; + p = realloc(self->segment_queue, + self->max_segment_queue_size * sizeof(*self->segment_queue)); + if (p == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + self->segment_queue = p; + } + seg = self->segment_queue + self->segment_queue_size; + seg->left = left; + seg->right = right; + seg->node = node; + self->segment_queue_size++; +out: + return ret; +} + +static int +tsk_ibd_finder_find_sample_pair_index2( + tsk_ibd_finder_t *self, tsk_id_t sample0, tsk_id_t sample1) +{ + int i = 0; + int ret = -1; + tsk_id_t s0, s1; + + for (i = 0; i < (tsk_id_t) self->num_pairs; i++) { + s0 = self->pairs[2 * i]; + s1 = self->pairs[2 * i + 1]; + if ((s0 == sample0 && s1 == sample1) || (s0 == sample1 && s1 == sample0)) { + ret = i; + goto out; + } + } +out: + return ret; +} + +static int TSK_WARN_UNUSED +tsk_ibd_finder_calculate_ibd(tsk_ibd_finder_t *self, tsk_id_t current_parent) +{ + int ret = 0; + int j, pair_index; + tsk_segment_t *seg, *seg0, *seg1; + double left, right; + + if (self->ancestor_map_head[current_parent] == NULL) { + for (j = 0; j != (int) self->segment_queue_size; j++) { + seg = &self->segment_queue[j]; + ret = tsk_ibd_finder_add_ancestry( + self, current_parent, seg->left, seg->right, seg->node); + if (ret != 0) { + goto out; + } + } + } else { + for (seg0 = self->ancestor_map_head[current_parent]; seg0 != NULL; + seg0 = seg0->next) { + for (j = 0; j != (int) self->segment_queue_size; j++) { + seg1 = &self->segment_queue[j]; + if (seg0->node == seg1->node) { + continue; + } + ret = tsk_ibd_finder_find_sample_pair_index2( + self, seg0->node, seg1->node); + if (ret < 0) { + continue; + } + pair_index = ret; + + if (seg0->left > seg1->left) { + left = seg0->left; + } else { + left = seg1->left; + } + if (seg0->right < seg1->right) { + right = seg0->right; + } else { + right = seg1->right; + } + + if (left < right) { + if (right - left > self->min_length) { + ret = tsk_ibd_finder_add_output( + self, left, right, current_parent, pair_index); + if (ret != 0) { + goto out; + } + } + } + } + } + for (j = 0; j != (int) self->segment_queue_size; j++) { + seg = &self->segment_queue[j]; + ret = tsk_ibd_finder_add_ancestry( + self, current_parent, seg->left, seg->right, seg->node); + if (ret != 0) { + goto out; + } + } + } + self->segment_queue_size = 0; + +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_get_ibd_segments( + tsk_ibd_finder_t *self, tsk_id_t pair_index, tsk_segment_t **ret_ibd_segments_head) +{ + int ret = 0; + + if (((pair_index < 0) || (pair_index >= (tsk_id_t) self->num_pairs))) { + ret = TSK_ERR_NO_SAMPLE_PAIRS; + goto out; + } + if (self->ibd_segments_head[pair_index] != NULL) { + *ret_ibd_segments_head = self->ibd_segments_head[pair_index]; + } else { + ret = -1; + } +out: + return ret; +} + +void +tsk_ibd_finder_print_state(tsk_ibd_finder_t *self, FILE *out) +{ + size_t j; + tsk_segment_t *u = NULL; + + fprintf(out, "--ibd-finder stats--\n"); + fprintf(out, "===\nEdge table\n==\n"); + for (j = 0; j < self->tables->edges.num_rows; j++) { + fprintf(out, "L:%f, R:%f, P:%d, C:%d\n", self->tables->edges.left[j], + self->tables->edges.right[j], self->tables->edges.parent[j], + self->tables->edges.child[j]); + } + fprintf(out, "===\nNode table\n==\n"); + for (j = 0; j < self->tables->nodes.num_rows; j++) { + fprintf(out, "ID:%f, Time:%f, Flag:%d\n", (double) j, + self->tables->nodes.time[j], self->tables->nodes.flags[j]); + } + fprintf(out, "==\nSample pairs\n==\n"); + for (j = 0; j < 2 * self->num_pairs; j++) { + fprintf(out, "%i ", (int) self->pairs[j]); + if (j % 2 != 0) { + fprintf(out, "\n"); + } + } + fprintf(out, "===\nAncestral map\n==\n"); + for (j = 0; j < self->tables->nodes.num_rows; j++) { + fprintf(out, "Node %d: ", (int) j); + for (u = self->ancestor_map_head[j]; u != NULL; u = u->next) { + fprintf(out, "(%f,%f->%d)", u->left, u->right, u->node); + } + fprintf(out, "\n"); + } + fprintf(out, "===\nIBD segments\n===\n"); + for (j = 0; j < self->num_pairs; j++) { + fprintf(out, "Pair (%i, %i)\n", (int) self->pairs[2 * j], + (int) self->pairs[2 * j + 1]); + for (u = self->ibd_segments_head[j]; u != NULL; u = u->next) { + fprintf(out, "(%f,%f->%d)", u->left, u->right, u->node); + } + fprintf(out, "\n"); + } +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_run(tsk_ibd_finder_t *self) +{ + const tsk_edge_table_t *input_edges = &self->tables->edges; + int ret = 0; + size_t j; + tsk_id_t u; + tsk_id_t current_parent = -1; + size_t num_edges = input_edges->num_rows; + tsk_segment_t *seg; + tsk_segment_t *s; + double intvl_l, intvl_r, current_time; + bool parent_should_be_added = true; + + for (j = 0; j < num_edges; j++) { + + if (current_parent >= 0 && current_parent != input_edges->parent[j]) { + parent_should_be_added = true; + } + + // Stop if the processed node's time exceeds the max time. + current_parent = input_edges->parent[j]; + current_time = self->tables->nodes.time[current_parent]; + if (current_time > self->max_time) { + goto out; + } + + // Extract segment. + seg = tsk_ibd_finder_alloc_segment( + self, input_edges->left[j], input_edges->right[j], input_edges->child[j]); + // Create a SegmentList holding all of the sample segments descending from + // seg. + u = seg->node; + if (self->is_sample[u]) { + ret = tsk_ibd_finder_enqueue_segment(self, seg->left, seg->right, seg->node); + if (ret != 0) { + goto out; + } + } else { + for (s = self->ancestor_map_head[u]; s != NULL; s = s->next) { + if (seg->left > s->left) { + intvl_l = seg->left; + } else { + intvl_l = s->left; + } + if (seg->right < s->right) { + intvl_r = seg->right; + } else { + intvl_r = s->right; + } + // Add to the segment queue. + if (intvl_r - intvl_l > 0) { + ret = tsk_ibd_finder_enqueue_segment( + self, intvl_l, intvl_r, s->node); + if (ret != 0) { + goto out; + } + } + } + } + + // Calculate new ibd segments descending from the current parent. + if (self->segment_queue_size > 0) { + ret = tsk_ibd_finder_calculate_ibd(self, current_parent); + } + if (ret != 0) { + goto out; + } + + // For samples that appear in the parent column of the edge table + if (self->is_sample[current_parent] && parent_should_be_added) { + parent_should_be_added = false; + } + } +out: + return ret; +} + +int TSK_WARN_UNUSED +tsk_ibd_finder_free(tsk_ibd_finder_t *self) +{ + tsk_safe_free(self->ibd_segments_head); + tsk_safe_free(self->ibd_segments_tail); + tsk_blkalloc_free(&self->segment_heap); + tsk_safe_free(self->is_sample); + tsk_safe_free(self->ancestor_map_head); + tsk_safe_free(self->ancestor_map_tail); + tsk_safe_free(self->segment_queue); + return 0; +} + /************************* * simplifier *************************/ diff --git a/c/tskit/tables.h b/c/tskit/tables.h index bd90670bec..91d549182f 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -629,6 +629,36 @@ typedef struct _tsk_table_sorter_t { tsk_id_t *site_id_map; } tsk_table_sorter_t; +/* Structs for IBD finding. + * TODO: document properly + * */ + +typedef struct _tsk_segment_t { + double left; + double right; + struct _tsk_segment_t *next; + tsk_id_t node; +} tsk_segment_t; + +typedef struct { + tsk_id_t *pairs; + size_t num_pairs; + size_t num_nodes; + double sequence_length; + tsk_table_collection_t *tables; + tsk_segment_t **ibd_segments_head; + tsk_segment_t **ibd_segments_tail; + tsk_blkalloc_t segment_heap; + bool *is_sample; + double min_length; + double max_time; + tsk_segment_t **ancestor_map_head; + tsk_segment_t **ancestor_map_tail; + tsk_segment_t *segment_queue; + size_t segment_queue_size; + size_t max_segment_queue_size; +} tsk_ibd_finder_t; + /****************************************************************************/ /* Common function options */ /****************************************************************************/ @@ -2732,7 +2762,6 @@ tsk_id_t tsk_table_collection_check_integrity( int tsk_table_collection_link_ancestors(tsk_table_collection_t *self, tsk_id_t *samples, tsk_size_t num_samples, tsk_id_t *ancestors, tsk_size_t num_ancestors, tsk_flags_t options, tsk_edge_table_t *result); - int tsk_table_collection_deduplicate_sites( tsk_table_collection_t *tables, tsk_flags_t options); int tsk_table_collection_compute_mutation_parents( @@ -2819,6 +2848,17 @@ int tsk_table_sorter_free(struct _tsk_table_sorter_t *self); int tsk_squash_edges( tsk_edge_t *edges, tsk_size_t num_edges, tsk_size_t *num_output_edges); +/* IBD finder API. This is experimental and the interface may change. */ +int tsk_ibd_finder_init(tsk_ibd_finder_t *ibd_finder, tsk_table_collection_t *tables, + tsk_id_t *pairs, tsk_size_t num_pairs); +int tsk_ibd_finder_set_min_length(tsk_ibd_finder_t *self, double min_length); +int tsk_ibd_finder_set_max_time(tsk_ibd_finder_t *self, double max_time); +int tsk_ibd_finder_free(tsk_ibd_finder_t *self); +int tsk_ibd_finder_run(tsk_ibd_finder_t *ibd_finder); +int tsk_ibd_finder_get_ibd_segments(tsk_ibd_finder_t *ibd_finder, tsk_id_t pair_index, + tsk_segment_t **ret_ibd_segments_head); +void tsk_ibd_finder_print_state(tsk_ibd_finder_t *self, FILE *out); + #ifdef __cplusplus } #endif diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 4526aac3b3..472a678a5d 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -6632,6 +6632,146 @@ TableCollection_union(TableCollection *self, PyObject *args, PyObject *kwds) { return ret; } +static PyObject * +convert_ibd_segments(tsk_ibd_finder_t *ibd_finder, tsk_id_t *pairs, + tsk_size_t num_pairs) +{ + PyObject *ret = NULL; + PyObject *key = NULL; + PyObject *value = NULL; + PyArrayObject *left_array = NULL; + PyArrayObject *right_array = NULL; + PyArrayObject *node_array = NULL; + double *left, *right; + int err; + tsk_id_t *node; + tsk_size_t j, seg_index; + tsk_segment_t *u, *head; + PyObject *pair_dict = PyDict_New(); + npy_intp num_segments; + + if (pair_dict == NULL) { + goto out; + } + + for (j = 0; j < num_pairs; j++) { + err = tsk_ibd_finder_get_ibd_segments(ibd_finder, j, &head); + if (err == -1) { + head = NULL; + } else if (err != 0) { + handle_library_error(err); + goto out; + } + num_segments = 0; + for (u = head; u != NULL; u = u->next) { + num_segments++; + } + /* For each pair we return an array of left, right, node values */ + left_array = (PyArrayObject *) PyArray_SimpleNew(1, &num_segments, NPY_FLOAT64); + right_array = (PyArrayObject *) PyArray_SimpleNew(1, &num_segments, NPY_FLOAT64); + node_array = (PyArrayObject *) PyArray_SimpleNew(1, &num_segments, NPY_INT32); + if (left_array == NULL || right_array == NULL || node_array == NULL) { + goto out; + } + left = (double *) PyArray_DATA(left_array); + right = (double *) PyArray_DATA(right_array); + node = (tsk_id_t *) PyArray_DATA(node_array); + seg_index = 0; + for (u = head; u != NULL; u = u->next) { + left[seg_index] = u->left; + right[seg_index] = u->right; + node[seg_index] = u->node; + seg_index++; + } + key = Py_BuildValue("(ii)", pairs[2 * j], pairs[2 * j + 1]); + value = Py_BuildValue("{s:O,s:O,s:O}", + "left", left_array, "right", right_array, "node", node_array); + if (key == NULL || value == NULL) { + goto out; + } + if (PyDict_SetItem(pair_dict, key, value) != 0) { + goto out; + } + Py_DECREF(key); + Py_DECREF(value); + Py_DECREF(left_array); + Py_DECREF(right_array); + Py_DECREF(node_array); + key = NULL; + value = NULL; + left_array = NULL; + right_array = NULL; + node_array = NULL; + } + ret = pair_dict; + pair_dict = NULL; +out: + Py_XDECREF(key); + Py_XDECREF(value); + Py_XDECREF(left_array); + Py_XDECREF(right_array); + Py_XDECREF(node_array); + Py_XDECREF(pair_dict); + return ret; +} + +static PyObject * +TableCollection_find_ibd(TableCollection *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + tsk_ibd_finder_t ibd_finder; + PyObject *samples; + PyArrayObject *samples_array = NULL; + double min_length = 0; + double max_time = DBL_MAX; + npy_intp *shape; + static char *kwlist[] = {"samples", "min_length", "max_time", NULL}; + + memset(&ibd_finder, 0, sizeof(ibd_finder)); + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|dd", kwlist, &samples, + &min_length, &max_time)) { + goto out; + } + samples_array = (PyArrayObject *) PyArray_FROMANY(samples, NPY_INT32, + 2, 2, NPY_ARRAY_IN_ARRAY); + if (samples_array == NULL) { + goto out; + } + shape = PyArray_DIMS(samples_array); + if (shape[1] != 2) { + PyErr_SetString(PyExc_ValueError, "sample pairs must have shape (n, 2)"); + goto out; + } + err = tsk_ibd_finder_init(&ibd_finder, self->tables, + PyArray_DATA(samples_array), (tsk_size_t) shape[0]); + if (err != 0) { + handle_library_error(err); + goto out; + } + err = tsk_ibd_finder_set_min_length(&ibd_finder, min_length); + if (err != 0) { + handle_library_error(err); + goto out; + } + err = tsk_ibd_finder_set_max_time(&ibd_finder, max_time); + if (err != 0) { + handle_library_error(err); + goto out; + } + err = tsk_ibd_finder_run(&ibd_finder); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = convert_ibd_segments(&ibd_finder, PyArray_DATA(samples_array), + (tsk_size_t) shape[0]); +out: + Py_XDECREF(samples_array); + tsk_ibd_finder_free(&ibd_finder); + return ret; +} + static PyObject * TableCollection_sort(TableCollection *self, PyObject *args, PyObject *kwds) { @@ -6790,7 +6930,11 @@ static PyMethodDef TableCollection_methods[] = { {"subset", (PyCFunction) TableCollection_subset, METH_VARARGS, "Subsets the table collection to a set of nodes." }, {"union", (PyCFunction) TableCollection_union, METH_VARARGS|METH_KEYWORDS, - "Adds to this table collection the portions of another table collection that are not shared with this one." }, + "Adds to this table collection the portions of another table collection " + "that are not shared with this one." }, + {"find_ibd", (PyCFunction) TableCollection_find_ibd, + METH_VARARGS|METH_KEYWORDS, + "Returns IBD segments for the specified sample pairs."}, {"sort", (PyCFunction) TableCollection_sort, METH_VARARGS|METH_KEYWORDS, "Sorts the tables to satisfy tree sequence requirements." }, {"equals", (PyCFunction) TableCollection_equals, METH_VARARGS, diff --git a/python/tests/ibd.py b/python/tests/ibd.py index d23a4fefec..d9a40ad11c 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -125,23 +125,20 @@ def add(self, other): class IbdFinder: """ - Finds all IBD relationships between specified samples in a tree sequence. + Finds all IBD relationships between specified sample pairs in a tree sequence. """ - def __init__(self, ts, samples=None, min_length=0, max_time=None): + def __init__(self, ts, sample_pairs, min_length=0, max_time=None): self.ts = ts - # Note: samples *must* be in order of ascending node ID - if samples is None: - self.samples = ts.samples() - else: - self.samples = samples - if len(self.samples) == 0: - raise ValueError("The tree sequence contains no samples.") + self.sample_pairs = sample_pairs + self.check_sample_pairs() + self.samples = list({i for pair in self.sample_pairs for i in pair}) self.sample_id_map = np.zeros(ts.num_nodes, dtype=int) - 1 for index, u in enumerate(self.samples): self.sample_id_map[u] = index + self.min_length = min_length if max_time is None: self.max_time = 2 * ts.max_root_time @@ -152,15 +149,6 @@ def __init__(self, ts, samples=None, min_length=0, max_time=None): self.oldest_parent = self.get_oldest_parents() - # Objects below are needed for the IBD segment-holding object. - self.num_samples = len(self.samples) - self.sample_pairs = self.get_sample_pairs() - - # Note: in the C code the object below should be a struct array. - # Each item will be accessed using its index, which corresponds to a particular - # sample pair. The mapping between index and sample pair is defined in the - # find_sample_pair_index method further down. - self.ibd_segments = {} for key in self.sample_pairs: self.ibd_segments[key] = None @@ -177,20 +165,37 @@ def get_oldest_parents(self): oldest_parents[c] = e.parent return oldest_parents - def add_ibd_segments(self, sample0, sample1, seg): - index = self.find_sample_pair_index(sample0, sample1) - - # Note: the code below is specific to the Python implementation, where the - # output is a dictionary indexed by sample pairs. - # In the C implementation, it'll be more like - # self.ibd_segments[index].add(seg) - + def add_ibd_segments(self, index, seg): + assert index != -1 if self.ibd_segments[self.sample_pairs[index]] is None: - self.ibd_segments[self.sample_pairs[index]] = SegmentList( - head=seg, tail=seg - ) + self.ibd_segments[self.sample_pairs[index]] = [seg] else: - self.ibd_segments[self.sample_pairs[index]].add(seg) + self.ibd_segments[self.sample_pairs[index]].append(seg) + + def check_sample_pairs(self): + """ + Checks that the user-inputted list of sample pairs is valid. + """ + for ind, p in enumerate(self.sample_pairs): + if not isinstance(p, tuple): + raise ValueError("Sample pairs must be a list of tuples.") + assert len(p) == 2 + # Assumes the node IDs are 0 ... ts.num_nodes - 1 + if not ( + p[0] in range(0, self.ts.num_nodes) + and p[1] in range(0, self.ts.num_nodes) + ): + raise ValueError("Each sample pair must contain valid node IDs.") + if p[0] == p[1]: + raise ValueError( + "Each sample pair must contain two different node IDs." + ) + # Ensure there are no duplicate pairs. + for ind2, p2 in enumerate(self.sample_pairs): + if ind == ind2: + continue + if p == p2 or (p[1], p[0]) == p2: + raise ValueError("The list of sample pairs contains duplicates.") def get_sample_pairs(self): """ @@ -212,26 +217,18 @@ def find_sample_pair_index(self, sample0, sample1): This calculates the position of the object corresponding to the inputted sample pair in the struct array. """ + index = 0 + while index < len(self.sample_pairs): + if self.sample_pairs[index] == (sample0, sample1) or self.sample_pairs[ + index + ] == (sample1, sample0): + break + index += 1 - # Ensure samples are in order. - if sample0 == sample1: - raise ValueError("Samples in pair must have different node IDs.") - elif sample0 > sample1: - sample0, sample1 = sample1, sample0 - - i0 = self.sample_id_map[sample0] - i1 = self.sample_id_map[sample1] - - # Calculate the position of the sample pair in the vector. - index = ( - (self.num_samples) * (self.num_samples - 1) / 2 - - (self.num_samples - i0) * (self.num_samples - i0 - 1) / 2 - + i1 - - i0 - - 1 - ) - - return int(index) + if index < len(self.sample_pairs): + return int(index) + else: + return -1 def find_ibd_segments(self): """ @@ -282,6 +279,9 @@ def find_ibd_segments(self): ) and parent_should_be_added: singleton_seg = SegmentList() singleton_seg.add(Segment(0, self.ts.sequence_length, current_parent)) + # u + # if self.A[u] is not None: + # list_to_add.add(self.A[u]) self.calculate_ibd_segs(current_parent, singleton_seg) parent_should_be_added = False @@ -289,13 +289,37 @@ def find_ibd_segments(self): e = next(edges_iter, None) # Remove any processed nodes that are no longer needed. + # Update parent_should_be_added. if e is not None and e.parent != current_parent: - for i, n in enumerate(self.oldest_parent): - if current_parent == n: - self.A[i] = None + parent_should_be_added = True + + self.convert_output_to_numpy() return self.ibd_segments + def convert_output_to_numpy(self): + """ + Converts the output to the format required by the Python-C interface layer. + """ + for key in self.sample_pairs: + left = [] + right = [] + node = [] + # Define numpy array values. + if self.ibd_segments[key] is None: + pass + else: + for seg in self.ibd_segments[key]: + left.append(seg.left) + right.append(seg.right) + node.append(seg.node) + # Convert lists to numpy arrays. + left = np.asarray(left, dtype=np.float64) + right = np.asarray(right, dtype=np.float64) + node = np.asarray(node, dtype=np.int64) + # Overwrite existing entry. + self.ibd_segments[key] = {"left": left, "right": right, "node": node} + def calculate_ibd_segs(self, current_parent, list_to_add): """ Write later. @@ -306,25 +330,29 @@ def calculate_ibd_segs(self, current_parent, list_to_add): if self.A[current_parent] is None: self.A[current_parent] = list_to_add - else: seg0 = self.A[current_parent].head while seg0 is not None: seg1 = list_to_add.head while seg1 is not None: + if seg0.node == seg1.node: + seg1 = seg1.next + continue + index = self.find_sample_pair_index(seg0.node, seg1.node) + if index == -1: + seg1 = seg1.next + continue left = max(seg0.left, seg1.left) right = min(seg0.right, seg1.right) if left >= right: seg1 = seg1.next continue - nodes = [seg0.node, seg1.node] - nodes.sort() # If there are any overlapping segments, record as a new # IBD relationship. if right - left > self.min_length: self.add_ibd_segments( - nodes[0], nodes[1], Segment(left, right, current_parent), + index, Segment(left, right, current_parent), ) seg1 = seg1.next seg0 = seg0.next diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index 2093107b43..91c4851f24 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -15,6 +15,39 @@ # Functions for computing IBD 'naively'. +def find_ibd( + ts, + sample_pairs, + min_length=0, + max_time=None, + compare_lib=True, + print_c=False, + print_py=False, +): + """ + Calculates IBD segments using Python and converts output to lists of segments. + Also compares result with C library. + """ + ibd_f = ibd.IbdFinder( + ts, sample_pairs=sample_pairs, max_time=max_time, min_length=min_length + ) + ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_ibd_output_to_seglists(ibd_segs) + if compare_lib: + c_out = ts.tables.find_ibd( + sample_pairs, max_time=max_time, min_length=min_length + ) + c_out = convert_ibd_output_to_seglists(c_out) + if print_c: + print("C output:\n") + print(c_out) + if print_py: + print("Python output:\n") + print(ibd_segs) + assert ibd_is_equal(ibd_segs, c_out) + return ibd_segs + + def get_ibd( sample0, sample1, @@ -110,8 +143,7 @@ def get_ibd_all_pairs( path_ibd=path_ibd, mrca_ibd=mrca_ibd, ) - if len(ibd_list) > 0: - ibd_dict[pair] = ibd_list + ibd_dict[pair] = ibd_list return ibd_dict @@ -138,7 +170,9 @@ def subtrees_are_equal(tree1, pdict0, root): return True -def verify_equal_ibd(treeSequence): +def verify_equal_ibd( + ts, sample_pairs=None, compare_lib=True, print_c=False, print_py=False +): """ Calculates IBD segments using both the 'naive' and sophisticated algorithms, verifies that the same output is produced. @@ -146,55 +180,45 @@ def verify_equal_ibd(treeSequence): of IBD options are tested simultaneously (all the MRCA and path-IBD combos), for example. """ - ts = treeSequence - ibd0 = ibd.IbdFinder(ts, samples=ts.samples()) - ibd0 = ibd0.find_ibd_segments() + if sample_pairs is None: + sample_pairs = list(itertools.combinations(ts.samples(), 2)) + ibd0 = find_ibd( + ts, + sample_pairs=sample_pairs, + compare_lib=compare_lib, + print_c=print_c, + print_py=print_py, + ) ibd1 = get_ibd_all_pairs(ts, path_ibd=True, mrca_ibd=True) - # Convert each SegmentList object into a list of Segment objects. - ibd0_tolist = {} - for key, val in ibd0.items(): - if val is not None: - ibd0_tolist[key] = convert_segmentlist_to_list(val) - # Check for equality. - for key0, val0 in ibd0_tolist.items(): - + for key0, val0 in ibd0.items(): assert key0 in ibd1.keys() val1 = ibd1[key0] val0.sort() val1.sort() -def convert_segmentlist_to_list(seglist): +def convert_ibd_output_to_seglists(ibd_out): """ - Turns a SegmentList object into a list of Segment objects. - (This makes them easier to compare for testing purposes) + Converts the Python mock-up output back into lists of segments. + This is needed to use the ibd_is_equal function. """ - outlist = [] - if seglist is None: - return outlist - else: - seg = seglist.head - outlist = [seg] - seg = seg.next - while seg is not None: - outlist.append(seg) - seg = seg.next - - return outlist + for key in ibd_out.keys(): + seg_list = [] + num_segs = len(ibd_out[key]["left"]) + for s in range(num_segs): + seg_list.append( + ibd.Segment( + left=ibd_out[key]["left"][s], + right=ibd_out[key]["right"][s], + node=ibd_out[key]["node"][s], + ) + ) + ibd_out[key] = seg_list -def convert_dict_of_segmentlists(dict0): - """ - Turns a dictionary of SegmentList objects into a dictionary of lists of - Segment objects. (makes them easier to compare in tests). - """ - dict_out = {} - for key, val in dict0.items(): - dict_out[key] = convert_segmentlist_to_list(val) - - return dict_out + return ibd_out def ibd_is_equal(dict1, dict2): @@ -270,9 +294,7 @@ class TestIbdSingleBinaryTree(unittest.TestCase): # Basic test def test_defaults(self): - ibd_f = ibd.IbdFinder(self.ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1), (0, 2), (1, 2)],) true_segs = { (0, 1): [ibd.Segment(0.0, 1.0, 3)], (0, 2): [ibd.Segment(0.0, 1.0, 4)], @@ -280,22 +302,34 @@ def test_defaults(self): } assert ibd_is_equal(ibd_segs, true_segs) - # Max time = 1.5 def test_time(self): - ibd_f = ibd.IbdFinder(self.ts, max_time=1.5) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd( + self.ts, + sample_pairs=[(0, 1), (0, 2), (1, 2)], + max_time=1.5, + compare_lib=True, + ) true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 3)], (0, 2): [], (1, 2): []} assert ibd_is_equal(ibd_segs, true_segs) # Min length = 2 def test_length(self): - ibd_f = ibd.IbdFinder(self.ts, min_length=2) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd( + self.ts, sample_pairs=[(0, 1), (0, 2), (1, 2)], min_length=2, + ) true_segs = {(0, 1): [], (0, 2): [], (1, 2): []} assert ibd_is_equal(ibd_segs, true_segs) + def test_input_errors(self): + with self.assertRaises(ValueError): + ibd.IbdFinder(self.ts, sample_pairs=[0]) + with self.assertRaises(AssertionError): + ibd.IbdFinder(self.ts, sample_pairs=[(0, 1, 2)]) + with self.assertRaises(ValueError): + ibd.IbdFinder(self.ts, sample_pairs=[(0, 5)]) + with self.assertRaises(ValueError): + ibd.IbdFinder(self.ts, sample_pairs=[(0, 1), (1, 0)]) + class TestIbdTwoSamplesTwoTrees(unittest.TestCase): @@ -326,25 +360,23 @@ class TestIbdTwoSamplesTwoTrees(unittest.TestCase): # Basic test def test_basic(self): - ibd_f = ibd.IbdFinder(self.ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)]) true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2), ibd.Segment(0.4, 1.0, 3)]} assert ibd_is_equal(ibd_segs, true_segs) # Max time = 1.2 def test_time(self): - ibd_f = ibd.IbdFinder(self.ts, max_time=1.2) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd( + self.ts, sample_pairs=[(0, 1)], max_time=1.2, compare_lib=True + ) true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2)]} assert ibd_is_equal(ibd_segs, true_segs) # Min length = 0.5 def test_length(self): - ibd_f = ibd.IbdFinder(self.ts, min_length=0.5) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd( + self.ts, sample_pairs=[(0, 1)], min_length=0.5, compare_lib=True + ) true_segs = {(0, 1): [ibd.Segment(0.4, 1.0, 3)]} assert ibd_is_equal(ibd_segs, true_segs) @@ -372,26 +404,21 @@ class TestIbdUnrelatedSamples(unittest.TestCase): 0 1 3 1 """ ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_basic(self): - ibd_f = ibd.IbdFinder(self.ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)]) true_segs = {(0, 1): []} assert ibd_is_equal(ibd_segs, true_segs) def test_time(self): - ibd_f = ibd.IbdFinder(self.ts, max_time=1.2) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)], max_time=1.2) true_segs = {(0, 1): []} assert ibd_is_equal(ibd_segs, true_segs) def test_length(self): - ibd_f = ibd.IbdFinder(self.ts, min_length=0.2) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)], min_length=0.2) true_segs = {(0, 1): []} assert ibd_is_equal(ibd_segs, true_segs) @@ -422,40 +449,61 @@ def test_no_samples(self): ) ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) with self.assertRaises(ValueError): - ibd.IbdFinder(ts) + ibd.IbdFinder(ts, sample_pairs=[(0, 1)]) class TestIbdSamplesAreDescendants(unittest.TestCase): # - # 2 - # | - # 1 - # | - # 0 + # 4 5 + # | | + # 2 3 + # | | + # 0 1 + # nodes = io.StringIO( """\ id is_sample time 0 1 0 - 1 1 1 - 2 0 2 + 1 1 0 + 2 1 1 + 3 1 1 + 4 0 2 + 5 0 2 """ ) edges = io.StringIO( """\ left right parent child - 0 1 1 0 - 0 1 2 1 + 0 1 2 0 + 0 1 3 1 + 0 1 4 2 + 0 1 5 3 """ ) ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_basic(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) - true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 1)]} + ibd_segs = find_ibd( + self.ts, sample_pairs=[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], + ) + true_segs = { + (0, 1): [], + (0, 2): [ibd.Segment(0.0, 1.0, 2)], + (0, 3): [], + (1, 2): [], + (1, 3): [ibd.Segment(0.0, 1.0, 3)], + (2, 3): [], + } + + assert ibd_is_equal(ibd_segs, true_segs) + def test_input_sample_pairs(self): + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 3), (0, 2), (3, 5)]) + true_segs = { + (0, 3): [], + (0, 2): [ibd.Segment(0.0, 1.0, 2)], + (3, 5): [ibd.Segment(0.0, 1.0, 5)], + } assert ibd_is_equal(ibd_segs, true_segs) @@ -496,9 +544,7 @@ class TestIbdDifferentPaths(unittest.TestCase): ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_defaults(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts) - ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)]) true_segs = { (0, 1): [ ibd.Segment(0.0, 0.2, 4), @@ -506,34 +552,86 @@ def test_defaults(self): ibd.Segment(0.2, 0.7, 4), ] } - ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) def test_time(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts, max_time=1.8) - ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)], max_time=1.8) true_segs = {(0, 1): []} - ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) - ibd_f = ibd.IbdFinder(ts, max_time=2.8) + + def test_length(self): + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1)], min_length=0.4) + true_segs = {(0, 1): [ibd.Segment(0.2, 0.7, 4)]} + assert ibd_is_equal(ibd_segs, true_segs) + + # This is a situation where the Python and the C libraries agree, + # but aren't doing as expected. + @unittest.expectedFailure + def test_input_sample_pairs(self): + ibd_f = ibd.IbdFinder(self.ts, sample_pairs=[(0, 1), (2, 3), (1, 3)]) ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_ibd_output_to_seglists(ibd_segs) true_segs = { (0, 1): [ ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.7, 1.0, 4), ibd.Segment(0.2, 0.7, 4), - ] + ], + (2, 3): [ibd.Segment(0.2, 0.7, 4)], } - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd( + self.ts, + sample_pairs=[(0, 1), (2, 3)], + compare_lib=True, + print_c=False, + print_py=False, + ) assert ibd_is_equal(ibd_segs, true_segs) - def test_length(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts, min_length=0.4) - ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0, 1): [ibd.Segment(0.2, 0.7, 4)]} - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + +class TestIbdDifferentPaths2(unittest.TestCase): + # + # 5 | + # / \ | + # / 4 | 4 + # / / \ | / \ + # / / \ | / \ + # / / \ | 3 \ + # / / \ | / \ \ + # 0 1 2 | 0 2 1 + # | + # 0.2 + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2.5 + 5 0 3.5 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.2 1.0 3 0 + 0.2 1.0 3 2 + 0.0 1.0 4 1 + 0.0 0.2 4 2 + 0.2 1.0 4 3 + 0.0 0.2 5 0 + 0.0 0.2 5 4 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_defaults(self): + ibd_segs = find_ibd(self.ts, sample_pairs=[(1, 2)]) + true_segs = { + (1, 2): [ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.2, 1.0, 4)], + } assert ibd_is_equal(ibd_segs, true_segs) @@ -576,10 +674,9 @@ class TestIbdPolytomies(unittest.TestCase): ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_defaults(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts) - ibd_segs = ibd_f.find_ibd_segments() - # print(ibd_segs[(0,1)]) + ibd_segs = find_ibd( + self.ts, sample_pairs=[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], + ) true_segs = { (0, 1): [ibd.Segment(0, 1, 4)], (0, 2): [ibd.Segment(0, 0.3, 4), ibd.Segment(0.3, 1, 5)], @@ -588,14 +685,14 @@ def test_defaults(self): (1, 3): [ibd.Segment(0, 0.3, 5), ibd.Segment(0.3, 1, 4)], (2, 3): [ibd.Segment(0.3, 1, 5), ibd.Segment(0, 0.3, 5)], } - ibd_segs = convert_dict_of_segmentlists(ibd_segs) - # print(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) def test_time(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts, max_time=3) - ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = find_ibd( + self.ts, + sample_pairs=[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], + max_time=3, + ) true_segs = { (0, 1): [ibd.Segment(0, 1, 4)], (0, 2): [ibd.Segment(0, 0.3, 4)], @@ -604,13 +701,14 @@ def test_time(self): (1, 3): [ibd.Segment(0.3, 1, 4)], (2, 3): [], } - ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) def test_length(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts, min_length=0.5) - ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = find_ibd( + self.ts, + sample_pairs=[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], + min_length=0.5, + ) true_segs = { (0, 1): [ibd.Segment(0, 1, 4)], (0, 2): [ibd.Segment(0.3, 1, 5)], @@ -619,7 +717,14 @@ def test_length(self): (1, 3): [ibd.Segment(0.3, 1, 4)], (2, 3): [ibd.Segment(0.3, 1, 5)], } - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + assert ibd_is_equal(ibd_segs, true_segs) + + def test_input_sample_pairs(self): + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 1), (0, 3)]) + true_segs = { + (0, 1): [ibd.Segment(0.0, 1.0, 4)], + (0, 3): [ibd.Segment(0.3, 1.0, 4), ibd.Segment(0.0, 0.3, 5)], + } assert ibd_is_equal(ibd_segs, true_segs) @@ -651,10 +756,7 @@ class TestIbdInternalSamples(unittest.TestCase): ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_defaults(self): - ts = self.ts - ibd_f = ibd.IbdFinder(ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) + ibd_segs = find_ibd(self.ts, sample_pairs=[(0, 2)]) true_segs = { (0, 2): [ibd.Segment(0, 1, 3)], } @@ -666,22 +768,10 @@ class TestIbdRandomExamples(unittest.TestCase): Randomly generated test cases. """ - # Infinite sites, Hudson model. - def test_random_example1(self): - ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=2) - verify_equal_ibd(ts) - - def test_random_example2(self): - ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=23) - verify_equal_ibd(ts) - - def test_random_example3(self): - ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=232) - verify_equal_ibd(ts) - - def test_random_example4(self): - ts = msprime.simulate(sample_size=10, recombination_rate=0.3, random_seed=726) - verify_equal_ibd(ts) + def test_random_examples(self): + for i in range(1, 50): + ts = msprime.simulate(sample_size=10, recombination_rate=0.3, random_seed=i) + verify_equal_ibd(ts) # Finite sites def sim_finite_sites(self, random_seed, dtwf=False): @@ -705,59 +795,21 @@ def sim_finite_sites(self, random_seed, dtwf=False): ) return ts - def test_finite_sites1(self): - ts = self.sim_finite_sites(9257) - verify_equal_ibd(ts) - - def test_finite_sites2(self): - ts = self.sim_finite_sites(835) - verify_equal_ibd(ts) - - def test_finite_sites3(self): - ts = self.sim_finite_sites(27278) - verify_equal_ibd(ts) - - def test_finite_sites4(self): - ts = self.sim_finite_sites(22446688) - verify_equal_ibd(ts) + def test_finite_sites(self): + for i in range(1, 11): + ts = self.sim_finite_sites(i) + verify_equal_ibd(ts) - # DTWF - def test_dtwf1(self): - ts = self.sim_finite_sites(84, dtwf=True) - verify_equal_ibd(ts) - - def test_dtwf2(self): - ts = self.sim_finite_sites(17482, dtwf=True) - verify_equal_ibd(ts) - - def test_dtwf3(self): - ts = self.sim_finite_sites(846, dtwf=True) - verify_equal_ibd(ts) - - def test_dtwf4(self): - ts = self.sim_finite_sites(273, dtwf=True) - verify_equal_ibd(ts) + def test_dtwf(self): + for i in range(1000, 1010): + ts = self.sim_finite_sites(i, dtwf=True) + verify_equal_ibd(ts) def test_sim_wright_fisher_generations(self): # Uses the bespoke DTWF forward-time simulator. - number_of_gens = 3 - tables = wf.wf_sim(4, number_of_gens, deep_history=False, seed=83) - tables.sort() - ts = tables.tree_sequence() - verify_equal_ibd(ts) - - def test_sim_wright_fisher_generations2(self): - # Uses the bespoke DTWF forward-time simulator. - number_of_gens = 10 - tables = wf.wf_sim(10, number_of_gens, deep_history=False, seed=837) - tables.sort() - ts = tables.tree_sequence() - verify_equal_ibd(ts) - - def test_sim_wright_fisher_generations3(self): - # Uses the bespoke DTWF forward-time simulator. - number_of_gens = 10 - tables = wf.wf_sim(10, number_of_gens, deep_history=False, seed=37) - tables.sort() - ts = tables.tree_sequence() - verify_equal_ibd(ts) + for i in range(1, 6): + number_of_gens = 10 + tables = wf.wf_sim(10, number_of_gens, deep_history=False, seed=i) + tables.sort() + ts = tables.tree_sequence() + verify_equal_ibd(ts) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 6843c5eeec..84681330e4 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -303,6 +303,61 @@ def test_union_bad_args(self): with self.assertRaises(ValueError): tc.union(tc2, np.array([[1], [2]], dtype="int32")) + def test_ibd_bad_args(self): + ts = msprime.simulate(10, random_seed=1) + tc = ts.tables._ll_tables + with self.assertRaises(TypeError): + tc.find_ibd() + for bad_samples in ["sdf", None, {}]: + with self.assertRaises(ValueError): + tc.find_ibd(bad_samples) + for not_enough_samples in [[], [0]]: + with self.assertRaises(ValueError): + tc.find_ibd(not_enough_samples) + # input array must be 2D + with self.assertRaises(ValueError): + tc.find_ibd([[[1], [1]]]) + # Input array must be (n, 2) + with self.assertRaises(ValueError): + tc.find_ibd([[1, 1, 1]]) + for bad_float in ["sdf", None, {}]: + with self.assertRaises(TypeError): + tc.find_ibd([(0, 1)], min_length=bad_float) + with self.assertRaises(TypeError): + tc.find_ibd([(0, 1)], max_time=bad_float) + with self.assertRaises(_tskit.LibraryError): + tc.find_ibd([(0, 1)], max_time=-1) + with self.assertRaises(_tskit.LibraryError): + tc.find_ibd([(0, 1)], min_length=-1) + + def test_ibd_output_no_recomb(self): + ts = msprime.simulate(10, random_seed=1) + tc = ts.tables._ll_tables + segs = tc.find_ibd([(0, 1), (2, 3)]) + self.assertIsInstance(segs, dict) + self.assertGreater(len(segs), 0) + for key, value in segs.items(): + self.assertIsInstance(key, tuple) + self.assertEqual(len(key), 2) + self.assertIsInstance(value, dict) + self.assertEqual(len(value), 3) + self.assertEqual(list(value["left"]), [0]) + self.assertEqual(list(value["right"]), [1]) + self.assertEqual(len(value["node"]), 1) + + def test_ibd_output_recomb(self): + ts = msprime.simulate(10, recombination_rate=1, random_seed=1) + self.assertGreater(ts.num_trees, 1) + tc = ts.tables._ll_tables + segs = tc.find_ibd([(0, 1), (2, 3)]) + self.assertIsInstance(segs, dict) + self.assertGreater(len(segs), 0) + for key, value in segs.items(): + self.assertIsInstance(key, tuple) + self.assertEqual(len(key), 2) + self.assertIsInstance(value, dict) + self.assertEqual(len(value), 3) + class TestTreeSequence(LowLevelTestCase, MetadataTestMixin): """ diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 6d7a6434e4..ad6bac83c3 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -27,6 +27,7 @@ import datetime import itertools import json +import sys import warnings from typing import Any from typing import Tuple @@ -2795,3 +2796,10 @@ def union( self.provenances.add_row( record=json.dumps(provenance.get_provenance_dict(parameters)) ) + + def find_ibd(self, samples, max_time=None, min_length=None): + max_time = sys.float_info.max if max_time is None else max_time + min_length = 0 if min_length is None else min_length + return self._ll_tables.find_ibd( + samples, max_time=max_time, min_length=min_length + )