diff --git a/docs/python-api.md b/docs/python-api.md index 8a9db91812..4b99ff5df5 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -219,6 +219,7 @@ which perform the same actions but modify the {class}`TableCollection` in place. TreeSequence.delete_sites TreeSequence.trim TreeSequence.split_edges + TreeSequence.decapitate ``` (sec_python_api_tree_sequences_ibd)= @@ -683,6 +684,7 @@ a functional way, returning a new tree sequence while leaving the original uncha TableCollection.delete_sites TableCollection.trim TableCollection.union + TableCollection.delete_older ``` (sec_tables_api_creating_valid_tree_sequence)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 8bdfaf9874..6cebfa942d 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -30,6 +30,11 @@ edges at a specific time. (:user:`jeromekelleher`, :issue:`2276`, :pr:`2296`). +- Add the ``TreeSequence.decapitate`` (and closely related + ``TableCollection.delete_older``) operation to remove topology and mutations + older than a given time. + (:user:`jeromekelleher`, :issue:`2236`, :pr:`2302`, :pr:`2331`). + - Add the ``TreeSequence.individuals_time`` and ``TreeSequence.individuals_population`` methods to return arrays of per-individual times and populations, respectively. (:user:`petrelharp`, :issue:`1481`, :pr:`2298`). diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py index 06f941271e..a137c8d074 100644 --- a/python/tests/test_table_transforms.py +++ b/python/tests/test_table_transforms.py @@ -22,6 +22,9 @@ """ Test cases for table transformation operations like trim(), decapitate, etc. 
""" +import decimal +import fractions +import io import math import numpy as np @@ -320,7 +323,7 @@ def test_older(self, time): tables.migrations.assert_equals(before.migrations) -def split_edges_definition(ts, time, *, flags=None, population=None, metadata=None): +def split_edges_definition(ts, time, *, flags=0, population=None, metadata=None): tables = ts.dump_tables() if ts.num_migrations > 0: raise ValueError("Migrations not supported") @@ -329,13 +332,6 @@ def split_edges_definition(ts, time, *, flags=None, population=None, metadata=No # -1 is a valid value if population < -1 or population >= ts.num_populations: raise ValueError("Population out of bounds") - flags = 0 if flags is None else flags - if metadata is None: - metadata = tables.nodes.metadata_schema.empty_value - metadata = tables.nodes.metadata_schema.validate_and_encode_row(metadata) - # This is the easiest way to turn off encoding when calling add_row below - schema = tables.nodes.metadata_schema - tables.nodes.metadata_schema = tskit.MetadataSchema(None) node_time = tables.nodes.time node_population = tables.nodes.population @@ -353,8 +349,6 @@ def split_edges_definition(ts, time, *, flags=None, population=None, metadata=No split_edge[edge.id] = u else: tables.edges.append(edge) - # Reinstate schema - tables.nodes.metadata_schema = schema tables.mutations.clear() for mutation in ts.mutations(): @@ -712,3 +706,521 @@ def test_default_metadata_with_schema(self): def test_specify_metadata_with_schema(self, metadata): ts = self.ts_with_schema().split_edges(0.5, metadata=metadata) assert ts.node(2).metadata == metadata + + +def decapitate_definition(ts, time, *, flags=0, population=None, metadata=None): + """ + Simple loop implementation of the decapitate operation + """ + default_population = population is None + + tables = ts.dump_tables() + node_time = tables.nodes.time + tables.edges.clear() + for edge in ts.edges(): + if node_time[edge.parent] <= time: + tables.edges.append(edge) + elif 
node_time[edge.child] < time: + if default_population: + population = tables.nodes[edge.child].population + new_parent = tables.nodes.add_row( + time=time, population=population, flags=flags, metadata=metadata + ) + tables.edges.append(edge.replace(parent=new_parent)) + + tables.mutations.clear() + for mutation in ts.mutations(): + mutation_time = ( + node_time[mutation.node] + if util.is_unknown_time(mutation.time) + else mutation.time + ) + if mutation_time < time: + tables.mutations.append(mutation.replace(parent=tskit.NULL)) + + tables.migrations.clear() + for migration in ts.migrations(): + if migration.time <= time: + tables.migrations.append(migration) + + tables.build_index() + tables.compute_mutation_parents() + return tables.tree_sequence() + + +class TestDecapitateExamples: + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_defaults(self, ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + decap1 = decapitate_definition(ts, time) + decap2 = ts.decapitate(time) + decap1.tables.assert_equals(decap2.tables, ignore_provenance=True) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_no_population(self, ts): + time = 0 if ts.num_nodes == 0 else np.median(ts.tables.nodes.time) + if ts.num_migrations == 0: + decap1 = decapitate_definition(ts, time, population=-1) + decap2 = ts.decapitate(time, population=-1) + decap1.tables.assert_equals(decap2.tables, ignore_provenance=True) + + +class TestDecapitateSimpleTree: + + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + @tests.cached_example + def ts(self): + tree = tskit.Tree.generate_balanced(3, branch_length=1) + return tree.tree_sequence + + @pytest.mark.parametrize("time", [0, -0.5, -100]) + def test_t0_or_before(self, time): + before = self.ts() + ts = before.decapitate(time) + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert 
list(sorted(tree.roots)) == [0, 1, 2] + assert before.tables.nodes.equals(ts.tables.nodes[: before.num_nodes]) + assert ts.num_edges == 0 + + @pytest.mark.parametrize("time", [0.01, 0.5, 0.999]) + def test_t0_to_1(self, time): + # + # 2.00┊ ┊ + # ┊ ┊ + # 0.99┊ 7 5 6 ┊ + # ┊ ┃ ┃ ┃ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before = self.ts() + ts = before.decapitate(time) + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 3 + assert list(sorted(tree.roots)) == [5, 6, 7] + assert ts.num_nodes == 8 + assert ts.tables.nodes[5].time == time + assert ts.tables.nodes[6].time == time + assert ts.tables.nodes[7].time == time + + def test_t1(self): + # + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before = self.ts() + ts = before.decapitate(1) + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [3, 5] + assert ts.num_nodes == 6 + assert ts.tables.nodes[5].time == 1 + + @pytest.mark.parametrize("time", [1.01, 1.5, 1.999]) + def test_t1_to_2(self, time): + # 2.00┊ ┊ + # ┊ ┊ + # 1.01┊ 5 6 ┊ + # ┊ ┃ ┃ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before = self.ts() + ts = before.decapitate(time) + assert ts.num_trees == 1 + tree = ts.first() + assert tree.num_roots == 2 + assert list(sorted(tree.roots)) == [5, 6] + assert ts.num_nodes == 7 + assert ts.tables.nodes[5].time == time + assert ts.tables.nodes[6].time == time + + @pytest.mark.parametrize("time", [2, 2.5, 1e9]) + def test_t2(self, time): + before = self.ts() + ts = before.decapitate(time) + ts.tables.assert_equals(before.tables, ignore_provenance=True) + + +class TestDecapitateSimpleTreeMutationExamples: + def test_single_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, 
derived_state="T") + before = tables.tree_sequence() + + ts = before.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.tables.mutations.assert_equals(ts.tables.mutations) + assert list(before.alignments()) == list(ts.alignments()) + + def test_single_mutation_at_decap_time(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, time=1, derived_state="T") + before = tables.tree_sequence() + + # Because the mutation is at exactly the decapitation time, we must + # remove it, or it would violate the requirement that a mutation must + # have a time less than that of the parent of the edge that its on. + ts = before.decapitate(1) + # 2.00┊ ┊ + # ┊ ┊ + # 1.00┊ 5 3 ┊ + # ┊ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts.num_mutations == 0 + assert list(ts.alignments()) == ["A", "A", "A"] + + def test_multi_mutation_over_sample(self): + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ x 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=0, derived_state="T") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="G") + before = tables.tree_sequence() + + ts = before.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ x ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + before.tables.mutations.assert_equals(ts.tables.mutations) + assert list(before.alignments()) == list(ts.alignments()) + + def test_multi_mutation_over_sample_time(self): + # 2.00┊ 4 ┊ + # ┊ x━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + 
tables.mutations.add_row(site=0, node=0, time=1.01, derived_state="T") + tables.mutations.add_row(site=0, node=0, time=0.99, parent=0, derived_state="G") + before = tables.tree_sequence() + + ts = before.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts.num_mutations == 1 + # Alignments are equal because the ancestral mutation was silent anyway. + assert list(before.alignments()) == list(ts.alignments()) + + def test_multi_mutation_over_root(self): + # x + # 2.00┊ 4 ┊ + # ┊ ┏━┻┓ ┊ + # 1.00┊ ┃ 3 ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + tree = tskit.Tree.generate_balanced(3, branch_length=1) + tables = tree.tree_sequence.dump_tables() + tables.sites.add_row(0, "A") + tables.mutations.add_row(site=0, node=4, derived_state="G") + tables.mutations.add_row(site=0, node=0, parent=0, derived_state="T") + before = tables.tree_sequence() + + ts = before.decapitate(1) + # 2.00┊ ┊ + # ┊ 5 3 ┊ + # ┊ ┃ ┃ ┊ + # ┊ x ┏┻┓ ┊ + # 0.00┊ 0 1 2 ┊ + # 0 1 + assert ts.num_mutations == 1 + assert list(before.alignments()) == ["T", "G", "G"] + # The states inherited by samples changes because we drop the old mutation + assert list(ts.alignments()) == ["T", "A", "A"] + + +class TestDecapitateSimpleTsExample: + # 9.08┊ 9 ┊ ┊ ┊ ┊ ┊ + # ┊ ┏━┻━┓ ┊ ┊ ┊ ┊ ┊ + # 6.57┊ ┃ ┃ ┊ ┊ ┊ ┊ 8 ┊ + # ┊ ┃ ┃ ┊ ┊ ┊ ┊ ┏━┻━┓ ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + + @tests.cached_example + def ts(self): + nodes = io.StringIO( + """\ + id is_sample population individual time metadata + 0 1 0 -1 0 + 1 1 0 -1 0 + 2 1 0 -1 0 + 3 1 0 -1 0 + 4 0 0 -1 0.114 + 5 0 0 -1 1.110 + 6 0 0 
-1 1.750 + 7 0 0 -1 5.310 + 8 0 0 -1 6.573 + 9 0 0 -1 9.083 + """ + ) + edges = io.StringIO( + """\ + id left right parent child + 0 0.00000000 1.00000000 4 0 + 1 0.00000000 1.00000000 4 1 + 2 0.00000000 1.00000000 5 2 + 3 0.00000000 1.00000000 5 3 + 4 0.79258618 0.90634460 6 4 + 5 0.79258618 0.90634460 6 5 + 6 0.05975243 0.79258618 7 4 + 7 0.90634460 0.91029435 7 4 + 8 0.05975243 0.79258618 7 5 + 9 0.90634460 0.91029435 7 5 + 10 0.91029435 1.00000000 8 4 + 11 0.91029435 1.00000000 8 5 + 12 0.00000000 0.05975243 9 4 + 13 0.00000000 0.05975243 9 5 + """ + ) + sites = io.StringIO( + """\ + position ancestral_state + 0.05 A + 0.06 0 + 0.3 C + 0.5 AAA + 0.91 T + """ + ) + muts = io.StringIO( + """\ + site node derived_state parent time + 0 9 T -1 15 + 0 9 GGG 0 9.1 + 0 5 1 1 9 + 1 4 C -1 1.6 + 1 4 G 3 1.5 + 2 7 G -1 10 + 2 3 C 5 1 + 4 3 G -1 1 + """ + ) + ts = tskit.load_text(nodes, edges, sites=sites, mutations=muts, strict=False) + return ts + + def test_at_time_of_5(self): + # NOTE: we don't remember that the edge 4-7 was shared in trees 1 and 3. + # 1.11┊ 14 5 ┊ 11 5 ┊ 10 5 ┊ 12 5 ┊ 13 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(1.110) + assert ts.num_nodes == 15 + assert ts.num_trees == 5 + # Most mutations are older than this. 
+ assert ts.num_mutations == 2 + for u in range(10, 15): + node = ts.node(u) + assert node.time == 1.110 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {5, 14}, + {11, 5}, + {10, 5}, + {12, 5}, + {13, 5}, + ] + + def test_at_time6(self): + # 6 ┊ 12 13 ┊ ┊ ┊ ┊ 10 11 ┊ + # 5.31┊ ┃ ┃ ┊ 7 ┊ ┊ 7 ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ + # 1.75┊ ┃ ┃ ┊ ┃ ┃ ┊ 6 ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # ┊ ┃ ┃ ┊ ┃ ┃ ┊ ┏━┻━┓ ┊ ┃ ┃ ┊ ┃ ┃ ┊ + # 1.11┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┊ + # 0.11┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ 4 ┃ ┃ ┊ + # ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ 0 1 2 3 ┊ + # 0.00 0.06 0.79 0.91 0.91 1.00 + ts = self.ts().decapitate(6) + assert ts.num_nodes == 14 + assert ts.num_trees == 5 + assert ts.num_mutations == 4 + for u in range(10, 14): + node = ts.node(u) + assert node.time == 6 + assert node.flags == 0 + assert [set(tree.roots) for tree in ts.trees()] == [ + {12, 13}, + {7}, + {6}, + {7}, + {10, 11}, + ] + + +class TestDecapitateNodeValues: + @tests.cached_example + def ts(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + @tests.cached_example + def ts_with_schema(self): + tables = tskit.TableCollection(1) + for _ in range(5): + tables.populations.add_row() + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, population=0, time=0) + tables.nodes.add_row(time=1) + tables.edges.add_row(0, 1, 1, 0) + return tables.tree_sequence() + + def test_default_population(self): + ts = self.ts().decapitate(0.5) + assert ts.node(2).population == 0 + + @pytest.mark.parametrize("population", range(-1, 5)) + def test_specify_population(self, population): + ts = 
self.ts().decapitate(0.5, population=population) + assert ts.node(2).population == population + + def test_default_flags(self): + ts = self.ts().decapitate(0.5) + assert ts.node(2).flags == 0 + + @pytest.mark.parametrize("flags", range(0, 5)) + def test_specify_flags(self, flags): + ts = self.ts().decapitate(0.5, flags=flags) + assert ts.node(2).flags == flags + + def test_default_metadata_no_schema(self): + ts = self.ts().decapitate(0.5) + assert ts.node(2).metadata == b"" + + @pytest.mark.parametrize("metadata", [b"", b"some bytes"]) + def test_specify_metadata_no_schema(self, metadata): + ts = self.ts().decapitate(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata + + def test_default_metadata_with_schema(self): + ts = self.ts_with_schema().decapitate(0.5) + assert ts.node(2).metadata == {} + + @pytest.mark.parametrize("metadata", [{}, {"some": "json"}]) + def test_specify_metadata_with_schema(self, metadata): + ts = self.ts_with_schema().decapitate(0.5, metadata=metadata) + assert ts.node(2).metadata == metadata + + +class TestDecapitateInterface: + @tests.cached_example + def ts(self): + tree = tskit.Tree.generate_balanced(3, branch_length=1) + return tree.tree_sequence + + @pytest.mark.parametrize("bad_type", ["x", "0.1", [], [0.1]]) + def test_bad_types(self, bad_type): + with pytest.raises(TypeError, match="number"): + self.ts().decapitate(bad_type) + + @pytest.mark.parametrize( + "time", [1, 1.0, np.array([1])[0], fractions.Fraction(1, 1), decimal.Decimal(1)] + ) + def test_number_types(self, time): + expected = self.ts().decapitate(1) + got = self.ts().decapitate(time) + expected.tables.assert_equals(got.tables, ignore_timestamps=True) + + def test_migrations_not_supported(self, ts_fixture): + with pytest.raises(tskit.LibraryError, match="MIGRATIONS_NOT_SUPPORTED"): + ts_fixture.decapitate(0) + + def test_population_out_of_bounds(self): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(tskit.LibraryError, 
match="POPULATION_OUT_OF_BOUNDS"): + ts.decapitate(0, population=0) + + def test_bad_flags(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.decapitate(0, flags="asdf") + + def test_bad_metadata_no_schema(self): + ts = tskit.TableCollection(1).tree_sequence() + with pytest.raises(TypeError): + ts.decapitate(0, metadata="asdf") + + def test_bad_metadata_json_schema(self): + tables = tskit.TableCollection(1) + tables.nodes.metadata_schema = tskit.MetadataSchema.permissive_json() + ts = tables.tree_sequence() + with pytest.raises(tskit.MetadataEncodingError): + ts.decapitate(0, metadata=b"bytes") + + @pytest.mark.parametrize("time", [math.inf, np.inf, tskit.UNKNOWN_TIME, np.nan]) + def test_nonfinite_time(self, time): + tables = tskit.TableCollection(1) + ts = tables.tree_sequence() + with pytest.raises(tskit.LibraryError, match="TIME_NONFINITE"): + ts.decapitate(time) diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 8ac2e6a1fc..156fbc14f6 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -3867,9 +3867,16 @@ def delete_older(self, time): Deletes edge, mutation and migration information at least as old as the specified time. - For the purposes of this method, an edge covers the times from the child node - up until the *parent* node, so that any any edge with parent node time > ``time`` - will be removed. + .. seealso:: This method is similar to the higher-level + :meth:`TreeSequence.decapitate` method, which also splits + edges that intersect with the given time. + :meth:`TreeSequence.decapitate` + is more useful for most purposes, and may be what + you need instead of this method! + + For the purposes of this method, an edge covers the times from the + child node up until the *parent* node, so that any edge with parent + node time > ``time`` will be removed. Any mutation whose time is >= ``time`` will be removed. 
A mutation's time is its associated ``time`` value, or the time of its node if the diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2d099846dc..cd6ebda056 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6081,6 +6081,53 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None): ) return TreeSequence(ll_ts) + def decapitate(self, time, *, flags=None, population=None, metadata=None): + """ + Delete all edge topology and mutational information at least as old + as the specified time from this tree sequence. + + Removes all edges in which the time of the child is >= the specified + time ``t``, and breaks edges that intersect with ``t``. For each edge + intersecting with ``t`` we create a new node with time equal to ``t``, + and set the parent of the edge to this new node. The node table + is not altered in any other way. Newly added nodes have values + for ``flags``, ``population`` and ``metadata`` controlled by parameters + to this function in the same way as :meth:`.split_edges`. + + .. note:: + Note that each edge is treated independently, so that even if two + edges that are broken by this operation share the same parent and + child nodes, there will be two different new parent nodes inserted. + + Any mutation whose time is >= ``t`` will be removed. A mutation's time + is its associated ``time`` value, or the time of its node if the + mutation's time was marked as unknown (:data:`UNKNOWN_TIME`). + + Migrations are not supported, and a LibraryError will be raised if + called on a tree sequence containing migration information. + + .. seealso:: This method is implemented using the :meth:`.split_edges` + and :meth:`TableCollection.delete_older` functions. + + :param float time: The cutoff time. + :param int flags: The flags value for newly-inserted nodes. (Default = 0) + :param int population: The population value for newly inserted nodes. 
+ Defaults to the population of the child node of the split edge + if not specified. + :param metadata: The metadata for any newly inserted nodes. See + :meth:`.NodeTable.add_row` for details on how default metadata + is produced for a given schema (or none). + :return: A copy of this tree sequence with edges split at the specified time. + :rtype: tskit.TreeSequence + """ + split_ts = self.split_edges( + time, flags=flags, population=population, metadata=metadata + ) + tables = split_ts.dump_tables() + del split_ts + tables.delete_older(time) + return tables.tree_sequence() + def subset( self, nodes,