From 6d12abb84bae1b469b30d0eaddb8bb97256cffcb Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 4 May 2022 11:13:21 +0100 Subject: [PATCH] Implement split_edges Closes #2276 --- c/tests/test_trees.c | 203 ++++++++++++ c/tskit/tables.h | 1 + c/tskit/trees.c | 120 +++++++ c/tskit/trees.h | 6 + docs/python-api.md | 1 + python/CHANGELOG.rst | 7 +- python/_tskitmodule.c | 63 ++++ python/tests/test_lowlevel.py | 39 +++ python/tests/test_table_transforms.py | 429 ++++++++++++++++++++++++++ python/tskit/trees.py | 58 ++++ 10 files changed, 926 insertions(+), 1 deletion(-) create mode 100644 python/tests/test_table_transforms.py diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 9093b6d794..2b89dd38f8 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -7249,6 +7249,206 @@ test_reference_sequence(void) tsk_table_collection_free(&tables); } +static void +test_split_edges_no_populations(void) +{ + int ret; + tsk_treeseq_t ts, split_ts; + tsk_table_collection_t tables; + tsk_id_t new_nodes[] = { 9, 10, 11 }; + tsk_size_t num_new_nodes = 3; + const char *metadata = "some metadata"; + tsk_size_t j; + tsk_node_t node; + double time = 0.09; + tsk_id_t ret_id; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + + ret_id = tsk_table_collection_copy(ts.tables, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + tsk_treeseq_free(&ts); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret = tsk_table_collection_compute_mutation_times(&tables, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret_id = tsk_treeseq_init(&ts, &tables, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + /* NOTE: haven't worked out the exact IDs on the branches here, just + * for illustration.
+ + 0.25┊ 8 ┊ ┊ ┊ + ┊ ┏━┻━┓ ┊ ┊ ┊ + 0.20┊ ┃ ┃ ┊ ┊ 7 ┊ + ┊ ┃ ┃ ┊ ┊ ┏━┻━┓ ┊ + 0.17┊ 6 ┃ ┊ 6 ┊ ┃ ┃ ┊ + ┊ ┏━┻┓ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + 0.09┊ 9 5 10┊ 9 5 ┊ 11 5 ┊ + ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏━┻┓ ┊ ┃ ┏━┻┓ ┊ + 0.07┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ 4 ┊ ┃ ┃ 4 ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┏┻┓ ┊ + 0.00┊ 0 1 3 2 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + 0.00 2.00 7.00 10.00 + */ + ret = tsk_treeseq_split_edges( + &ts, time, 1234, 0, metadata, strlen(metadata), 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12); + + for (j = 0; j < num_new_nodes; j++) { + ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.time, time); + CU_ASSERT_EQUAL(node.flags, 1234); + CU_ASSERT_EQUAL(node.individual, TSK_NULL); + CU_ASSERT_EQUAL(node.population, 0); + CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata)); + CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0); + } + tsk_treeseq_free(&split_ts); + + /* And again with imputed population value */ + ret = tsk_treeseq_split_edges(&ts, time, 1234, 0, metadata, strlen(metadata), + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12); + + for (j = 0; j < num_new_nodes; j++) { + ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.time, time); + CU_ASSERT_EQUAL(node.flags, 1234); + CU_ASSERT_EQUAL(node.individual, TSK_NULL); + CU_ASSERT_EQUAL(node.population, TSK_NULL); + CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata)); + CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0); + } + tsk_treeseq_free(&split_ts); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_split_edges_populations(void) +{ + int ret; + tsk_treeseq_t ts, 
split_ts; + tsk_table_collection_t tables; + double time = 0.5; + tsk_node_t node; + tsk_id_t valid_pops[] = { -1, 0, 1 }; + tsk_id_t num_valid_pops = 3; + tsk_id_t j, population, ret_id; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, 0, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, 1, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_edge_table_add_row(&tables.edges, 0, 1, 1, 0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + for (j = 0; j < num_valid_pops; j++) { + population = valid_pops[j]; + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3); + ret = tsk_treeseq_get_node(&split_ts, 2, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.population, population); + tsk_treeseq_free(&split_ts); + + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3); + ret = tsk_treeseq_get_node(&split_ts, 2, &node); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL(node.population, 0); + tsk_treeseq_free(&split_ts); + } + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + +static void +test_split_edges_errors(void) +{ + int ret; + tsk_treeseq_t 
ts, split_ts; + tsk_table_collection_t tables; + double time = 0.5; + tsk_id_t invalid_pops[] = { -2, 2, 3 }; + tsk_id_t num_invalid_pops = 3; + tsk_id_t j, population, ret_id; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_population_table_add_row(&tables.populations, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 0, 0, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret_id = tsk_node_table_add_row(&tables.nodes, 0, 1, 1, TSK_NULL, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 1); + ret_id = tsk_edge_table_add_row(&tables.edges, 0, 1, 1, 0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_split_edges( + &ts, TSK_UNKNOWN_TIME, 0, TSK_NULL, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TIME_NONFINITE); + + for (j = 0; j < num_invalid_pops; j++) { + population = invalid_pops[j]; + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&split_ts); + + /* We always check population values, even if they aren't used */ + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, + TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); + tsk_treeseq_free(&split_ts); + } + tsk_treeseq_free(&ts); + + ret_id + = tsk_migration_table_add_row(&tables.migrations, 0, 1, 0, 0, 1, 1.0, NULL, 0); + CU_ASSERT_EQUAL_FATAL(ret_id, 0); + ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); + CU_ASSERT_EQUAL_FATAL(ret, 
TSK_ERR_MIGRATIONS_NOT_SUPPORTED); + tsk_treeseq_free(&split_ts); + + tsk_table_collection_free(&tables); + tsk_treeseq_free(&ts); +} + static void test_init_take_ownership_no_edge_metadata(void) { @@ -7449,6 +7649,9 @@ main(int argc, char **argv) { "test_tree_sequence_metadata", test_tree_sequence_metadata }, { "test_time_uncalibrated", test_time_uncalibrated }, { "test_reference_sequence", test_reference_sequence }, + { "test_split_edges_no_populations", test_split_edges_no_populations }, + { "test_split_edges_populations", test_split_edges_populations }, + { "test_split_edges_errors", test_split_edges_errors }, { "test_init_take_ownership_no_edge_metadata", test_init_take_ownership_no_edge_metadata }, { NULL, NULL }, diff --git a/c/tskit/tables.h b/c/tskit/tables.h index 54f6ca7967..ec14e90be3 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4252,6 +4252,7 @@ int tsk_provenance_table_takeset_columns(tsk_provenance_table_t *self, tsk_size_t *record_offset); bool tsk_table_collection_has_reference_sequence(const tsk_table_collection_t *self); + int tsk_reference_sequence_init(tsk_reference_sequence_t *self, tsk_flags_t options); int tsk_reference_sequence_free(tsk_reference_sequence_t *self); bool tsk_reference_sequence_is_null(const tsk_reference_sequence_t *self); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 98d50ed6ae..c1fa398b81 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3253,6 +3253,138 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, return ret; } +/* Builds in output a copy of this tree sequence in which each edge with + * node_time[child] < time < node_time[parent] is replaced by two edges that + * meet at a newly added node with the specified time, flags and metadata. + * New nodes are given the specified population, or the population of the + * edge's child if TSK_SPLIT_EDGES_IMPUTE_POPULATION is set in options. + * Mutations on a split edge whose time is >= time are reassigned to the new + * node; an unknown mutation time is treated as the time of the mutation's node. + * Returns 0 on success or a tskit error code; output is zeroed on entry. */ +int TSK_WARN_UNUSED +tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, + tsk_id_t population, const char *metadata, tsk_size_t metadata_length, + tsk_flags_t options, tsk_treeseq_t *output) +{ + int ret = 0; + tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); + const double *restrict node_time = self->tables->nodes.time; + const tsk_id_t *restrict node_population = self->tables->nodes.population; + const tsk_size_t
num_edges = self->tables->edges.num_rows; + const tsk_size_t num_mutations = self->tables->mutations.num_rows; + tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge)); + tsk_id_t j, u, mapped_node, ret_id; + double mutation_time; + tsk_edge_t edge; + tsk_mutation_t mutation; + tsk_bookmark_t sort_start; + bool impute_population = options & TSK_SPLIT_EDGES_IMPUTE_POPULATION; + + memset(output, 0, sizeof(*output)); + if (tables == NULL || split_edge == NULL) { + /* tables is not initialised yet: free the raw allocation here so that + * the cleanup code does not call tsk_table_collection_free on it. */ + tsk_safe_free(tables); + tables = NULL; + ret = TSK_ERR_NO_MEMORY; + goto out; + } + ret = tsk_treeseq_copy_tables(self, tables, 0); + if (ret != 0) { + goto out; + } + if (tables->migrations.num_rows > 0) { + ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; + goto out; + } + if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) { + ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; + goto out; + } + if (!tsk_isfinite(time)) { + ret = TSK_ERR_TIME_NONFINITE; + goto out; + } + + tsk_edge_table_clear(&tables->edges); + tsk_memset(split_edge, TSK_NULL, num_edges * sizeof(*split_edge)); + + for (j = 0; j < (tsk_id_t) num_edges; j++) { + /* Would prefer to use tsk_edge_table_get_row_unsafe, but it's + * currently static to tables.c */ + ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge); + tsk_bug_assert(ret == 0); + if (node_time[edge.child] < time && time < node_time[edge.parent]) { + if (impute_population) { + population = TSK_NULL; + if (node_population[edge.child] != TSK_NULL) { + population = node_population[edge.child]; + } + } + u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL, + metadata, metadata_length); + if (u < 0) { + ret = (int) u; + goto out; + } + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, u, + edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + edge.child = u; + split_edge[j] = u; + } + ret_id = tsk_edge_table_add_row(&tables->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata,
edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + + for (j = 0; j < (tsk_id_t) num_mutations; j++) { + /* Note: we could speed this up a bit by accessing the local + * memory for mutations directly. */ + ret = tsk_treeseq_get_mutation(self, j, &mutation); + tsk_bug_assert(ret == 0); + mapped_node = TSK_NULL; + if (mutation.edge != TSK_NULL) { + mapped_node = split_edge[mutation.edge]; + } + mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] + : mutation.time; + if (mapped_node != TSK_NULL && mutation_time >= time) { + /* Update the column in-place to save a bit of time. */ + tables->mutations.node[j] = mapped_node; + } + } + + /* Skip mutations and sites as they haven't been altered */ + /* Note we can probably optimise the edge sort a bit here also by + * reasoning about when the first edge gets altered in the table. + */ + memset(&sort_start, 0, sizeof(sort_start)); + sort_start.sites = tables->sites.num_rows; + sort_start.mutations = tables->mutations.num_rows; + ret = tsk_table_collection_sort(tables, &sort_start, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_treeseq_init( + output, tables, TSK_TS_INIT_BUILD_INDEXES | TSK_TAKE_OWNERSHIP); + tables = NULL; +out: + if (tables != NULL) { + tsk_table_collection_free(tables); + tsk_safe_free(tables); + } + tsk_safe_free(split_edge); + return ret; +} + /* ======================================================== * * Tree * ======================================================== */ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 60b011c667..c0bb804873 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -882,6 +882,12 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, /** @} */ +#define TSK_SPLIT_EDGES_IMPUTE_POPULATION (1 << 1) + +int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, + tsk_id_t population, const char *metadata, tsk_size_t metadata_length, + tsk_flags_t options,
tsk_treeseq_t *output); + bool tsk_treeseq_has_reference_sequence(const tsk_treeseq_t *self); int tsk_treeseq_kc_distance(const tsk_treeseq_t *self, const tsk_treeseq_t *other, diff --git a/docs/python-api.md b/docs/python-api.md index b169058c4f..8a9db91812 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -218,6 +218,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.delete_intervals TreeSequence.delete_sites TreeSequence.trim + TreeSequence.split_edges ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 1178169b28..808ab519f2 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,7 +4,8 @@ **Changes** -- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the output VCF files. +- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the + output VCF files. (:user:`roohy`, :issue:`2103`, :pr:`2107`) - Make dumping of tables and tree sequences to disk a zero-copy operation. @@ -21,6 +22,10 @@ edge that the mutation falls on. (:user:`jeromekelleher`, :issue:`685`, :pr:`2279`). +- Add the ``TreeSequence.split_edges`` operation which inserts nodes into + edges at a specific time. + (:user:`jeromekelleher`, :issue:`2276`, :pr:`2296`). + **Breaking Changes** - The JSON metadata codec now interprets the empty string as an empty object. 
This means diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index a8ae82d79a..f470b35364 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9344,6 +9344,65 @@ TreeSequence_get_genotype_matrix(TreeSequence *self, PyObject *args, PyObject *k return ret; } +static PyObject * +TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] + = { "time", "flags", "population", "metadata", "impute_population", NULL }; + double time; + tsk_flags_t flags; + tsk_id_t population; + int impute_population; + PyObject *py_metadata = Py_None; + char *metadata; + Py_ssize_t metadata_length; + int err; + tsk_flags_t options = 0; + TreeSequence *output = NULL; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + /* NOTE: we could make most of these arguments optional and reason about + * impute_population through the value of the population parameter, + * but we're trying to keep this code as simple as possible and put + * the logic in the high-level Python code */ + if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&Oi", kwlist, &time, + &uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata, + &impute_population)) { + goto out; + } + + if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { + goto out; + } + if (impute_population) { + options |= TSK_SPLIT_EDGES_IMPUTE_POPULATION; + } + + output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); + if (output == NULL) { + goto out; + } + output->tree_sequence = PyMem_Malloc(sizeof(*output->tree_sequence)); + if (output->tree_sequence == NULL) { + PyErr_NoMemory(); + goto out; + } + err = tsk_treeseq_split_edges(self->tree_sequence, time, flags, population, metadata, + metadata_length, options, output->tree_sequence); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) output; + output = NULL; +out: + Py_XDECREF(output); + return ret; +} + 
static PyObject * TreeSequence_has_reference_sequence(TreeSequence *self) { @@ -9577,6 +9636,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_get_genotype_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns the genotypes matrix." }, + { .ml_name = "split_edges", + .ml_meth = (PyCFunction) TreeSequence_split_edges, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Returns a copy of this tree sequence with edges split at time t" }, { .ml_name = "has_reference_sequence", .ml_meth = (PyCFunction) TreeSequence_has_reference_sequence, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 02caa178d4..6c97136129 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1535,6 +1535,45 @@ def test_discrete_time(self): ts.load_tables(tables) assert ts.get_discrete_time() == 1 + def test_split_edges_return_type(self): + ts = self.get_example_tree_sequence() + split = ts.split_edges( + time=0, flags=0, population=0, metadata=b"", impute_population=True + ) + assert isinstance(split, _tskit.TreeSequence) + + def test_split_edges_bad_types(self): + ts = self.get_example_tree_sequence() + + def f(time=0, flags=0, population=0, metadata=b"", impute_population=False): + return ts.split_edges( + time=time, + flags=flags, + population=population, + metadata=metadata, + impute_population=impute_population, + ) + + with pytest.raises(TypeError): + f(time="0") + with pytest.raises(TypeError): + f(flags="0") + with pytest.raises(TypeError): + f(metadata="0") + with pytest.raises(TypeError): + f(impute_population="0") + + def test_split_edges_bad_population(self): + ts = self.get_example_tree_sequence() + with pytest.raises(_tskit.LibraryError, match="POPULATION_OUT_OF_BOUNDS"): + ts.split_edges( + time=0, + flags=0, + population=ts.get_num_populations(), + metadata=b"", + impute_population=False, + ) + + class StatsInterfaceMixin: """ diff --git
a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py new file mode 100644 index 0000000000..a7c1e5c319 --- /dev/null +++ b/python/tests/test_table_transforms.py @@ -0,0 +1,429 @@ +# MIT License +# +# Copyright (c) 2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for table transformation operations like trim(), decapitate, etc. +""" +import math + +import numpy as np +import pytest + +import tests +import tskit +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. 
+ + +def split_edges_definition(ts, time, *, flags=None, population=None, metadata=None): + tables = ts.dump_tables() + if ts.num_migrations > 0: + raise ValueError("Migrations not supported") + default_population = population is None + if not default_population: + # -1 is a valid value + if population < -1 or population >= ts.num_populations: + raise ValueError("Population out of bounds") + flags = 0 if flags is None else flags + if metadata is None: + metadata = tables.nodes.metadata_schema.empty_value + metadata = tables.nodes.metadata_schema.validate_and_encode_row(metadata) + # This is the easiest way to turn off encoding when calling add_row below + schema = tables.nodes.metadata_schema + tables.nodes.metadata_schema = tskit.MetadataSchema(None) + + node_time = tables.nodes.time + node_population = tables.nodes.population + tables.edges.clear() + split_edge = np.full(ts.num_edges, tskit.NULL, dtype=int) + for edge in ts.edges(): + if node_time[edge.child] < time < node_time[edge.parent]: + if default_population: + population = node_population[edge.child] + u = tables.nodes.add_row( + flags=flags, time=time, population=population, metadata=metadata + ) + tables.edges.append(edge.replace(parent=u)) + tables.edges.append(edge.replace(child=u)) + split_edge[edge.id] = u + else: + tables.edges.append(edge) + # Reinstate schema + tables.nodes.metadata_schema = schema + + tables.mutations.clear() + for mutation in ts.mutations(): + mapped_node = tskit.NULL + if mutation.edge != tskit.NULL: + mapped_node = split_edge[mutation.edge] + if mapped_node != tskit.NULL and mutation.time >= time: + mutation = mutation.replace(node=mapped_node) + tables.mutations.append(mutation) + + tables.sort() + return tables.tree_sequence() + + +class TestSplitEdgesSimpleTree: + + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + @tests.cached_example + def ts(self): + return tskit.Tree.generate_balanced(3, branch_length=1).tree_sequence + + 
@pytest.mark.parametrize("time", [0.1, 0.5, 0.9]) + def test_lowest_branches(self, time): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # ┊ ┃ ┃ ┃ ┊ t ┊ 7 5 6 ┊ + # ┊ ┃ ┃ ┃ ┊ -> ┊ ┃ ┃ ┃ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + before_ts = self.ts() + ts = before_ts.split_edges(time) + assert ts.num_nodes == 8 + assert all(ts.node(u).time == time for u in [5, 6, 7]) + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 7, 1: 5, 2: 6, 5: 3, 6: 3, 7: 4, 3: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + def test_same_time_as_node(self): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + before_ts = self.ts() + ts = before_ts.split_edges(1) + assert ts.num_nodes == 6 + assert ts.node(5).time == 1 + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 5, 1: 3, 2: 3, 5: 4, 3: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [1.1, 1.5, 1.9]) + def test_top_branches(self, time): + # 2.00┊ 4 ┊ 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ ┊ ┏━┻┓ ┊ + # ┊ ┃ ┃ ┊ t ┊ 5 6 ┊ + # ┊ ┃ ┃ ┊ -> ┊ ┃ ┃ ┊ + # 1.00┊ ┃ 3 ┊ 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ 0.00┊ 0 1 2 ┊ + # 0 1 0 1 + + before_ts = self.ts() + ts = before_ts.split_edges(time) + assert ts.num_nodes == 7 + assert all(ts.node(u).time == time for u in [5, 6]) + assert ts.num_trees == 1 + assert ts.first().parent_dict == {0: 5, 1: 3, 2: 3, 3: 6, 6: 4, 5: 4} + ts = ts.simplify() + ts.tables.assert_equals(before_ts.tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [0, 2]) + def test_at_leaf_or_root_time(self, time): + split = self.ts().split_edges(time) + split.tables.assert_equals(self.ts().tables, ignore_provenance=True) + + @pytest.mark.parametrize("time", [-1, 2.1]) + def test_outside_time_scales(self, time): + 
split = self.ts().split_edges(time) + split.tables.assert_equals(self.ts().tables, ignore_provenance=True) + + +class TestSplitEdgesSimpleTreeMutationExamples: + def test_single_mutation_no_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T", metadata=b"1234") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_nodes == 6 + mut = ts_split.mutation(0) + assert mut.node == 0 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert tskit.is_unknown_time(mut.time) + + def test_single_mutation_split_before_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row( + site=0, node=0, time=1.5, derived_state="T", metadata=b"1234" + ) + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_nodes == 6 + mut = ts_split.mutation(0) + assert mut.node == 5 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert mut.time == 1.5 + + def test_single_mutation_split_at_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row( + site=0, node=0, time=1, derived_state="T", metadata=b"1234" + ) + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ 5x 3 ┊ + # ┊ 
┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + mut = ts_split.mutation(0) + assert mut.node == 5 + assert mut.derived_state == "T" + assert mut.metadata == b"1234" + assert mut.time == 1.0 + + def test_multi_mutation_no_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="G") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # ┊ 5 3 ┊ + # ┊ x ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + ts_split.tables.mutations.assert_equals(tables.mutations) + + def test_multi_mutation_over_sample_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1.01, derived_state="T") + tables.mutations.add_row(site=0, node=0, time=0.99, parent=0, derived_state="G") + ts = tables.tree_sequence() + + ts_split = ts.split_edges(1) + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts_split.num_mutations == 2 + + mut = ts_split.mutation(0) + assert mut.site == 0 + assert mut.node == 5 + assert mut.time == 1.01 + mut = ts_split.mutation(1) + assert mut.site == 0 + assert mut.node == 0 + assert mut.time == 0.99 + + def test_mutation_not_on_branch(self): + tables = tskit.TableCollection(1) + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + ts = tables.tree_sequence() + tables.assert_equals(ts.split_edges(0).tables, ignore_provenance=True) + + +class TestSplitEdgesExamples: + @pytest.mark.parametrize("ts", 
get_example_tree_sequences()) + def test_genotypes_round_trip(self, ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + split_ts = ts.split_edges(time) + assert np.array_equal(split_ts.genotype_matrix(), ts.genotype_matrix()) + else: + with pytest.raises(tskit.LibraryError): + ts.split_edges(time) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("population", [-1, None]) + def test_definition(self, ts, population): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + ts1 = split_edges_definition(ts, time, population=population) + ts2 = ts.split_edges(time, population=population) + ts1.tables.assert_equals(ts2.tables, ignore_provenance=True) + + +class TestSplitEdgesInterface: + def test_migrations_fail(self, ts_fixture): + assert ts_fixture.num_migrations > 0 + with pytest.raises(tskit.LibraryError, match="MIGRATIONS_NOT_SUPPORTED"): + ts_fixture.split_edges(0) + + def test_population_out_of_bounds(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="POPULATION_OUT_OF_BOUNDS"): + ts.split_edges(0, population=0) + + def test_bad_flags(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.split_edges(0, flags="asdf") + + def test_bad_metadata_no_schema(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.split_edges(0, metadata="asdf") + + def test_bad_metadata_json_schema(self): + tables = tskit.TableCollection(1) + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + ts = tables.tree_sequence() + with pytest.raises(tskit.MetadataEncodingError): + ts.split_edges(0, metadata=b"bytes") + + @pytest.mark.parametrize("time", [math.inf, np.inf, tskit.UNKNOWN_TIME, np.nan]) + def test_nonfinite_time(self, time): + tables = tskit.TableCollection(1) + ts = 
tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="TIME_NONFINITE"): + ts.split_edges(time) + + +class TestSplitEdgesNodeValues: + @tests.cached_example + def ts(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + @tests.cached_example + def ts_with_schema(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + def test_default_population(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).population == 0 + + @pytest.mark.parametrize("population", range(-1, 5)) + def test_specify_population(self, population): + ts = self.ts().split_edges(0.5, population=population) + assert ts.node(2).population == population + + def test_default_flags(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).flags == 0 + + @pytest.mark.parametrize("flags", range(0, 5)) + def test_specify_flags(self, flags): + ts = self.ts().split_edges(0.5, flags=flags) + assert ts.node(2).flags == flags + + def test_default_metadata_no_schema(self): + ts = self.ts().split_edges(0.5) + assert ts.node(2).metadata == b"" + + @pytest.mark.parametrize("metadata", [b"", b"some bytes"]) + def test_specify_metadata_no_schema(self, metadata): + ts = self.ts().split_edges(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata + + def test_default_metadata_with_schema(self): + ts = self.ts_with_schema().split_edges(0.5) + assert ts.node(2).metadata == {} + + @pytest.mark.parametrize("metadata", [{}, {"some": "json"}]) + def test_specify_metadata_with_schema(self, 
metadata): + ts = self.ts_with_schema().split_edges(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 11b83d6bdb..dcfc1dd083 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5935,6 +5935,64 @@ def trim(self, record_provenance=True): tables.trim(record_provenance) return tables.tree_sequence() + def split_edges(self, time, *, flags=None, population=None, metadata=None): + """ + Returns a copy of this tree sequence in which we replace any + edge ``(left, right, parent, child)`` in which + ``node_time[child] < time < node_time[parent]`` with two edges + ``(left, right, parent, u)`` and ``(left, right, u, child)``, + where ``u`` is a newly added node for each intersecting edge. + + If ``metadata``, ``flags``, or ``population`` are specified, newly + added nodes will be assigned these values. Otherwise, default values + will be used. The default metadata is an empty dictionary if a metadata + schema is defined for the node table, and is an empty byte string + otherwise. The default population for the new node will be derived from + the population of the edge's child. Newly added nodes have a default + ``flags`` value of 0. + + Any metadata associated with a split edge will be copied to the new edge. + + .. warning:: This method currently does not support migrations + and an error will be raised if the migration table is not + empty. Future versions may take migrations that intersect with the + edge into account when determining the default population + assignments for new nodes. + + Any mutations lying on the edge whose time is >= ``time`` will have + their node value set to ``u``. Note that the time of the mutation is + defined as the time of the child node if the mutation's time is + unknown. + + :param float time: The cutoff time. + :param int flags: The flags value for newly-inserted nodes. (Default = 0) + :param int population: The population value for newly inserted nodes.
+ Defaults to the population of the child node of the split edge + if not specified. + :param metadata: The metadata for any newly inserted nodes. See + :meth:`.NodeTable.add_row` for details on how default metadata + is produced for a given schema (or none). + :return: A copy of this tree sequence with edges split at the specified time. + :rtype: tskit.TreeSequence + """ + impute_population = False + if population is None: + impute_population = True + population = tskit.NULL + flags = 0 if flags is None else flags + schema = self.table_metadata_schemas.node + if metadata is None: + metadata = schema.empty_value + metadata = schema.validate_and_encode_row(metadata) + ll_ts = self._ll_tree_sequence.split_edges( + time=time, + flags=flags, + population=population, + metadata=metadata, + impute_population=impute_population, + ) + return TreeSequence(ll_ts) + def subset( self, nodes,