Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -7951,3 +7951,22 @@ def test_star_branch_length(self):

assert ts.node(ts.first().root).time == branch_length
assert ts.kc_distance(topological_equiv_ts) == 0

def test_balanced_equivalent(self):
for extra_params in [{}, {"span": 2.5}]:
for n in range(2, 10):
tree = tskit.Tree.generate_balanced(n, **extra_params)
# balanced shape is the last shape in the rank
assert tree.rank()[0] == tskit.combinatorics.num_shapes(n) - 1

def test_balanced_bad_params(self):
for n in [-1, 0, 1, np.array([1, 2])]:
with pytest.raises(ValueError):
tskit.Tree.generate_balanced(n)
for n in [None, "", []]:
with pytest.raises(TypeError):
tskit.Tree.generate_balanced(n)
with pytest.raises(tskit.LibraryError):
tskit.Tree.generate_balanced(2, span=0)
with pytest.raises(tskit.LibraryError):
tskit.Tree.generate_balanced(2, branch_length=0)
59 changes: 59 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2390,6 +2390,65 @@ def generate_star(num_leaves, *, span=1, branch_length=1, record_provenance=True
)
return tc.tree_sequence().first()

@staticmethod
def generate_balanced(
num_leaves, *, span=1, branch_length=1, record_provenance=True
):
"""
Generate a single bifurcating :class:<Tree> where the tree is as balanced as
possible (i.e. all parents have exactly 2 children, and the number of branches
(edges) between each leaf and the root is either :math:`n` or :math:`n+1` where
:math:`n = \\text{floor}(\\text{log}_{2}(\\text{num\\_leaves}))`. If
``num_leaves`` is an integer power of 2, the tree will be fully balanced, such
that all leaves have :math:`n` edges to the root.

.. note::
This creates a tree topology identical to that generated by
``tskit.Tree.unrank((tskit.combinatorics.num_shapes(n)-1, 0), n, span=span)``
but is more efficient for large n. However, the ``unrank`` method provides
a concise way of generating alternative (non-balanced) topologies.

:param int num_leaves: The number of leaf nodes in the returned tree (must be
be 2 or greater).
:param float span: The span of the tree, and therefore the
:attr:`~TreeSequence.sequence_length` of the :attr:`~Tree.tree_sequence`
property of the returned :class:<Tree>
:param float branch_length: The default length of branches in the tree. If the
tree is fully balanced (i.e. ``num_leaves`` is an integer power of 2) then
all branches will be this length, otherwise some branches may be double
:return: A balanced tree. Its corresponding :class:`TreeSequence` is available
via the :attr:`.tree_sequence` attribute.
:rtype: Tree
"""
if num_leaves < 2:
raise ValueError("The number of leaves must be 2 or greater")
tc = tables.TableCollection(sequence_length=span)
tc.nodes.set_columns(
flags=np.full(num_leaves, NODE_IS_SAMPLE, dtype=np.uint32),
time=np.zeros(num_leaves),
)
merge_nodes = list(range(num_leaves))
while len(merge_nodes) > 1:
t = tc.nodes[merge_nodes[-1]].time + branch_length
merge_up_to = (len(merge_nodes) // 2) * 2
for merge_index in range(0, merge_up_to, 2):
node1 = merge_nodes[merge_index]
node2 = merge_nodes[merge_index + 1]
parent = tc.nodes.add_row(time=t)
merge_nodes.append(parent)
tc.edges.add_row(left=0, right=span, parent=parent, child=node1)
tc.edges.add_row(left=0, right=span, parent=parent, child=node2)
del merge_nodes[0:merge_up_to]
if record_provenance:
# TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243
# TODO also make sure we convert all the arguments so that they are
# definitely JSON encodable.
parameters = {"command": "generate_balanced", "TODO": "add parameters"}
tc.provenances.add_row(
record=json.dumps(provenance.get_provenance_dict(parameters))
)
return tc.tree_sequence().first()


def load(file):
"""
Expand Down