From ecaa2108612ed29f793d818ef285a2eee5813eda Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Sat, 24 Oct 2020 18:30:56 +0100 Subject: [PATCH] Add a Tree.create_star method --- python/CHANGELOG.rst | 3 +++ python/tests/test_topology.py | 34 +++++++++++++++++++++++++ python/tskit/trees.py | 48 +++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 293eb4b433..2cec6bf8d3 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,9 @@ **Features** +- Added ``Tree.generate_star`` static method to create star-topologies (:user:`hyanwong`, + :pr:`934`). + - Added ``equals`` method to TableCollection and each of the tables which provides more flexible equality comparisons, for example, allowing users to ignore metadata or provenance in the comparison. diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index d588ee1a8c..cb77a87e52 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -7917,3 +7917,37 @@ def test_is_isolated_bad(self): tree.is_isolated("abc") with pytest.raises(TypeError): tree.is_isolated(1.1) + + +class TestExampleTrees: + """ + Test hard-coded example tree (sequence) generation + """ + + def test_star_equivalent(self): + for extra_params in [{}, {"span": 2.5}]: + for n in range(2, 6): + ts = tskit.Tree.generate_star(n, **extra_params).tree_sequence + equiv_ts = tskit.Tree.unrank((0, 0), n, **extra_params).tree_sequence + assert ts.tables.equals(equiv_ts.tables, ignore_provenance=True) + + def test_star_bad_params(self): + for n in [-1, 0, 1, np.array([1, 2])]: + with pytest.raises(ValueError): + tskit.Tree.generate_star(n) + for n in [None, "", []]: + with pytest.raises(TypeError): + tskit.Tree.generate_star(n) + with pytest.raises(tskit.LibraryError): + tskit.Tree.generate_star(2, span=0) + with pytest.raises(tskit.LibraryError): + tskit.Tree.generate_star(2, branch_length=0) + + def test_star_branch_length(self): 
+ branch_length = 10 + n = 7 + ts = tskit.Tree.generate_star(n, branch_length=branch_length).tree_sequence + topological_equiv_ts = tskit.Tree.unrank((0, 0), n).tree_sequence + + assert ts.node(ts.first().root).time == branch_length + assert ts.kc_distance(topological_equiv_ts) == 0 diff --git a/python/tskit/trees.py b/python/tskit/trees.py index d542a37607..b86b9a8fb9 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -30,6 +30,7 @@ import copy import functools import itertools +import json import math import os import textwrap @@ -45,6 +46,7 @@ import tskit.exceptions as exceptions import tskit.formats as formats import tskit.metadata as metadata_module +import tskit.provenance as provenance import tskit.tables as tables import tskit.util as util import tskit.vcf as vcf @@ -2342,6 +2344,52 @@ def kc_distance(self, other, lambda_=0.0): """ return self._ll_tree.get_kc_distance(other._ll_tree, lambda_) + @staticmethod + def generate_star(num_leaves, *, span=1, branch_length=1, record_provenance=True): + """ + Generate a single :class:`Tree` whose leaf nodes all have the same parent (i.e. + a "star" tree). The leaf nodes are all at time 0 and are marked as sample nodes. + + .. note:: + This is similar to ``tskit.Tree.unrank((0,0), n, span=span)`` but is more + efficient for large n. However, the ``unrank`` method provides + a concise way of generating alternative (non-star) topologies. + + :param int num_leaves: The number of leaf nodes in the returned tree (must be + 2 or greater). + :param float span: The span of the tree, and therefore the + :attr:`~TreeSequence.sequence_length` of the :attr:`.tree_sequence` + property of the returned :class:`Tree`. + :param float branch_length: The length of every branch in the tree (equivalent + to the time of the root node). + :return: A star-shaped tree. Its corresponding :class:`TreeSequence` is available + via the :attr:`.tree_sequence` attribute. 
+ :rtype: Tree + """ + if num_leaves < 2: + raise ValueError("The number of leaves must be 2 or greater") + tc = tables.TableCollection(sequence_length=span) + tc.nodes.set_columns( + flags=np.full(num_leaves, NODE_IS_SAMPLE, dtype=np.uint32), + time=np.zeros(num_leaves), + ) + root = tc.nodes.add_row(time=branch_length) + tc.edges.set_columns( + left=np.full(num_leaves, 0), + right=np.full(num_leaves, span), + parent=np.full(num_leaves, root, dtype=np.int32), + child=np.arange(num_leaves, dtype=np.int32), + ) + if record_provenance: + # TODO replace with a version of https://github.com/tskit-dev/tskit/pull/243 + # TODO also make sure we convert all the arguments so that they are + # definitely JSON encodable. + parameters = {"command": "generate_star", "TODO": "add parameters"} + tc.provenances.add_row( + record=json.dumps(provenance.get_provenance_dict(parameters)) + ) + return tc.tree_sequence().first() + def load(file): """