Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

**Features**

- Tree accessor functions (e.g. ``ts.first()``, ``ts.at()``) pass extra parameters such as
``sample_lists`` to the underlying ``Tree`` constructor; also ``root_threshold`` can
be specified when calling ``ts.trees()`` (:user:`hyanwong`, :issue:`847`, :pr:`848`)

- Genomic intervals returned by python functions are now namedtuples, allowing ``.left``,
``.right`` and ``.span`` usage (:user:`hyanwong`, :issue:`784`, :pr:`786`, :pr:`811`)

Expand Down
123 changes: 82 additions & 41 deletions python/tests/test_highlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Test cases for the high level interface to tskit.
"""
import collections
import inspect
import io
import itertools
import json
Expand Down Expand Up @@ -879,17 +880,26 @@ def test_samples(self):

def test_first_last(self):
for ts in get_example_tree_sequences():
t1 = ts.first()
t2 = next(ts.trees())
self.assertFalse(t1 is t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
self.assertEqual(t1.index, 0)

t1 = ts.last()
t2 = next(reversed(ts.trees()))
self.assertFalse(t1 is t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
self.assertEqual(t1.index, ts.num_trees - 1)
for kwargs in [{}, {"tracked_samples": ts.samples()}]:
t1 = ts.first(**kwargs)
t2 = next(ts.trees())
self.assertFalse(t1 is t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
self.assertEqual(t1.index, 0)
if "tracked_samples" in kwargs:
self.assertNotEqual(t1.num_tracked_samples(), 0)
else:
self.assertEqual(t1.num_tracked_samples(), 0)

t1 = ts.last(**kwargs)
t2 = next(reversed(ts.trees()))
self.assertFalse(t1 is t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
self.assertEqual(t1.index, ts.num_trees - 1)
if "tracked_samples" in kwargs:
self.assertNotEqual(t1.num_tracked_samples(), 0)
else:
self.assertEqual(t1.num_tracked_samples(), 0)

def test_trees_interface(self):
ts = list(get_example_tree_sequences())[0]
Expand Down Expand Up @@ -1311,14 +1321,21 @@ def test_len_trees(self):

def test_list(self):
for ts in get_example_tree_sequences():
tree_list = ts.aslist()
self.assertEqual(len(tree_list), ts.num_trees)
self.assertEqual(len(set(map(id, tree_list))), ts.num_trees)
for index, tree in enumerate(tree_list):
self.assertEqual(index, tree.index)
for t1, t2 in zip(tree_list, ts.trees()):
self.assertEqual(t1, t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
for kwargs in [{}, {"tracked_samples": ts.samples()}]:
tree_list = ts.aslist(**kwargs)
self.assertEqual(len(tree_list), ts.num_trees)
self.assertEqual(len(set(map(id, tree_list))), ts.num_trees)
for index, tree in enumerate(tree_list):
self.assertEqual(index, tree.index)
for t1, t2 in zip(tree_list, ts.trees(**kwargs)):
self.assertEqual(t1, t2)
self.assertEqual(t1.parent_dict, t2.parent_dict)
if "tracked_samples" in kwargs:
self.assertNotEqual(t1.num_tracked_samples(), 0)
self.assertNotEqual(t2.num_tracked_samples(), 0)
else:
self.assertEqual(t1.num_tracked_samples(), 0)
self.assertEqual(t2.num_tracked_samples(), 0)

def test_reversed_trees(self):
for ts in get_example_tree_sequences():
Expand All @@ -1333,31 +1350,41 @@ def test_reversed_trees(self):

def test_at_index(self):
for ts in get_example_tree_sequences():
tree_list = ts.aslist()
for index in list(range(ts.num_trees)) + [-1]:
t1 = tree_list[index]
t2 = ts.at_index(index)
self.assertEqual(t1, t2)
self.assertEqual(t1.interval, t2.interval)
self.assertEqual(t1.parent_dict, t2.parent_dict)

def test_at(self):
for ts in get_example_tree_sequences():
tree_list = ts.aslist()
for t1 in tree_list:
left, right = t1.interval
mid = left + (right - left) / 2
for pos in [left, left + 1e-9, mid, right - 1e-9]:
t2 = ts.at(pos)
for kwargs in [{}, {"tracked_samples": ts.samples()}]:
tree_list = ts.aslist(**kwargs)
for index in list(range(ts.num_trees)) + [-1]:
t1 = tree_list[index]
t2 = ts.at_index(index, **kwargs)
self.assertEqual(t1, t2)
self.assertEqual(t1.interval, t2.interval)
self.assertEqual(t1.parent_dict, t2.parent_dict)
if right < ts.sequence_length:
t2 = ts.at(right)
t3 = tree_list[t1.index + 1]
self.assertEqual(t3, t2)
self.assertEqual(t3.interval, t2.interval)
self.assertEqual(t3.parent_dict, t2.parent_dict)
if "tracked_samples" in kwargs:
self.assertNotEqual(t2.num_tracked_samples(), 0)
else:
self.assertEqual(t2.num_tracked_samples(), 0)

def test_at(self):
for ts in get_example_tree_sequences():
for kwargs in [{}, {"tracked_samples": ts.samples()}]:
tree_list = ts.aslist(**kwargs)
for t1 in tree_list:
left, right = t1.interval
mid = left + (right - left) / 2
for pos in [left, left + 1e-9, mid, right - 1e-9]:
t2 = ts.at(pos, **kwargs)
self.assertEqual(t1, t2)
self.assertEqual(t1.interval, t2.interval)
self.assertEqual(t1.parent_dict, t2.parent_dict)
if right < ts.sequence_length:
t2 = ts.at(right, **kwargs)
t3 = tree_list[t1.index + 1]
self.assertEqual(t3, t2)
self.assertEqual(t3.interval, t2.interval)
self.assertEqual(t3.parent_dict, t2.parent_dict)
if "tracked_samples" in kwargs:
self.assertNotEqual(t2.num_tracked_samples(), 0)
else:
self.assertEqual(t2.num_tracked_samples(), 0)

def test_sequence_iteration(self):
for ts in get_example_tree_sequences():
Expand Down Expand Up @@ -1408,6 +1435,20 @@ def test_kwargs_only(self):
with self.assertRaisesRegex(TypeError, "argument"):
self.ts.draw_svg("filename", True)

def test_trees_params(self):
    """
    Check that the leading parameters accepted by ``TreeSequence.trees()``
    are the same as those accepted by the ``Tree`` constructor.
    """
    tree_signature = inspect.signature(tskit.Tree)
    trees_signature = inspect.signature(tskit.TreeSequence.trees)
    # Drop the first parameter of the constructor (`tree_sequence`).
    expected = list(tree_signature.parameters.items())[1:]
    # Drop the first parameter of the iterator (`self`) together with the
    # three trailing extra (deprecated) aliases that only `trees()` accepts.
    actual = list(trees_signature.parameters.items())[1:-3]
    self.assertEqual(actual, expected)


class TestTreeSequenceMetadata(unittest.TestCase):
metadata_tables = [
Expand Down
57 changes: 41 additions & 16 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,6 @@ class Tree:
:param TreeSequence tree_sequence: The parent tree sequence.
:param list tracked_samples: The list of samples to be tracked and
counted using the :meth:`Tree.num_tracked_samples` method.
:param bool sample_counts: Deprecated since 0.2.4.
:param bool sample_lists: If True, provide more efficient access
to the samples beneath a given node using the
:meth:`Tree.samples` method.
Expand All @@ -654,16 +653,17 @@ class Tree:
are roots. To efficiently restrict the roots of the tree to
those subtending meaningful topology, set this to 2. This value
is only relevant when trees have multiple roots.
:param bool sample_counts: Deprecated since 0.2.4.
"""

