diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 10f722fb6b..40e4f99cb4 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -10155,6 +10155,80 @@ test_table_collection_clear(void) | TSK_CLEAR_TS_METADATA_AND_SCHEMA); } +static void +test_table_collection_decapitate(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_table_collection_t t; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* Add some migrations */ + tsk_population_table_add_row(&t.populations, NULL, 0); + tsk_population_table_add_row(&t.populations, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 0, 1, 0.05, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 1, 0, 0.09, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 0, 1, 0.10, NULL, 0); + CU_ASSERT_EQUAL(t.migrations.num_rows, 3); + + /* NOTE: haven't worked out the exact IDs on the branches here, just + * for illustration. + 0.09┊ 9 5 10 ┊ 9 5 ┊11 5 ┊ + ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏━┻┓ ┊ ┃ ┏━┻┓ ┊ + 0.07┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ 4 ┊ ┃ ┃ 4 ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┏┻┓ ┊ + 0.00┊ 0 1 3 2 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + 0.00 2.00 7.00 10.00 + */ + + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_init(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 12); + /* Lost the mutation over 5 */ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 2); + /* We keep the migration at exactly 0.09. */ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_migrations(&ts), 2); + + tsk_table_collection_free(&t); + tsk_treeseq_free(&ts); +} + +static void +test_table_collection_decapitate_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_table_collection_t t; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* This should be caught later when we try to index */ + reverse_edges(&t); + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGES_NOT_SORTED_CHILD); + + /* This should be caught immediately on entry to the function */ + t.sequence_length = -1; + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SEQUENCE_LENGTH); + + tsk_table_collection_free(&t); +} + static void test_table_collection_takeset_indexes(void) { @@ -10317,6 +10391,9 @@ main(int argc, char **argv) test_table_collection_union_middle_merge }, { "test_table_collection_union_errors", test_table_collection_union_errors }, { "test_table_collection_clear", test_table_collection_clear }, + { "test_table_collection_decapitate", test_table_collection_decapitate }, + { "test_table_collection_decapitate_errors", + test_table_collection_decapitate_errors }, { "test_table_collection_takeset_indexes", test_table_collection_takeset_indexes }, { NULL, NULL }, diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 9093b6d794..e6432447b1 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -7249,6 +7249,202 @@ test_reference_sequence(void) tsk_table_collection_free(&tables); } +static void +test_split_edges_no_populations(void) +{ + int ret; + tsk_treeseq_t ts, split_ts; + tsk_table_collection_t tables; + tsk_id_t new_nodes[] = { 9, 10, 11 }; + tsk_size_t num_new_nodes = 3; + const char *metadata = "some metadata"; + tsk_size_t j; + tsk_node_t node; + double time = 0.09; + tsk_id_t ret_id; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + ret_id = tsk_table_collection_copy(ts.tables, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + tsk_treeseq_free(&ts); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret = tsk_table_collection_compute_mutation_times(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_treeseq_init(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + /* NOTE: haven't worked out the exact IDs on the branches here, just + * for illustration. + + 0.25┊ 8 ┊ ┊ ┊ + ┊ ┏━┻━┓ ┊ ┊ ┊ + 0.20┊ ┃ ┃ ┊ ┊ 7 ┊ + ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊ + 0.17┊ 6 ┃ ┊ 6 ┊ ┃ ┃ ┊ + ┊ ┏━┻┓ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + 0.09┊ 9 5 10┊ 9 5 ┊ 11 5 ┊ + ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏━┻┓ ┊ ┃ ┏━┻┓ ┊ + 0.07┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ 4 ┊ ┃ ┃ 4 ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┏┻┓ ┊ + 0.00┊ 0 1 3 2 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + 0.00 2.00 7.00 10.00 + */ + ret = tsk_treeseq_split_edges( + &ts, time, 1234, 0, metadata, strlen(metadata), 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12); + + for (j = 0; j < num_new_nodes; j++) { + ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.time, time); + CU_ASSERT_EQUAL(node.flags, 1234); + CU_ASSERT_EQUAL(node.individual, TSK_NULL); + CU_ASSERT_EQUAL(node.population, 0); + CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata)); + CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0); + } + tsk_treeseq_free(&split_ts); + + /* And again with imputed population value */ + ret = tsk_treeseq_split_edges(&ts, time, 1234, 0, metadata, strlen(metadata), + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12); + + for (j = 0; j < num_new_nodes; j++) { + ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.time, time); + CU_ASSERT_EQUAL(node.flags, 1234); + CU_ASSERT_EQUAL(node.individual, TSK_NULL); + CU_ASSERT_EQUAL(node.population, TSK_NULL); + CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata)); + CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0); + } + tsk_treeseq_free(&split_ts); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_split_edges_populations(void) +{ + int ret; + tsk_treeseq_t ts, split_ts; + tsk_table_collection_t tables; + double time = 0.5; + tsk_node_t node; + tsk_id_t valid_pops[] = { -1, 0, 1 }; + tsk_id_t num_valid_pops = 3; + tsk_id_t j, population, ret_id; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, 0, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, 1, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_edge_table_add_row(&tables.edges, 0, 1, 1, 0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < num_valid_pops; j++) { + population = valid_pops[j]; + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3); + ret = tsk_treeseq_get_node(&split_ts, 2, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.population, population); + tsk_treeseq_free(&split_ts); + + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3); + ret = tsk_treeseq_get_node(&split_ts, 2, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.population, 0); + tsk_treeseq_free(&split_ts); + } + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_split_edges_errors(void) +{ + int ret; + tsk_treeseq_t ts, split_ts; + tsk_table_collection_t tables; + double time = 0.5; + tsk_id_t invalid_pops[] = { -2, 2, 3 }; + tsk_id_t num_invalid_pops = 3; + tsk_id_t j, population, ret_id; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, 0, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, 1, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_edge_table_add_row(&tables.edges, 0, 1, 1, 0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < num_invalid_pops; j++) { + population = invalid_pops[j]; + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&split_ts); + + /* We always check population values, even if they aren't used */ + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&split_ts); + } + tsk_treeseq_free(&ts); + + ret_id + = tsk_migration_table_add_row(&tables.migrations, 0, 1, 0, 0, 1, 1.0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + tsk_treeseq_free(&split_ts); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + static void test_init_take_ownership_no_edge_metadata(void) { @@ -7449,6 +7645,9 @@ main(int argc, char **argv) { "test_tree_sequence_metadata", test_tree_sequence_metadata }, { "test_time_uncalibrated", test_time_uncalibrated }, { "test_reference_sequence", test_reference_sequence }, + { "test_split_edges_no_populations", test_split_edges_no_populations }, + { "test_split_edges_populations", test_split_edges_populations }, + { "test_split_edges_errors", test_split_edges_errors }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index d916fdee46..a9faf76ee4 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12647,6 +12647,133 @@ tsk_table_collection_union(tsk_table_collection_t *self, return ret; } +int +tsk_table_collection_decapitate( + tsk_table_collection_t *self, double time, tsk_flags_t TSK_UNUSED(options)) +{ + int ret = 0; + tsk_edge_t edge; + tsk_mutation_t mutation; + tsk_migration_t migration; + tsk_edge_table_t edges; + tsk_mutation_table_t mutations; + tsk_migration_table_t migrations; + const double *restrict node_time = self->nodes.time; + tsk_id_t j, ret_id; + double mutation_time; + + memset(&edges, 0, sizeof(edges)); + memset(&mutations, 0, sizeof(mutations)); + memset(&migrations, 0, sizeof(migrations)); + + /* Note: perhaps should be stricter about what we accept here in terms + * of sorting, but compute_mutation_parents below will check anyway and + * so we're safe while we're calling that. + */ + ret = (int) tsk_table_collection_check_integrity(self, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_edge_table_copy(&self->edges, &edges, 0); + if (ret != 0) { + goto out; + } + /* Note: we are assuming below that the tables are sorted, so we could save + * a bit of time and memory here by truncating part-way through the + * edges. */ + ret = tsk_edge_table_clear(&self->edges); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) edges.num_rows; j++) { + tsk_edge_table_get_row_unsafe(&edges, j, &edge); + if (node_time[edge.child] < time) { + if (time < node_time[edge.parent]) { + ret_id = tsk_node_table_add_row( + &self->nodes, 0, time, TSK_NULL, TSK_NULL, NULL, 0); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + edge.parent = ret_id; + } + ret_id = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + /* Calling x_table_free multiple times is safe, so get rid of the + * extra edge table memory as soon as we can. */ + tsk_edge_table_free(&edges); + + ret = tsk_mutation_table_copy(&self->mutations, &mutations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_table_clear(&self->mutations); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) mutations.num_rows; j++) { + tsk_mutation_table_get_row_unsafe(&mutations, j, &mutation); + mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] + : mutation.time; + if (mutation_time < time) { + /* Set the mutation parent to NULL, and recalculate below */ + ret_id = tsk_mutation_table_add_row(&self->mutations, mutation.site, + mutation.node, TSK_NULL, mutation.time, mutation.derived_state, + mutation.derived_state_length, mutation.metadata, + mutation.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + tsk_mutation_table_free(&mutations); + + ret = tsk_migration_table_copy(&self->migrations, &migrations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_migration_table_clear(&self->migrations); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) migrations.num_rows; j++) { + tsk_migration_table_get_row_unsafe(&migrations, j, &migration); + if (migration.time <= time) { + ret_id = tsk_migration_table_add_row(&self->migrations, migration.left, + migration.right, migration.node, migration.source, migration.dest, + migration.time, migration.metadata, migration.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + tsk_migration_table_free(&migrations); + + ret = tsk_table_collection_build_index(self, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_compute_mutation_parents(self, 0); + if (ret != 0) { + goto out; + } + +out: + tsk_edge_table_free(&edges); + tsk_mutation_table_free(&mutations); + tsk_migration_table_free(&migrations); + return ret; +} + static int cmp_edge_cl(const void *a, const void *b) { diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 54f6ca7967..ac8448c614 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4252,6 +4252,12 @@ int tsk_provenance_table_takeset_columns(tsk_provenance_table_t *self, tsk_size_t *record_offset); bool tsk_table_collection_has_reference_sequence(const tsk_table_collection_t *self); + +/* Not documenting this because we may want to pass through default values + * for the new nodes (in particular the metadata) in the future */ +int tsk_table_collection_decapitate( + tsk_table_collection_t *self, double time, tsk_flags_t options); + int tsk_reference_sequence_init(tsk_reference_sequence_t *self, tsk_flags_t options); int tsk_reference_sequence_free(tsk_reference_sequence_t *self); bool tsk_reference_sequence_is_null(const tsk_reference_sequence_t *self); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 98d50ed6ae..86f02e79b7 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3253,6 +3253,123 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, return ret; } +int TSK_WARN_UNUSED +tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, + tsk_id_t population, const char *metadata, tsk_size_t metadata_length, + tsk_flags_t options, tsk_treeseq_t *output) +{ + int ret = 0; + tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); + const double *restrict node_time = self->tables->nodes.time; + const tsk_id_t *restrict node_population = self->tables->nodes.population; + const tsk_size_t num_edges = self->tables->edges.num_rows; + const tsk_size_t num_mutations = self->tables->mutations.num_rows; + tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge)); + tsk_id_t j, u, mapped_node, ret_id; + double mutation_time; + tsk_edge_t edge; + tsk_mutation_t mutation; + tsk_bookmark_t sort_start; + bool impute_population = options & TSK_SPLIT_EDGES_IMPUTE_POPULATION; + + memset(output, 0, sizeof(*output)); + if (split_edge == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_treeseq_copy_tables(self, tables, 0); + if (ret != 0) { + goto out; + } + if (tables->migrations.num_rows > 0) { + ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + + tsk_edge_table_clear(&tables->edges); + tsk_memset(split_edge, TSK_NULL, num_edges * sizeof(*split_edge)); + + for (j = 0; j < (tsk_id_t) num_edges; j++) { + /* Would prefer to use tsk_edge_table_get_row_unsafe, but it's + * currently static to tables.c */ + ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge); + tsk_bug_assert(ret == 0); + if (node_time[edge.child] < time && time < node_time[edge.parent]) { + if (impute_population) { + population = TSK_NULL; + if (node_population[edge.child] != TSK_NULL) { + population = node_population[edge.child]; + } + } + u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL, + metadata, metadata_length); + if (u < 0) { + ret = (int) u; + goto out; + } + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, u, + edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret; + goto out; + } + edge.child = u; + split_edge[j] = u; + } + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); + if (ret < 0) { + ret = (int) ret; + goto out; + } + } + + for (j = 0; j < (tsk_id_t) num_mutations; j++) { + /* Note: we could speed this up a bit by accessing the local + * memory for mutations directly. */ + ret = tsk_treeseq_get_mutation(self, j, &mutation); + tsk_bug_assert(ret == 0); + mapped_node = TSK_NULL; + if (mutation.edge != TSK_NULL) { + mapped_node = split_edge[mutation.edge]; + } + mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] + : mutation.time; + if (mapped_node != TSK_NULL && mutation_time >= time) { + mutation.node = mapped_node; + } + /* Update the column in-place to save a bit of time. */ + tables->mutations.node[j] = mutation.node; + } + + /* Skip mutations and sites as they haven't been altered */ + /* Note we can probably optimise the edge sort a bit here also by + * reasoning about when the first edge gets altered in the table. + */ + memset(&sort_start, 0, sizeof(sort_start)); + sort_start.sites = tables->sites.num_rows; + sort_start.mutations = tables->mutations.num_rows; + ret = tsk_table_collection_sort(tables, &sort_start, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_treeseq_init( + output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); + tables = NULL; +out: + if (tables != NULL) { + tsk_table_collection_free(tables); + tsk_safe_free(tables); + } + tsk_safe_free(split_edge); + return ret; +} + /* ======================================================== * * Tree * ======================================================== */ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 60b011c667..c0bb804873 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -882,6 +882,12 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, /** @} */ +#define TSK_SPLIT_EDGES_IMPUTE_POPULATION (1 << 1) + +int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, + tsk_id_t population, const char *metadata, tsk_size_t metadata_length, + tsk_flags_t options, tsk_treeseq_t *output); + bool tsk_treeseq_has_reference_sequence(const tsk_treeseq_t *self); int tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, diff --git a/docs/python-api.md b/docs/python-api.md index b169058c4f..9b0a44ba2d 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -218,6 +218,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.delete_intervals TreeSequence.delete_sites TreeSequence.trim + TreeSequence.decapitate ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 1178169b28..461c3b73cd 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,7 +4,12 @@ **Changes** -- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the output VCF files. +- Add ``TableCollection.decapitate`` and ``TreeSequence.decapitate`` operations + to remove information from that data model that is older than a specific time. + (:user:`jeromekelleher`, :issue:`2236`, :pr:`2240`) + +- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the + output VCF files. (:user:`roohy`, :issue:`2103`, :pr:`2107`) - Make dumping of tables and tree sequences to disk a zero-copy operation. diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index a8ae82d79a..2320701cb3 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7022,6 +7022,29 @@ TableCollection_compute_mutation_parents(TableCollection *self) return ret; } +static PyObject * +TableCollection_decapitate(TableCollection *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + double time; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "d", &time)) { + goto out; + } + err = tsk_table_collection_decapitate(self->tables, time, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyObject * TableCollection_compute_mutation_times(TableCollection *self) { @@ -7507,6 +7530,10 @@ static PyMethodDef TableCollection_methods[] = { .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns True if the parameter table collection is equal to this one." }, + { .ml_name = "decapitate", + .ml_meth = (PyCFunction) TableCollection_decapitate, + .ml_flags = METH_VARARGS, + .ml_doc = "Removes information older than the specified time." }, { .ml_name = "compute_mutation_parents", .ml_meth = (PyCFunction) TableCollection_compute_mutation_parents, .ml_flags = METH_NOARGS, @@ -9344,6 +9371,66 @@ TreeSequence_get_genotype_matrix(TreeSequence *self, PyObject *args, PyObject *k return ret; } +static PyObject * +TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] + = { "time", "flags", "population", "metadata", "impute_population", NULL }; + double time; + tsk_flags_t flags; + tsk_id_t population; + int impute_population; + PyObject *py_metadata = Py_None; + char *metadata; + Py_ssize_t metadata_length; + int err; + tsk_flags_t options = 0; + TreeSequence *output = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + /* NOTE: we could make most of these arguments optional and reason about + * impute_population through the value of the population parameter, + * but we're trying to keep this code as simple as possible and put + * the logic in the high-level Python code */ + if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&Oi", kwlist, &time, + &uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata, + &impute_population)) { + goto out; + } + + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + if (impute_population) { + options |= TSK_SPLIT_EDGES_IMPUTE_POPULATION; + } + + output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); + if (output == NULL) { + goto out; + } + output->tree_sequence = PyMem_Malloc(sizeof(*output->tree_sequence)); + if (output->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_treeseq_split_edges(self->tree_sequence, time, flags, + (tsk_id_t) population, metadata, metadata_length, options, + output->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) output; + output = NULL; +out: + Py_XDECREF(output); + return ret; +} + static PyObject * TreeSequence_has_reference_sequence(TreeSequence *self) { @@ -9577,6 +9664,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_get_genotype_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns the genotypes matrix." }, + { .ml_name = "split_edges", + .ml_meth = (PyCFunction) TreeSequence_split_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Returns a copy of this tree sequence edges split at time t" }, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 02caa178d4..6356a02233 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -410,6 +410,19 @@ def test_union_bad_args(self): with pytest.raises(ValueError): tc.union(tc2, np.array([[1], [2]], dtype="int32")) + def test_decapitate_bad_args(self): + tc = _tskit.TableCollection(1) + self.get_example_tree_sequence().dump_tables(tc) + with pytest.raises(TypeError): + tc.decapitate() + with pytest.raises(TypeError): + tc.decapitate("1234") + + def test_decapitate_error(self): + tc = _tskit.TableCollection(-1) + with pytest.raises(_tskit.LibraryError, match="Sequence length"): + tc.decapitate(0) + def test_equals_bad_args(self): ts = msprime.simulate(10, random_seed=1242) tc = ts.tables._ll_tables @@ -1535,6 +1548,45 @@ def test_discrete_time(self): ts.load_tables(tables) assert ts.get_discrete_time() == 1 + def test_split_edges_return_type(self): + ts = self.get_example_tree_sequence() + split = ts.split_edges( + time=0, flags=0, population=0, metadata=b"", impute_population=False + ) + assert isinstance(split, _tskit.TreeSequence) + + def test_split_edges_bad_types(self): + ts = self.get_example_tree_sequence() + + def f(time=0, flags=0, population=0, metadata=b"", impute_population=False): + return ts.split_edges( + time=time, + flags=flags, + population=population, + metadata=metadata, + impute_population=impute_population, + ) + + with pytest.raises(TypeError): + f(time="0") + with pytest.raises(TypeError): + f(flags="0") + with pytest.raises(TypeError): + f(metadata="0") + with pytest.raises(TypeError): + f(impute_population="0") + + def test_split_edges_bad_population(self): + ts = self.get_example_tree_sequence() + with pytest.raises(_tskit.LibraryError, match="POPULATION_OUT_OF_BOUNDS"): + ts.split_edges( + time=0, + flags=0, + population=ts.get_num_populations(), + metadata=b"", + impute_population=False, + ) + class StatsInterfaceMixin: """ diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py new file mode 100644 index 0000000000..98f8d00def --- /dev/null +++ b/python/tests/test_table_transforms.py @@ -0,0 +1,997 @@ +# MIT License +# +# Copyright (c) 2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for table transformation operations like trim(), decapitate, etc. +""" +import decimal +import fractions +import io +import json + +import numpy as np +import pytest + +import tests +import tskit +import tskit.util as util +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +# class TestDeleteOlderSimpleTree: + +# # 2.00┊ 4 ┊ +# # ┊ ┏━┻┓ ┊ +# # 1.00┊ ┃ 3 ┊ +# # ┊ ┃ ┏┻┓ ┊ +# # 0.00┊ 0 1 2 ┊ +# # 0 1 +# def tables(self): +# # Don't cache this because we modify the result! +# tree = tskit.Tree.generate_balanced(3, branch_length=1) +# return tree.tree_sequence.dump_tables() + +# @pytest.mark.parametrize("time", [0, -0.5, -100, 0.01, 0.999]) +# def test_before_first_internal_node(self, time): +# tables = self.tables() +# before = tables.copy() +# tables.delete_older(time) +# ts = tables.tree_sequence() +# assert ts.num_trees == 1 +# tree = ts.first() +# assert tree.num_roots == 3 +# assert list(sorted(tree.roots)) == [0, 1, 2] +# assert before.nodes.equals(tables.nodes[: len(before.nodes)]) +# assert len(tables.edges) == 0 + +# def test_t1(self): +# # +# # 2.00┊ ┊ +# # ┊ ┊ +# # 1.00┊ 3 ┊ +# # ┊ ┏┻┓ ┊ +# # 0.00┊ 0 1 2 ┊ +# # 0 1 +# tables = self.tables() +# before = tables.copy() +# tables.delete_older(1) +# print(tables) +# ts = tables.tree_sequence() +# assert ts.num_trees == 1 +# tree = ts.first() +# assert tree.num_roots == 2 +# assert list(sorted(tree.roots)) == [0, 3] +# assert len(tables.nodes) == 5 +# assert before.nodes.equals(tables.nodes) + +# @pytest.mark.parametrize("time", [1.01, 1.5, 1.999]) +# def test_t1_to_2(self, time): +# # 2.00┊ ┊ +# # ┊ ┊ +# # 1.01┊ 5 6 ┊ +# # ┊ ┃ ┃ ┊ +# # 1.00┊ ┃ 3 ┊ +# # ┊ ┃ ┏┻┓ ┊ +# # 0.00┊ 0 1 2 ┊ +# # 0 1 +# tables = self.tables() +# before = tables.copy() +# tables.decapitate(time) +# ts = tables.tree_sequence() +# assert ts.num_trees == 1 +# tree = ts.first() +# assert tree.num_roots == 2 +# assert list(sorted(tree.roots)) == [5, 6] +# assert len(tables.nodes) == 7 +# assert tables.nodes[5].time == time +# assert tables.nodes[6].time == time +# assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + +# @pytest.mark.parametrize("time", [2, 2.5, 1e9]) +# def test_t2(self, time): +# tables = self.tables() +# before = tables.copy() +# tables.decapitate(time) +# tables.assert_equals(before, ignore_provenance=True) + + +def decapitate_definition(ts, time): + """ + Simple loop implementation of the decapitate operation + """ + tables = ts.dump_tables() + node_time = tables.nodes.time + tables.edges.clear() + # To avoid having a discrepancy between the C version and the Python + # version here we unset the metadata schema before adding any nodes. + # Setting empty metadata seems like the right thing to do here, but + # there will definitely be cases where doing this will break encoding + # (e.g., suppose there was a fixed-size binary schema - empty metadata + # is going to break this). It's not clear what the right answer is here, + # but this is at least fast to do and simple to describe. + md_schema = tables.nodes.metadata_schema + tables.nodes.metadata_schema = tskit.MetadataSchema(None) + for edge in ts.edges(): + if node_time[edge.parent] <= time: + tables.edges.append(edge) + elif node_time[edge.child] < time: + new_parent = tables.nodes.add_row(time=time) + tables.edges.append(edge.replace(parent=new_parent)) + tables.nodes.metadata_schema = md_schema + + tables.mutations.clear() + for mutation in ts.mutations(): + mutation_time = ( + node_time[mutation.node] + if util.is_unknown_time(mutation.time) + else mutation.time + ) + if mutation_time < time: + tables.mutations.append(mutation.replace(parent=tskit.NULL)) + + tables.migrations.clear() + for migration in ts.migrations(): + if migration.time <= time: + tables.migrations.append(migration) + + tables.build_index() + tables.compute_mutation_parents() + return tables.tree_sequence() + + +@pytest.mark.parametrize("ts", get_example_tree_sequences()) +def test_decapitate_examples(ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + decap1 = decapitate_definition(ts, time) + decap2 = ts.decapitate(time) + decap1.tables.assert_equals(decap2.tables, ignore_provenance=True) + + +class TestDecapitateSimpleTree: + + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + def tables(self): + # Don't cache this because we modify the result! + tree = tskit.Tree.generate_balanced(3, branch_length=1) + return tree.tree_sequence.dump_tables() + + @pytest.mark.parametrize("time", [0, -0.5, -100]) + def test_t0_or_before(self, time): + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert list(sorted(tree.roots)) == [0, 1, 2] + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + assert len(tables.edges) == 0 + + @pytest.mark.parametrize("time", [0.01, 0.5, 0.999]) + def test_t0_to_1(self, time): + # + # 2.00┊ ┊ + # ┊ ┊ + # 0.99┊ 7 5 6 ┊ + # ┊ ┃ ┃ ┃ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert list(sorted(tree.roots)) == [5, 6, 7] + assert len(tables.nodes) == 8 + assert tables.nodes[5].time == time + assert tables.nodes[6].time == time + assert tables.nodes[7].time == time + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + def test_t1(self): + # + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(1) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [3, 5] + assert len(tables.nodes) == 6 + assert tables.nodes[5].time == 1 + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + @pytest.mark.parametrize("time", [1.01, 1.5, 1.999]) + def test_t1_to_2(self, time): + # 2.00┊ ┊ + # ┊ ┊ + # 1.01┊ 5 6 ┊ + # ┊ ┃ ┃ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [5, 6] + assert len(tables.nodes) == 7 + assert tables.nodes[5].time == time + assert tables.nodes[6].time == time + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + @pytest.mark.parametrize("time", [2, 2.5, 1e9]) + def test_t2(self, time): + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + tables.assert_equals(before, ignore_provenance=True) + + +class TestDecapitateSimpleTreeMutationExamples: + def test_single_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.mutations.assert_equals(tables.mutations) + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_single_mutation_at_decap_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1, derived_state="T") + + # Because the mutation is at exactly the decapitation time, we must + # remove it, or it would violate the requirement that a mutation must + # have a time less than that of the parent of the edge that its on. + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 0 + assert list(tables.tree_sequence().alignments()) == ["A", "A", "A"] + + def test_multi_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="G") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ x ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.mutations.assert_equals(tables.mutations) + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_multi_mutation_over_sample_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1.01, derived_state="T") + tables.mutations.add_row(site=0, node=0, time=0.99, parent=0, derived_state="G") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 1 + # Alignments are equal because the ancestral mutation was silent anyway. + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_multi_mutation_over_root(self): + # x + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=4, derived_state="G") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="T") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 1 + assert list(before.tree_sequence().alignments()) == ["T", "G", "G"] + # The states inherited by samples changes because we drop the old mutation + assert list(tables.tree_sequence().alignments()) == ["T", "A", "A"] + + +class TestDecapitateSimpleTreeMigrationExamples: + @tests.cached_example + def ts(self): + # 2.00┊ 4 ┊ + # ┊ o━┻┓ ┊ + # 1.00┊ o 3 ┊ + # ┊ o ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.populations.add_row() + tables.populations.add_row() + tables.migrations.add_row(source=0, dest=1, node=0, time=0.5, left=0, right=1) + tables.migrations.add_row(source=1, dest=0, node=0, time=1.0, left=0, right=1) + tables.migrations.add_row(source=0, dest=1, node=0, time=1.5, left=0, right=1) + tables.compute_mutation_parents() + ts = tables.tree_sequence() + return ts + + def test_t099(self): + ts = self.ts() + tables = ts.decapitate(0.99).tables + assert len(tables.migrations) == 1 + assert tables.migrations[0].time == 0.5 + + def test_t1(self): + ts = self.ts() + tables = ts.decapitate(1).tables + assert len(tables.migrations) == 2 + assert tables.migrations[0].time == 0.5 + assert tables.migrations[1].time == 1.0 + + +class TestSimpleTsExample: + # 9.08┊ 9 ┊ ┊ ┊ ┊ ┊ + # ┊ ┏━┻━┓ ┊ ┊ ┊ ┊ ┊ + # 6.57┊ ┃ ┃ ┊ ┊ ┊ ┊ 8 ┊ + # ┊ ┃ ┃ ┊ ┊ ┊ ┊ ┏━┻━┓ ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + + @tests.cached_example + def ts(self): + nodes = io.StringIO( + """\ + id is_sample population individual time metadata + 0 1 0 -1 0 + 1 1 0 -1 0 + 2 1 0 -1 0 + 3 1 0 -1 0 + 4 0 0 -1 0.114 + 5 0 0 -1 1.110 + 6 0 0 -1 1.750 + 7 0 0 -1 5.310 + 8 0 0 -1 6.573 + 9 0 0 -1 9.083 + """ + ) + edges = io.StringIO( + """\ + id left right parent child + 0 0.00000000 1.00000000 4 0 + 1 0.00000000 1.00000000 4 1 + 2 0.00000000 1.00000000 5 2 + 3 0.00000000 1.00000000 5 3 + 4 0.79258618 0.90634460 6 4 + 5 0.79258618 0.90634460 6 5 + 6 0.05975243 0.79258618 7 4 + 7 0.90634460 0.91029435 7 4 + 8 0.05975243 0.79258618 7 5 + 9 0.90634460 0.91029435 7 5 + 10 0.91029435 1.00000000 8 4 + 11 0.91029435 1.00000000 8 5 + 12 0.00000000 0.05975243 9 4 + 13 0.00000000 0.05975243 9 5 + """ + ) + sites = io.StringIO( + """\ + position ancestral_state + 0.05 A + 0.06 0 + 0.3 C + 0.5 AAA + 0.91 T + """ + ) + muts = io.StringIO( + """\ + site node derived_state parent time + 0 9 T -1 15 + 0 9 GGG 0 9.1 + 0 5 1 1 9 + 1 4 C -1 1.6 + 1 4 G 3 1.5 + 2 7 G -1 10 + 2 3 C 5 1 + 4 3 G -1 1 + """ + ) + ts = tskit.load_text(nodes, edges, sites=sites, mutations=muts, strict=False) + return ts + + def test_at_time_of_5(self): + # NOTE: we don't remember that the edge 4-7 was shared in trees 1 and 3. + # 1.11┊ 14 5 ┊ 11 5 ┊ 10 5 ┊ 12 5 ┊ 13 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(1.110) + assert ts.num_nodes == 15 + assert ts.num_trees == 5 + # Most mutations are older than this. + assert ts.num_mutations == 2 + for u in range(10, 15): + node = ts.node(u) + assert node.time == 1.110 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {5, 14}, + {11, 5}, + {10, 5}, + {12, 5}, + {13, 5}, + ] + + def test_at_time6(self): + # 6 ┊ 12 13 ┊ ┊ ┊ ┊ 10 11 ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(6) + assert ts.num_nodes == 14 + assert ts.num_trees == 5 + assert ts.num_mutations == 4 + for u in range(10, 14): + node = ts.node(u) + assert node.time == 6 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {12, 13}, + {7}, + {6}, + {7}, + {10, 11}, + ] + + +class TestDecapitateInterface: + @pytest.mark.parametrize("bad_type", ["x", "0.1", [], [0.1]]) + def test_bad_types(self, ts_fixture, bad_type): + with pytest.raises(TypeError, match="number"): + ts_fixture.decapitate(bad_type) + + @pytest.mark.parametrize( + "time", [1, 1.0, np.array([1])[0], fractions.Fraction(1, 1), decimal.Decimal(1)] + ) + def test_number_types(self, ts_fixture, time): + expected = ts_fixture.decapitate(1) + got = ts_fixture.decapitate(time) + expected.tables.assert_equals(got.tables, ignore_timestamps=True) + + def test_provenance(self, ts_fixture): + ts = ts_fixture.decapitate(1.5) + assert ts.num_provenances == ts_fixture.num_provenances + 1 + prov = json.loads(ts.provenance(ts.num_provenances - 1).record) + assert prov["parameters"] == {"command": "decapitate", "time": 1.5} + + def test_no_provenance(self, ts_fixture): + ts = ts_fixture.decapitate(1.5, record_provenance=False) + assert ts.num_provenances == ts_fixture.num_provenances + + def test_tables_ts_equivalent(self, ts_fixture): + time = 0.5 + ts = ts_fixture.decapitate(time) + tables = ts_fixture.dump_tables() + tables.decapitate(time) + tables.assert_equals(ts.tables, ignore_timestamps=True) + + def test_unsorted_tables_raises_error(self): + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + edges = tables.edges.copy() + tables.edges.clear() + for edge in reversed(edges): + tables.edges.append(edge) + with pytest.raises(tskit.LibraryError, match="order violated"): + tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="order violated"): + tables.decapitate(1) + + +def split_edges_definition(ts, time, *, flags=None, population=None, metadata=None): + tables = ts.dump_tables() + if ts.num_migrations > 0: + raise ValueError("Migrations not supported") + default_population = population is None + if not default_population: + # -1 is a valid value + if population < -1 or population >= ts.num_populations: + raise ValueError("Population out of bounds") + flags = 0 if flags is None else flags + if metadata is None: + metadata = tables.nodes.metadata_schema.empty_value + metadata = tables.nodes.metadata_schema.validate_and_encode_row(metadata) + # This is the easiest way to turn off encoding when calling add_row below + schema = tables.nodes.metadata_schema + tables.nodes.metadata_schema = tskit.MetadataSchema(None) + + node_time = tables.nodes.time + node_population = tables.nodes.population + tables.edges.clear() + split_edge = np.full(ts.num_edges, tskit.NULL, dtype=int) + for edge in ts.edges(): + if node_time[edge.child] < time < node_time[edge.parent]: + if default_population: + population = node_population[edge.child] + u = tables.nodes.add_row( + flags=flags, time=time, population=population, metadata=metadata + ) + tables.edges.append(edge.replace(parent=u)) + tables.edges.append(edge.replace(child=u)) + split_edge[edge.id] = u + else: + tables.edges.append(edge) + # Reinstate schema + tables.nodes.metadata_schema = schema + + tables.mutations.clear() + for mutation in ts.mutations(): + mapped_node = tskit.NULL + if mutation.edge != tskit.NULL: + mapped_node = split_edge[mutation.edge] + if mapped_node != tskit.NULL and mutation.time >= time: + mutation = mutation.replace(node=mapped_node) + tables.mutations.append(mutation) + + tables.sort() + return tables.tree_sequence() + + +class TestSplitEdgesSimpleTree: + + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + @tests.cached_example + def ts(self): + return tskit.Tree.generate_balanced(3, branch_length=1).tree_sequence + + @pytest.mark.parametrize("time", [0.1, 0.5, 0.9]) + def test_lowest_branches(self, time): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # ┊ ┃ ┃ ┃ ┊ t ┊ 7 5 6 ┊ + # ┊ ┃ ┃ ┃ ┊ -> ┊ ┃ ┃ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + before_ts = self.ts() + ts = before_ts.split_edges(time) + assert ts.num_nodes == 8 + assert all(ts.node(u).time == time for u in [5, 6, 7]) + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 7, 1: 5, 2: 6, 5: 3, 6: 3, 7: 4, 3: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + def test_same_time_as_node(self): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + before_ts = self.ts() + ts = before_ts.split_edges(1) + assert ts.num_nodes == 6 + assert ts.node(5).time == 1 + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 5, 1: 3, 2: 3, 5: 4, 3: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [1.1, 1.5, 1.9]) + def test_top_branches(self, time): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # ┊ ┃ ┃ ┊ t ┊ 5 6 ┊ + # ┊ ┃ ┃ ┊ -> ┊ ┃ ┃ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + + before_ts = self.ts() + ts = before_ts.split_edges(time) + assert ts.num_nodes == 7 + assert all(ts.node(u).time == time for u in [5, 6]) + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 5, 1: 3, 2: 3, 3: 6, 6: 4, 5: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [0, 2]) + def test_at_leaf_or_root_time(self, time): + split = self.ts().split_edges(time) + split.tables.assert_equals(self.ts().tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [-1, 2.1]) + def test_outside_time_scales(self, time): + split = self.ts().split_edges(time) + split.tables.assert_equals(self.ts().tables, ignore_provenance=True) + + +class TestSplitEdgesSimpleTreeMutationExamples: + def test_single_mutation_no_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T", metadata=b"1234") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_nodes == 6 + mut = ts_split.mutation(0) + assert mut.node == 0 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert tskit.is_unknown_time(mut.time) + + def test_single_mutation_split_before_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row( + site=0, node=0, time=1.5, derived_state="T", metadata=b"1234" + ) + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_nodes == 6 + mut = ts_split.mutation(0) + assert mut.node == 5 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert mut.time == 1.5 + + def test_single_mutation_split_at_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row( + site=0, node=0, time=1, derived_state="T", metadata=b"1234" + ) + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ 5x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + mut = ts_split.mutation(0) + assert mut.node == 5 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert mut.time == 1.0 + + def test_multi_mutation_no_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="G") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # ┊ 5 3 ┊ + # ┊ x ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts_split.tables.mutations.assert_equals(tables.mutations) + + def test_multi_mutation_over_sample_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1.01, derived_state="T") + tables.mutations.add_row(site=0, node=0, time=0.99, parent=0, derived_state="G") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_mutations == 2 + + mut = ts_split.mutation(0) + assert mut.site == 0 + assert mut.node == 5 + assert mut.time == 1.01 + mut = ts_split.mutation(1) + assert mut.site == 0 + assert mut.node == 0 + assert mut.time == 0.99 + + def test_mutation_not_on_branch(self): + tables = tskit.TableCollection(1) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + ts = tables.tree_sequence() + tables.assert_equals(ts.split_edges(0).tables, ignore_provenance=True) + + +class TestSplitEdgesExamples: + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_genotypes_round_trip(self, ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + split_ts = ts.split_edges(time) + assert np.array_equal(split_ts.genotype_matrix(), ts.genotype_matrix()) + else: + with pytest.raises(tskit.LibraryError): + ts.split_edges(time) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("population", [-1, None]) + def test_definition(self, ts, population): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + ts1 = split_edges_definition(ts, time, population=population) + ts2 = ts.split_edges(time, population=population) + ts1.tables.assert_equals(ts2.tables, ignore_provenance=True) + + +class TestSplitEdgesInterface: + def test_migrations_fail(self, ts_fixture): + assert ts_fixture.num_migrations > 0 + with pytest.raises(tskit.LibraryError, match="MIGRATIONS_NOT_SUPPORTED"): + ts_fixture.split_edges(0) + + def test_population_out_of_bounds(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="POPULATION_OUT_OF_BOUNDS"): + ts.split_edges(0, population=0) + + def test_bad_flags(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.split_edges(0, flags="asdf") + + def test_bad_metadata_no_schema(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.split_edges(0, metadata="asdf") + + def test_bad_metadata_json_schema(self): + tables = tskit.TableCollection(1) + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + ts = tables.tree_sequence() + with pytest.raises(tskit.MetadataEncodingError): + ts.split_edges(0, metadata=b"bytes") + + +class TestSplitEdgesNodeValues: + @tests.cached_example + def ts(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + @tests.cached_example + def ts_with_schema(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + def test_default_population(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).population == 0 + + @pytest.mark.parametrize("population", range(-1, 5)) + def test_specify_population(self, population): + ts = self.ts().split_edges(0.5, population=population) + assert ts.node(2).population == population + + def test_default_flags(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).flags == 0 + + @pytest.mark.parametrize("flags", range(0, 5)) + def test_specify_flags(self, flags): + ts = self.ts().split_edges(0.5, flags=flags) + assert ts.node(2).flags == flags + + def test_default_metadata_no_schema(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).metadata == b"" + + @pytest.mark.parametrize("metadata", [b"", b"some bytes"]) + def test_specify_metadata_no_schema(self, metadata): + ts = self.ts().split_edges(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata + + def test_default_metadata_with_schema(self): + ts = self.ts_with_schema().split_edges(0.5) + assert ts.node(2).metadata == {} + + @pytest.mark.parametrize("metadata", [{}, {"some": "json"}]) + def test_specify_metadata_with_schema(self, metadata): + ts = self.ts_with_schema().split_edges(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index af0f315215..ae00b9f5ba 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -6245,14 +6245,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 9c30d0728a..b481b893d1 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3862,6 +3862,90 @@ def trim(self, record_provenance=True): record=json.dumps(provenance.get_provenance_dict(parameters)) ) + # def delete_older(self, time, *, record_provenance=True): + # node_time = self.nodes.time + # edges = self.edges.copy() + # self.edges.clear() + # for edge in edges: + # if node_time[edge.child] < time: + # self.edges.append(edge) + + # mutations = self.mutations.copy() + # # Map of old ID -> new ID + # mutation_map = np.full(len(mutations), tskit.NULL, dtype=int) + # self.mutations.clear() + # keep = [] + # for mutation in mutations: + # mutation_time = ( + # node_time[mutation.node] + # if util.is_unknown_time(mutation.time) + # else mutation.time + # ) + # if mutation_time < time: + # mutation_map[len(keep)] = mutation.id + # keep.append(mutation_map) + # # Not making assumptions about ordering, so it it in two passes. + # for mutation in keep: + # if mutation.parent != tskit.NULL: + # mutation = mutation.replace(parent=mutation_map[mutation.parent]) + # self.mutations.append(mutation) + + # migrations = self.migrations.copy() + # self.migrations.clear() + # for migration in migrations: + # if migration.time < time: + # self.migrations.append(migration) + + def decapitate(self, time, *, record_provenance=True): + """ + Delete all edge topology and mutational information older than the + specified time from this set of tables. + + Removes all edges in which the time of the child is >= the specified + time ``t``, and breaks edges that intersect with ``t``. For each edge + intersecting with ``t`` we create a new node with time equal to ``t``, + and set the parent of the edge to this new node. The node table + is not altered in any other way. Newly added nodes have empty metadata, + a ``flags`` value of 0, and NULL ``population`` and ``individual`` + references. + + .. warning:: + The empty metadata values for newly added nodes may not be compatible + with the node table's metadata schema! The current behaviour may + change in future versions to better accomodate metadata schemas. + + .. note:: + Note that each edge is treated independently, so that even if two + edges that are broken by this operation share the same parent and + child nodes, there will be two different new parent nodes inserted. + + Any mutation whose time is >= ``t`` will be removed. A mutation's time + is its associated ``time`` value, or the time of its node if the + mutation's time was marked as unknown (:data:`UNKNOWN_TIME`). + + Any migration with time > ``t`` will be removed. + + .. important:: + The tables must satisfy the standard + :ref:`sortedness requirements `. + + .. note:: + As a side-effect, this method will build the table indexes + and recompute the parents of all mutations. This is an implementation + detail and may not happen in future versions. + + :param float time: The cutoff time. + :param bool record_provenance: If ``True``, add details of this operation + to the provenance table in this TableCollection. (Default: ``True``). + """ + self._ll_tables.decapitate(time) + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + parameters = {"command": "decapitate", "time": float(time)} + self.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + def clear( self, clear_provenance=False, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 11b83d6bdb..a02c6a8dd2 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5935,6 +5935,80 @@ def trim(self, record_provenance=True): tables.trim(record_provenance) return tables.tree_sequence() + def decapitate(self, time, record_provenance=True): + """ + Return a copy of this tree sequence with topology and mutation information + older than the specified time removed. Please see the + :meth:`.TableCollection.decapitate` method for details. + + :param float time: The cutoff time. + :param bool record_provenance: If True, add details of this operation to the + provenance information of the returned tree sequence. (Default: True). + :rtype: tskit.TreeSequence + """ + tables = self.dump_tables() + tables.decapitate(time, record_provenance=record_provenance) + return tables.tree_sequence() + + def split_edges(self, time, *, flags=None, population=None, metadata=None): + """ + Returns a copy of this tree sequence in which we replace any + edge ``(left, right, parent, child)`` in which + ``node_time[child] < time < node_time[parent]`` with two edges + ``(left, right, parent, u)`` and ``(left, right, u, child)``, + where ``u`` is a newly added node for each intersecting edge. + + If ``metadata``, ``flags``, or ``population`` are specified, newly + added nodes will be assigned these values. Otherwise, default values + will be used. The default metadata is an empty dictionary if a metadata + schema is defined for the node table, and is an empty byte string + otherwise. The default population for the new node will be derived from + the population of the edge's child. Newly added have a default + ``flags`` value of 0. + + Any metadata associated with a split edge will be copied to the new edge. + + .. warning:: This method currently does not support migrations + and a ValueError will be raised if the migration table is not + empty. Future versions may take migrations that intersect with the + edge into account when determining the default population + assignments for new nodes. + + Any mutations lying on the edge whose time is >= ``time`` will have + their node value set to ``u``. Note that the time of the mutation is + defined as the time of the child node if the mutation's time is + unknown. + + :param float time: The cutoff time. + :param int flags: The flags value for newly-inserted nodes. (Default = 0) + :param int population: The population value for newly inserted nodes. + Defaults to the population of the child node of the split edge + if not specified. + :param metadata: The metadata for any newly inserted nodes. See + :meth:`.NodeTable.add_row` for details on how default metadata + is produced for a given schema (or none). + :param bool record_provenance: If True, add details of this operation to the + provenance information of the returned tree sequence. (Default: True). + :rtype: tskit.TreeSequence + """ + impute_population = False + if population is None: + impute_population = True + population = tskit.NULL + flags = 0 if flags is None else flags + schema = self.table_metadata_schemas.node + if metadata is None: + metadata = schema.empty_value + metadata = schema.validate_and_encode_row(metadata) + ll_ts = self._ll_tree_sequence.split_edges( + time=time, + flags=flags, + population=population, + metadata=metadata, + impute_population=impute_population, + ) + return TreeSequence(ll_ts) + def subset( self, nodes,