diff --git a/docs/python-api.rst b/docs/python-api.rst index 287e81ec38..d75a24deec 100644 --- a/docs/python-api.rst +++ b/docs/python-api.rst @@ -86,6 +86,9 @@ directly, but are the return types for the various iterators provided by the .. autoclass:: Edge() :members: +.. autoclass:: Interval() + :members: + .. autoclass:: Site() :members: diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index fdc6b90811..4bffc7c6f3 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,10 @@ **Features** +- Expose ``TreeSequence.coiterate()`` method to allow iteration over 2 sequences + simultaneously, aiding comparison of trees from two sequences. + (:user:`jeromekelleher`, :user:`hyanwong`, :issue:`1021`, :pr:`1022`) + - tskit is now supported on, and has wheels for, python3.9. (:user:`benjeffery`, :issue:`982`, :pr:`907`) diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index f24ca18712..b61222353e 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -4758,6 +4758,95 @@ def test_partial_overlap_contradictory_children(self): tskit.load_text(nodes=nodes, edges=edges, strict=False) +class TestCoiteration: + """ + Test ability to iterate over multiple (currently 2) tree sequences simultaneously + """ + + def test_identical_ts(self): + ts = msprime.simulate(4, recombination_rate=1, random_seed=123) + assert ts.num_trees > 1 + total_iterations = 0 + for tree, (_, t1, t2) in zip(ts.trees(), ts.coiterate(ts)): + total_iterations += 1 + assert tree == t1 == t2 + assert ts.num_trees == total_iterations + + def test_intervals(self): + ts1 = msprime.simulate(4, recombination_rate=1, random_seed=1) + assert ts1.num_trees > 1 + one_tree_ts = msprime.simulate(5, random_seed=2) + multi_tree_ts = msprime.simulate(5, recombination_rate=1, random_seed=2) + assert multi_tree_ts.num_trees > 1 + for ts2 in (one_tree_ts, multi_tree_ts): + bp1 = set(ts1.breakpoints()) + bp2 = set(ts2.breakpoints()) + assert bp1 != bp2 + 
breaks = set() + for interval, t1, t2 in ts1.coiterate(ts2): + assert set(interval) <= set(t1.interval) | set(t2.interval) + breaks.add(interval.left) + breaks.add(interval.right) + assert t1.tree_sequence == ts1 + assert t2.tree_sequence == ts2 + assert breaks == bp1 | bp2 + + def test_simple_ts(self): + nodes = """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + """ + edges1 = """\ + left right parent child + 0 0.2 3 0,1 + 0 0.2 4 2,3 + 0.2 1 3 2,1 + 0.2 1 4 0,3 + """ + edges2 = """\ + left right parent child + 0 0.8 3 2,1 + 0 0.8 4 0,3 + 0.8 1 3 0,1 + 0.8 1 4 2,3 + """ + ts1 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges1), strict=False) + ts2 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges2), strict=False) + coiterator = ts1.coiterate(ts2) + interval, tree1, tree2 = next(coiterator) + assert interval.left == 0 + assert interval.right == 0.2 + assert tree1 == ts1.at_index(0) + assert tree2 == ts2.at_index(0) + interval, tree1, tree2 = next(coiterator) + assert interval.left == 0.2 + assert interval.right == 0.8 + assert tree1 == ts1.at_index(1) + assert tree2 == ts2.at_index(0) + interval, tree1, tree2 = next(coiterator) + assert interval.left == 0.8 + assert interval.right == 1 + assert tree1 == ts1.at_index(1) + assert tree2 == ts2.at_index(1) + + def test_nonequal_lengths(self): + ts1 = msprime.simulate(4, random_seed=1, length=2) + ts2 = msprime.simulate(4, random_seed=1) + with pytest.raises(ValueError, match="equal sequence length"): + next(ts1.coiterate(ts2)) + + def test_kwargs(self): + ts = msprime.simulate(4, recombination_rate=1, random_seed=123) + for _, t1, t2 in ts.coiterate(ts): + assert t1.num_tracked_samples() == t2.num_tracked_samples() == 0 + for _, t1, t2 in ts.coiterate(ts, tracked_samples=ts.samples()): + assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4 + + class SimplifyTestBase: """ Base class for simplify tests. 
@@ -5695,9 +5784,7 @@ def verify_keep_input_roots(self, ts, samples): new_to_input_map = { value: key for key, value in enumerate(node_map) if value != tskit.NULL } - for (left, right), input_tree, tree_with_roots in tsutil.coiterate( - ts, ts_with_roots - ): + for (left, right), input_tree, tree_with_roots in ts.coiterate(ts_with_roots): input_roots = input_tree.roots assert len(tree_with_roots.roots) > 0 for root in tree_with_roots.roots: diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 4eddc4cb3f..81ddc9ecdc 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1282,27 +1282,3 @@ def genealogical_nearest_neighbours(ts, focal, reference_sets): L[L == 0] = 1 A /= L.reshape((len(focal), 1)) return A - - -def coiterate(ts1, ts2, **kwargs): - """ - Returns an iterator over the pairs of trees for each distinct - interval in the specified pair of tree sequences. - """ - if ts1.sequence_length != ts2.sequence_length: - raise ValueError("Tree sequences must be equal length.") - L = ts1.sequence_length - trees1 = ts1.trees(**kwargs) - trees2 = ts2.trees(**kwargs) - tree1 = next(trees1) - tree2 = next(trees2) - right = 0 - while right != L: - left = right - right = min(tree1.interval[1], tree2.interval[1]) - yield (left, right), tree1, tree2 - # Advance - if tree1.interval[1] == right: - tree1 = next(trees1, None) - if tree2.interval[1] == right: - tree2 = next(trees2, None) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 58ef346386..5acfeabf2a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -62,6 +62,19 @@ class Interval(BaseInterval): + """ + A tuple of 2 numbers, ``[left, right)``, defining an interval over the genome. + + :ivar left: The left hand end of the interval. By convention this value is included + in the interval. + :vartype left: float + :ivar right: The right hand end of the interval. By convention this value is *not* + included in the interval, i.e. the interval is half-open. 
+ :vartype right: float + :ivar span: The span of the genome covered by this interval, simply ``right-left``. + :vartype span: float + """ + @property def span(self): return self.right - self.left @@ -3933,6 +3946,43 @@ def trees( ) return TreeIterator(tree) + def coiterate(self, other, **kwargs): + """ + Returns an iterator over the pairs of trees for each distinct + interval in the specified pair of tree sequences. + + :param TreeSequence other: The other tree sequence from which to take trees. The + sequence length must be the same as the current tree sequence. + :param \\**kwargs: Further named arguments that will be passed to the + :meth:`.trees` method when constructing the returned trees. + + :return: An iterator returning successive tuples of the form + ``(interval, tree_self, tree_other)``. For example, the first item returned + will consist of a tuple of the initial interval, the first tree of the + current tree sequence, and the first tree of the ``other`` tree sequence; + the ``.left`` attribute of the initial interval will be 0 and the ``.right`` + attribute will be the smallest non-zero breakpoint of the 2 tree sequences. + :rtype: iter(:class:`Interval`, :class:`Tree`, :class:`Tree`) + + """ + if self.sequence_length != other.sequence_length: + raise ValueError("Tree sequences must be of equal sequence length.") + L = self.sequence_length + trees1 = self.trees(**kwargs) + trees2 = other.trees(**kwargs) + tree1 = next(trees1) + tree2 = next(trees2) + right = 0 + while right != L: + left = right + right = min(tree1.interval[1], tree2.interval[1]) + yield Interval(left, right), tree1, tree2 + # Advance + if tree1.interval[1] == right: + tree1 = next(trees1, None) + if tree2.interval[1] == right: + tree2 = next(trees2, None) + def haplotypes( self, *,