diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index a3c8a47da5..644be33d84 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -13,6 +13,10 @@ **Features** +- Tree accessor functions (e.g. ``ts.first()``, ``ts.at()`` pass extra parameters such as + ``sample_indexes`` to the underlying ``Tree`` constructor; also ``root_threshold`` can + be specified when calling ``ts.trees()`` (:user:`hyanwong`, :issue:`847`, :pr:`848`) + - Genomic intervals returned by python functions are now namedtuples, allowing ``.left`` ``.right`` and ``.span`` usage (:user:`hyanwong`, :issue:`784`, :pr:`786`, :pr:`811`) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index fe9fd498df..685614b732 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -24,6 +24,7 @@ Test cases for the high level interface to tskit. """ import collections +import inspect import io import itertools import json @@ -879,17 +880,26 @@ def test_samples(self): def test_first_last(self): for ts in get_example_tree_sequences(): - t1 = ts.first() - t2 = next(ts.trees()) - self.assertFalse(t1 is t2) - self.assertEqual(t1.parent_dict, t2.parent_dict) - self.assertEqual(t1.index, 0) - - t1 = ts.last() - t2 = next(reversed(ts.trees())) - self.assertFalse(t1 is t2) - self.assertEqual(t1.parent_dict, t2.parent_dict) - self.assertEqual(t1.index, ts.num_trees - 1) + for kwargs in [{}, {"tracked_samples": ts.samples()}]: + t1 = ts.first(**kwargs) + t2 = next(ts.trees()) + self.assertFalse(t1 is t2) + self.assertEqual(t1.parent_dict, t2.parent_dict) + self.assertEqual(t1.index, 0) + if "tracked_samples" in kwargs: + self.assertNotEqual(t1.num_tracked_samples(), 0) + else: + self.assertEqual(t1.num_tracked_samples(), 0) + + t1 = ts.last(**kwargs) + t2 = next(reversed(ts.trees())) + self.assertFalse(t1 is t2) + self.assertEqual(t1.parent_dict, t2.parent_dict) + self.assertEqual(t1.index, ts.num_trees - 1) + if "tracked_samples" in kwargs: + self.assertNotEqual(t1.num_tracked_samples(), 0) + else: + self.assertEqual(t1.num_tracked_samples(), 0) def test_trees_interface(self): ts = list(get_example_tree_sequences())[0] @@ -1311,14 +1321,21 @@ def test_len_trees(self): def test_list(self): for ts in get_example_tree_sequences(): - tree_list = ts.aslist() - self.assertEqual(len(tree_list), ts.num_trees) - self.assertEqual(len(set(map(id, tree_list))), ts.num_trees) - for index, tree in enumerate(tree_list): - self.assertEqual(index, tree.index) - for t1, t2 in zip(tree_list, ts.trees()): - self.assertEqual(t1, t2) - self.assertEqual(t1.parent_dict, t2.parent_dict) + for kwargs in [{}, {"tracked_samples": ts.samples()}]: + tree_list = ts.aslist(**kwargs) + self.assertEqual(len(tree_list), ts.num_trees) + self.assertEqual(len(set(map(id, tree_list))), ts.num_trees) + for index, tree in enumerate(tree_list): + self.assertEqual(index, tree.index) + for t1, t2 in zip(tree_list, ts.trees(**kwargs)): + self.assertEqual(t1, t2) + self.assertEqual(t1.parent_dict, t2.parent_dict) + if "tracked_samples" in kwargs: + self.assertNotEqual(t1.num_tracked_samples(), 0) + self.assertNotEqual(t2.num_tracked_samples(), 0) + else: + self.assertEqual(t1.num_tracked_samples(), 0) + self.assertEqual(t2.num_tracked_samples(), 0) def test_reversed_trees(self): for ts in get_example_tree_sequences(): @@ -1333,31 +1350,41 @@ def test_reversed_trees(self): def test_at_index(self): for ts in get_example_tree_sequences(): - tree_list = ts.aslist() - for index in list(range(ts.num_trees)) + [-1]: - t1 = tree_list[index] - t2 = ts.at_index(index) - self.assertEqual(t1, t2) - self.assertEqual(t1.interval, t2.interval) - self.assertEqual(t1.parent_dict, t2.parent_dict) - - def test_at(self): - for ts in get_example_tree_sequences(): - tree_list = ts.aslist() - for t1 in tree_list: - left, right = t1.interval - mid = left + (right - left) / 2 - for pos in [left, left + 1e-9, mid, right - 1e-9]: - t2 = ts.at(pos) + for kwargs in [{}, {"tracked_samples": ts.samples()}]: + tree_list = ts.aslist(**kwargs) + for index in list(range(ts.num_trees)) + [-1]: + t1 = tree_list[index] + t2 = ts.at_index(index, **kwargs) self.assertEqual(t1, t2) self.assertEqual(t1.interval, t2.interval) self.assertEqual(t1.parent_dict, t2.parent_dict) - if right < ts.sequence_length: - t2 = ts.at(right) - t3 = tree_list[t1.index + 1] - self.assertEqual(t3, t2) - self.assertEqual(t3.interval, t2.interval) - self.assertEqual(t3.parent_dict, t2.parent_dict) + if "tracked_samples" in kwargs: + self.assertNotEqual(t2.num_tracked_samples(), 0) + else: + self.assertEqual(t2.num_tracked_samples(), 0) + + def test_at(self): + for ts in get_example_tree_sequences(): + for kwargs in [{}, {"tracked_samples": ts.samples()}]: + tree_list = ts.aslist(**kwargs) + for t1 in tree_list: + left, right = t1.interval + mid = left + (right - left) / 2 + for pos in [left, left + 1e-9, mid, right - 1e-9]: + t2 = ts.at(pos, **kwargs) + self.assertEqual(t1, t2) + self.assertEqual(t1.interval, t2.interval) + self.assertEqual(t1.parent_dict, t2.parent_dict) + if right < ts.sequence_length: + t2 = ts.at(right, **kwargs) + t3 = tree_list[t1.index + 1] + self.assertEqual(t3, t2) + self.assertEqual(t3.interval, t2.interval) + self.assertEqual(t3.parent_dict, t2.parent_dict) + if "tracked_samples" in kwargs: + self.assertNotEqual(t2.num_tracked_samples(), 0) + else: + self.assertEqual(t2.num_tracked_samples(), 0) def test_sequence_iteration(self): for ts in get_example_tree_sequences(): @@ -1408,6 +1435,20 @@ def test_kwargs_only(self): with self.assertRaisesRegex(TypeError, "argument"): self.ts.draw_svg("filename", True) + def test_trees_params(self): + """ + The initial .trees() iterator parameters should match those in Tree.__init__() + """ + tree_class_params = list(inspect.signature(tskit.Tree).parameters.items()) + trees_iter_params = list( + inspect.signature(tskit.TreeSequence.trees).parameters.items() + ) + # Skip the first param, which is `tree_sequence` and `self` respectively + tree_class_params = tree_class_params[1:] + # The trees iterator has some extra (deprecated) aliases + trees_iter_params = trees_iter_params[1:-3] + self.assertEqual(trees_iter_params, tree_class_params) + class TestTreeSequenceMetadata(unittest.TestCase): metadata_tables = [ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index ed9bd82f8a..f6d64a94b2 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -644,7 +644,6 @@ class Tree: :param TreeSequence tree_sequence: The parent tree sequence. :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. - :param bool sample_counts: Deprecated since 0.2.4. :param bool sample_lists: If True, provide more efficient access to the samples beneath a give node using the :meth:`Tree.samples` method. @@ -654,6 +653,7 @@ class Tree: are roots. To efficiently restrict the roots of the tree to those subtending meaningful topology, set this to 2. This value is only relevant when trees have multiple roots. + :param bool sample_counts: Deprecated since 0.2.4. """ def __init__( @@ -661,9 +661,9 @@ def __init__( tree_sequence, tracked_samples=None, *, - sample_counts=None, sample_lists=False, root_threshold=1, + sample_counts=None, ): options = 0 if sample_counts is not None: @@ -2947,7 +2947,7 @@ def ll_tree_sequence(self): def get_ll_tree_sequence(self): return self._ll_tree_sequence - def aslist(self): + def aslist(self, **kwargs): """ Returns the trees in this tree sequence as a list. Each tree is represented by a different instance of :class:`Tree`. As such, this @@ -2956,10 +2956,13 @@ def aslist(self): method is the recommended way to efficiently iterate over the trees in a tree sequence. + :param \\**kwargs: Further arguments used as parameters when constructing the + returned trees. For example ``ts.aslist(sample_lists=True)`` will result + in a list of :class:`Tree` instances created with ``sample_lists=True``. :return: A list of the trees in this tree sequence. :rtype: list """ - return [tree.copy() for tree in self.trees()] + return [tree.copy() for tree in self.trees(**kwargs)] @classmethod def load(cls, path): @@ -3601,52 +3604,66 @@ def breakpoints(self, as_array=False): breakpoints = map(float, breakpoints) return breakpoints - def at(self, position): + def at(self, position, **kwargs): """ Returns the tree covering the specified genomic location. The returned tree will have ``tree.interval.left`` <= ``position`` < ``tree.interval.right``. See also :meth:`Tree.seek`. + :param float position: A genomic location. + :param \\**kwargs: Further arguments used as parameters when constructing the + returned :class:`Tree`. For example ``ts.at(2.5, sample_lists=True)`` will + result in a :class:`Tree` created with ``sample_lists=True``. :return: A new instance of :class:`Tree` positioned to cover the specified - position. + genomic location. :rtype: Tree """ - tree = Tree(self) + tree = Tree(self, **kwargs) tree.seek(position) return tree - def at_index(self, index): + def at_index(self, index, **kwargs): """ Returns the tree at the specified index. See also :meth:`Tree.seek_index`. + :param int index: The index of the required tree. + :param \\**kwargs: Further arguments used as parameters when constructing the + returned :class:`Tree`. For example ``ts.at_index(4, sample_lists=True)`` + will result in a :class:`Tree` created with ``sample_lists=True``. :return: A new instance of :class:`Tree` positioned at the specified index. :rtype: Tree """ - tree = Tree(self) + tree = Tree(self, **kwargs) tree.seek_index(index) return tree - def first(self): + def first(self, **kwargs): """ Returns the first tree in this :class:`TreeSequence`. To iterate over all trees in the sequence, use the :meth:`.trees` method. + :param \\**kwargs: Further arguments used as parameters when constructing the + returned :class:`Tree`. For example ``ts.first(sample_lists=True)`` will + result in a :class:`Tree` created with ``sample_lists=True``. :return: The first tree in this tree sequence. :rtype: :class:`Tree`. """ - tree = Tree(self) + tree = Tree(self, **kwargs) tree.first() return tree - def last(self): + def last(self, **kwargs): """ Returns the last tree in this :class:`TreeSequence`. To iterate over all trees in the sequence, use the :meth:`.trees` method. + :param \\**kwargs: Further arguments used as parameters when constructing the + returned :class:`Tree`. For example ``ts.first(sample_lists=True)`` will + result in a :class:`Tree` created with ``sample_lists=True``. :return: The last tree in this tree sequence. :rtype: :class:`Tree`. """ - tree = Tree(self) + tree = Tree(self, **kwargs) tree.last() return tree @@ -3654,8 +3671,9 @@ def trees( self, tracked_samples=None, *, - sample_counts=None, sample_lists=False, + root_threshold=1, + sample_counts=None, tracked_leaves=None, leaf_counts=None, leaf_lists=None, @@ -3678,10 +3696,16 @@ def trees( :param list tracked_samples: The list of samples to be tracked and counted using the :meth:`Tree.num_tracked_samples` method. - :param bool sample_counts: Deprecated since 0.2.4. :param bool sample_lists: If True, provide more efficient access to the samples beneath a give node using the :meth:`Tree.samples` method. + :param int root_threshold: The minimum number of samples that a node + must be ancestral to for it to be in the list of roots. By default + this is 1, so that isolated samples (representing missing data) + are roots. To efficiently restrict the roots of the tree to + those subtending meaningful topology, set this to 2. This value + is only relevant when trees have multiple roots. + :param bool sample_counts: Deprecated since 0.2.4. :return: An iterator over the Trees in this tree sequence. :rtype: collections.abc.Iterable, :class:`Tree` """ @@ -3698,8 +3722,9 @@ def trees( tree = Tree( self, tracked_samples=tracked_samples, - sample_counts=sample_counts, sample_lists=sample_lists, + root_threshold=root_threshold, + sample_counts=sample_counts, ) return TreeIterator(tree)