diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index a0d9dc1e45..3f3cfbc16e 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -116,7 +116,7 @@ def get_multiroot_tree(self): def get_mutations_over_roots_tree(self): ts = msprime.simulate(15, random_seed=1) - ts = tsutil.decapitate(ts, 20) + ts = ts.decapitate(ts.tables.nodes.time[-1] / 2) tables = ts.dump_tables() delta = 1.0 / (ts.num_nodes + 1) x = 0 diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 803b80e198..ec3fd67104 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -1703,7 +1703,8 @@ def test_non_ascii_missing_data_char(self, missing_data_char): class TestAlignmentExamples: @pytest.mark.parametrize("ts", get_example_discrete_genome_tree_sequences()) def test_defaults(self, ts): - if any(tree.num_roots > 1 for tree in ts.trees()): + has_missing_data = np.any(ts.genotype_matrix() == -1) + if has_missing_data: with pytest.raises(ValueError, match="1896"): list(ts.alignments()) else: @@ -1721,7 +1722,8 @@ def test_defaults(self, ts): @pytest.mark.parametrize("ts", get_example_discrete_genome_tree_sequences()) def test_reference_sequence(self, ts): ref = tskit.random_nucleotides(ts.sequence_length, seed=1234) - if any(tree.num_roots > 1 for tree in ts.trees()): + has_missing_data = np.any(ts.genotype_matrix() == -1) + if has_missing_data: with pytest.raises(ValueError, match="1896"): list(ts.alignments(reference_sequence=ref)) else: diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index bf7651b27a..b81608209a 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -267,11 +267,11 @@ def get_decapitated_examples(): Returns example tree sequences in which the oldest edges have been removed. """ ts = msprime.simulate(10, random_seed=1234) - yield tsutil.decapitate(ts, ts.num_edges // 2) + yield ts.decapitate(ts.tables.nodes.time[-1] / 2) ts = msprime.simulate(20, recombination_rate=1, random_seed=1234) assert ts.num_trees > 2 - yield tsutil.decapitate(ts, ts.num_edges // 4) + yield ts.decapitate(ts.tables.nodes.time[-1] / 4) def get_example_tree_sequences(back_mutations=True, gaps=True, internal_samples=True): @@ -3832,7 +3832,7 @@ def test_copy_tracked_samples(self): def test_copy_multiple_roots(self): ts = msprime.simulate(20, recombination_rate=2, length=3, random_seed=42) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for root_threshold in [1, 2, 100]: tree = tskit.Tree(ts, root_threshold=root_threshold) copy = tree.copy() diff --git a/python/tests/test_parsimony.py b/python/tests/test_parsimony.py index f48580a87f..80cec45b96 100644 --- a/python/tests/test_parsimony.py +++ b/python/tests/test_parsimony.py @@ -678,17 +678,17 @@ def test_jukes_cantor_balanced_ternary_internal_samples(self): def test_infinite_sites_n20_multiroot(self): ts = msprime.simulate(20, mutation_rate=3, random_seed=3) - self.verify(tsutil.decapitate(ts, ts.num_edges // 2)) + self.verify(ts.decapitate(np.max(ts.tables.nodes.time) / 2)) def test_jukes_cantor_n15_multiroot(self): ts = msprime.simulate(15, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 3) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 5) ts = tsutil.jukes_cantor(ts, 15, 2, seed=3) self.verify(ts) def test_jukes_cantor_balanced_ternary_multiroot(self): ts = tskit.Tree.generate_balanced(50, arity=3).tree_sequence - ts = tsutil.decapitate(ts, ts.num_edges // 3) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 3) ts = tsutil.jukes_cantor(ts, 15, 2, seed=3) self.verify(ts) assert ts.num_sites > 1 @@ -696,7 +696,7 @@ def test_jukes_cantor_balanced_ternary_multiroot(self): def test_jukes_cantor_n50_multiroot(self): ts = msprime.simulate(50, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 5, 2, seed=2) self.verify(ts) @@ -1389,7 +1389,7 @@ def test_mutations_over_root(self): def test_all_isolated_different_from_ancestral(self): ts = tskit.Tree.generate_star(6).tree_sequence - ts = tsutil.decapitate(ts, 0) + ts = ts.decapitate(0) tree = ts.first() genotypes = [0, 0, 0, 1, 1, 1] ancestral_state, transitions = self.do_map_mutations( diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index af0f315215..3656828ea5 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -315,12 +315,12 @@ def test_nonbinary_trees(self): def test_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) self.verify(ts) def test_multiroot_tree(self): ts = msprime.simulate(15, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) self.verify(ts) def test_all_missing_data(self): @@ -4832,7 +4832,7 @@ def verify_single_childified(self, ts, keep_unary=False): assert t1.mutations == t2.mutations def verify_multiroot_internal_samples(self, ts, keep_unary=False): - ts_multiroot = tsutil.decapitate(ts, ts.num_edges // 2) + ts_multiroot = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts1 = tsutil.jiggle_samples(ts_multiroot) ts2, node_map = self.do_simplify(ts1, keep_unary=keep_unary) assert ts1.num_trees >= ts2.num_trees @@ -5556,7 +5556,7 @@ def test_many_trees_recurrent_mutations(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) for num_samples in range(1, ts.num_samples): @@ -5567,7 +5567,7 @@ def test_single_multiroot_tree_recurrent_mutations(self): def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: ts = tsutil.insert_branch_mutations(ts, mutations_per_branch) for num_samples in range(1, ts.num_samples): @@ -5716,7 +5716,7 @@ def test_many_trees_internal_samples(self): def test_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for num_samples in range(1, ts.num_samples): for samples in itertools.combinations(ts.samples(), num_samples): self.verify_keep_input_roots(ts, samples) @@ -6051,7 +6051,7 @@ def test_sim_coalescent_trees_internal_samples(self): def test_sim_many_multiroot_trees(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ancestors = [4 * n for n in np.arange(0, ts.num_nodes // 4)] self.verify(ts, ts.samples(), ancestors) random_samples = [4 * n for n in np.arange(0, ts.num_nodes // 4)] @@ -6189,14 +6189,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) @@ -6245,14 +6245,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) @@ -6396,14 +6396,14 @@ def test_single_tree_three_mutations_per_branch(self): def test_single_multiroot_tree_recurrent_mutations(self): ts = msprime.simulate(6, random_seed=10) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) def test_many_multiroot_trees_recurrent_mutations(self): ts = msprime.simulate(7, recombination_rate=1, random_seed=10) assert ts.num_trees > 3 - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) for mutations_per_branch in [1, 2, 3]: self.verify_branch_mutations(ts, mutations_per_branch) @@ -7178,13 +7178,6 @@ def test_zero_sites(self): assert mts.num_trees == 1 assert mts.num_edges == 0 - def test_many_roots(self): - ts = msprime.simulate(25, random_seed=12, recombination_rate=2, length=10) - tables = tsutil.decapitate(ts, ts.num_edges // 2).dump_tables() - for x in range(10): - tables.sites.add_row(x, "0") - self.verify(tables.tree_sequence()) - def test_branch_sites(self): ts = msprime.simulate(15, random_seed=12, recombination_rate=2, length=10) ts = tsutil.insert_branch_sites(ts) diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index aa69d3d4e8..3d49a79c14 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -581,8 +581,8 @@ def test_single_tree_sequence_length(self): self.verify(ts) def test_single_tree_multiple_roots(self): - ts = msprime.simulate(8, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = msprime.simulate(8, random_seed=1, end_time=0.5) + assert ts.first().num_roots > 1 self.verify(ts) def test_many_trees(self): diff --git a/python/tests/test_utilities.py b/python/tests/test_utilities.py index bca37df3ae..710fe24933 100644 --- a/python/tests/test_utilities.py +++ b/python/tests/test_utilities.py @@ -44,13 +44,13 @@ def verify(self, ts): def test_n10_multiroot(self): ts = msprime.simulate(10, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 1, 2, seed=7) self.verify(ts) def test_n50_multiroot(self): ts = msprime.simulate(50, random_seed=1) - ts = tsutil.decapitate(ts, ts.num_edges // 2) + ts = ts.decapitate(np.max(ts.tables.nodes.time) / 2) ts = tsutil.jukes_cantor(ts, 5, 2, seed=2) self.verify(ts) diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index eeb5b4d689..00dec81fba 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -76,24 +76,6 @@ def subsample_sites(ts, num_sites): return t.tree_sequence() -def decapitate(ts, num_edges): - """ - Returns a copy of the specified tree sequence in which the specified number of - edges have been retained. - """ - t = ts.dump_tables() - t.edges.set_columns( - left=t.edges.left[:num_edges], - right=t.edges.right[:num_edges], - parent=t.edges.parent[:num_edges], - child=t.edges.child[:num_edges], - ) - add_provenance(t.provenances, "decapitate") - # Simplify to get rid of any mutations that are lying around above roots. - t.simplify() - return t.tree_sequence() - - def insert_branch_mutations(ts, mutations_per_branch=1): """ Returns a copy of the specified tree sequence with a mutation on every branch