diff --git a/c/tests/test_tables.c b/c/tests/test_tables.c index 10f722fb6b..40e4f99cb4 100644 --- a/c/tests/test_tables.c +++ b/c/tests/test_tables.c @@ -10155,6 +10155,80 @@ test_table_collection_clear(void) | TSK_CLEAR_TS_METADATA_AND_SCHEMA); } +static void +test_table_collection_decapitate(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_table_collection_t t; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* Add some migrations */ + tsk_population_table_add_row(&t.populations, NULL, 0); + tsk_population_table_add_row(&t.populations, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 0, 1, 0.05, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 1, 0, 0.09, NULL, 0); + tsk_migration_table_add_row(&t.migrations, 0, 10, 0, 0, 1, 0.10, NULL, 0); + CU_ASSERT_EQUAL(t.migrations.num_rows, 3); + + /* NOTE: haven't worked out the exact IDs on the branches here, just + * for illustration. + 0.09┊ 9 5 10 ┊ 9 5 ┊11 5 ┊ + ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏━┻┓ ┊ ┃ ┏━┻┓ ┊ + 0.07┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ 4 ┊ ┃ ┃ 4 ┊ + ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┏┻┓ ┊ + 0.00┊ 0 1 3 2 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + 0.00 2.00 7.00 10.00 + */ + + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = tsk_treeseq_init(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&ts), 3); + CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&ts), 12); + /* Lost the mutation over 5 */ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_mutations(&ts), 2); + /* We keep the migration at exactly 0.09. */ + CU_ASSERT_EQUAL(tsk_treeseq_get_num_migrations(&ts), 2); + + tsk_table_collection_free(&t); + tsk_treeseq_free(&ts); +} + +static void +test_table_collection_decapitate_errors(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_table_collection_t t; + + tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, + paper_ex_mutations, paper_ex_individuals, NULL, 0); + ret = tsk_treeseq_copy_tables(&ts, &t, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tsk_treeseq_free(&ts); + + /* This should be caught later when we try to index */ + reverse_edges(&t); + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_EDGES_NOT_SORTED_CHILD); + + /* This should be caught immediately on entry to the function */ + t.sequence_length = -1; + ret = tsk_table_collection_decapitate(&t, 0.09, 0); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SEQUENCE_LENGTH); + + tsk_table_collection_free(&t); +} + static void test_table_collection_takeset_indexes(void) { @@ -10317,6 +10391,9 @@ main(int argc, char **argv) test_table_collection_union_middle_merge }, { "test_table_collection_union_errors", test_table_collection_union_errors }, { "test_table_collection_clear", test_table_collection_clear }, + { "test_table_collection_decapitate", test_table_collection_decapitate }, + { "test_table_collection_decapitate_errors", + test_table_collection_decapitate_errors }, { "test_table_collection_takeset_indexes", test_table_collection_takeset_indexes }, { NULL, NULL }, diff --git a/c/tskit/tables.c b/c/tskit/tables.c index 5fb3c1fe38..a2cc4ecc85 100644 --- a/c/tskit/tables.c +++ b/c/tskit/tables.c @@ -12646,6 +12646,133 @@ tsk_table_collection_union(tsk_table_collection_t *self, return ret; } +int +tsk_table_collection_decapitate( + tsk_table_collection_t *self, double time, tsk_flags_t TSK_UNUSED(options)) +{ + int ret = 0; + tsk_edge_t edge; + tsk_mutation_t mutation; + tsk_migration_t migration; + tsk_edge_table_t edges; + tsk_mutation_table_t mutations; + tsk_migration_table_t migrations; + const double *restrict node_time = self->nodes.time; + tsk_id_t j, ret_id; + double mutation_time; + + memset(&edges, 0, sizeof(edges)); + memset(&mutations, 0, sizeof(mutations)); + memset(&migrations, 0, sizeof(migrations)); + + /* Note: perhaps should be stricter about what we accept here in terms + * of sorting, but compute_mutation_parents below will check anyway and + * so we're safe while we're calling that. + */ + ret = (int) tsk_table_collection_check_integrity(self, 0); + if (ret != 0) { + goto out; + } + + ret = tsk_edge_table_copy(&self->edges, &edges, 0); + if (ret != 0) { + goto out; + } + /* Note: we are assuming below that the tables are sorted, so we could save + * a bit of time and memory here by truncating part-way through the + * edges. */ + ret = tsk_edge_table_clear(&self->edges); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) edges.num_rows; j++) { + tsk_edge_table_get_row_unsafe(&edges, j, &edge); + if (node_time[edge.child] < time) { + if (time < node_time[edge.parent]) { + ret_id = tsk_node_table_add_row( + &self->nodes, 0, time, TSK_NULL, TSK_NULL, NULL, 0); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + edge.parent = ret_id; + } + ret_id = tsk_edge_table_add_row(&self->edges, edge.left, edge.right, + edge.parent, edge.child, edge.metadata, edge.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + /* Calling x_table_free multiple times is safe, so get rid of the + * extra edge table memory as soon as we can. */ + tsk_edge_table_free(&edges); + + ret = tsk_mutation_table_copy(&self->mutations, &mutations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_mutation_table_clear(&self->mutations); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) mutations.num_rows; j++) { + tsk_mutation_table_get_row_unsafe(&mutations, j, &mutation); + mutation_time = tsk_is_unknown_time(mutation.time) ? node_time[mutation.node] + : mutation.time; + if (mutation_time < time) { + /* Set the mutation parent to NULL, and recalculate below */ + ret_id = tsk_mutation_table_add_row(&self->mutations, mutation.site, + mutation.node, TSK_NULL, mutation.time, mutation.derived_state, + mutation.derived_state_length, mutation.metadata, + mutation.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + tsk_mutation_table_free(&mutations); + + ret = tsk_migration_table_copy(&self->migrations, &migrations, 0); + if (ret != 0) { + goto out; + } + ret = tsk_migration_table_clear(&self->migrations); + if (ret != 0) { + goto out; + } + for (j = 0; j < (tsk_id_t) migrations.num_rows; j++) { + tsk_migration_table_get_row_unsafe(&migrations, j, &migration); + if (migration.time <= time) { + ret_id = tsk_migration_table_add_row(&self->migrations, migration.left, + migration.right, migration.node, migration.source, migration.dest, + migration.time, migration.metadata, migration.metadata_length); + if (ret_id < 0) { + ret = (int) ret_id; + goto out; + } + } + } + tsk_migration_table_free(&migrations); + + ret = tsk_table_collection_build_index(self, 0); + if (ret != 0) { + goto out; + } + ret = tsk_table_collection_compute_mutation_parents(self, 0); + if (ret != 0) { + goto out; + } + +out: + tsk_edge_table_free(&edges); + tsk_mutation_table_free(&mutations); + tsk_migration_table_free(&migrations); + return ret; +} + static int cmp_edge_cl(const void *a, const void *b) { diff --git a/c/tskit/tables.h b/c/tskit/tables.h index e625f471dc..fc0a716e26 100644 --- a/c/tskit/tables.h +++ b/c/tskit/tables.h @@ -4247,6 +4247,11 @@ int tsk_table_collection_compute_mutation_parents( int tsk_table_collection_compute_mutation_times( tsk_table_collection_t *self, double *random, tsk_flags_t TSK_UNUSED(options)); +/* Not documenting this because we may want to pass through default values + * for the new nodes (in particular the metadata) in the future */ +int tsk_table_collection_decapitate( + tsk_table_collection_t *self, double time, tsk_flags_t options); + int tsk_reference_sequence_init(tsk_reference_sequence_t *self, tsk_flags_t options); int tsk_reference_sequence_free(tsk_reference_sequence_t *self); bool tsk_reference_sequence_is_null(const tsk_reference_sequence_t *self); diff --git a/docs/python-api.md b/docs/python-api.md index b169058c4f..9b0a44ba2d 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -218,6 +218,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.delete_intervals TreeSequence.delete_sites TreeSequence.trim + TreeSequence.decapitate ``` (sec_python_api_tree_sequences_ibd)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 78e7ac2cbc..ea0fc80d82 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,7 +4,12 @@ **Changes** -- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the output VCF files. +- Add ``TableCollection.decapitate`` and ``TreeSequence.decapitate`` operations + to remove information from that data model that is older than a specific time. + (:user:`jeromekelleher`, :issue:`2236`, :pr:`2240`) + +- ``VcfWriter.write`` now prints the site ID of variants in the ID field of the + output VCF files. (:user:`roohy`, :issue:`2103`, :pr:`2107`) - Make dumping of tables and tree sequences to disk a zero-copy operation. diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 1091377e10..570dc4fc74 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7004,6 +7004,29 @@ TableCollection_compute_mutation_parents(TableCollection *self) return ret; } +static PyObject * +TableCollection_decapitate(TableCollection *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + double time; + + if (TableCollection_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "d", &time)) { + goto out; + } + err = tsk_table_collection_decapitate(self->tables, time, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue(""); +out: + return ret; +} + static PyObject * TableCollection_compute_mutation_times(TableCollection *self) { @@ -7489,6 +7512,10 @@ static PyMethodDef TableCollection_methods[] = { .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Returns True if the parameter table collection is equal to this one." }, + { .ml_name = "decapitate", + .ml_meth = (PyCFunction) TableCollection_decapitate, + .ml_flags = METH_VARARGS, + .ml_doc = "Removes information older than the specified time." }, { .ml_name = "compute_mutation_parents", .ml_meth = (PyCFunction) TableCollection_compute_mutation_parents, .ml_flags = METH_NOARGS, diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index cbd9521a4d..1b971780b6 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -116,7 +116,7 @@ def get_multiroot_tree(self): def get_mutations_over_roots_tree(self): ts = msprime.simulate(15, random_seed=1) - ts = tsutil.decapitate(ts, 20) + ts = ts.decapitate(ts.tables.nodes.time[-1] / 2) tables = ts.dump_tables() delta = 1.0 / (ts.num_nodes + 1) x = 0 diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 0683ea2c62..e32f5432d4 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -1702,7 +1702,8 @@ def test_non_ascii_missing_data_char(self, missing_data_char): class TestAlignmentExamples: @pytest.mark.parametrize("ts", get_example_discrete_genome_tree_sequences()) def test_defaults(self, ts): - if any(tree.num_roots > 1 for tree in ts.trees()): + has_missing_data = np.any(ts.genotype_matrix() == -1) + if has_missing_data: with pytest.raises(ValueError, match="1896"): list(ts.alignments()) else: @@ -1720,7 +1721,8 @@ def test_defaults(self, ts): @pytest.mark.parametrize("ts", get_example_discrete_genome_tree_sequences()) def test_reference_sequence(self, ts): ref = tskit.random_nucleotides(ts.sequence_length, seed=1234) - if any(tree.num_roots > 1 for tree in ts.trees()): + has_missing_data = np.any(ts.genotype_matrix() == -1) + if has_missing_data: with pytest.raises(ValueError, match="1896"): list(ts.alignments(reference_sequence=ref)) else: diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 4100328bc1..a6b4c2cb94 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -266,11 +266,11 @@ def get_decapitated_examples(): Returns example tree sequences in which the oldest edges have been removed. """ ts = msprime.simulate(10, random_seed=1234) - yield tsutil.decapitate(ts, ts.num_edges // 2) + yield ts.decapitate(ts.tables.nodes.time[-1] / 2) ts = msprime.simulate(20, recombination_rate=1, random_seed=1234) assert ts.num_trees > 2 - yield tsutil.decapitate(ts, ts.num_edges // 4) + yield ts.decapitate(ts.tables.nodes.time[-1] / 4) def get_example_tree_sequences(back_mutations=True, gaps=True, internal_samples=True): @@ -3622,7 +3622,7 @@ def test_copy_tracked_samples(self): def test_copy_multiple_roots(self): ts = msprime.simulate(20, recombination_rate=2, length=3, random_seed=42) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for root_threshold in [1, 2, 100]: tree = tskit.Tree(ts, root_threshold=root_threshold) copy = tree.copy() diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 0d3556b340..50e258079a 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -410,6 +410,19 @@ def test_union_bad_args(self): with pytest.raises(ValueError): tc.union(tc2, np.array([[1], [2]], dtype="int32")) + def test_decapitate_bad_args(self): + tc = _tskit.TableCollection(1) + self.get_example_tree_sequence().dump_tables(tc) + with pytest.raises(TypeError): + tc.decapitate() + with pytest.raises(TypeError): + tc.decapitate("1234") + + def test_decapitate_error(self): + tc = _tskit.TableCollection(-1) + with pytest.raises(_tskit.LibraryError, match="Sequence length"): + tc.decapitate(0) + def test_equals_bad_args(self): ts = msprime.simulate(10, random_seed=1242) tc = ts.tables._ll_tables diff --git a/python/tests/test_parsimony.py b/python/tests/test_parsimony.py index f48580a87f..80cec45b96 100644 --- a/python/tests/test_parsimony.py +++ b/python/tests/test_parsimony.py @@ -678,17 +678,17 @@ def test_jukes_cantor_balanced_ternary_internal_samples(self): def test_infinite_sites_n20_multiroot(self): ts = msprime.simulate(20, mutation_rate=3, random_seed=3) - self.verify(tsutil.decapitate(ts, ts.num_edges // 2)) + self.verify(ts.decapitate(np.max(ts.tables.nodes.time) / 2)) def test_jukes_cantor_n15_multiroot(self): ts = msprime.simulate(15, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 3) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 5) ts = tsutil.jukes_cantor(ts, 15, 2, seed=3) self.verify(ts) def test_jukes_cantor_balanced_ternary_multiroot(self): ts = tskit.Tree.generate_balanced(50, arity=3).tree_sequence - ts = tsutil.decapitate(ts, ts.num_edges // 3) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 3) ts = tsutil.jukes_cantor(ts, 15, 2, seed=3) self.verify(ts) assert ts.num_sites > 1 @@ -696,7 +696,7 @@ def test_jukes_cantor_balanced_ternary_multiroot(self): def test_jukes_cantor_n50_multiroot(self): ts = msprime.simulate(50, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 5, 2, seed=2) self.verify(ts) @@ -1389,7 +1389,7 @@ def test_mutations_over_root(self): def test_all_isolated_different_from_ancestral(self): ts = tskit.Tree.generate_star(6).tree_sequence - ts = tsutil.decapitate(ts, 0) + ts = ts.decapitate(0) tree = ts.first() genotypes = [0, 0, 0, 1, 1, 1] ancestral_state, transitions = self.do_map_mutations( diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py new file mode 100644 index 0000000000..9c9fe4daa9 --- /dev/null +++ b/python/tests/test_table_transforms.py @@ -0,0 +1,533 @@ +# MIT License +# +# Copyright (c) 2022 Tskit Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Test cases for table transformation operations like trim(), decapitate, etc. +""" +import decimal +import fractions +import io +import json + +import numpy as np +import pytest + +import tests +import tskit +import tskit.util as util +from tests.test_highlevel import get_example_tree_sequences + +# ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when +# we can remove this. + + +def decapitate_definition(ts, time): + """ + Simple loop implementation of the decapitate operation + """ + tables = ts.dump_tables() + node_time = tables.nodes.time + tables.edges.clear() + # To avoid having a discrepancy between the C version and the Python + # version here we unset the metadata schema before adding any nodes. + # Setting empty metadata seems like the right thing to do here, but + # there will definitely be cases where doing this will break encoding + # (e.g., suppose there was a fixed-size binary schema - empty metadata + # is going to break this). It's not clear what the right answer is here, + # but this is at least fast to do and simple to describe. + md_schema = tables.nodes.metadata_schema + tables.nodes.metadata_schema = tskit.MetadataSchema(None) + for edge in ts.edges(): + if node_time[edge.parent] <= time: + tables.edges.append(edge) + elif node_time[edge.child] < time: + new_parent = tables.nodes.add_row(time=time) + tables.edges.append(edge.replace(parent=new_parent)) + tables.nodes.metadata_schema = md_schema + + tables.mutations.clear() + for mutation in ts.mutations(): + mutation_time = ( + node_time[mutation.node] + if util.is_unknown_time(mutation.time) + else mutation.time + ) + if mutation_time < time: + tables.mutations.append(mutation.replace(parent=tskit.NULL)) + + tables.migrations.clear() + for migration in ts.migrations(): + if migration.time <= time: + tables.migrations.append(migration) + + tables.build_index() + tables.compute_mutation_parents() + return tables.tree_sequence() + + +@pytest.mark.parametrize("ts", get_example_tree_sequences()) +def test_decapitate_examples(ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + decap1 = decapitate_definition(ts, time) + decap2 = ts.decapitate(time) + decap1.tables.assert_equals(decap2.tables, ignore_provenance=True) + + +class TestDecapitateSimpleTree: + + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + def tables(self): + # Don't cache this because we modify the result! + tree = tskit.Tree.generate_balanced(3, branch_length=1) + return tree.tree_sequence.dump_tables() + + @pytest.mark.parametrize("time", [0, -0.5, -100]) + def test_t0_or_before(self, time): + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert list(sorted(tree.roots)) == [0, 1, 2] + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + assert len(tables.edges) == 0 + + @pytest.mark.parametrize("time", [0.01, 0.5, 0.999]) + def test_t0_to_1(self, time): + # + # 2.00┊ ┊ + # ┊ ┊ + # 0.99┊ 7 5 6 ┊ + # ┊ ┃ ┃ ┃ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert list(sorted(tree.roots)) == [5, 6, 7] + assert len(tables.nodes) == 8 + assert tables.nodes[5].time == time + assert tables.nodes[6].time == time + assert tables.nodes[7].time == time + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + def test_t1(self): + # + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(1) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [3, 5] + assert len(tables.nodes) == 6 + assert tables.nodes[5].time == 1 + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + @pytest.mark.parametrize("time", [1.01, 1.5, 1.999]) + def test_t1_to_2(self, time): + # 2.00┊ ┊ + # ┊ ┊ + # 1.01┊ 5 6 ┊ + # ┊ ┃ ┃ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + ts = tables.tree_sequence() + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [5, 6] + assert len(tables.nodes) == 7 + assert tables.nodes[5].time == time + assert tables.nodes[6].time == time + assert before.nodes.equals(tables.nodes[: len(before.nodes)]) + + @pytest.mark.parametrize("time", [2, 2.5, 1e9]) + def test_t2(self, time): + tables = self.tables() + before = tables.copy() + tables.decapitate(time) + tables.assert_equals(before, ignore_provenance=True) + + +class TestDecapitateSimpleTreeMutationExamples: + def test_single_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.mutations.assert_equals(tables.mutations) + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_single_mutation_at_decap_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1, derived_state="T") + + # Because the mutation is at exactly the decapitation time, we must + # remove it, or it would violate the requirement that a mutation must + # have a time less than that of the parent of the edge that its on. + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 0 + assert list(tables.tree_sequence().alignments()) == ["A", "A", "A"] + + def test_multi_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="G") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ x ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.mutations.assert_equals(tables.mutations) + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_multi_mutation_over_sample_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1.01, derived_state="T") + tables.mutations.add_row(site=0, node=0, time=0.99, parent=0, derived_state="G") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 1 + # Alignments are equal because the ancestral mutation was silent anyway. + assert list(before.tree_sequence().alignments()) == list( + tables.tree_sequence().alignments() + ) + + def test_multi_mutation_over_root(self): + # x + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=4, derived_state="G") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="T") + + before = tables.copy() + tables.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert len(tables.mutations) == 1 + assert list(before.tree_sequence().alignments()) == ["T", "G", "G"] + # The states inherited by samples changes because we drop the old mutation + assert list(tables.tree_sequence().alignments()) == ["T", "A", "A"] + + +class TestDecapitateSimpleTreeMigrationExamples: + @tests.cached_example + def ts(self): + # 2.00┊ 4 ┊ + # ┊ o━┻┓ ┊ + # 1.00┊ o 3 ┊ + # ┊ o ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.populations.add_row() + tables.populations.add_row() + tables.migrations.add_row(source=0, dest=1, node=0, time=0.5, left=0, right=1) + tables.migrations.add_row(source=1, dest=0, node=0, time=1.0, left=0, right=1) + tables.migrations.add_row(source=0, dest=1, node=0, time=1.5, left=0, right=1) + tables.compute_mutation_parents() + ts = tables.tree_sequence() + return ts + + def test_t099(self): + ts = self.ts() + tables = ts.decapitate(0.99).tables + assert len(tables.migrations) == 1 + assert tables.migrations[0].time == 0.5 + + def test_t1(self): + ts = self.ts() + tables = ts.decapitate(1).tables + assert len(tables.migrations) == 2 + assert tables.migrations[0].time == 0.5 + assert tables.migrations[1].time == 1.0 + + +class TestSimpleTsExample: + # 9.08┊ 9 ┊ ┊ ┊ ┊ ┊ + # ┊ ┏━┻━┓ ┊ ┊ ┊ ┊ ┊ + # 6.57┊ ┃ ┃ ┊ ┊ ┊ ┊ 8 ┊ + # ┊ ┃ ┃ ┊ ┊ ┊ ┊ ┏━┻━┓ ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + + @tests.cached_example + def ts(self): + nodes = io.StringIO( + """\ + id is_sample population individual time metadata + 0 1 0 -1 0 + 1 1 0 -1 0 + 2 1 0 -1 0 + 3 1 0 -1 0 + 4 0 0 -1 0.114 + 5 0 0 -1 1.110 + 6 0 0 -1 1.750 + 7 0 0 -1 5.310 + 8 0 0 -1 6.573 + 9 0 0 -1 9.083 + """ + ) + edges = io.StringIO( + """\ + id left right parent child + 0 0.00000000 1.00000000 4 0 + 1 0.00000000 1.00000000 4 1 + 2 0.00000000 1.00000000 5 2 + 3 0.00000000 1.00000000 5 3 + 4 0.79258618 0.90634460 6 4 + 5 0.79258618 0.90634460 6 5 + 6 0.05975243 0.79258618 7 4 + 7 0.90634460 0.91029435 7 4 + 8 0.05975243 0.79258618 7 5 + 9 0.90634460 0.91029435 7 5 + 10 0.91029435 1.00000000 8 4 + 11 0.91029435 1.00000000 8 5 + 12 0.00000000 0.05975243 9 4 + 13 0.00000000 0.05975243 9 5 + """ + ) + sites = io.StringIO( + """\ + position ancestral_state + 0.05 A + 0.06 0 + 0.3 C + 0.5 AAA + 0.91 T + """ + ) + muts = io.StringIO( + """\ + site node derived_state parent time + 0 9 T -1 15 + 0 9 GGG 0 9.1 + 0 5 1 1 9 + 1 4 C -1 1.6 + 1 4 G 3 1.5 + 2 7 G -1 10 + 2 3 C 5 1 + 4 3 G -1 1 + """ + ) + ts = tskit.load_text(nodes, edges, sites=sites, mutations=muts, strict=False) + return ts + + def test_at_time_of_5(self): + # NOTE: we don't remember that the edge 4-7 was shared in trees 1 and 3. + # 1.11┊ 14 5 ┊ 11 5 ┊ 10 5 ┊ 12 5 ┊ 13 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(1.110) + assert ts.num_nodes == 15 + assert ts.num_trees == 5 + # Most mutations are older than this. + assert ts.num_mutations == 2 + for u in range(10, 15): + node = ts.node(u) + assert node.time == 1.110 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {5, 14}, + {11, 5}, + {10, 5}, + {12, 5}, + {13, 5}, + ] + + def test_at_time6(self): + # 6 ┊ 12 13 ┊ ┊ ┊ ┊ 10 11 ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(6) + assert ts.num_nodes == 14 + assert ts.num_trees == 5 + assert ts.num_mutations == 4 + for u in range(10, 14): + node = ts.node(u) + assert node.time == 6 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {12, 13}, + {7}, + {6}, + {7}, + {10, 11}, + ] + + +class TestDecapitateInterface: + @pytest.mark.parametrize("bad_type", ["x", "0.1", [], [0.1]]) + def test_bad_types(self, ts_fixture, bad_type): + with pytest.raises(TypeError, match="number"): + ts_fixture.decapitate(bad_type) + + @pytest.mark.parametrize( + "time", [1, 1.0, np.array([1])[0], fractions.Fraction(1, 1), decimal.Decimal(1)] + ) + def test_number_types(self, ts_fixture, time): + expected = ts_fixture.decapitate(1) + got = ts_fixture.decapitate(time) + expected.tables.assert_equals(got.tables, ignore_timestamps=True) + + def test_provenance(self, ts_fixture): + ts = ts_fixture.decapitate(1.5) + assert ts.num_provenances == ts_fixture.num_provenances + 1 + prov = json.loads(ts.provenance(ts.num_provenances - 1).record) + assert prov["parameters"] == {"command": "decapitate", "time": 1.5} + + def test_no_provenance(self, ts_fixture): + ts = ts_fixture.decapitate(1.5, record_provenance=False) + assert ts.num_provenances == ts_fixture.num_provenances + + def test_tables_ts_equivalent(self, ts_fixture): + time = 0.5 + ts = ts_fixture.decapitate(time) + tables = ts_fixture.dump_tables() + tables.decapitate(time) + tables.assert_equals(ts.tables, ignore_timestamps=True) + + def test_unsorted_tables_raises_error(self): + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + edges = tables.edges.copy() + tables.edges.clear() + for edge in reversed(edges): + tables.edges.append(edge) + with pytest.raises(tskit.LibraryError, match="order violated"): + tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="order violated"): + tables.decapitate(1) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index fc15494d68..aa1439871d 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -315,12 +315,12 @@ def test_nonbinary_trees(self): def test_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) self.verify(ts) def test_multiroot_tree(self): ts = msprime.simulate(15, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) self.verify(ts) def test_all_missing_data(self): @@ -4832,7 +4832,7 @@ def verify_single_childified(self, ts, keep_unary=False): assert t1.mutations == t2.mutations def verify_multiroot_internal_samples(self, ts, keep_unary=False): - ts_multiroot = tsutil.decapitate(ts, ts.num_edges // 2) + ts_multiroot = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts1 = tsutil.jiggle_samples(ts_multiroot) ts2, node_map = self.do_simplify(ts1, keep_unary=keep_unary) assert ts1.num_trees >= ts2.num_trees @@ -5556,7 +5556,7 @@ def test_many_trees_recurrent_mutations(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) for num_samples in range(1, ts.num_samples): @@ -5567,7 +5567,7 @@ def test_single_multiroot_tree_recurrent_mutations(self): def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) for num_samples in range(1, ts.num_samples): @@ -5716,7 +5716,7 @@ def test_many_trees_internal_samples(self): def test_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for num_samples in range(1, ts.num_samples): for samples in itertools.combinations(ts.samples(), num_samples): self.verify_keep_input_roots(ts, samples) @@ -6051,7 +6051,7 @@ def test_sim_coalescent_trees_internal_samples(self): def test_sim_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ancestors = [4 * n for n in np.arange(0, ts.num_nodes // 4)] self.verify(ts, ts.samples(), ancestors) random_samples = [4 * n for n in np.arange(0, ts.num_nodes // 4)] @@ -6189,14 +6189,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) @@ -6328,14 +6328,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) @@ -7110,13 +7110,6 @@ def test_zero_sites(self): assert mts.num_trees == 1 assert mts.num_edges == 0 - def test_many_roots(self): - ts = msprime.simulate(25, random_seed=12, recombination_rate=2, length=10) - tables = tsutil.decapitate(ts, ts.num_edges // 2).dump_tables() - for x in range(10): - tables.sites.add_row(x, "0") - self.verify(tables.tree_sequence()) - def test_branch_sites(self): ts = msprime.simulate(15, random_seed=12, recombination_rate=2, length=10) ts = tsutil.insert_branch_sites(ts) diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index aa69d3d4e8..3d49a79c14 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -581,8 +581,8 @@ def test_single_tree_sequence_length(self): self.verify(ts) def test_single_tree_multiple_roots(self): - ts = msprime.simulate(8, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = msprime.simulate(8, random_seed=1, end_time=0.5) + assert ts.first().num_roots > 1 self.verify(ts) def test_many_trees(self): diff --git a/python/tests/test_utilities.py b/python/tests/test_utilities.py index ef09aa3070..acf3e76e67 100644 --- a/python/tests/test_utilities.py +++ b/python/tests/test_utilities.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2021 Tskit Developers +# Copyright (c) 2019-2022 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ Tests for the various testing utilities. """ import msprime +import numpy as np import pytest import tests.tsutil as tsutil @@ -43,13 +44,13 @@ def verify(self, ts): def test_n10_multiroot(self): ts = msprime.simulate(10, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 1, 2, seed=7) self.verify(ts) def test_n50_multiroot(self): ts = msprime.simulate(50, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 5, 2, seed=2) self.verify(ts) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index cb59de95f7..60ba5b0250 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -76,24 +76,6 @@ def subsample_sites(ts, num_sites): return t.tree_sequence() -def decapitate(ts, num_edges): - """ - Returns a copy of the specified tree sequence in which the specified number of - edges have been retained. - """ - t = ts.dump_tables() - t.edges.set_columns( - left=t.edges.left[:num_edges], - right=t.edges.right[:num_edges], - parent=t.edges.parent[:num_edges], - child=t.edges.child[:num_edges], - ) - add_provenance(t.provenances, "decapitate") - # Simplify to get rid of any mutations that are lying around above roots. - t.simplify() - return t.tree_sequence() - - def insert_branch_mutations(ts, mutations_per_branch=1): """ Returns a copy of the specified tree sequence with a mutation on every branch diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 9c30d0728a..c2fe259015 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3862,6 +3862,56 @@ def trim(self, record_provenance=True): record=json.dumps(provenance.get_provenance_dict(parameters)) ) + def decapitate(self, time, *, record_provenance=True): + """ + Delete all edge topology and mutational information older than the + specified time from this set of tables. + + Removes all edges in which the time of the child is >= the specified + time ``t``, and breaks edges that intersect with ``t``. For each edge + intersecting with ``t`` we create a new node with time equal to ``t``, + and set the parent of the edge to this new node. The node table + is not altered in any other way. Newly added nodes have empty metadata, + a ``flags`` value of 0, and NULL ``population`` and ``individual`` + references. + + .. warning:: + The empty metadata values for newly added nodes may not be compatible + with the node table's metadata schema! The current behaviour may + change in future versions to better accomodate metadata schemas. + + .. note:: + Note that each edge is treated independently, so that even if two + edges that are broken by this operation share the same parent and + child nodes, there will be two different new parent nodes inserted. + + Any mutation whose time is >= ``t`` will be removed. A mutation's time + is its associated ``time`` value, or the time of its node if the + mutation's time was marked as unknown (:data:`UNKNOWN_TIME`). + + Any migration with time > ``t`` will be removed. + + .. important:: + The tables must satisfy the standard + :ref:`sortedness requirements `. + + .. note:: + As a side-effect, this method will build the table indexes + and recompute the parents of all mutations. This is an implementation + detail and may not happen in future versions. + + :param float time: The cutoff time. + :param bool record_provenance: If ``True``, add details of this operation + to the provenance table in this TableCollection. (Default: ``True``). + """ + self._ll_tables.decapitate(time) + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + parameters = {"command": "decapitate", "time": float(time)} + self.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + def clear( self, clear_provenance=False, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 048da98e22..46b5d44f12 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -5869,6 +5869,20 @@ def trim(self, record_provenance=True): tables.trim(record_provenance) return tables.tree_sequence() + def decapitate(self, time, record_provenance=True): + """ + Return a copy of this tree sequence with topology and mutation information + older than the specified time removed. Please see the + :meth:`.TableCollection.decapitate` method for details. + + :param float time: The cutoff time. + :param bool record_provenance: If True, add details of this operation to the + provenance information of the returned tree sequence. (Default: True). + """ + tables = self.dump_tables() + tables.decapitate(time, record_provenance=record_provenance) + return tables.tree_sequence() + def subset( self, nodes,