diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py
index cb77a87e52..c7122a4726 100644
--- a/python/tests/test_topology.py
+++ b/python/tests/test_topology.py
@@ -7951,3 +7951,22 @@ def test_star_branch_length(self):
 
         assert ts.node(ts.first().root).time == branch_length
         assert ts.kc_distance(topological_equiv_ts) == 0
+
+    def test_balanced_equivalent(self):
+        for extra_params in [{}, {"span": 2.5}]:
+            for n in range(2, 10):
+                tree = tskit.Tree.generate_balanced(n, **extra_params)
+                # balanced shape is the last shape in the rank
+                assert tree.rank()[0] == tskit.combinatorics.num_shapes(n) - 1
+
+    def test_balanced_bad_params(self):
+        for n in [-1, 0, 1, np.array([1, 2])]:
+            with pytest.raises(ValueError):
+                tskit.Tree.generate_balanced(n)
+        for n in [None, "", []]:
+            with pytest.raises(TypeError):
+                tskit.Tree.generate_balanced(n)
+        with pytest.raises(tskit.LibraryError):
+            tskit.Tree.generate_balanced(2, span=0)
+        with pytest.raises(tskit.LibraryError):
+            tskit.Tree.generate_balanced(2, branch_length=0)
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index b86b9a8fb9..a1a8dc7bda 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -2390,6 +2390,65 @@ def generate_star(num_leaves, *, span=1, branch_length=1, record_provenance=True
             )
         return tc.tree_sequence().first()
 
+    @staticmethod
+    def generate_balanced(
+        num_leaves, *, span=1, branch_length=1, record_provenance=True
+    ):
+        """
+        Generate a single bifurcating :class:`Tree` where the tree is as balanced as
+        possible (i.e. all parents have exactly 2 children, and the number of branches
+        (edges) between each leaf and the root is either :math:`n` or :math:`n+1` where
+        :math:`n = \\text{floor}(\\text{log}_{2}(\\text{num\\_leaves}))`). If
+        ``num_leaves`` is an integer power of 2, the tree will be fully balanced, such
+        that all leaves have :math:`n` edges to the root.
+
+        .. note::
+            This creates a tree topology identical to that generated by
+            ``tskit.Tree.unrank((tskit.combinatorics.num_shapes(n)-1, 0), n, span=span)``
+            but is more efficient for large n. However, the ``unrank`` method provides
+            a concise way of generating alternative (non-balanced) topologies.
+
+        :param int num_leaves: The number of leaf nodes in the returned tree (must be
+            2 or greater).
+        :param float span: The span of the tree, and therefore the
+            :attr:`~TreeSequence.sequence_length` of the :attr:`~Tree.tree_sequence`
+            property of the returned :class:`Tree`.
+        :param float branch_length: The default length of branches in the tree. If the
+            tree is fully balanced (i.e. ``num_leaves`` is an integer power of 2) then
+            all branches will be this length; otherwise some may be twice as long.
+        :return: A balanced tree. Its corresponding :class:`TreeSequence` is available
+            via the :attr:`.tree_sequence` attribute.
+        :rtype: Tree
+        """
+        if num_leaves < 2:
+            raise ValueError("The number of leaves must be 2 or greater")
+        tc = tables.TableCollection(sequence_length=span)
+        tc.nodes.set_columns(
+            flags=np.full(num_leaves, NODE_IS_SAMPLE, dtype=np.uint32),
+            time=np.zeros(num_leaves),
+        )
+        merge_nodes = list(range(num_leaves))  # nodes still lacking a parent
+        while len(merge_nodes) > 1:
+            t = tc.nodes[merge_nodes[-1]].time + branch_length
+            merge_up_to = (len(merge_nodes) // 2) * 2  # even prefix to pair off
+            for merge_index in range(0, merge_up_to, 2):
+                node1 = merge_nodes[merge_index]
+                node2 = merge_nodes[merge_index + 1]
+                parent = tc.nodes.add_row(time=t)
+                merge_nodes.append(parent)
+                tc.edges.add_row(left=0, right=span, parent=parent, child=node1)
+                tc.edges.add_row(left=0, right=span, parent=parent, child=node2)
+            del merge_nodes[0:merge_up_to]  # an unpaired leftover ends up at the front
+        if record_provenance:
+            # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
+            # TODO also make sure we convert all the arguments so that they are
+            # definitely JSON encodable.
+            parameters = {"command": "generate_balanced", "TODO": "add parameters"}
+            tc.provenances.add_row(
+                record=json.dumps(provenance.get_provenance_dict(parameters))
+            )
+        return tc.tree_sequence().first()
+
 
 def load(file):
     """