Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ directly, but are the return types for the various iterators provided by the
.. autoclass:: Edge()
:members:

.. autoclass:: Interval()
:members:

.. autoclass:: Site()
:members:

Expand Down
4 changes: 4 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

**Features**

- Expose ``TreeSequence.coiterate()`` method to allow iteration over 2 sequences
simultaneously, aiding comparison of trees from two sequences.
(:user:`jeromekelleher`, :user:`hyanwong`, :issue:`1021`, :pr:`1022`)

- tskit is now supported on, and has wheels for, python3.9.
(:user:`benjeffery`, :issue:`982`, :pr:`907`)

Expand Down
93 changes: 90 additions & 3 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4758,6 +4758,95 @@ def test_partial_overlap_contradictory_children(self):
tskit.load_text(nodes=nodes, edges=edges, strict=False)


class TestCoiteration:
"""
Test ability to iterate over multiple (currently 2) tree sequences simultaneously
"""

def test_identical_ts(self):
ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
assert ts.num_trees > 1
total_iterations = 0
for tree, (_, t1, t2) in zip(ts.trees(), ts.coiterate(ts)):
total_iterations += 1
assert tree == t1 == t2
assert ts.num_trees == total_iterations

def test_intervals(self):
ts1 = msprime.simulate(4, recombination_rate=1, random_seed=1)
assert ts1.num_trees > 1
one_tree_ts = msprime.simulate(5, random_seed=2)
multi_tree_ts = msprime.simulate(5, recombination_rate=1, random_seed=2)
assert multi_tree_ts.num_trees > 1
for ts2 in (one_tree_ts, multi_tree_ts):
bp1 = set(ts1.breakpoints())
bp2 = set(ts2.breakpoints())
assert bp1 != bp2
breaks = set()
for interval, t1, t2 in ts1.coiterate(ts2):
assert set(interval) <= set(t1.interval) | set(t2.interval)
breaks.add(interval.left)
breaks.add(interval.right)
assert t1.tree_sequence == ts1
assert t2.tree_sequence == ts2
assert breaks == bp1 | bp2

def test_simple_ts(self):
nodes = """\
id is_sample time
0 1 0
1 1 0
2 1 0
3 0 1
4 0 2
"""
edges1 = """\
left right parent child
0 0.2 3 0,1
0 0.2 4 2,3
0.2 1 3 2,1
0.2 1 4 0,3
"""
edges2 = """\
left right parent child
0 0.8 3 2,1
0 0.8 4 0,3
0.8 1 3 0,1
0.8 1 4 2,3
"""
ts1 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges1), strict=False)
ts2 = tskit.load_text(io.StringIO(nodes), io.StringIO(edges2), strict=False)
coiterator = ts1.coiterate(ts2)
interval, tree1, tree2 = next(coiterator)
assert interval.left == 0
assert interval.right == 0.2
assert tree1 == ts1.at_index(0)
assert tree2 == ts2.at_index(0)
interval, tree1, tree2 = next(coiterator)
assert interval.left == 0.2
assert interval.right == 0.8
assert tree1 == ts1.at_index(1)
assert tree2 == ts2.at_index(0)
interval, tree1, tree2 = next(coiterator)
assert interval.left == 0.8
assert interval.right == 1
assert tree1 == ts1.at_index(1)
assert tree2 == ts2.at_index(1)

def test_nonequal_lengths(self):
ts1 = msprime.simulate(4, random_seed=1, length=2)
ts2 = msprime.simulate(4, random_seed=1)
with pytest.raises(ValueError, match="equal sequence length"):
next(ts1.coiterate(ts2))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't need the next here I think

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we do. If I remove it, we don't get the error raised, because we don't actually start the generator going?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generator functions are only executed when after the first next so yeah, we do need it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, my bad.


def test_kwargs(self):
ts = msprime.simulate(4, recombination_rate=1, random_seed=123)
for _, t1, t2 in ts.coiterate(ts):
assert t1.num_tracked_samples() == t2.num_tracked_samples() == 0
for _, t1, t2 in ts.coiterate(ts, tracked_samples=ts.samples()):
assert t1.num_tracked_samples() == t2.num_tracked_samples() == 4


