diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index a23c050125..ef0cf3701c 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -17,6 +17,9 @@ - Added ``Tree.generate_star`` static method to create star-topologies (:user:`hyanwong`, :pr:`934`). +- Added ``Tree.generate_comb`` and ``Tree.generate_balanced`` methods to create + example trees. (:user:`jeromekelleher`, :pr:`1026`). + - Added ``equals`` method to TreeSequence, TableCollection and each of the tables which provides more flexible equality comparisons, for example, allowing users to ignore metadata or provenance in the comparison. diff --git a/python/tests/test_combinatorics.py b/python/tests/test_combinatorics.py index 07c4bb4439..0882a2ea2d 100644 --- a/python/tests/test_combinatorics.py +++ b/python/tests/test_combinatorics.py @@ -26,6 +26,7 @@ import collections import io import itertools +import json import random import msprime @@ -113,18 +114,18 @@ def test_group_partition(self): class TestRankTree: - def test_num_shapes(self): - for i in range(11): - all_trees = RankTree.all_unlabelled_trees(i) - assert len(list(all_trees)) == comb.num_shapes(i) - - def test_num_labellings(self): - for n in range(2, 8): - for tree in RankTree.all_unlabelled_trees(n): - tree = tree.label_unrank(0) - tree2 = tree.to_tsk_tree() - n_labellings = sum(1 for _ in RankTree.all_labellings(tree)) - assert n_labellings == RankTree.from_tsk_tree(tree2).num_labellings() + @pytest.mark.parametrize("n", range(11)) + def test_num_shapes(self, n): + all_trees = RankTree.all_unlabelled_trees(n) + assert len(list(all_trees)) == comb.num_shapes(n) + + @pytest.mark.parametrize("n", range(2, 8)) + def test_num_labellings(self, n): + for tree in RankTree.all_unlabelled_trees(n): + tree = tree.label_unrank(0) + tree2 = tree.to_tsk_tree() + n_labellings = sum(1 for _ in RankTree.all_labellings(tree)) + assert n_labellings == RankTree.from_tsk_tree(tree2).num_labellings() def test_num_labelled_trees(self): # Number of leaf-labelled trees with n leaves on OEIS @@ -209,23 +210,23 @@ def test_all_labellings_roundtrip(self): for rank_t, tsk_t in zip(rank_tree_labellings, tsk_tree_labellings): assert rank_t == RankTree.from_tsk_tree(tsk_t) - def test_unrank(self): - for n in range(6): - for shape_rank, t in enumerate(RankTree.all_unlabelled_trees(n)): - for label_rank, labelled_tree in enumerate(RankTree.all_labellings(t)): - unranked = RankTree.unrank(n, (shape_rank, label_rank)) - assert labelled_tree == unranked + @pytest.mark.parametrize("n", range(6)) + def test_unrank_labelled(self, n): + for shape_rank, t in enumerate(RankTree.all_unlabelled_trees(n)): + for label_rank, labelled_tree in enumerate(RankTree.all_labellings(t)): + unranked = RankTree.unrank(n, (shape_rank, label_rank)) + assert labelled_tree == unranked - # The number of labelled trees gets very big quickly - for n in range(6, 10): - for shape_rank in range(comb.num_shapes(n)): - rank = (shape_rank, 0) - unranked = RankTree.unrank(n, rank) - assert rank, unranked.rank() + @pytest.mark.parametrize("n", range(10)) + def test_unrank_unlabelled(self, n): + for shape_rank in range(comb.num_shapes(n)): + rank = (shape_rank, 0) + unranked = RankTree.unrank(n, rank) + assert rank, unranked.rank() - rank = (shape_rank, comb.num_labellings(n, shape_rank) - 1) - unranked = RankTree.unrank(n, rank) - assert rank, unranked.rank() + rank = (shape_rank, comb.num_labellings(n, shape_rank) - 1) + unranked = RankTree.unrank(n, rank) + assert rank, unranked.rank() def test_unrank_errors(self): self.verify_unrank_errors((-1, 0), 1) @@ -253,53 +254,52 @@ def verify_unrank_errors(self, rank, n): with pytest.raises(ValueError): tskit.Tree.unrank(n, rank) - def test_shape_rank(self): - for n in range(10): - for rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): - assert tree.shape_rank() == rank - - def test_shape_unrank(self): - for n in range(6): - for rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): - t = RankTree.shape_unrank(n, rank) - assert tree.shape_equal(t) - - for n in range(2, 9): - for shape_rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): - tsk_tree = tskit.Tree.unrank(n, (shape_rank, 0)) - assert shape_rank == tree.shape_rank() - shape_rank, _ = tsk_tree.rank() - assert shape_rank == tree.shape_rank() - - def test_label_rank(self): - for n in range(7): - for tree in RankTree.all_unlabelled_trees(n): - for rank, labelled_tree in enumerate(RankTree.all_labellings(tree)): - assert labelled_tree.label_rank() == rank - - def test_label_unrank(self): - for n in range(7): - for shape_rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): - for label_rank, labelled_tree in enumerate( - RankTree.all_labellings(tree) - ): - rank = (shape_rank, label_rank) - unranked = tree.label_unrank(label_rank) - assert labelled_tree.rank() == rank - assert unranked.rank() == rank - - def test_unrank_rank_round_trip(self): - for n in range(6): # Can do more but gets slow pretty quickly after 6 - for shape_rank in range(comb.num_shapes(n)): - tree = RankTree.shape_unrank(n, shape_rank) - tree = tree.label_unrank(0) - assert tree.shape_rank() == shape_rank - for label_rank in range(tree.num_labellings()): - tree = tree.label_unrank(label_rank) - assert tree.label_rank() == label_rank - tsk_tree = tree.label_unrank(label_rank).to_tsk_tree() - _, tsk_label_rank = tsk_tree.rank() - assert tsk_label_rank == label_rank + @pytest.mark.parametrize("n", range(6)) + def test_shape_rank(self, n): + for rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): + assert tree.shape_rank() == rank + + @pytest.mark.parametrize("n", range(6)) + def test_shape_unrank(self, n): + for rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): + t = RankTree.shape_unrank(n, rank) + assert tree.shape_equal(t) + + @pytest.mark.parametrize("n", range(2, 9)) + def test_shape_unrank_tsk_tree(self, n): + for shape_rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): + tsk_tree = tskit.Tree.unrank(n, (shape_rank, 0)) + assert shape_rank == tree.shape_rank() + shape_rank, _ = tsk_tree.rank() + assert shape_rank == tree.shape_rank() + + @pytest.mark.parametrize("n", range(7)) + def test_label_rank(self, n): + for tree in RankTree.all_unlabelled_trees(n): + for rank, labelled_tree in enumerate(RankTree.all_labellings(tree)): + assert labelled_tree.label_rank() == rank + + @pytest.mark.parametrize("n", range(7)) + def test_label_unrank(self, n): + for shape_rank, tree in enumerate(RankTree.all_unlabelled_trees(n)): + for label_rank, labelled_tree in enumerate(RankTree.all_labellings(tree)): + rank = (shape_rank, label_rank) + unranked = tree.label_unrank(label_rank) + assert labelled_tree.rank() == rank + assert unranked.rank() == rank + + @pytest.mark.parametrize("n", range(6)) + def test_unrank_rank_round_trip(self, n): + for shape_rank in range(comb.num_shapes(n)): + tree = RankTree.shape_unrank(n, shape_rank) + tree = tree.label_unrank(0) + assert tree.shape_rank() == shape_rank + for label_rank in range(tree.num_labellings()): + tree = tree.label_unrank(label_rank) + assert tree.label_rank() == label_rank + tsk_tree = tree.label_unrank(label_rank).to_tsk_tree() + _, tsk_label_rank = tsk_tree.rank() + assert tsk_label_rank == label_rank def test_is_canonical(self): for n in range(7): @@ -347,25 +347,41 @@ def test_is_canonical(self): ) assert not labels_not_canonical.is_canonical() - def test_unranking_is_canonical(self): - for n in range(7): - for shape_rank in range(comb.num_shapes(n)): - for label_rank in range(comb.num_labellings(n, shape_rank)): - t = RankTree.shape_unrank(n, shape_rank) - assert t.is_canonical() - t = t.label_unrank(label_rank) - assert t.is_canonical() - t = tskit.Tree.unrank(n, (shape_rank, label_rank)) - assert RankTree.from_tsk_tree(t).is_canonical() - - def test_to_from_tsk_tree(self): - for n in range(5): - for tree in RankTree.all_labelled_trees(n): - assert tree.is_canonical() - tsk_tree = tree.to_tsk_tree() - reconstructed = RankTree.from_tsk_tree(tsk_tree) - assert tree.is_canonical() - assert tree == reconstructed + @pytest.mark.parametrize("n", range(7)) + def test_unranking_is_canonical(self, n): + for shape_rank in range(comb.num_shapes(n)): + for label_rank in range(comb.num_labellings(n, shape_rank)): + t = RankTree.shape_unrank(n, shape_rank) + assert t.is_canonical() + t = t.label_unrank(label_rank) + assert t.is_canonical() + t = tskit.Tree.unrank(n, (shape_rank, label_rank)) + assert RankTree.from_tsk_tree(t).is_canonical() + + @pytest.mark.parametrize("n", range(5)) + def test_to_from_tsk_tree(self, n): + for tree in RankTree.all_labelled_trees(n): + assert tree.is_canonical() + tsk_tree = tree.to_tsk_tree() + reconstructed = RankTree.from_tsk_tree(tsk_tree) + assert tree.is_canonical() + assert tree == reconstructed + + @pytest.mark.parametrize("n", range(6)) + def test_to_tsk_tree_internal_nodes(self, n): + branch_length = 1234 + for tree in RankTree.all_labelled_trees(n): + tsk_tree = tree.to_tsk_tree(branch_length=branch_length) + internal_nodes = [ + u for u in tsk_tree.nodes(order="postorder") if tsk_tree.is_internal(u) + ] + assert np.all(internal_nodes == n + np.arange(len(internal_nodes))) + for u in tsk_tree.nodes(): + if tsk_tree.is_internal(u): + max_child_time = max(tsk_tree.time(v) for v in tsk_tree.children(u)) + assert tsk_tree.time(u) == max_child_time + branch_length + else: + assert tsk_tree.time(u) == 0 def test_from_unary_tree(self): tables = tskit.TableCollection(sequence_length=1) @@ -984,14 +1000,11 @@ class TestTreeNode: Tests for the TreeNode class used to build simple trees in memory. """ - @pytest.mark.parametrize("n", [2, 3, 5, 10]) - def test_random_binary_tree(self, n): + def verify_tree(self, root, labels): # Note this doesn't check any statistical properties of the returned # trees, just that a single instance returned in a valid binary tree. - rng = random.Random(32) - labels = range(n) - root = comb.TreeNode.random_binary_tree(labels, rng) - + # Structural properties are best verified using the tskit API, and so + # we test these properties elsewhere. stack = [root] num_nodes = 0 recovered_labels = [] @@ -1004,5 +1017,421 @@ def test_random_binary_tree(self, n): for child in node.children: assert child.parent == node stack.append(child) - assert len(recovered_labels) == n assert sorted(recovered_labels) == list(labels) + + @pytest.mark.parametrize("n", range(1, 16)) + def test_random_binary_tree(self, n): + rng = random.Random(32) + labels = range(n) + root = comb.TreeNode.random_binary_tree(labels, rng) + self.verify_tree(root, range(n)) + + @pytest.mark.parametrize("n", range(1, 16)) + def test_balanced_binary(self, n): + root = comb.TreeNode.balanced_tree(range(n), 2) + self.verify_tree(root, range(n)) + + @pytest.mark.parametrize("arity", range(2, 8)) + def test_balanced_arity(self, arity): + labels = range(30) + root = comb.TreeNode.balanced_tree(labels, arity) + self.verify_tree(root, labels) + + +def num_leaf_labelled_binary_trees(n): + """ + Returns the number of leaf labelled binary trees with n leaves. + + TODO: this would probably be helpful to have in the combinatorics + module. + + https://oeis.org/A005373/ + """ + return int(np.math.factorial(2 * n - 3) / (2 ** (n - 2) * np.math.factorial(n - 2))) + + +class TestPolytomySplitting: + """ + Test the ability to randomly split polytomies + """ + + # A complex ts with polytomies + # + # 1.00┊ 6 ┊ 6 ┊ 6 ┊ ┊ 6 ┊ + # ┊ ┏━┳┻┳━┓ ┊ ┏━┳┻┳━┓ ┊ ┏━━╋━┓ ┊ ┊ ┏━┳┻┳━┓ ┊ + # 0.50┊ 5 ┃ ┃ ┃ ┊ 5 ┃ ┃ ┃ ┊ 5 ┃ ┃ ┊ 5 ┊ ┃ ┃ ┃ ┃ ┊ + # ┊ ┃ ┃ ┃ ┃ . ┊ ┃ ┃ ┃ ┃ ┊ . ┏┻┓ ┃ ┃ ┊ . ┏━┳┻┳━┓ ┊ . ┃ ┃ ┃ ┃ ┊ + # 0.00┊ 0 2 3 4 1 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ + # 0.00 0.20 0.40 0.60 0.80 1.00 + nodes_polytomy_44344 = """\ + id is_sample population time + 0 1 0 0.0 + 1 1 0 0.0 + 2 1 0 0.0 + 3 1 0 0.0 + 4 1 0 0.0 + 5 0 0 0.5 + 6 0 0 1.0 + """ + edges_polytomy_44344 = """\ + id left right parent child + 0 0.0 0.2 5 0 + 1 0.0 0.8 5 1 + 2 0.0 0.4 6 2 + 3 0.4 0.8 5 2 + 4 0.0 0.6 6 3,4 + 5 0.0 0.6 6 5 + 6 0.6 0.8 5 3,4 + 7 0.8 1.0 6 1,2,3,4 + """ + + def tree_polytomy_4(self): + return tskit.Tree.generate_star(4) + + def ts_polytomy_44344(self): + return tskit.load_text( + nodes=io.StringIO(self.nodes_polytomy_44344), + edges=io.StringIO(self.edges_polytomy_44344), + strict=False, + ) + + @pytest.mark.slow + @pytest.mark.parametrize("n", [2, 3, 4, 5]) + def test_all_topologies(self, n): + N = num_leaf_labelled_binary_trees(n) + ranks = collections.Counter() + tree = tskit.Tree.generate_star(n) + for seed in range(20 * N): + split_tree = tree.split_polytomies(random_seed=seed) + ranks[split_tree.rank()] += 1 + # There are N possible binary trees here, we should have seen them + # all with high probability after 20 N attempts. + assert len(ranks) == N + + def verify_trees(self, source_tree, split_tree, epsilon=None): + if epsilon is None: + epsilon = 1e-10 + N = 0 + for u in split_tree.nodes(): + assert split_tree.num_children(u) < 3 + N += 1 + if u >= source_tree.tree_sequence.num_nodes: + # This is a new node + assert epsilon == pytest.approx(split_tree.branch_length(u)) + assert N == len(list(split_tree.leaves())) * 2 - 1 + for u in source_tree.nodes(): + if source_tree.num_children(u) <= 2: + assert source_tree.children(u) == split_tree.children(u) + else: + assert len(split_tree.children(u)) == 2 + + @pytest.mark.parametrize("n", [2, 3, 4, 5, 6]) + def test_resolve_star(self, n): + tree = tskit.Tree.generate_star(n) + self.verify_trees(tree, tree.split_polytomies(random_seed=12)) + + def test_large_epsilon(self): + tree = tskit.Tree.generate_star(10, branch_length=100) + eps = 10 + split = tree.split_polytomies(random_seed=12234, epsilon=eps) + self.verify_trees(tree, split, epsilon=eps) + + def verify_tree_sequence_splits(self, ts): + n_poly = 0 + for e in ts.edgesets(): + if len(e.children) > 2: + n_poly += 1 + assert n_poly > 3 + assert ts.num_trees > 3 + for tree in ts.trees(): + binary_tree = tree.split_polytomies(random_seed=11) + assert binary_tree.interval == tree.interval + for u in binary_tree.nodes(): + assert binary_tree.num_children(u) < 3 + for u in tree.nodes(): + assert binary_tree.time(u) == tree.time(u) + resolved_ts = binary_tree.tree_sequence + assert resolved_ts.sequence_length == ts.sequence_length + assert resolved_ts.num_trees <= 3 + if tree.interval[0] == 0: + assert resolved_ts.num_trees == 2 + null_tree = resolved_ts.last() + assert null_tree.num_roots == ts.num_samples + elif tree.interval[1] == ts.sequence_length: + assert resolved_ts.num_trees == 2 + null_tree = resolved_ts.first() + assert null_tree.num_roots == ts.num_samples + else: + null_tree = resolved_ts.first() + assert null_tree.num_roots == ts.num_samples + null_tree.next() + assert null_tree.num_roots == tree.num_roots + null_tree.next() + assert null_tree.num_roots == ts.num_samples + + def test_complex_examples(self): + self.verify_tree_sequence_splits(self.ts_polytomy_44344()) + + def test_nonbinary_simulation(self): + demographic_events = [ + msprime.SimpleBottleneck(time=1.0, population=0, proportion=0.95) + ] + ts = msprime.simulate( + 20, + recombination_rate=10, + mutation_rate=5, + demographic_events=demographic_events, + random_seed=7, + ) + self.verify_tree_sequence_splits(ts) + + def test_seeds(self): + base = tskit.Tree.generate_star(5) + t1 = base.split_polytomies(random_seed=1234) + t2 = base.split_polytomies(random_seed=1234) + assert t1.tree_sequence.tables.equals( + t2.tree_sequence.tables, ignore_timestamps=True + ) + t2 = base.split_polytomies(random_seed=1) + assert not t1.tree_sequence.tables.equals( + t2.tree_sequence.tables, ignore_provenance=True + ) + + def test_internal_polytomy(self): + # 9 + # ┏━┳━━━┻┳━━━━┓ + # ┃ ┃ 8 ┃ + # ┃ ┃ ┏━━╋━━┓ ┃ + # ┃ ┃ ┃ 7 ┃ ┃ + # ┃ ┃ ┃ ┏┻┓ ┃ ┃ + # 0 1 2 3 5 4 6 + t1 = tskit.Tree.unrank(7, (6, 25)) + t2 = t1.split_polytomies(random_seed=1234) + assert t2.parent(3) == 7 + assert t2.parent(5) == 7 + assert t2.root == 9 + for u in t2.nodes(): + assert t2.num_children(u) in [0, 2] + + def test_binary_tree(self): + t1 = msprime.simulate(10, random_seed=1234).first() + t2 = t1.split_polytomies(random_seed=1234) + tables = t1.tree_sequence.dump_tables() + assert tables.equals(t2.tree_sequence.tables, ignore_provenance=True) + + def test_bad_method(self): + with pytest.raises(ValueError, match="Method"): + self.tree_polytomy_4().split_polytomies(method="something_else") + + def test_epsilon_for_nodes(self): + with pytest.raises( + tskit.LibraryError, + match="not small enough to create new nodes below a polytomy", + ): + self.tree_polytomy_4().split_polytomies(epsilon=1, random_seed=12) + + def test_epsilon_for_mutations(self): + tables = tskit.Tree.generate_star(3).tree_sequence.dump_tables() + root_time = tables.nodes.time[-1] + assert root_time == 1 + site = tables.sites.add_row(position=0.5, ancestral_state="0") + tables.mutations.add_row(site=site, time=0.9, node=0, derived_state="1") + tables.mutations.add_row(site=site, time=0.9, node=1, derived_state="1") + tree = tables.tree_sequence().first() + with pytest.raises( + tskit.LibraryError, + match="not small enough to create new nodes below a polytomy", + ): + tree.split_polytomies(epsilon=0.5, random_seed=123) + + def test_provenance(self): + tree = self.tree_polytomy_4() + ts_split = tree.split_polytomies(random_seed=14).tree_sequence + record = json.loads(ts_split.provenance(ts_split.num_provenances - 1).record) + assert record["parameters"]["command"] == "split_polytomies" + ts_split = tree.split_polytomies( + random_seed=12, record_provenance=False + ).tree_sequence + record = json.loads(ts_split.provenance(ts_split.num_provenances - 1).record) + assert record["parameters"]["command"] != "split_polytomies" + + +class TreeGeneratorTestBase: + """ + Abstract superclass of tree generator test methods. + + Concrete subclasses should defined "method_name" class variable. + """ + + def method(self, n, **kwargs): + return getattr(tskit.Tree, self.method_name)(n, **kwargs) + + @pytest.mark.parametrize("n", range(2, 10)) + def test_leaves(self, n): + tree = self.method(n) + assert list(tree.leaves()) == list(range(n)) + + def test_bad_n(self): + for n in [-1, 0, np.array([1, 2])]: + with pytest.raises(ValueError): + self.method(n) + for n in [None, "", []]: + with pytest.raises(TypeError): + self.method(n) + + def test_bad_span(self): + with pytest.raises(tskit.LibraryError): + self.method(2, span=0) + + def test_bad_branch_length(self): + with pytest.raises(tskit.LibraryError): + self.method(2, branch_length=0) + + @pytest.mark.parametrize("span", [0.1, 1, 100]) + def test_span(self, span): + tree = self.method(5, span=span) + assert tree.tree_sequence.sequence_length == span + + @pytest.mark.parametrize("branch_length", [0.25, 1, 100]) + def test_branch_length(self, branch_length): + tree = self.method(5, branch_length=branch_length) + for u in tree.nodes(): + if u != tree.root: + assert tree.branch_length(u) >= branch_length + + def test_provenance(self): + ts = self.method(2).tree_sequence + assert ts.num_provenances == 1 + record = json.loads(ts.provenance(0).record) + assert record["parameters"]["command"] == self.method_name + ts = self.method(2, record_provenance=False).tree_sequence + assert ts.num_provenances == 0 + + @pytest.mark.parametrize("n", range(2, 10)) + def test_rank_unrank_round_trip(self, n): + tree1 = self.method(n) + rank = tree1.rank() + tree2 = tskit.Tree.unrank(n, rank) + tables1 = tree1.tree_sequence.tables + tables2 = tree2.tree_sequence.tables + assert tables1.equals(tables2, ignore_provenance=True) + + +class TestGenerateStar(TreeGeneratorTestBase): + method_name = "generate_star" + + @pytest.mark.parametrize("n", range(2, 10)) + def test_unrank_equal(self, n): + for extra_params in [{}, {"span": 2.5}, {"branch_length": 3}]: + ts = tskit.Tree.generate_star(n, **extra_params).tree_sequence + equiv_ts = tskit.Tree.unrank(n, (0, 0), **extra_params).tree_sequence + assert ts.tables.equals(equiv_ts.tables, ignore_provenance=True) + + def test_branch_length_semantics(self): + branch_length = 10 + ts = tskit.Tree.generate_star(7, branch_length=branch_length).tree_sequence + time = ts.tables.nodes.time + edges = ts.tables.edges + length = time[edges.parent] - time[edges.child] + assert np.all(length == branch_length) + + +class TestGenerateBalanced(TreeGeneratorTestBase): + method_name = "generate_balanced" + + @pytest.mark.parametrize("arity", range(2, 10)) + def test_arity_leaves(self, arity): + n = 20 + tree = tskit.Tree.generate_balanced(n, arity=arity) + assert list(tree.leaves()) == list(range(n)) + + @pytest.mark.parametrize("n", range(1, 13)) + def test_binary_unrank_equal(self, n): + for extra_params in [{}, {"span": 2.5}, {"branch_length": 3}]: + ts = tskit.Tree.generate_balanced(n, **extra_params).tree_sequence + N = tskit.combinatorics.num_shapes(n) + equiv_ts = tskit.Tree.unrank(n, (N - 1, 0), **extra_params).tree_sequence + assert ts.tables.equals(equiv_ts.tables, ignore_provenance=True) + + @pytest.mark.parametrize( + ("n", "arity"), [(2, 2), (8, 2), (27, 3), (29, 3), (11, 5), (5, 10)] + ) + def test_rank_unrank_round_trip_arity(self, n, arity): + tree1 = tskit.Tree.generate_balanced(n, arity=arity) + rank = tree1.rank() + tree2 = tskit.Tree.unrank(n, rank) + tables1 = tree1.tree_sequence.tables + tables2 = tree2.tree_sequence.tables + assert tables1.equals(tables2, ignore_provenance=True) + + def test_bad_arity(self): + for arity in [-1, 0, 1]: + with pytest.raises(ValueError): + tskit.Tree.generate_balanced(10, arity=arity) + + def test_branch_length_semantics(self): + branch_length = 10 + tree = tskit.Tree.generate_balanced(8, branch_length=branch_length) + for u in tree.nodes(): + for v in tree.children(u): + # Special case cause n is a power of 2 + assert tree.time(u) == tree.time(v) + branch_length + + +class TestGenerateComb(TreeGeneratorTestBase): + method_name = "generate_comb" + + # Hard-code in some pre-computed ranks for the comb(n) tree. + @pytest.mark.parametrize(["n", "rank"], [(2, 0), (3, 1), (4, 3), (5, 8), (6, 20)]) + def test_unrank_equal(self, n, rank): + for extra_params in [{}, {"span": 2.5}, {"branch_length": 3}]: + ts = tskit.Tree.generate_comb(n, **extra_params).tree_sequence + equiv_ts = tskit.Tree.unrank(n, (rank, 0), **extra_params).tree_sequence + assert ts.tables.equals(equiv_ts.tables, ignore_provenance=True) + + def test_branch_length_semantics(self): + branch_length = 10 + tree = tskit.Tree.generate_comb(2, branch_length=branch_length) + assert tree.time(tree.root) == branch_length + + +class TestEqualChunks: + @pytest.mark.parametrize(("n", "k"), [(2, 1), (4, 2), (9, 3), (100, 10)]) + def test_evenly_divisible(self, n, k): + lst = range(n) + chunks = list(comb.equal_chunks(lst, k)) + assert len(chunks) == k + for chunk in chunks: + assert len(chunk) == n // k + assert list(itertools.chain(*chunks)) == list(range(n)) + + @pytest.mark.parametrize("n", range(1, 5)) + def test_one_chunk(self, n): + lst = list(range(n)) + chunks = list(comb.equal_chunks(lst, 1)) + assert chunks == [lst] + + @pytest.mark.parametrize(("n", "k"), [(1, 2), (5, 6), (10, 20), (5, 100)]) + def test_empty_chunks(self, n, k): + lst = range(n) + chunks = list(comb.equal_chunks(lst, k)) + assert len(chunks) == n + for chunk in chunks: + assert len(chunk) == 1 + assert list(itertools.chain(*chunks)) == list(range(n)) + + @pytest.mark.parametrize(("n", "k"), [(3, 2), (10, 3), (11, 5), (13, 10)]) + def test_trailing_chunk(self, n, k): + lst = range(n) + chunks = list(comb.equal_chunks(lst, k)) + assert len(chunks) == k + assert list(itertools.chain(*chunks)) == list(range(n)) + + def test_empty_list(self): + assert len(list(comb.equal_chunks([], 1))) == 0 + assert len(list(comb.equal_chunks([], 2))) == 0 + + def test_bad_num_chunks(self): + for bad_num_chunks in [0, -1, 0.1]: + with pytest.raises(ValueError): + list(comb.equal_chunks([1], bad_num_chunks)) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index b61222353e..b135afc234 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -23,7 +23,6 @@ """ Test cases for the supported topological variations and operations. """ -import collections import functools import io import itertools @@ -7977,255 +7976,3 @@ def test_is_isolated_bad(self): tree.is_isolated("abc") with pytest.raises(TypeError): tree.is_isolated(1.1) - - -class TestExampleTrees: - """ - Test hard-coded example tree (sequence) generation - """ - - def test_star_equivalent(self): - for extra_params in [{}, {"span": 2.5}]: - for n in range(2, 6): - ts = tskit.Tree.generate_star(n, **extra_params).tree_sequence - equiv_ts = tskit.Tree.unrank(n, (0, 0), **extra_params).tree_sequence - assert ts.tables.equals(equiv_ts.tables, ignore_provenance=True) - - def test_star_bad_params(self): - for n in [-1, 0, 1, np.array([1, 2])]: - with pytest.raises(ValueError): - tskit.Tree.generate_star(n) - for n in [None, "", []]: - with pytest.raises(TypeError): - tskit.Tree.generate_star(n) - with pytest.raises(tskit.LibraryError): - tskit.Tree.generate_star(2, span=0) - with pytest.raises(tskit.LibraryError): - tskit.Tree.generate_star(2, branch_length=0) - - def test_star_branch_length(self): - branch_length = 10 - n = 7 - ts = tskit.Tree.generate_star(n, branch_length=branch_length).tree_sequence - topological_equiv_ts = tskit.Tree.unrank(n, (0, 0)).tree_sequence - - assert ts.node(ts.first().root).time == branch_length - assert ts.kc_distance(topological_equiv_ts) == 0 - - -def num_leaf_labelled_binary_trees(n): - """ - Returns the number of leaf labelled binary trees with n leaves. - - TODO: this would probably be helpful to have in the combinatorics - module. - - https://oeis.org/A005373/ - """ - return int(np.math.factorial(2 * n - 3) / (2 ** (n - 2) * np.math.factorial(n - 2))) - - -class TestPolytomySplitting: - """ - Test the ability to randomly split polytomies - """ - - # A complex ts with polytomies - # - # 1.00┊ 6 ┊ 6 ┊ 6 ┊ ┊ 6 ┊ - # ┊ ┏━┳┻┳━┓ ┊ ┏━┳┻┳━┓ ┊ ┏━━╋━┓ ┊ ┊ ┏━┳┻┳━┓ ┊ - # 0.50┊ 5 ┃ ┃ ┃ ┊ 5 ┃ ┃ ┃ ┊ 5 ┃ ┃ ┊ 5 ┊ ┃ ┃ ┃ ┃ ┊ - # ┊ ┃ ┃ ┃ ┃ . ┊ ┃ ┃ ┃ ┃ ┊ . ┏┻┓ ┃ ┃ ┊ . ┏━┳┻┳━┓ ┊ . ┃ ┃ ┃ ┃ ┊ - # 0.00┊ 0 2 3 4 1 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ 0 1 2 3 4 ┊ - # 0.00 0.20 0.40 0.60 0.80 1.00 - nodes_polytomy_44344 = """\ - id is_sample population time - 0 1 0 0.0 - 1 1 0 0.0 - 2 1 0 0.0 - 3 1 0 0.0 - 4 1 0 0.0 - 5 0 0 0.5 - 6 0 0 1.0 - """ - edges_polytomy_44344 = """\ - id left right parent child - 0 0.0 0.2 5 0 - 1 0.0 0.8 5 1 - 2 0.0 0.4 6 2 - 3 0.4 0.8 5 2 - 4 0.0 0.6 6 3,4 - 5 0.0 0.6 6 5 - 6 0.6 0.8 5 3,4 - 7 0.8 1.0 6 1,2,3,4 - """ - - def tree_polytomy_4(self): - return tskit.Tree.generate_star(4) - - def ts_polytomy_44344(self): - return tskit.load_text( - nodes=io.StringIO(self.nodes_polytomy_44344), - edges=io.StringIO(self.edges_polytomy_44344), - strict=False, - ) - - @pytest.mark.slow - @pytest.mark.parametrize("n", [2, 3, 4, 5]) - def test_all_topologies(self, n): - N = num_leaf_labelled_binary_trees(n) - ranks = collections.Counter() - tree = tskit.Tree.generate_star(n) - for seed in range(20 * N): - split_tree = tree.split_polytomies(random_seed=seed) - ranks[split_tree.rank()] += 1 - # There are N possible binary trees here, we should have seen them - # all with high probability after 20 N attempts. - assert len(ranks) == N - - def verify_trees(self, source_tree, split_tree, epsilon=None): - if epsilon is None: - epsilon = 1e-10 - N = 0 - for u in split_tree.nodes(): - assert split_tree.num_children(u) < 3 - N += 1 - if u >= source_tree.tree_sequence.num_nodes: - # This is a new node - assert epsilon == pytest.approx(split_tree.branch_length(u)) - assert N == len(list(split_tree.leaves())) * 2 - 1 - for u in source_tree.nodes(): - if source_tree.num_children(u) <= 2: - assert source_tree.children(u) == split_tree.children(u) - else: - assert len(split_tree.children(u)) == 2 - - @pytest.mark.parametrize("n", [2, 3, 4, 5, 6]) - def test_resolve_star(self, n): - tree = tskit.Tree.generate_star(n) - self.verify_trees(tree, tree.split_polytomies(random_seed=12)) - - def test_large_epsilon(self): - tree = tskit.Tree.generate_star(10, branch_length=100) - eps = 10 - split = tree.split_polytomies(random_seed=12234, epsilon=eps) - self.verify_trees(tree, split, epsilon=eps) - - def verify_tree_sequence_splits(self, ts): - n_poly = 0 - for e in ts.edgesets(): - if len(e.children) > 2: - n_poly += 1 - assert n_poly > 3 - assert ts.num_trees > 3 - for tree in ts.trees(): - binary_tree = tree.split_polytomies(random_seed=11) - assert binary_tree.interval == tree.interval - for u in binary_tree.nodes(): - assert binary_tree.num_children(u) < 3 - for u in tree.nodes(): - assert binary_tree.time(u) == tree.time(u) - resolved_ts = binary_tree.tree_sequence - assert resolved_ts.sequence_length == ts.sequence_length - assert resolved_ts.num_trees <= 3 - if tree.interval[0] == 0: - assert resolved_ts.num_trees == 2 - null_tree = resolved_ts.last() - assert null_tree.num_roots == ts.num_samples - elif tree.interval[1] == ts.sequence_length: - assert resolved_ts.num_trees == 2 - null_tree = resolved_ts.first() - assert null_tree.num_roots == ts.num_samples - else: - null_tree = resolved_ts.first() - assert null_tree.num_roots == ts.num_samples - null_tree.next() - assert null_tree.num_roots == tree.num_roots - null_tree.next() - assert null_tree.num_roots == ts.num_samples - - def test_complex_examples(self): - self.verify_tree_sequence_splits(self.ts_polytomy_44344()) - - def test_nonbinary_simulation(self): - demographic_events = [ - msprime.SimpleBottleneck(time=1.0, population=0, proportion=0.95) - ] - ts = msprime.simulate( - 20, - recombination_rate=10, - mutation_rate=5, - demographic_events=demographic_events, - random_seed=7, - ) - self.verify_tree_sequence_splits(ts) - - def test_seeds(self): - base = tskit.Tree.generate_star(5) - t1 = base.split_polytomies(random_seed=1234) - t2 = base.split_polytomies(random_seed=1234) - assert t1.tree_sequence.tables.equals( - t2.tree_sequence.tables, ignore_timestamps=True - ) - t2 = base.split_polytomies(random_seed=1) - assert not t1.tree_sequence.tables.equals( - t2.tree_sequence.tables, ignore_provenance=True - ) - - def test_internal_polytomy(self): - # 9 - # ┏━┳━━━┻┳━━━━┓ - # ┃ ┃ 8 ┃ - # ┃ ┃ ┏━━╋━━┓ ┃ - # ┃ ┃ ┃ 7 ┃ ┃ - # ┃ ┃ ┃ ┏┻┓ ┃ ┃ - # 0 1 2 3 5 4 6 - t1 = tskit.Tree.unrank(7, (6, 25)) - t2 = t1.split_polytomies(random_seed=1234) - assert t2.parent(3) == 7 - assert t2.parent(5) == 7 - assert t2.root == 9 - for u in t2.nodes(): - assert t2.num_children(u) in [0, 2] - - def test_binary_tree(self): - t1 = msprime.simulate(10, random_seed=1234).first() - t2 = t1.split_polytomies(random_seed=1234) - tables = t1.tree_sequence.dump_tables() - assert tables.equals(t2.tree_sequence.tables, ignore_provenance=True) - - def test_bad_method(self): - with pytest.raises(ValueError, match="Method"): - self.tree_polytomy_4().split_polytomies(method="something_else") - - def test_epsilon_for_nodes(self): - with pytest.raises( - _tskit.LibraryError, - match="not small enough to create new nodes below a polytomy", - ): - self.tree_polytomy_4().split_polytomies(epsilon=1, random_seed=12) - - def test_epsilon_for_mutations(self): - tables = tskit.Tree.generate_star(3).tree_sequence.dump_tables() - root_time = tables.nodes.time[-1] - assert root_time == 1 - site = tables.sites.add_row(position=0.5, ancestral_state="0") - tables.mutations.add_row(site=site, time=0.9, node=0, derived_state="1") - tables.mutations.add_row(site=site, time=0.9, node=1, derived_state="1") - tree = tables.tree_sequence().first() - with pytest.raises( - _tskit.LibraryError, - match="not small enough to create new nodes below a polytomy", - ): - tree.split_polytomies(epsilon=0.5, random_seed=123) - - def test_provenance(self): - tree = self.tree_polytomy_4() - ts_split = tree.split_polytomies(random_seed=14).tree_sequence - record = json.loads(ts_split.provenance(ts_split.num_provenances - 1).record) - assert record["parameters"]["command"] == "split_polytomies" - ts_split = tree.split_polytomies( - random_seed=12, record_provenance=False - ).tree_sequence - record = json.loads(ts_split.provenance(ts_split.num_provenances - 1).record) - assert record["parameters"]["command"] != "split_polytomies" diff --git a/python/tskit/combinatorics.py b/python/tskit/combinatorics.py index e55f546c54..ab364bd4c5 100644 --- a/python/tskit/combinatorics.py +++ b/python/tskit/combinatorics.py @@ -37,6 +37,30 @@ import tskit +def equal_chunks(lst, k): + """ + Yield k successive equally sized chunks from lst of size n. + + If k >= n, we return n chunks of size 1. + + Otherwise, we always return k chunks. The first k - 1 chunks will + contain exactly n // k items, and the last chunk the remainder. + """ + n = len(lst) + if k <= 0 or int(k) != k: + raise ValueError("Number of chunks must be a positive integer") + + if n > 0: + chunk_size = max(1, n // k) + offset = 0 + j = 0 + while offset < n - chunk_size and j < k - 1: + yield lst[offset : offset + chunk_size] + offset += chunk_size + j += 1 + yield lst[offset:] + + @attr.s(eq=False) class TreeNode: """ @@ -82,6 +106,135 @@ def random_binary_tree(leaf_labels, rng): root = root.parent return root + @classmethod + def balanced_tree(cls, leaf_labels, arity): + """ + Returns a balanced tree of the specified arity. At each node the + leaf labels are split equally among the arity children using the + equal_chunks method. + """ + assert len(leaf_labels) > 0 + if len(leaf_labels) == 1: + root = cls(label=leaf_labels[0]) + else: + children = [ + cls.balanced_tree(chunk, arity) + for chunk in equal_chunks(leaf_labels, arity) + ] + root = cls(children=children) + for child in children: + child.parent = root + return root + + +def generate_star(num_leaves, *, span, branch_length, record_provenance): + """ + Generate a star tree for the specified number of leaves. + + See the documentation for Tree.generate_balanced for more details. + """ + if num_leaves < 2: + raise ValueError("The number of leaves must be 2 or greater") + tables = tskit.TableCollection(sequence_length=span) + tables.nodes.set_columns( + flags=np.full(num_leaves, tskit.NODE_IS_SAMPLE, dtype=np.uint32), + time=np.zeros(num_leaves), + ) + root = tables.nodes.add_row(time=branch_length) + tables.edges.set_columns( + left=np.full(num_leaves, 0), + right=np.full(num_leaves, span), + parent=np.full(num_leaves, root, dtype=np.int32), + child=np.arange(num_leaves, dtype=np.int32), + ) + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = {"command": "generate_star", "TODO": "add parameters"} + tables.provenances.add_row( + record=json.dumps(tskit.provenance.get_provenance_dict(parameters)) + ) + return tables.tree_sequence().first() + + +def generate_comb(num_leaves, *, span, branch_length, record_provenance): + """ + Generate a comb tree for the specified number of leaves. + + See the documentation for Tree.generate_balanced for more details. + """ + if num_leaves < 2: + raise ValueError("The number of leaves must be 2 or greater") + tables = tskit.TableCollection(sequence_length=span) + tables.nodes.set_columns( + flags=np.full(num_leaves, tskit.NODE_IS_SAMPLE, dtype=np.uint32), + time=np.zeros(num_leaves), + ) + right_child = num_leaves - 1 + time = branch_length + for left_child in range(num_leaves - 2, -1, -1): + parent = tables.nodes.add_row(time=time) + time += branch_length + tables.edges.add_row(0, span, parent, left_child) + tables.edges.add_row(0, span, parent, right_child) + right_child = parent + + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = {"command": "generate_comb", "TODO": "add parameters"} + tables.provenances.add_row( + record=json.dumps(tskit.provenance.get_provenance_dict(parameters)) + ) + return tables.tree_sequence().first() + + +def generate_balanced(num_leaves, *, arity, span, branch_length, record_provenance): + """ + Generate a balanced tree for the specified number of leaves. + + See the documentation for Tree.generate_balanced for more details. + """ + if num_leaves < 1: + raise ValueError("The number of leaves must be at least 1") + if arity < 2: + raise ValueError("The arity must be at least 2") + + tables = tskit.TableCollection(span) + for _ in range(num_leaves): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + + def assign_internal_labels(node): + if len(node.children) == 0: + node.time = 0 + else: + max_child_time = 0 + for child in node.children: + assign_internal_labels(child) + max_child_time = max(max_child_time, child.time) + node.time = max_child_time + branch_length + node.label = tables.nodes.add_row(time=node.time) + for child in node.children: + tables.edges.add_row(0, span, node.label, child.label) + + root = TreeNode.balanced_tree(range(num_leaves), arity) + # Do a postorder traversal to assign the internal node labels and times. + assign_internal_labels(root) + tables.sort() + + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = {"command": "generate_balanced", "TODO": "add parameters"} + tables.provenances.add_row( + record=json.dumps(tskit.provenance.get_provenance_dict(parameters)) + ) + + return tables.tree_sequence().first() + def split_polytomies( tree, @@ -640,14 +793,17 @@ def from_tsk_tree(tree): return RankTree.from_tsk_tree_node(tree, tree.root) - def to_tsk_tree(self, span=1): + def to_tsk_tree(self, span=1, branch_length=1): """ - Convert a ``RankTree`` into the only tree in a new tree sequence. + Convert a ``RankTree`` into the only tree in a new tree sequence. Internal + nodes and their times are assigned via a postorder traversal of the tree. :param float span: The genomic span of the returned tree. The tree will cover the interval :math:`[0, span)` and the :attr:`~Tree.tree_sequence` from which the tree is taken will have its :attr:`~tskit.TreeSequence.sequence_length` equal to ``span``. + :param float branch_length: The minimum length of a branch in the returned + tree. """ if set(self.labels) != set(range(self.num_leaves)): raise ValueError("Labels set must be equivalent to [0, num_leaves)") @@ -660,9 +816,8 @@ def add_node(node): return node.label child_ids = [add_node(child) for child in node.children] - # Arbitrarily set parent time +1 from their oldest child max_child_time = max(tables.nodes.time[c] for c in child_ids) - parent_id = tables.nodes.add_row(time=max_child_time + 1) + parent_id = tables.nodes.add_row(time=max_child_time + branch_length) for child_id in child_ids: tables.edges.add_row(0, span, parent_id, child_id) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5acfeabf2a..6e8d01a3b6 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -30,7 +30,6 @@ import copy import functools import itertools -import json import math import textwrap import warnings @@ -45,7 +44,6 @@ import tskit.exceptions as exceptions import tskit.formats as formats import tskit.metadata as metadata_module -import tskit.provenance as provenance import tskit.tables as tables import tskit.util as util import tskit.vcf as vcf @@ -861,12 +859,14 @@ def rank(self): return combinatorics.RankTree.from_tsk_tree(self).rank() @staticmethod - def unrank(num_leaves, rank, *, span=1): + def unrank(num_leaves, rank, *, span=1, branch_length=1): """ Reconstruct the tree of the given ``rank`` (see :meth:`tskit.Tree.rank`) with ``num_leaves`` leaves. - The labels and times of internal nodes are chosen arbitrarily, and - the time of each leaf is 0. + The labels and times of internal nodes are assigned by a postorder + traversal of the nodes, such that the time of each internal node + is the maximum time of its children plus the specified ``branch_length``. + The time of each leaf is 0. See the :ref:`sec_tree_ranks` section for details on ranking and unranking trees and what constitutes valid ranks. @@ -877,11 +877,13 @@ def unrank(num_leaves, rank, *, span=1): the interval :math:`[0, \\text{span})` and the :attr:`~Tree.tree_sequence` from which the tree is taken will have its :attr:`~tskit.TreeSequence.sequence_length` equal to ``span``. + :param: float branch_length: The minimum length of a branch in this tree. :rtype: Tree :raises: ValueError: If the given rank is out of bounds for trees with ``num_leaves`` leaves. """ - return combinatorics.RankTree.unrank(num_leaves, rank).to_tsk_tree(span=span) + rank_tree = combinatorics.RankTree.unrank(num_leaves, rank) + return rank_tree.to_tsk_tree(span=span, branch_length=branch_length) def count_topologies(self, sample_sets=None): """ @@ -2445,13 +2447,11 @@ def split_polytomies( @staticmethod def generate_star(num_leaves, *, span=1, branch_length=1, record_provenance=True): """ - Generate a single :class: whose leaf nodes all have the same parent (i.e. + Generate a :class: whose leaf nodes all have the same parent (i.e. a "star" tree). The leaf nodes are all at time 0 and are marked as sample nodes. - .. note:: - This is similar to ``tskit.Tree.unrank(n, (0,0), span=span)`` but is more - efficient for large n. However, the ``unrank`` method provides - a concise way of generating alternative (non-star) topologies. + The tree produced by this method is identical to + ``tskit.Tree.unrank(n, (0, 0))``, but generated more efficiently for large ``n``. :param int num_leaves: The number of leaf nodes in the returned tree (must be be 2 or greater). @@ -2464,29 +2464,82 @@ def generate_star(num_leaves, *, span=1, branch_length=1, record_provenance=True via the :attr:`.tree_sequence` attribute. :rtype: Tree """ - if num_leaves < 2: - raise ValueError("The number of leaves must be 2 or greater") - tc = tables.TableCollection(sequence_length=span) - tc.nodes.set_columns( - flags=np.full(num_leaves, NODE_IS_SAMPLE, dtype=np.uint32), - time=np.zeros(num_leaves), + return combinatorics.generate_star( + num_leaves, + span=span, + branch_length=branch_length, + record_provenance=record_provenance, ) - root = tc.nodes.add_row(time=branch_length) - tc.edges.set_columns( - left=np.full(num_leaves, 0), - right=np.full(num_leaves, span), - parent=np.full(num_leaves, root, dtype=np.int32), - child=np.arange(num_leaves, dtype=np.int32), + + @staticmethod + def generate_balanced( + num_leaves, *, arity=2, span=1, branch_length=1, record_provenance=True + ): + """ + Generate a :class: with the specified number of leaves that is maximally + balanced. By default, the tree returned is binary, such that for each + node that subtends :math:`n` leaves, the left child will subtend + :math:`\\floor{n / 2}` leaves and the right child the remainder. Balanced + trees with higher arity can also generated using the ``arity`` parameter, + where the leaves subtending a node are distributed among its children + analogously. + + In the returned tree, the leaf nodes are all at time 0, marked as samples, + and labelled 0 to n from left-to-right. Internal node IDs are assigned + sequentially from n in a postorder traversal, and the time of an internal + node is the maximum time of its children plus the specified ``branch_length``. + + :param int num_leaves: The number of leaf nodes in the returned tree (must be + be 2 or greater). + :param int arity: The maximum number of children a node can have in the returned + tree. + :param float span: The span of the tree, and therefore the + :attr:`~TreeSequence.sequence_length` of the :attr:`.tree_sequence` + property of the returned :class:. + :param float branch_length: The minimum length of a branch in the tree (see + above for details on how internal node times are assigned). + :return: A balanced tree. Its corresponding :class:`TreeSequence` is available + via the :attr:`.tree_sequence` attribute. + :rtype: Tree + """ + return combinatorics.generate_balanced( + num_leaves, + arity=arity, + span=span, + branch_length=branch_length, + record_provenance=record_provenance, + ) + + @staticmethod + def generate_comb(num_leaves, *, span=1, branch_length=1, record_provenance=True): + """ + Generate a :class: in which all internal nodes have two children + and the left child is a leaf. This is a "comb", "ladder" or "pectinate" + phylogeny, and also known as a `caterpiller tree + `_. + + The leaf nodes are all at time 0, marked as samples, + and labelled 0 to n from left-to-right. Internal node IDs are assigned + sequentially from n as we ascend the tree, and the time of an internal + node is the maximum time of its children plus the specified ``branch_length``. + + :param int num_leaves: The number of leaf nodes in the returned tree (must be + be 2 or greater). + :param float span: The span of the tree, and therefore the + :attr:`~TreeSequence.sequence_length` of the :attr:`.tree_sequence` + property of the returned :class:. + :param float branch_length: The length of every branch in the tree (equivalent + to the time of the root node). + :return: A star-shaped tree. Its corresponding :class:`TreeSequence` is available + via the :attr:`.tree_sequence` attribute. + :rtype: Tree + """ + return combinatorics.generate_comb( + num_leaves, + span=span, + branch_length=branch_length, + record_provenance=record_provenance, ) - if record_provenance: - # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 - # TODO also make sure we convert all the arguments so that they are - # definitely JSON encodable. - parameters = {"command": "generate_star", "TODO": "add parameters"} - tc.provenances.add_row( - record=json.dumps(provenance.get_provenance_dict(parameters)) - ) - return tc.tree_sequence().first() def load(file):