diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index d5245c1c4e..74694916e1 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,9 @@ **Features** +- Added ``Tree.generate_random_binary`` static method to create random + binary trees (:user:`hyanwong`, :user:`jeromekelleher`, :pr:`1037`). + - Change the default behaviour of Tree.split_polytomies to generate the shortest possible branch lengths instead of a fixed epsilon of 1e-10. (:user:`jeromekelleher`, :issue:`1089`, :pr:`1090`) diff --git a/python/tests/test_combinatorics.py b/python/tests/test_combinatorics.py index e5f62ec0b3..19bd9e9ed5 100644 --- a/python/tests/test_combinatorics.py +++ b/python/tests/test_combinatorics.py @@ -1092,19 +1092,6 @@ def ts_polytomy_44344(self): strict=False, ) - @pytest.mark.slow - @pytest.mark.parametrize("n", [2, 3, 4, 5]) - def test_all_topologies(self, n): - N = num_leaf_labelled_binary_trees(n) - ranks = collections.Counter() - tree = tskit.Tree.generate_star(n) - for seed in range(20 * N): - split_tree = tree.split_polytomies(random_seed=seed) - ranks[split_tree.rank()] += 1 - # There are N possible binary trees here, we should have seen them - # all with high probability after 20 N attempts. - assert len(ranks) == N - def verify_trees(self, source_tree, split_tree, epsilon=None): N = 0 for u in split_tree.nodes(): @@ -1346,6 +1333,19 @@ def test_kwargs(self): split_tree = tree.split_polytomies(random_seed=14, tracked_samples=[0, 1]) assert split_tree.num_tracked_samples() == 2 + @pytest.mark.slow + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_topologies(self, n): + N = num_leaf_labelled_binary_trees(n) + ranks = collections.Counter() + for seed in range(20 * N): + star = tskit.Tree.generate_star(n) + random_tree = star.split_polytomies(random_seed=seed) + ranks[random_tree.rank()] += 1 + # There are N possible binary trees here, we should have seen them + # all with high probability after 20 N attempts. + assert len(ranks) == N + class TreeGeneratorTestBase: """ @@ -1473,6 +1473,41 @@ def test_branch_length_semantics(self): assert tree.time(u) == tree.time(v) + branch_length +class TestGenerateRandomBinary(TreeGeneratorTestBase): + method_name = "generate_random_binary" + + def method(self, n, **kwargs): + return tskit.Tree.generate_random_binary(n, random_seed=53, **kwargs) + + @pytest.mark.slow + @pytest.mark.parametrize("n", [3, 4, 5]) + def test_all_topologies(self, n): + N = num_leaf_labelled_binary_trees(n) + ranks = collections.Counter() + for seed in range(20 * N): + random_tree = tskit.Tree.generate_random_binary(n, random_seed=seed) + ranks[random_tree.rank()] += 1 + # There are N possible binary trees here, we should have seen them + # all with high probability after 20 N attempts. + assert len(ranks) == N + + @pytest.mark.parametrize("n", range(2, 10)) + def test_leaves(self, n): + tree = tskit.Tree.generate_random_binary(n, random_seed=1234) + # The leaves should be a permutation of range(n) + assert list(sorted(tree.leaves())) == list(range(n)) + + @pytest.mark.parametrize("seed", range(1, 20)) + def test_rank_unrank_round_trip_seeds(self, seed): + n = 10 + tree1 = tskit.Tree.generate_random_binary(n, random_seed=seed) + rank = tree1.rank() + tree2 = tskit.Tree.unrank(n, rank) + tables1 = tree1.tree_sequence.tables + tables2 = tree2.tree_sequence.tables + assert tables1.equals(tables2, ignore_provenance=True) + + class TestGenerateComb(TreeGeneratorTestBase): method_name = "generate_comb" diff --git a/python/tskit/combinatorics.py b/python/tskit/combinatorics.py index af6b785ee7..3547710f46 100644 --- a/python/tskit/combinatorics.py +++ b/python/tskit/combinatorics.py @@ -71,6 +71,33 @@ class TreeNode: children = attr.ib(factory=list) label = attr.ib(default=None) + def as_tables(self, *, num_leaves, span, branch_length): + """ + Convert the tree rooted at this node into an equivalent + TableCollection. Internal nodes are allocated in postorder. + """ + tables = tskit.TableCollection(span) + for _ in range(num_leaves): + tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) + + def assign_internal_labels(node): + if len(node.children) == 0: + node.time = 0 + else: + max_child_time = 0 + for child in node.children: + assign_internal_labels(child) + max_child_time = max(max_child_time, child.time) + node.time = max_child_time + branch_length + node.label = tables.nodes.add_row(time=node.time) + for child in node.children: + tables.edges.add_row(0, span, node.label, child.label) + + # Do a postorder traversal to assign the internal node labels and times. + assign_internal_labels(self) + tables.sort() + return tables + @staticmethod def random_binary_tree(leaf_labels, rng): """ @@ -104,6 +131,24 @@ def random_binary_tree(leaf_labels, rng): root = nodes[0] while root.parent is not None: root = root.parent + + # Canonicalise the order of the children within a node. This + # is given by (num_leaves, min_label). See also the + # RankTree.canonical_order function for the definition of + # how these are ordered during rank/unrank. + + def reorder_children(node): + if len(node.children) == 0: + return 1, node.label + keys = [reorder_children(child) for child in node.children] + if keys[0] > keys[1]: + node.children = node.children[::-1] + return ( + sum(leaf_count for leaf_count, _ in keys), + min(min_label for _, min_label in keys), + ) + + reorder_children(root) return root @classmethod @@ -204,27 +249,10 @@ def generate_balanced( if arity < 2: raise ValueError("The arity must be at least 2") - tables = tskit.TableCollection(span) - for _ in range(num_leaves): - tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0) - - def assign_internal_labels(node): - if len(node.children) == 0: - node.time = 0 - else: - max_child_time = 0 - for child in node.children: - assign_internal_labels(child) - max_child_time = max(max_child_time, child.time) - node.time = max_child_time + branch_length - node.label = tables.nodes.add_row(time=node.time) - for child in node.children: - tables.edges.add_row(0, span, node.label, child.label) - root = TreeNode.balanced_tree(range(num_leaves), arity) - # Do a postorder traversal to assign the internal node labels and times. - assign_internal_labels(root) - tables.sort() + tables = root.as_tables( + num_leaves=num_leaves, span=span, branch_length=branch_length + ) if record_provenance: # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 @@ -238,6 +266,35 @@ def assign_internal_labels(node): return tables.tree_sequence().first(**kwargs) +def generate_random_binary( + num_leaves, *, span, branch_length, random_seed, record_provenance, **kwargs +): + """ + Sample a leaf-labelled binary tree uniformly. + + See the documentation for :meth:`Tree.generate_random_binary` for more details. + """ + if num_leaves < 1: + raise ValueError("The number of leaves must be at least 1") + + rng = random.Random(random_seed) + root = TreeNode.random_binary_tree(range(num_leaves), rng) + tables = root.as_tables( + num_leaves=num_leaves, span=span, branch_length=branch_length + ) + + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = {"command": "generate_random_binary", "TODO": "add parameters"} + tables.provenances.add_row( + record=json.dumps(tskit.provenance.get_provenance_dict(parameters)) + ) + ts = tables.tree_sequence() + return ts.first(**kwargs) + + def split_polytomies( tree, *, @@ -252,8 +309,7 @@ def split_polytomies( so that any any node with more than two children is resolved into a binary tree. - For further documentation, please refer to the :meth:`Tree.split_polytomies` - method, which is the usual route through which this function is called. + See the documentation for :meth:`Tree.split_polytomies` for more details. """ allowed_methods = ["random"] if method is None: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 5daf47f097..92e93b79f0 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2584,6 +2584,60 @@ def generate_comb( **kwargs, ) + @staticmethod + def generate_random_binary( + num_leaves, + *, + span=1, + branch_length=1, + random_seed=None, + record_provenance=True, + **kwargs, + ): + """ + Generate a random binary :class: with :math:`n` = ``num_leaves`` + leaves with an equal probability of returning any topology and + leaf label permutation among the :math:`(2n - 3)! / (2^(n - 2) (n - 2)!)` + leaf-labelled binary trees. + + The leaf nodes are marked as samples, labelled 0 to n, and placed at + time 0. Internal node IDs are assigned sequentially from n as we ascend + the tree, and the time of an internal node is the maximum time of its + children plus the specified ``branch_length``. + + .. note:: + The returned tree has not been created under any explicit model of + evolution. In order to simulate such trees, additional software + such as `msprime `` is required. + + :param int num_leaves: The number of leaf nodes in the returned tree (must + be 2 or greater). + :param float span: The span of the tree, and therefore the + :attr:`~TreeSequence.sequence_length` of the :attr:`.tree_sequence` + property of the returned :class:. + :param float branch_length: The minimum time between parent and child nodes. + :param int random_seed: The random seed. If this is None, a random seed will + be automatically generated. Valid random seeds must be between 1 and + :math:`2^32 − 1`. + :param bool record_provenance: If True, add details of this operation to the + provenance information of the returned tree sequence. (Default: True). + :param \\**kwargs: Further arguments used as parameters when constructing the + returned :class:`Tree`. For example + ``tskit.Tree.generate_comb(sample_lists=True)`` will + return a :class:`Tree` created with ``sample_lists=True``. + :return: A random binary tree. Its corresponding :class:`TreeSequence` is + available via the :attr:`.tree_sequence` attribute. + :rtype: Tree + """ + return combinatorics.generate_random_binary( + num_leaves, + span=span, + branch_length=branch_length, + random_seed=random_seed, + record_provenance=record_provenance, + **kwargs, + ) + def load(file): """