class SimplifyTestBase:
"""
Base class for simplify tests.
Expand Down Expand Up @@ -5695,9 +5784,7 @@ def verify_keep_input_roots(self, ts, samples):
new_to_input_map = {
value: key for key, value in enumerate(node_map) if value != tskit.NULL
}
for (left, right), input_tree, tree_with_roots in tsutil.coiterate(
ts, ts_with_roots
):
for (left, right), input_tree, tree_with_roots in ts.coiterate(ts_with_roots):
input_roots = input_tree.roots
assert len(tree_with_roots.roots) > 0
for root in tree_with_roots.roots:
Expand Down
24 changes: 0 additions & 24 deletions python/tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,27 +1282,3 @@ def genealogical_nearest_neighbours(ts, focal, reference_sets):
L[L == 0] = 1
A /= L.reshape((len(focal), 1))
return A


def coiterate(ts1, ts2, **kwargs):
"""
Returns an iterator over the pairs of trees for each distinct
interval in the specified pair of tree sequences.
"""
if ts1.sequence_length != ts2.sequence_length:
raise ValueError("Tree sequences must be equal length.")
L = ts1.sequence_length
trees1 = ts1.trees(**kwargs)
trees2 = ts2.trees(**kwargs)
tree1 = next(trees1)
tree2 = next(trees2)
right = 0
while right != L:
left = right
right = min(tree1.interval[1], tree2.interval[1])
yield (left, right), tree1, tree2
# Advance
if tree1.interval[1] == right:
tree1 = next(trees1, None)
if tree2.interval[1] == right:
tree2 = next(trees2, None)
50 changes: 50 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@


class Interval(BaseInterval):
"""
A tuple of 2 numbers, ``[left, right)``, defining an interval over the genome.

:ivar left: The left hand end of the interval. By convention this value is included
in the interval.
:vartype left: float
:ivar right: The right hand end of the iterval. By convention this value is *not*
included in the interval, i.e. the interval is half-open.
:vartype right: float
:ivar span: The span of the genome covered by this interval, simply ``right-left``.
:vartype span: float
"""

@property
def span(self):
return self.right - self.left
Expand Down Expand Up @@ -3933,6 +3946,43 @@ def trees(
)
return TreeIterator(tree)

def coiterate(self, other, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, might make more sense to put this next to the trees iterator in the file? It's nice to keep methods somewhat grouped by topic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Erm. That's where it is, isn't it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, so it is. GitHub's way of showing context is a bit weird sometimes

"""
Returns an iterator over the pairs of trees for each distinct
interval in the specified pair of tree sequences.

:param TreeSequence other: The other tree sequence from which to take trees. The
sequence length must be the same as the current tree sequence.
:param \\**kwargs: Further named arguments that will be passed to the
:meth:`.trees` method when constructing the returned trees.

:return: An iterator returning successive tuples of the form
``(interval, tree_self, tree_other)``. For example, the first item returned
will consist of an tuple of the initial interval, the first tree of the
current tree sequence, and the first tree of the ``other`` tree sequence;
the ``.left`` attribute of the initial interval will be 0 and the ``.right``
attribute will be the smallest non-zero breakpoint of the 2 tree sequences.
:rtype: iter(:class:`Interval`, :class:`Tree`, :class:`Tree`)

"""
if self.sequence_length != other.sequence_length:
raise ValueError("Tree sequences must be of equal sequence length.")
L = self.sequence_length
trees1 = self.trees(**kwargs)
trees2 = other.trees(**kwargs)
tree1 = next(trees1)
tree2 = next(trees2)
right = 0
while right != L:
left = right
right = min(tree1.interval[1], tree2.interval[1])
yield Interval(left, right), tree1, tree2
# Advance
if tree1.interval[1] == right:
tree1 = next(trees1, None)
if tree2.interval[1] == right:
tree2 = next(trees2, None)

def haplotypes(
self,
*,
Expand Down