def __init__(
self,
tree_sequence,
tracked_samples=None,
*,
sample_counts=None,
sample_lists=False,
root_threshold=1,
sample_counts=None,
):
options = 0
if sample_counts is not None:
Expand Down Expand Up @@ -2947,7 +2947,7 @@ def ll_tree_sequence(self):
def get_ll_tree_sequence(self):
return self._ll_tree_sequence

def aslist(self):
def aslist(self, **kwargs):
"""
Returns the trees in this tree sequence as a list. Each tree is
represented by a different instance of :class:`Tree`. As such, this
Expand All @@ -2956,10 +2956,13 @@ def aslist(self):
method is the recommended way to efficiently iterate over the trees
in a tree sequence.

:param \\**kwargs: Further arguments used as parameters when constructing the
returned trees. For example ``ts.aslist(sample_lists=True)`` will result
in a list of :class:`Tree` instances created with ``sample_lists=True``.
:return: A list of the trees in this tree sequence.
:rtype: list
"""
return [tree.copy() for tree in self.trees()]
return [tree.copy() for tree in self.trees(**kwargs)]

@classmethod
def load(cls, path):
Expand Down Expand Up @@ -3601,61 +3604,76 @@ def breakpoints(self, as_array=False):
breakpoints = map(float, breakpoints)
return breakpoints

def at(self, position):
def at(self, position, **kwargs):
"""
Returns the tree covering the specified genomic location. The returned tree
will have ``tree.interval.left`` <= ``position`` < ``tree.interval.right``.
See also :meth:`Tree.seek`.

:param float position: A genomic location.
:param \\**kwargs: Further arguments used as parameters when constructing the
returned :class:`Tree`. For example ``ts.at(2.5, sample_lists=True)`` will
result in a :class:`Tree` created with ``sample_lists=True``.
:return: A new instance of :class:`Tree` positioned to cover the specified
position.
genomic location.
:rtype: Tree
"""
tree = Tree(self)
tree = Tree(self, **kwargs)
tree.seek(position)
return tree

def at_index(self, index):
def at_index(self, index, **kwargs):
"""
Returns the tree at the specified index. See also :meth:`Tree.seek_index`.

:param int index: The index of the required tree.
:param \\**kwargs: Further arguments used as parameters when constructing the
returned :class:`Tree`. For example ``ts.at_index(4, sample_lists=True)``
will result in a :class:`Tree` created with ``sample_lists=True``.
:return: A new instance of :class:`Tree` positioned at the specified index.
:rtype: Tree
"""
tree = Tree(self)
tree = Tree(self, **kwargs)
tree.seek_index(index)
return tree

def first(self):
def first(self, **kwargs):
"""
Returns the first tree in this :class:`TreeSequence`. To iterate over all
trees in the sequence, use the :meth:`.trees` method.

:param \\**kwargs: Further arguments used as parameters when constructing the
returned :class:`Tree`. For example ``ts.first(sample_lists=True)`` will
result in a :class:`Tree` created with ``sample_lists=True``.
:return: The first tree in this tree sequence.
:rtype: :class:`Tree`.
"""
tree = Tree(self)
tree = Tree(self, **kwargs)
tree.first()
return tree

def last(self):
def last(self, **kwargs):
"""
Returns the last tree in this :class:`TreeSequence`. To iterate over all
trees in the sequence, use the :meth:`.trees` method.

:param \\**kwargs: Further arguments used as parameters when constructing the
returned :class:`Tree`. For example ``ts.last(sample_lists=True)`` will
result in a :class:`Tree` created with ``sample_lists=True``.
:return: The last tree in this tree sequence.
:rtype: :class:`Tree`.
"""
tree = Tree(self)
tree = Tree(self, **kwargs)
tree.last()
return tree

def trees(
self,
tracked_samples=None,
*,
sample_counts=None,
sample_lists=False,
root_threshold=1,
sample_counts=None,
tracked_leaves=None,
leaf_counts=None,
leaf_lists=None,
Expand All @@ -3678,10 +3696,16 @@ def trees(

:param list tracked_samples: The list of samples to be tracked and
counted using the :meth:`Tree.num_tracked_samples` method.
:param bool sample_counts: Deprecated since 0.2.4.
:param bool sample_lists: If True, provide more efficient access
to the samples beneath a given node using the
:meth:`Tree.samples` method.
:param int root_threshold: The minimum number of samples that a node
must be ancestral to for it to be in the list of roots. By default
this is 1, so that isolated samples (representing missing data)
are roots. To efficiently restrict the roots of the tree to
those subtending meaningful topology, set this to 2. This value
is only relevant when trees have multiple roots.
:param bool sample_counts: Deprecated since 0.2.4.
:return: An iterator over the Trees in this tree sequence.
:rtype: collections.abc.Iterable, :class:`Tree`
"""
Expand All @@ -3698,8 +3722,9 @@ def trees(
tree = Tree(
self,
tracked_samples=tracked_samples,
sample_counts=sample_counts,
sample_lists=sample_lists,
root_threshold=root_threshold,
sample_counts=sample_counts,
)
return TreeIterator(tree)

Expand Down