From cd99a785368f5c6e328c3eaf8531f7399d32ef61 Mon Sep 17 00:00:00 2001 From: Georgia Tsambos Date: Wed, 11 Mar 2020 14:29:58 +1100 Subject: [PATCH 1/4] Added functions for computing 'naive' versions of IBD. Baby commit of IBD tests. Working, but messy Python implementation of fast IBD algorithm. ibdFinder now returns segments rather than intervals. verify_equal_ibd function has been changed Bug fix. Rudimentary tests of IbdFinder for path_IBD=true, mrca_IBD=true seem to work. --- python/tests/ibd.py | 293 +++++++++++++++++++++++++++++++++++++++ python/tests/test_ibd.py | 241 ++++++++++++++++++++++++++++++++ 2 files changed, 534 insertions(+) create mode 100644 python/tests/ibd.py create mode 100644 python/tests/test_ibd.py diff --git a/python/tests/ibd.py b/python/tests/ibd.py new file mode 100644 index 0000000000..17fb1f8759 --- /dev/null +++ b/python/tests/ibd.py @@ -0,0 +1,293 @@ +# MIT License +# +# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2015-2018 University of Oxford +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +Python implementation of the IBD-finding algorithms. +""" + +import tskit +import numpy as np +import itertools +import sys + + +class Segment(object): + """ + A class representing a single segment. Each segment has a left and right, + denoting the loci over which it spans, a node and a next, giving the next + in the chain. + + The node it records is the *output* node ID. + """ + def __init__(self, left=None, right=None, node=None, next=None): + self.left = left + self.right = right + self.node = node + self.next = next + + def __str__(self): + s = "({}-{}->{}:next={})".format( + self.left, self.right, self.node, repr(self.next)) + return s + + def __repr__(self): + return repr((self.left, self.right, self.node)) + + def __eq__(self, other): + return (self.left == other.left and self.right == other.right and self.node == other.node) + + def __lt__(self, other): + return (self.node, self.left, self.right) < (other.node, other.left, other.right) + + +class IbdFinder(object): + """ + Finds all IBD relationships between specified samples in a tree sequence. 
+ """ + + def __init__( + self, + ts, + samples, + min_length=0): + + self.ts = ts + self.samples = samples + self.min_length = min_length + self.current_parent = self.ts.tables.edges.parent[0] + self.A_head = [None for _ in range(ts.num_nodes)] + self.A_tail = [None for _ in range(ts.num_nodes)] + self.tables = tskit.TableCollection(sequence_length=ts.sequence_length) + # self.ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) + + + def find_ibd_segments_of_length(self, min_length=0): + + # 1 + A = [[] for n in range(0, self.ts.num_nodes)] + ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) + edges = self.ts.edges() + parent_list = self.list_of_parents() ## Needed for memory-pruning step + + # 2 + mygen = iter(self.ts.edges()) + e = next(mygen) + + # 3 + while e is not None: + + # 3a + S = [] + self.current_parent = e.parent + + # 3b + while e is not None and self.current_parent == e.parent: + # Create the list S of immediate descendants of u. + S.append(Segment(e.left, e.right, e.child)) + # if e.id < edges.num_rows - 1: + if e.id < self.ts.num_edges - 1: + e = next(mygen) + continue + else: + e = None + # break + + # 3c + for seg in S: + # Create A[u] from S. + # Do we still need to do the initialisation if the below is there?? + u = seg.node + if u in self.ts.samples(): + A[self.current_parent].append([seg]) + else: + list_to_add = [] + for s in A[u]: + l = (max(seg.left, s.left), min(seg.right, s.right)) + if l[1] - l[0] > 0: + list_to_add.append(Segment(l[0], l[1], s.node)) + A[self.current_parent].append(list_to_add) + + # d. Squash + # A[self.current_parent] = self.squash(A[self.current_parent]) + + # e. Process A[self.current_parent] + if len(A[self.current_parent]) > 1: + new_segs, nodes_to_remove = self.update_A_and_find_ibd_segs( + A[self.current_parent], ibd_segments) + + # e. Add any new IBD segments discovered. 
+ for key, val in new_segs.items(): + for v in val: + if len(ibd_segments[key]) == 0: + ibd_segments[key] = [v] + else: + ibd_segments[key].append(v) + + # g. Remove elements of A[u] if they are no longer needed. + ## (memory-pruning step) + for n in nodes_to_remove: + if self.current_parent in parent_list[n]: + parent_list[n].remove(self.current_parent) + if len(parent_list[n]) == 0: + A[n] = [] + + # Unlist the ancestral segments in A. + A[self.current_parent] = list(itertools.chain(*A[self.current_parent])) + + # 4 + return ibd_segments + + + def update_A_and_find_ibd_segs(self, ancestral_segs, ibd_segments, mrca_ibd=False): + + new_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) + num_coalescing_sets = len(ancestral_segs) + index_pairs = list(itertools.combinations(range(0, num_coalescing_sets), 2)) + + for setpair in index_pairs: + for seg0 in ancestral_segs[setpair[0]]: + for seg1 in ancestral_segs[setpair[1]]: + + if seg0.node == seg1.node: + continue + left = max(seg0.left, seg1.left) + right = min(seg0.right, seg1.right) + if left >= right: + continue + nodes = [seg0.node, seg1.node] + nodes.sort() + + if mrca_ibd: + pass # for now + # existing_segs = ibd_segments[(nodes[0], nodes[1])].copy() + # if right - left > self.min_length: + # if len(existing_segs) == 0: + # new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] + # existing_segs.append(Segment(left, right, self.current_parent)) + # else: + # for i in existing_segs: + # # no overlap. + # if right <= i.left or left >= i.right: + # if len(new_segments[(nodes[0], nodes[1])]) == 0: + # new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] + # else: + # new_segments[(nodes[0], nodes[1])].append(Segment(left, right, self.current_parent)) + # existing_segs.append(Segment(left, right, self.current_parent)) + # # partial overlap -- does this even happen? 
+ # elif (left < i.left and right < i.right) or (i.left < left and i.right < right): + # print('partial overlap') + # Yes, but I think it's okay to leave these segments... + else: + if len(new_segments[(nodes[0], nodes[1])]) == 0: + new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] + else: + new_segments[(nodes[0], nodes[1])].append(Segment(left, right, self.current_parent)) + + # iv. specify elements of A that can be removed (for memory-pruning step) + processed_child_nodes = [] + for seglist in ancestral_segs: + processed_child_nodes += [seg.node for seg in seglist] + processed_child_nodes = list(set(processed_child_nodes)) + + return new_segments, processed_child_nodes + + + def list_of_parents(self): + parents = [[] for i in range(0, self.ts.num_nodes)] + edges = self.ts.tables.edges + for e in edges: + if len(parents[e.child]) == 0 or e.parent != parents[e.child][-1]: + parents[e.child].append(e.parent) + return parents + + + # def squash(self, segment_lists): + + # # Concatenate the input lists and record the number of + # # segments in each. + # A_u = [] + # num_desc_edges = [] + # for L in segment_lists: + # for l in L: + # A_u.append(l) + # num_desc_edges.append(len(L)) + + # # Sort the list, keeping track of the original order. + # sorted_A = sorted(enumerate(A_u), key=lambda i:i[1]) + + # # Squash the list. + # next_ind = len(sorted_A) + # inds_to_remove = [] + # ind = 1 + # while ind < len(sorted_A): + # if sorted_A[ind][1].node == sorted_A[ind - 1][1].node: + # if sorted_A[ind][1].right > sorted_A[ind - 1][1].right and\ + # sorted_A[ind][1].left <= sorted_A[ind - 1][1].right: + # # Squash the previous int into the current one. + # sorted_A[ind][1].left = sorted_A[ind - 1][1].left + # # Flag the interval to be removed. + # inds_to_remove.append(ind - 1) + # # Change order index. + # sorted_A[ind] = (next_ind, sorted_A[ind][1]) + # next_ind += 1 + # ind += 1 + + # # Remove any unnecessary list items. 
+ # for i in reversed(inds_to_remove): + # # Needs to be done in reverse order!! + # sorted_A.pop(i) + + # # Restore the original order as lists of lists. + # cum_sum = np.cumsum(num_desc_edges) + # squashed_sorted_A = [[] for _ in range(0, next_ind)] + # for a in sorted_A: + # ind = a[0] + # if ind < cum_sum[-1]: + # s = 0 + # while s < len(cum_sum): + # if a[0] < cum_sum[s]: + # squashed_sorted_A[s].append(a[1]) + # break + # s += 1 + + # else: + # squashed_sorted_A[ind].append(a[1]) + + # # Remove lists of length 0. + # squashed_sorted_A = [_ for _ in squashed_sorted_A if len(_) > 0] + + # return squashed_sorted_A + + +if __name__ == "__main__": + # Simple CLI for running simplifier/ancestor mapping above. + + ts = tskit.load(sys.argv[1]) + s = IbdFinder(ts, samples = ts.samples()) + all_segs = s.find_ibd_segments_of_length() + + if sys.argv[2] is not None and sys.argv[3] is not None: + sample0 = int(sys.argv[2]) + sample1 = int(sys.argv[3]) + print(all_segs[(sample0, sample1)]) + else: + print(all_segs) diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py new file mode 100644 index 0000000000..efa5dc7119 --- /dev/null +++ b/python/tests/test_ibd.py @@ -0,0 +1,241 @@ + +""" +Tests of IBD finding algorithms. +""" +import unittest +import sys +import random +import io +import itertools + +import tests as tests +import tests.ibd as ibd + +import tskit +import msprime +import numpy as np + +# Functions for computing IBD 'naively'. + +class Segment(object): + """ + A class representing a single segment. Each segment has a left and right, + denoting the loci over which it spans, a node and a next, giving the next + in the chain. + + The node it records is the *output* node ID. 
+ """ + def __init__(self, left=None, right=None, node=None, next=None): + self.left = left + self.right = right + self.node = node + self.next = next + + def __str__(self): + s = "({}-{}->{}:next={})".format( + self.left, self.right, self.node, repr(self.next)) + return s + + def __repr__(self): + return repr((self.left, self.right, self.node)) + + def __lt__(self, other): + return (self.node, self.left, self.right) < (other.node, other.left, other.right) + + +def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, + path_ibd=True, mrca_ibd=False): + """ + Returns all IBD segments for a given pair of nodes in a tree + using a naive algorithm. + """ + + ibd_list = [] + ts, node_map = treeSequence.simplify(samples=[sample0, sample1], keep_unary=True, + map_nodes=True) + node_map = node_map.tolist() + + for n in ts.nodes(): + + if max_time is not None and n.time > max_time: + break + + node_id = n.id + interval_list = [] + if n.flags == 1: + continue + + prev_dict = None + for t in ts.trees(): + + if len(list(t.nodes(n.id))) == 1 or t.num_samples(n.id) < 2: + continue + if mrca_ibd and n.id != t.mrca(0, 1): + continue + + current_int = t.get_interval() + if len(interval_list) == 0: + interval_list.append(current_int) + else: + prev_int = interval_list[-1] + if not path_ibd and prev_int[1] == current_int[0]: + interval_list[-1] = (prev_int[0], current_int[1]) + elif prev_dict is not None and subtrees_are_equal(t, prev_dict, node_id): + interval_list[-1] = (prev_int[0], current_int[1]) + else: + interval_list.append(current_int) + + prev_dict = t.get_parent_dict() + + for interval in interval_list: + if min_length == 0 or interval[1] - interval[0] > min_length: + orig_id = node_map.index(node_id) + ibd_list.append(Segment(interval[0], interval[1], orig_id)) + + return(ibd_list) + + +def get_ibd_all_pairs(treeSequence, samples=None, min_length=0, max_time=None, + path_ibd=True, mrca_ibd=False): + + ibd_dict = {} + + if samples is None: + samples = 
treeSequence.samples().tolist() + + pairs = itertools.combinations(samples, 2) + for pair in pairs: + ibd_list = get_ibd(pair[0], pair[1], treeSequence, + min_length=min_length, max_time=max_time, + path_ibd=path_ibd, mrca_ibd=mrca_ibd) + if len(ibd_list) > 0: + ibd_dict[pair] = ibd_list + + return(ibd_dict) + + +def subtrees_are_equal(tree1, pdict0, root): + pdict1 = tree1.get_parent_dict() + if root not in pdict0.values() or root not in pdict1.values(): + return False + leaves1 = set(tree1.leaves(root)) + for l in leaves1: + node = l + while node != root: + p1 = pdict1[node] + if p1 not in pdict0.values(): + return False + p0 = pdict0[node] + if p0 != p1: + return False + node = p1 + + return True + + +def verify_equal_ibd(treeSequence): + """ + Calculates IBD segments using both the 'naive' and sophisticated algorithms, + verifies that the same output is produced. + NB: May be good to expand this in the future so that many different combos + of IBD options are tested simultaneously (all the MRCA and path-IBD combos), + for example. + """ + ts = treeSequence + ibd0 = ibd.IbdFinder(ts, samples = ts.samples()) + ibd0 = ibd0.find_ibd_segments_of_length() + ibd1 = get_ibd_all_pairs(ts, path_ibd=True, mrca_ibd=True) + + for key0, val0 in ibd0.items(): + # print(key0) + assert key0 in ibd1.keys() + val1 = ibd1[key0] + val0.sort() + val1.sort() + + # print('IBD from IBDFinder') + # print(val0) + # print('IBD from naive function') + # print(val1) + + if val0 is None: # Get rid of this later -- don't want empty dict values at all + assert val1 is None + continue + elif val1 is None: + # print(val0) + # print(val1) + assert val0 is None + assert len(val0) == len(val1) + for i in range(0, len(val0)): + assert val0[i] == val1[i] + + +class TestIbdByLength(unittest.TestCase): + """ + Tests of length-based IBD function. 
+ """ + # 11 * * * + # / \ * * * 10 + # / \ * 9 * * / \ + # / \ * / \ * 8 * / 8 + # | | * / \ * / \ * / / \ + # | 7 * / 7 * / 7 * / / 7 + # | / \ * | / \ * / / \ * / / / \ + # | / 6 * | / 6 * / / 6 * | / | | + # 5 | / \ * 5 | / \ * 5 | / \ * | 5 | | + # / \ | / \ * / \ | / \ * / \ | / \ * | / \ | | + # 0 4 1 2 3 * 0 4 1 2 3 * 0 4 1 2 3 * 3 0 4 1 2 + # + # ------------------------------------------------------------------------------ + # 0 0.10 0.50 0.75 1.00 + + small_tree_ex_nodes = """\ + id flags population individual time + 0 1 0 -1 0.00000000000000 + 1 1 0 -1 0.00000000000000 + 2 1 0 -1 0.00000000000000 + 3 1 0 -1 0.00000000000000 + 4 1 0 -1 0.00000000000000 + 5 0 0 -1 1.00000000000000 + 6 0 0 -1 2.00000000000000 + 7 0 0 -1 3.00000000000000 + 8 0 0 -1 4.00000000000000 + 9 0 0 -1 5.00000000000000 + 10 0 0 -1 6.00000000000000 + 11 0 0 -1 7.00000000000000 + """ + small_tree_ex_edges = """\ + id left right parent child + 0 0.00000000 1.00000000 5 0 + 1 0.00000000 1.00000000 5 4 + 2 0.00000000 0.75000000 6 2 + 3 0.00000000 0.75000000 6 3 + 4 0.00000000 1.00000000 7 1 + 5 0.75000000 1.00000000 7 2 + 6 0.00000000 0.75000000 7 6 + 7 0.50000000 1.00000000 8 5 + 8 0.50000000 1.00000000 8 7 + 9 0.10000000 0.50000000 9 5 + 10 0.10000000 0.50000000 9 7 + 11 0.75000000 1.00000000 10 3 + 12 0.75000000 1.00000000 10 8 + 13 0.00000000 0.10000000 11 5 + 14 0.00000000 0.10000000 11 7 + """ + + + def test_canonical_example1(self): + ts0 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=2) + verify_equal_ibd(ts0) + + def test_canonical_example2(self): + ts1 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=23) + verify_equal_ibd(ts1) + + def test_canonical_example3(self): + ts2 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=232) + verify_equal_ibd(ts2) + + def test_random_example(self): + ts_r = msprime.simulate(sample_size=10, recombination_rate=.3, random_seed=726) + verify_equal_ibd(ts_r) From 
fa6d2f95b1e660b66fc301ad1b64e68bbe101d61 Mon Sep 17 00:00:00 2001 From: Georgia Tsambos Date: Thu, 19 Mar 2020 17:02:48 +1100 Subject: [PATCH 2/4] Improved CLI for Python implementation of IBD finder. min-length and max-time argument of IBDFinder should now be working. Added a new test class, TestIbdTopologies. Added functions to test for equality of ibd segments. extra test topologies in test_ibd.py Converted list of ancetral segments into more C-like objects. --- python/tests/ibd.py | 169 +++++++++++++++----- python/tests/test_ibd.py | 338 ++++++++++++++++++++++++++++++++++----- 2 files changed, 421 insertions(+), 86 deletions(-) diff --git a/python/tests/ibd.py b/python/tests/ibd.py index 17fb1f8759..6d68dc4f7b 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -28,6 +28,7 @@ import numpy as np import itertools import sys +import argparse class Segment(object): @@ -59,6 +60,27 @@ def __lt__(self, other): return (self.node, self.left, self.right) < (other.node, other.left, other.right) +class SegmentList(object): + """ + A class representing a list of segments that are descended from a given ancestral node + via a particular child of the ancestor. + """ + def __init__(self, head=None, tail=None, next=None): + self.head = head + self.tail = tail + self.next = next + + def __str__(self): + s = "head={},tail={},next={}".format( + self.head, self.right, repr(self.next)) + return s + + def __repr__(self): + s = "[{}...]".format(repr(self.next.head)) + return s + + + class IbdFinder(object): """ Finds all IBD relationships between specified samples in a tree sequence. 
@@ -67,23 +89,28 @@ class IbdFinder(object): def __init__( self, ts, - samples, - min_length=0): + samples=None, + min_length=0, + max_time=None): self.ts = ts - self.samples = samples + if samples is None: + samples = ts.samples() + else: + self.samples = samples self.min_length = min_length + self.max_time = max_time self.current_parent = self.ts.tables.edges.parent[0] - self.A_head = [None for _ in range(ts.num_nodes)] - self.A_tail = [None for _ in range(ts.num_nodes)] - self.tables = tskit.TableCollection(sequence_length=ts.sequence_length) + self.current_time = 0 + self.A = [[] for _ in range(ts.num_nodes)] # Descendant segments + self.tables = self.ts.tables # self.ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) - def find_ibd_segments_of_length(self, min_length=0): + def find_ibd_segments(self, min_length=0, max_time=None): # 1 - A = [[] for n in range(0, self.ts.num_nodes)] + # A = [[] for n in range(0, self.ts.num_nodes)] ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) edges = self.ts.edges() parent_list = self.list_of_parents() ## Needed for memory-pruning step @@ -96,43 +123,53 @@ def find_ibd_segments_of_length(self, min_length=0): while e is not None: # 3a - S = [] + S_head = None + S_tail = None self.current_parent = e.parent + self.current_time = self.tables.nodes.time[self.current_parent] + if self.max_time is not None and self.current_time > self.max_time: + # Stop looking for IBD segments once the + # processed nodes are older than the max time. + break # 3b while e is not None and self.current_parent == e.parent: # Create the list S of immediate descendants of u. 
- S.append(Segment(e.left, e.right, e.child)) - # if e.id < edges.num_rows - 1: + seg = Segment(e.left, e.right, e.child) + if S_head is None: + S_head = seg + S_tail = seg + else: + S_tail.next = seg + S_tail = S_tail.next if e.id < self.ts.num_edges - 1: e = next(mygen) continue else: e = None - # break # 3c - for seg in S: + while S_head is not None: # Create A[u] from S. # Do we still need to do the initialisation if the below is there?? - u = seg.node + u = S_head.node if u in self.ts.samples(): - A[self.current_parent].append([seg]) + self.A[self.current_parent].append([S_head]) else: list_to_add = [] - for s in A[u]: - l = (max(seg.left, s.left), min(seg.right, s.right)) + for s in self.A[u]: + l = (max(S_head.left, s.left), min(S_head.right, s.right)) if l[1] - l[0] > 0: list_to_add.append(Segment(l[0], l[1], s.node)) - A[self.current_parent].append(list_to_add) + self.A[self.current_parent].append(list_to_add) + S_head = S_head.next # d. Squash - # A[self.current_parent] = self.squash(A[self.current_parent]) + # A[self.current_parent] = self.squash(self.A[self.current_parent]) # e. Process A[self.current_parent] - if len(A[self.current_parent]) > 1: - new_segs, nodes_to_remove = self.update_A_and_find_ibd_segs( - A[self.current_parent], ibd_segments) + if len(self.A[self.current_parent]) > 1: + new_segs, nodes_to_remove = self.update_A_and_find_ibd_segs(ibd_segments) # e. Add any new IBD segments discovered. for key, val in new_segs.items(): @@ -142,30 +179,30 @@ def find_ibd_segments_of_length(self, min_length=0): else: ibd_segments[key].append(v) - # g. Remove elements of A[u] if they are no longer needed. + # g. Remove elements of self.A[u] if they are no longer needed. ## (memory-pruning step) for n in nodes_to_remove: if self.current_parent in parent_list[n]: parent_list[n].remove(self.current_parent) if len(parent_list[n]) == 0: - A[n] = [] + self.A[n] = [] - # Unlist the ancestral segments in A. 
- A[self.current_parent] = list(itertools.chain(*A[self.current_parent])) + # Unlist the ancestral segments in self.A. + self.A[self.current_parent] = list(itertools.chain(*self.A[self.current_parent])) # 4 return ibd_segments - def update_A_and_find_ibd_segs(self, ancestral_segs, ibd_segments, mrca_ibd=False): + def update_A_and_find_ibd_segs(self, ibd_segments): new_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) - num_coalescing_sets = len(ancestral_segs) + num_coalescing_sets = len(self.A[self.current_parent]) index_pairs = list(itertools.combinations(range(0, num_coalescing_sets), 2)) for setpair in index_pairs: - for seg0 in ancestral_segs[setpair[0]]: - for seg1 in ancestral_segs[setpair[1]]: + for seg0 in self.A[self.current_parent][setpair[0]]: + for seg1 in self.A[self.current_parent][setpair[1]]: if seg0.node == seg1.node: continue @@ -176,8 +213,8 @@ def update_A_and_find_ibd_segs(self, ancestral_segs, ibd_segments, mrca_ibd=Fals nodes = [seg0.node, seg1.node] nodes.sort() - if mrca_ibd: - pass # for now + # if mrca_ibd: + # pass # for now # existing_segs = ibd_segments[(nodes[0], nodes[1])].copy() # if right - left > self.min_length: # if len(existing_segs) == 0: @@ -196,7 +233,8 @@ def update_A_and_find_ibd_segs(self, ancestral_segs, ibd_segments, mrca_ibd=Fals # elif (left < i.left and right < i.right) or (i.left < left and i.right < right): # print('partial overlap') # Yes, but I think it's okay to leave these segments... - else: + # else: + if right - left > self.min_length: if len(new_segments[(nodes[0], nodes[1])]) == 0: new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] else: @@ -204,7 +242,7 @@ def update_A_and_find_ibd_segs(self, ancestral_segs, ibd_segments, mrca_ibd=Fals # iv. 
specify elements of A that can be removed (for memory-pruning step) processed_child_nodes = [] - for seglist in ancestral_segs: + for seglist in self.A[self.current_parent]: processed_child_nodes += [seg.node for seg in seglist] processed_child_nodes = list(set(processed_child_nodes)) @@ -279,15 +317,60 @@ def list_of_parents(self): if __name__ == "__main__": - # Simple CLI for running simplifier/ancestor mapping above. + # Simple CLI for running IBDFinder. + + parser = argparse.ArgumentParser(description="Command line interface for the IBDFinder.") + + parser.add_argument('--infile', + type=str, + dest='infile', + nargs=1, + metavar="IN_FILE", + help="The tree sequence to be analysed.") + + parser.add_argument('--min-length', + type=float, + dest='min_length', + nargs=1, + metavar="MIN_LENGTH", + help="Only segments longer than this cutoff will be returned.") + + parser.add_argument('--max-time', + type=float, + dest='max_time', + nargs=1, + metavar="MAX_TIME", + help="Only segments younger this time will be returned.") + + parser.add_argument('--samples', + type=int, + dest='samples', + nargs=2, + metavar="SAMPLES", + help="If provided, only IBD relationships between the given node pair are returned." 
+ ) + + args = parser.parse_args() + + ts = tskit.load(args.infile[0]) + if args.min_length is None: + min_length = 0 + else: + min_length = args.min_length[0] + if args.max_time is None: + max_time = None + else: + max_time = args.max_time[0] - ts = tskit.load(sys.argv[1]) - s = IbdFinder(ts, samples = ts.samples()) - all_segs = s.find_ibd_segments_of_length() + s = IbdFinder(ts, samples = ts.samples(), + min_length=min_length, max_time=max_time) + all_segs = s.find_ibd_segments() - if sys.argv[2] is not None and sys.argv[3] is not None: - sample0 = int(sys.argv[2]) - sample1 = int(sys.argv[3]) - print(all_segs[(sample0, sample1)]) - else: + + if args.samples is None: print(all_segs) + else: + samples = args.samples + print(all_segs[(samples[0], samples[1])]) + + diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index efa5dc7119..b3f31619cb 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -17,32 +17,6 @@ # Functions for computing IBD 'naively'. -class Segment(object): - """ - A class representing a single segment. Each segment has a left and right, - denoting the loci over which it spans, a node and a next, giving the next - in the chain. - - The node it records is the *output* node ID. 
- """ - def __init__(self, left=None, right=None, node=None, next=None): - self.left = left - self.right = right - self.node = node - self.next = next - - def __str__(self): - s = "({}-{}->{}:next={})".format( - self.left, self.right, self.node, repr(self.next)) - return s - - def __repr__(self): - return repr((self.left, self.right, self.node)) - - def __lt__(self, other): - return (self.node, self.left, self.right) < (other.node, other.left, other.right) - - def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, path_ibd=True, mrca_ibd=False): """ @@ -90,7 +64,7 @@ def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, for interval in interval_list: if min_length == 0 or interval[1] - interval[0] > min_length: orig_id = node_map.index(node_id) - ibd_list.append(Segment(interval[0], interval[1], orig_id)) + ibd_list.append(ibd.Segment(interval[0], interval[1], orig_id)) return(ibd_list) @@ -143,32 +117,310 @@ def verify_equal_ibd(treeSequence): """ ts = treeSequence ibd0 = ibd.IbdFinder(ts, samples = ts.samples()) - ibd0 = ibd0.find_ibd_segments_of_length() + ibd0 = ibd0.find_ibd_segments() ibd1 = get_ibd_all_pairs(ts, path_ibd=True, mrca_ibd=True) for key0, val0 in ibd0.items(): - # print(key0) assert key0 in ibd1.keys() val1 = ibd1[key0] val0.sort() val1.sort() - # print('IBD from IBDFinder') - # print(val0) - # print('IBD from naive function') - # print(val1) - if val0 is None: # Get rid of this later -- don't want empty dict values at all assert val1 is None continue elif val1 is None: - # print(val0) - # print(val1) assert val0 is None assert len(val0) == len(val1) for i in range(0, len(val0)): assert val0[i] == val1[i] +def ibd_is_equal(dict1, dict2): + """ + Verifies that two dictionaries have the same keys, and that + the set of items corresponding to each key is identical. + Used to check identical IBD output. + NOTE: is there a better/neater way to do this??? 
+ """ + if len(dict1) != len(dict2): + return False + for key1, val1 in dict1.items(): + if key1 not in dict2.keys(): + return False + val2 = dict2[key1] + if not segment_lists_are_equal(val1, val2): + return False + + return True + + +def segment_lists_are_equal(val1, val2): + + if len(val1) != len(val2): + return False + + val1.sort() + val2.sort() + + if val1 is None: # get rid of this later -- we don't any empty dict values! + if val2 is not None: + return False + elif val2 is None: + if val1 is not None: + return False + for i in range(len(val1)): + if val1[i] != val2[i]: + # print(val1q, val2) + # print(val1[i], val2[i]) + return False + + return True + + +class TestIbdTopologies(unittest.TestCase): + + def test_single_binary_tree(self): + # + # 2 4 + # / \ + # 1 3 \ + # / \ \ + # 0 0 1 2 + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + # Basic test + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = { + (0, 1): [ibd.Segment(0.0, 1.0, 3)], + (0, 2): [ibd.Segment(0.0, 1.0, 4)], + (1, 2): [ibd.Segment(0.0, 1.0, 4)]} + assert ibd_is_equal(ibd_segs, true_segs) + # Max time = 1.5 + ibd_f = ibd.IbdFinder(ts, max_time=1.5) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 3)], + (0, 2): [], (1, 2): []} + assert ibd_is_equal(ibd_segs, true_segs) + # # Min length = 2 + ibd_f = ibd.IbdFinder(ts, min_length=2) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [], (0, 2): [], (1, 2): []} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_two_samples_two_trees(self): + # + # 2 + # | 3 + # 1 2 | / \ + # / \ | / \ + # 0 0 1 | 0 1 + #|------------|----------| + #0.0 0.4 1.0 + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1.5 + """ + ) + edges = io.StringIO( + 
"""\ + left right parent child + 0 0.4 2 0,1 + 0.4 1.0 3 0,1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + # Basic test + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2), ibd.Segment(0.4, 1.0, 3)]} + assert ibd_is_equal(ibd_segs, true_segs) + # Max time = 1.2 + ibd_f = ibd.IbdFinder(ts, max_time=1.2) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2)]} + assert ibd_is_equal(ibd_segs, true_segs) + # Min length = 0.5 + ibd_f = ibd.IbdFinder(ts, max_time=1.2) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2)]} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_unrelated_samples(self): + # + # 2 3 + # | | + # 0 1 + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 2 0 + 0 1 3 1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0,1): []} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_no_samples(self): + # + # 2 + # / \ + # / \ + # / \ + # (0) (1) + nodes = io.StringIO( + """\ + id is_sample time + 0 0 0 + 1 0 0 + 2 0 1 + 3 0 1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 2 0 + 0 1 3 1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {} + assert ibd_is_equal(ibd_segs, true_segs) + +class TestIbdSamplesAreDescendants(unittest.TestCase): + # + # 2 + # | + # 1 + # | + # 0 + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 1 + 2 0 2 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 1 0 + 0 1 2 1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_basic(self): + ts = self.ts + ibd_f = 
ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + # print(ibd_segs) + ### NOTE: At the moment, this returns an empty set: + # {(0, 1): []} + # Should modify code to account for this case. + # Should return that 1 is the ancestor of samples 0 and 1. + + +class TestIbdDifferentPaths(unittest.TestCase): + # + # 4 | 4 | 4 + # / \ | / \ | / \ + # / \ | / 3 | / \ + # / \ | 2 \ | / \ + # / \ | / \ | / \ + # 0 1 | 0 1 | 0 1 + # | | + # 0.2 0.7 + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1.5 + 4 0 2.5 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.2 0.7 2 0 + 0.2 0.7 3 1 + 0.0 0.2 4 0 + 0.0 0.2 4 1 + 0.7 1.0 4 0 + 0.7 1.0 4 1 + 0.2 0.7 4 2 + 0.2 0.7 4 3 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_defaults(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.7, 1.0, 4), + ibd.Segment(0.2, 0.7, 4)]} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_time(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts, max_time=1.8) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): []} + assert ibd_is_equal(ibd_segs, true_segs) + ibd_f = ibd.IbdFinder(ts, max_time=2.8) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.7, 1.0, 4), + ibd.Segment(0.2, 0.7, 4)]} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_length(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts, min_length=0.4) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = {(0, 1): [ibd.Segment(0.2, 0.7, 4)]} + assert ibd_is_equal(ibd_segs, true_segs) + class TestIbdByLength(unittest.TestCase): """ @@ -224,18 +476,18 @@ class TestIbdByLength(unittest.TestCase): """ - def test_canonical_example1(self): - ts0 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=2) + def test_random_example1(self): + ts0 = msprime.simulate(sample_size=10, 
recombination_rate=.5, random_seed=2) verify_equal_ibd(ts0) - def test_canonical_example2(self): - ts1 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=23) + def test_random_example2(self): + ts1 = msprime.simulate(sample_size=10, recombination_rate=.5, random_seed=23) verify_equal_ibd(ts1) - def test_canonical_example3(self): - ts2 = msprime.simulate(sample_size=5, recombination_rate=.5, random_seed=232) + def test_random_example3(self): + ts2 = msprime.simulate(sample_size=10, recombination_rate=.5, random_seed=232) verify_equal_ibd(ts2) - def test_random_example(self): + def test_random_example4(self): ts_r = msprime.simulate(sample_size=10, recombination_rate=.3, random_seed=726) verify_equal_ibd(ts_r) From f1a18d8b552273d7664bd7c8382f0347ef0abd9b Mon Sep 17 00:00:00 2001 From: Georgia Tsambos Date: Thu, 2 Apr 2020 14:39:49 +1100 Subject: [PATCH 3/4] Added SegmentList class into IBD algorithm. Added to SegmentList class and changed all uses of A, S to use the SegmentList class. Fixed bug. Added more examples to test_ibd.py ibd_segments is now an attribute of the IbdFinder class Removed new_segs Infrastructure for converting sample pairs to index in struct array. ibd_segments now contains linked lists Fixed bug, modified tests to convert SegmentList output into normal python lists. Neatened the IBD tests. Added some documentation. Added more tests, code to deal with fringe cases. Jerome and Peter's latest comments. Peter's comments. Updated IBD tests with forward time DTWF tests. 
--- python/tests/ibd.py | 595 ++++++++++++++++++-------------- python/tests/test_ibd.py | 715 ++++++++++++++++++++++++++++----------- 2 files changed, 855 insertions(+), 455 deletions(-) diff --git a/python/tests/ibd.py b/python/tests/ibd.py index 6d68dc4f7b..52625f252e 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -1,7 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers -# Copyright (c) 2015-2018 University of Oxford +# Copyright (c) 2020 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,15 +22,12 @@ """ Python implementation of the IBD-finding algorithms. """ +import argparse import tskit -import numpy as np -import itertools -import sys -import argparse -class Segment(object): +class Segment: """ A class representing a single segment. Each segment has a left and right, denoting the loci over which it spans, a node and a next, giving the next @@ -39,90 +35,223 @@ class Segment(object): The node it records is the *output* node ID. """ - def __init__(self, left=None, right=None, node=None, next=None): + + def __init__(self, left=None, right=None, node=None, next_seg=None): self.left = left self.right = right self.node = node - self.next = next + self.next = next_seg def __str__(self): s = "({}-{}->{}:next={})".format( - self.left, self.right, self.node, repr(self.next)) + self.left, self.right, self.node, repr(self.next) + ) return s def __repr__(self): return repr((self.left, self.right, self.node)) def __eq__(self, other): - return (self.left == other.left and self.right == other.right and self.node == other.node) + # NOTE: to simplify tests, we DON'T check for equality of 'next'. 
+ return ( + self.left == other.left + and self.right == other.right + and self.node == other.node + ) def __lt__(self, other): - return (self.node, self.left, self.right) < (other.node, other.left, other.right) + return (self.node, self.left, self.right) < ( + other.node, + other.left, + other.right, + ) -class SegmentList(object): +class SegmentList: """ - A class representing a list of segments that are descended from a given ancestral node - via a particular child of the ancestor. + A class representing a list of segments that are descended from a given ancestral + node via a particular child of the ancestor. + Each SegmentList keeps track of the first and last segment in the list, head and + tail. The next attribute points to another SegmentList, allowing SegmentList + objects to be 'chained' to one another. """ - def __init__(self, head=None, tail=None, next=None): + + def __init__(self, head=None, tail=None, next_list=None): self.head = head self.tail = tail - self.next = next + self.next = next_list def __str__(self): - s = "head={},tail={},next={}".format( - self.head, self.right, repr(self.next)) + s = "head={},tail={},next={}".format(self.head, self.tail, repr(self.next)) return s def __repr__(self): - s = "[{}...]".format(repr(self.next.head)) + if self.head is None: + s = "[{}]".format(repr(None)) + elif self.head == self.tail: + s = "[{}]".format(repr(self.head)) + elif self.head.next == self.tail: + s = "[{}, {}]".format(repr(self.head), repr(self.tail)) + else: + s = "[{}, ..., {}]".format(repr(self.head), repr(self.tail)) return s - - -class IbdFinder(object): + def __len__(self): + # Returns the number of segments in the list. + count = 0 + seg = self.head + while seg is not None: + count += 1 + seg = seg.next + return count + + def add(self, other): + """ + Use to append another SegmentList, or a single segment. 
+ SegmentList1.add(SegmentList2) will modify SegmentList1 so that + SegmentList1.tail.next = SegmentList2.head + SegmentList1.add(Segment1) will add Segment1 to the tail of SegmentList1 + """ + assert isinstance(other, SegmentList) or isinstance(other, Segment) + + if isinstance(other, SegmentList): + if self.head is None: + self.head = other.head + self.tail = other.tail + else: + self.tail.next = other.head + self.tail = other.tail + elif isinstance(other, Segment): + if self.head is None: + self.head = other + self.tail = other + else: + self.tail.next = other + self.tail = other + + +class IbdFinder: """ Finds all IBD relationships between specified samples in a tree sequence. """ - def __init__( - self, - ts, - samples=None, - min_length=0, - max_time=None): + def __init__(self, ts, samples=None, min_length=0, max_time=None): self.ts = ts + # Note: samples *must* be in order of ascending node ID if samples is None: - samples = ts.samples() + self.samples = ts.samples() else: self.samples = samples + self.samples.sort() self.min_length = min_length self.max_time = max_time self.current_parent = self.ts.tables.edges.parent[0] self.current_time = 0 - self.A = [[] for _ in range(ts.num_nodes)] # Descendant segments + self.A = [None for _ in range(ts.num_nodes)] # Descendant segments self.tables = self.ts.tables - # self.ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) - + self.parent_list = [[] for i in range(0, self.ts.num_nodes)] + for e in self.ts.tables.edges: + if ( + len(self.parent_list[e.child]) == 0 + or e.parent != self.parent_list[e.child][-1] + ): + self.parent_list[e.child].append(e.parent) + # Objects below are needed for the IBD segment-holding object. + self.sample_id_map = self.get_sample_id_map() + self.num_samples = len(self.samples) + self.sample_pairs = self.get_sample_pairs() + + # Note: in the C code the object below should be a struct array. 
+ # Each item will be accessed using its index, which corresponds to a particular + # sample pair. The mapping between index and sample pair is defined in the + # find_sample_pair_index method further down. + + self.ibd_segments = {} + for key in self.sample_pairs: + self.ibd_segments[key] = None + + def add_ibd_segments(self, sample0, sample1, seg): + index = self.find_sample_pair_index(sample0, sample1) + + # Note: the code below is specific to the Python implementation, where the + # output is a dictionary indexed by sample pairs. + # In the C implementation, it'll be more like + # self.ibd_segments[index].add(seg) + + if self.ibd_segments[self.sample_pairs[index]] is None: + self.ibd_segments[self.sample_pairs[index]] = SegmentList( + head=seg, tail=seg + ) + else: + self.ibd_segments[self.sample_pairs[index]].add(seg) + + def get_sample_id_map(self): + """ + Returns id_map, a vector of length ts.num_nodes. For a node with id i, + id_map[i] is the position of the node in self.samples. If i is not a + user-specified sample, id_map[i] is -1. + Note: it assumes nodes in the TS are numbered 0 to ts.num_nodes - 1 + """ + id_map = [-1] * self.ts.num_nodes + for i, samp in enumerate(self.samples): + id_map[samp] = i + + return id_map + + def get_sample_pairs(self): + """ + Returns a list of all pairs of samples. Replaces itertools.combinations. + Note: they must be sorted + """ + sample_pairs = [] + for ind, i in enumerate(self.samples): + for j in self.samples[(ind + 1) :]: + sample_pairs.append((i, j)) + + return sample_pairs + + def find_sample_pair_index(self, sample0, sample1): + """ + Note: this method isn't strictly necessary for the Python implementation + but is needed for the C implementation, where the output ibd_segments is a + struct array. + This calculates the position of the object corresponding to the inputted + sample pair in the struct array. + """ + + # Ensure samples are in order. 
+ if sample0 == sample1: + raise ValueError("Samples in pair must have different node IDs.") + elif sample0 > sample1: + sample0, sample1 = sample1, sample0 + + i0 = self.sample_id_map[sample0] + i1 = self.sample_id_map[sample1] + + # Calculate the position of the sample pair in the vector. + index = ( + (self.num_samples) * (self.num_samples - 1) / 2 + - (self.num_samples - i0) * (self.num_samples - i0 - 1) / 2 + + i1 + - i0 + - 1 + ) + + return int(index) def find_ibd_segments(self, min_length=0, max_time=None): - - # 1 - # A = [[] for n in range(0, self.ts.num_nodes)] - ibd_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) - edges = self.ts.edges() - parent_list = self.list_of_parents() ## Needed for memory-pruning step - - # 2 + """ + The wrapper for the procedure that calculates IBD segments. + """ + + # Set up an iterator over the edges in the tree sequence. mygen = iter(self.ts.edges()) e = next(mygen) - - # 3 - while e is not None: - # 3a + # Iterate over the edges. + while e is not None: + # Process all edges with the same parent node. S_head = None S_tail = None self.current_parent = e.parent @@ -131,10 +260,11 @@ def find_ibd_segments(self, min_length=0, max_time=None): # Stop looking for IBD segments once the # processed nodes are older than the max time. break - - # 3b + + # Create a list of segments, S, bookended by S_head and S_tail. + # The list holds all segments that are immediate descendants of + # the current parent node being processed. while e is not None and self.current_parent == e.parent: - # Create the list S of immediate descendants of u. 
seg = Segment(e.left, e.right, e.child) if S_head is None: S_head = seg @@ -142,216 +272,185 @@ def find_ibd_segments(self, min_length=0, max_time=None): else: S_tail.next = seg S_tail = S_tail.next - if e.id < self.ts.num_edges - 1: - e = next(mygen) - continue - else: - e = None + e = next(mygen, None) + + # Create a list of SegmentList objects, SegL, bookended by SegL_head and + # SegL_tail. Each SegmentList corresponds to a segment s in S, and + # contains all segments of samples that descend from s. + SegL_head = SegmentList() + SegL_tail = SegmentList() - # 3c while S_head is not None: - # Create A[u] from S. - # Do we still need to do the initialisation if the below is there?? u = S_head.node - if u in self.ts.samples(): - self.A[self.current_parent].append([S_head]) + list_to_add = SegmentList() + if u in self.samples: + list_to_add.add(Segment(S_head.left, S_head.right, u)) else: - list_to_add = [] - for s in self.A[u]: - l = (max(S_head.left, s.left), min(S_head.right, s.right)) - if l[1] - l[0] > 0: - list_to_add.append(Segment(l[0], l[1], s.node)) - self.A[self.current_parent].append(list_to_add) + if self.A[u] is not None: + s = self.A[u].head + while s is not None: + intvl = ( + max(S_head.left, s.left), + min(S_head.right, s.right), + ) + if intvl[1] - intvl[0] > 0: + list_to_add.add(Segment(intvl[0], intvl[1], s.node)) + s = s.next + + # Add this list to the end of SegL. + if SegL_head.head is None: + SegL_head = list_to_add + SegL_tail = list_to_add + else: + SegL_tail.next = list_to_add + SegL_tail = list_to_add + S_head = S_head.next - # d. Squash - # A[self.current_parent] = self.squash(self.A[self.current_parent]) - - # e. Process A[self.current_parent] - if len(self.A[self.current_parent]) > 1: - new_segs, nodes_to_remove = self.update_A_and_find_ibd_segs(ibd_segments) - - # e. Add any new IBD segments discovered. 
- for key, val in new_segs.items(): - for v in val: - if len(ibd_segments[key]) == 0: - ibd_segments[key] = [v] - else: - ibd_segments[key].append(v) - - # g. Remove elements of self.A[u] if they are no longer needed. - ## (memory-pruning step) + # If we wanted to squash segments, we'd do it here. + + # Use the info in SegL to find new ibd segments that descend from the + # parent node currently being processed. + if SegL_head.next is not None or self.current_parent in self.samples: + nodes_to_remove = self.calculate_ibd_segs(SegL_head) + + # g. Remove elements of the list of ancestral segments A if they + # are no longer needed. (Memory-pruning step) for n in nodes_to_remove: - if self.current_parent in parent_list[n]: - parent_list[n].remove(self.current_parent) - if len(parent_list[n]) == 0: + if self.current_parent in self.parent_list[n]: + self.parent_list[n].remove(self.current_parent) + if len(self.parent_list[n]) == 0: self.A[n] = [] - - # Unlist the ancestral segments in self.A. 
- self.A[self.current_parent] = list(itertools.chain(*self.A[self.current_parent])) - - # 4 - return ibd_segments - - - def update_A_and_find_ibd_segs(self, ibd_segments): - - new_segments = dict.fromkeys(itertools.combinations(self.ts.samples(), 2), []) - num_coalescing_sets = len(self.A[self.current_parent]) - index_pairs = list(itertools.combinations(range(0, num_coalescing_sets), 2)) - - for setpair in index_pairs: - for seg0 in self.A[self.current_parent][setpair[0]]: - for seg1 in self.A[self.current_parent][setpair[1]]: - - if seg0.node == seg1.node: - continue - left = max(seg0.left, seg1.left) - right = min(seg0.right, seg1.right) - if left >= right: - continue - nodes = [seg0.node, seg1.node] - nodes.sort() - - # if mrca_ibd: - # pass # for now - # existing_segs = ibd_segments[(nodes[0], nodes[1])].copy() - # if right - left > self.min_length: - # if len(existing_segs) == 0: - # new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] - # existing_segs.append(Segment(left, right, self.current_parent)) - # else: - # for i in existing_segs: - # # no overlap. - # if right <= i.left or left >= i.right: - # if len(new_segments[(nodes[0], nodes[1])]) == 0: - # new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] - # else: - # new_segments[(nodes[0], nodes[1])].append(Segment(left, right, self.current_parent)) - # existing_segs.append(Segment(left, right, self.current_parent)) - # # partial overlap -- does this even happen? - # elif (left < i.left and right < i.right) or (i.left < left and i.right < right): - # print('partial overlap') - # Yes, but I think it's okay to leave these segments... - # else: - if right - left > self.min_length: - if len(new_segments[(nodes[0], nodes[1])]) == 0: - new_segments[(nodes[0], nodes[1])] = [Segment(left, right, self.current_parent)] - else: - new_segments[(nodes[0], nodes[1])].append(Segment(left, right, self.current_parent)) - - # iv. 
specify elements of A that can be removed (for memory-pruning step) + + # Create the list of ancestral segments, A, descending from current node. + # Concatenate all of the segment lists in SegL into a single list, and + # store in A. This is a list of all sample segments descending from + # the currently processed node. + seglist = SegL_head + self.A[self.current_parent] = seglist + seglist = seglist.next + while seglist is not None: + self.A[self.current_parent].add(seglist) + seglist = seglist.next + + return self.ibd_segments + + def calculate_ibd_segs(self, SegL_head): + """ + Calculates all new IBD segments found at the current iteration of the + algorithm. Returns information about nodes processed in this step, which + allows memory to be cleared as the procedure runs. + """ + + list0 = SegL_head + # Iterate through all pairs of segment lists. + while list0.next is not None: + list1 = list0.next + while list1 is not None: + seg0 = list0.head + # Iterate through all segments with one taken from each list. + while seg0 is not None: + seg1 = list1.head + while seg1 is not None: + left = max(seg0.left, seg1.left) + right = min(seg0.right, seg1.right) + if left >= right: + seg1 = seg1.next + continue + nodes = [seg0.node, seg1.node] + nodes.sort() + + # If there are any overlapping segments, record as a new + # IBD relationship. + if right - left > self.min_length: + self.add_ibd_segments( + nodes[0], + nodes[1], + Segment(left, right, self.current_parent), + ) + + seg1 = seg1.next + + seg0 = seg0.next + list1 = list1.next + list0 = list0.next + + # If the current processed node is itself a sample, calculate IBD separately. 
+ if self.current_parent in self.samples: + seglist = SegL_head + while seglist is not None: + seg = seglist.head + while seg is not None: + self.add_ibd_segments( + seg.node, + self.current_parent, + Segment(seg.left, seg.right, self.current_parent), + ) + seg = seg.next + seglist = seglist.next + + # Specify elements of A that can be removed (for memory-pruning step) processed_child_nodes = [] - for seglist in self.A[self.current_parent]: - processed_child_nodes += [seg.node for seg in seglist] - processed_child_nodes = list(set(processed_child_nodes)) - - return new_segments, processed_child_nodes - - - def list_of_parents(self): - parents = [[] for i in range(0, self.ts.num_nodes)] - edges = self.ts.tables.edges - for e in edges: - if len(parents[e.child]) == 0 or e.parent != parents[e.child][-1]: - parents[e.child].append(e.parent) - return parents - - - # def squash(self, segment_lists): - - # # Concatenate the input lists and record the number of - # # segments in each. - # A_u = [] - # num_desc_edges = [] - # for L in segment_lists: - # for l in L: - # A_u.append(l) - # num_desc_edges.append(len(L)) - - # # Sort the list, keeping track of the original order. - # sorted_A = sorted(enumerate(A_u), key=lambda i:i[1]) - - # # Squash the list. - # next_ind = len(sorted_A) - # inds_to_remove = [] - # ind = 1 - # while ind < len(sorted_A): - # if sorted_A[ind][1].node == sorted_A[ind - 1][1].node: - # if sorted_A[ind][1].right > sorted_A[ind - 1][1].right and\ - # sorted_A[ind][1].left <= sorted_A[ind - 1][1].right: - # # Squash the previous int into the current one. - # sorted_A[ind][1].left = sorted_A[ind - 1][1].left - # # Flag the interval to be removed. - # inds_to_remove.append(ind - 1) - # # Change order index. - # sorted_A[ind] = (next_ind, sorted_A[ind][1]) - # next_ind += 1 - # ind += 1 - - # # Remove any unnecessary list items. - # for i in reversed(inds_to_remove): - # # Needs to be done in reverse order!! 
- # sorted_A.pop(i) - - # # Restore the original order as lists of lists. - # cum_sum = np.cumsum(num_desc_edges) - # squashed_sorted_A = [[] for _ in range(0, next_ind)] - # for a in sorted_A: - # ind = a[0] - # if ind < cum_sum[-1]: - # s = 0 - # while s < len(cum_sum): - # if a[0] < cum_sum[s]: - # squashed_sorted_A[s].append(a[1]) - # break - # s += 1 - - # else: - # squashed_sorted_A[ind].append(a[1]) - - # # Remove lists of length 0. - # squashed_sorted_A = [_ for _ in squashed_sorted_A if len(_) > 0] - - # return squashed_sorted_A - + seglist = SegL_head + while seglist is not None: + seg = seglist.head + while seg is not None: + processed_child_nodes += [seg.node] + seg = seg.next + seglist = seglist.next + processed_child_nodes = list(set(processed_child_nodes)) + + return processed_child_nodes + if __name__ == "__main__": - # Simple CLI for running IBDFinder. - - parser = argparse.ArgumentParser(description="Command line interface for the IBDFinder.") - - parser.add_argument('--infile', - type=str, - dest='infile', - nargs=1, - metavar="IN_FILE", - help="The tree sequence to be analysed.") - - parser.add_argument('--min-length', - type=float, - dest='min_length', - nargs=1, - metavar="MIN_LENGTH", - help="Only segments longer than this cutoff will be returned.") - - parser.add_argument('--max-time', - type=float, - dest='max_time', - nargs=1, - metavar="MAX_TIME", - help="Only segments younger this time will be returned.") - - parser.add_argument('--samples', - type=int, - dest='samples', - nargs=2, - metavar="SAMPLES", - help="If provided, only IBD relationships between the given node pair are returned." - ) + """ + A simple CLI for running IBDFinder on a command line from the `python` + subdirectory. Basic usage: + > python3 ./tests/ibd.py --infile test.trees + """ - args = parser.parse_args() + parser = argparse.ArgumentParser( + description="Command line interface for the IBDFinder." 
+ ) + + parser.add_argument( + "--infile", + type=str, + dest="infile", + nargs=1, + metavar="IN_FILE", + help="The tree sequence to be analysed.", + ) + + parser.add_argument( + "--min-length", + type=float, + dest="min_length", + nargs=1, + metavar="MIN_LENGTH", + help="Only segments longer than this cutoff will be returned.", + ) + + parser.add_argument( + "--max-time", + type=float, + dest="max_time", + nargs=1, + metavar="MAX_TIME", + help="Only segments younger this time will be returned.", + ) + + parser.add_argument( + "--samples", + type=int, + dest="samples", + nargs=2, + metavar="SAMPLES", + help="If provided, only this pair's IBD info is returned.", + ) + args = parser.parse_args() ts = tskit.load(args.infile[0]) if args.min_length is None: min_length = 0 @@ -362,15 +461,11 @@ def list_of_parents(self): else: max_time = args.max_time[0] - s = IbdFinder(ts, samples = ts.samples(), - min_length=min_length, max_time=max_time) + s = IbdFinder(ts, samples=ts.samples(), min_length=min_length, max_time=max_time) all_segs = s.find_ibd_segments() - if args.samples is None: print(all_segs) else: samples = args.samples print(all_segs[(samples[0], samples[1])]) - - diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index b3f31619cb..75b6870661 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -1,39 +1,48 @@ - """ Tests of IBD finding algorithms. """ -import unittest -import sys -import random import io import itertools +import random +import unittest -import tests as tests -import tests.ibd as ibd +import msprime +from packaging import version +import tests.ibd as ibd import tskit -import msprime -import numpy as np # Functions for computing IBD 'naively'. 
-def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, - path_ibd=True, mrca_ibd=False): + +def get_ibd( + sample0, + sample1, + treeSequence, + min_length=0, + max_time=None, + path_ibd=True, + mrca_ibd=True, +): """ Returns all IBD segments for a given pair of nodes in a tree using a naive algorithm. + Note: This function probably looks more complicated than it needs to be -- + This is because it also calculates other 'versions' of IBD (mrca_ibd=False, + path_ibd=False) that we haven't implemented properly yet. """ ibd_list = [] - ts, node_map = treeSequence.simplify(samples=[sample0, sample1], keep_unary=True, - map_nodes=True) + ts, node_map = treeSequence.simplify( + samples=[sample0, sample1], keep_unary=True, map_nodes=True + ) node_map = node_map.tolist() - + for n in ts.nodes(): - + if max_time is not None and n.time > max_time: break - + node_id = n.id interval_list = [] if n.flags == 1: @@ -45,7 +54,7 @@ def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, if len(list(t.nodes(n.id))) == 1 or t.num_samples(n.id) < 2: continue if mrca_ibd and n.id != t.mrca(0, 1): - continue + continue current_int = t.get_interval() if len(interval_list) == 0: @@ -54,41 +63,63 @@ def get_ibd(sample0, sample1, treeSequence, min_length=0, max_time=None, prev_int = interval_list[-1] if not path_ibd and prev_int[1] == current_int[0]: interval_list[-1] = (prev_int[0], current_int[1]) - elif prev_dict is not None and subtrees_are_equal(t, prev_dict, node_id): + elif prev_dict is not None and subtrees_are_equal( + t, prev_dict, node_id + ): interval_list[-1] = (prev_int[0], current_int[1]) else: interval_list.append(current_int) - + prev_dict = t.get_parent_dict() - + for interval in interval_list: if min_length == 0 or interval[1] - interval[0] > min_length: orig_id = node_map.index(node_id) ibd_list.append(ibd.Segment(interval[0], interval[1], orig_id)) - - return(ibd_list) + return ibd_list + + +def get_ibd_all_pairs( + treeSequence, + 
samples=None, + min_length=0, + max_time=None, + path_ibd=True, + mrca_ibd=False, +): + + """ + Returns all IBD segments for all pairs of nodes in a tree sequence + using the naive algorithm above. + """ -def get_ibd_all_pairs(treeSequence, samples=None, min_length=0, max_time=None, - path_ibd=True, mrca_ibd=False): - ibd_dict = {} - + if samples is None: samples = treeSequence.samples().tolist() - + pairs = itertools.combinations(samples, 2) for pair in pairs: - ibd_list = get_ibd(pair[0], pair[1], treeSequence, - min_length=min_length, max_time=max_time, - path_ibd=path_ibd, mrca_ibd=mrca_ibd) + ibd_list = get_ibd( + pair[0], + pair[1], + treeSequence, + min_length=min_length, + max_time=max_time, + path_ibd=path_ibd, + mrca_ibd=mrca_ibd, + ) if len(ibd_list) > 0: ibd_dict[pair] = ibd_list - - return(ibd_dict) + + return ibd_dict def subtrees_are_equal(tree1, pdict0, root): + """ + Checks for equality of two subtrees beneath a given root node. + """ pdict1 = tree1.get_parent_dict() if root not in pdict0.values() or root not in pdict1.values(): return False @@ -100,33 +131,39 @@ def subtrees_are_equal(tree1, pdict0, root): if p1 not in pdict0.values(): return False p0 = pdict0[node] - if p0 != p1: - return False + if p0 != p1: + return False node = p1 - + return True def verify_equal_ibd(treeSequence): """ - Calculates IBD segments using both the 'naive' and sophisticated algorithms, + Calculates IBD segments using both the 'naive' and sophisticated algorithms, verifies that the same output is produced. NB: May be good to expand this in the future so that many different combos of IBD options are tested simultaneously (all the MRCA and path-IBD combos), for example. 
""" ts = treeSequence - ibd0 = ibd.IbdFinder(ts, samples = ts.samples()) + ibd0 = ibd.IbdFinder(ts, samples=ts.samples()) ibd0 = ibd0.find_ibd_segments() ibd1 = get_ibd_all_pairs(ts, path_ibd=True, mrca_ibd=True) - for key0, val0 in ibd0.items(): + # Convert each SegmentList object into a list of Segment objects. + ibd0_tolist = {} + for key, val in ibd0.items(): + ibd0_tolist[key] = convert_segmentlist_to_list(val) + + # Check for equality. + for key0, val0 in ibd0_tolist.items(): assert key0 in ibd1.keys() val1 = ibd1[key0] val0.sort() val1.sort() - if val0 is None: # Get rid of this later -- don't want empty dict values at all + if val0 is None: # Get rid of this later -- don't want empty dict values at all assert val1 is None continue elif val1 is None: @@ -135,6 +172,38 @@ def verify_equal_ibd(treeSequence): for i in range(0, len(val0)): assert val0[i] == val1[i] + +def convert_segmentlist_to_list(seglist): + """ + Turns a SegmentList object into a list of Segment objects. + (This makes them easier to compare for testing purposes) + """ + outlist = [] + if seglist is None: + return outlist + else: + seg = seglist.head + outlist = [seg] + seg = seg.next + while seg is not None: + outlist.append(seg) + seg = seg.next + + return outlist + + +def convert_dict_of_segmentlists(dict0): + """ + Turns a dictionary of SegmentList objects into a dictionary of lists of + Segment objects. (makes them easier to compare in tests). + """ + dict_out = {} + for key, val in dict0.items(): + dict_out[key] = convert_segmentlist_to_list(val) + + return dict_out + + def ibd_is_equal(dict1, dict2): """ Verifies that two dictionaries have the same keys, and that @@ -155,6 +224,10 @@ def ibd_is_equal(dict1, dict2): def segment_lists_are_equal(val1, val2): + """ + Returns True if the two lists hold the same set of segments, otherwise + returns False. 
+ """ if len(val1) != len(val2): return False @@ -162,7 +235,7 @@ def segment_lists_are_equal(val1, val2): val1.sort() val2.sort() - if val1 is None: # get rid of this later -- we don't any empty dict values! + if val1 is None: # get rid of this later -- we don't want any empty dict values! if val2 is not None: return False elif val2 is None: @@ -177,123 +250,162 @@ def segment_lists_are_equal(val1, val2): return True -class TestIbdTopologies(unittest.TestCase): +class TestIbdSingleBinaryTree(unittest.TestCase): - def test_single_binary_tree(self): - # - # 2 4 - # / \ - # 1 3 \ - # / \ \ - # 0 0 1 2 - nodes = io.StringIO( - """\ - id is_sample time - 0 1 0 - 1 1 0 - 2 1 0 - 3 0 1 - 4 0 2 - """ - ) - edges = io.StringIO( - """\ - left right parent child - 0 1 3 0,1 - 0 1 4 2,3 - """ - ) - ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) - # Basic test - ibd_f = ibd.IbdFinder(ts) + # + # 2 4 + # / \ + # 1 3 \ + # / \ \ + # 0 0 1 2 + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 0 1 + 4 0 2 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 3 0,1 + 0 1 4 2,3 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + # Basic test + def test_defaults(self): + ibd_f = ibd.IbdFinder(self.ts) ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = { - (0, 1): [ibd.Segment(0.0, 1.0, 3)], - (0, 2): [ibd.Segment(0.0, 1.0, 4)], - (1, 2): [ibd.Segment(0.0, 1.0, 4)]} + (0, 1): [ibd.Segment(0.0, 1.0, 3)], + (0, 2): [ibd.Segment(0.0, 1.0, 4)], + (1, 2): [ibd.Segment(0.0, 1.0, 4)], + } assert ibd_is_equal(ibd_segs, true_segs) - # Max time = 1.5 - ibd_f = ibd.IbdFinder(ts, max_time=1.5) + + # Max time = 1.5 + def test_time(self): + ibd_f = ibd.IbdFinder(self.ts, max_time=1.5) ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 3)], - (0, 2): [], (1, 2): []} + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): 
[ibd.Segment(0.0, 1.0, 3)], (0, 2): [], (1, 2): []} assert ibd_is_equal(ibd_segs, true_segs) - # # Min length = 2 - ibd_f = ibd.IbdFinder(ts, min_length=2) + + # Min length = 2 + def test_length(self): + ibd_f = ibd.IbdFinder(self.ts, min_length=2) ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = {(0, 1): [], (0, 2): [], (1, 2): []} assert ibd_is_equal(ibd_segs, true_segs) - def test_two_samples_two_trees(self): - # - # 2 - # | 3 - # 1 2 | / \ - # / \ | / \ - # 0 0 1 | 0 1 - #|------------|----------| - #0.0 0.4 1.0 - nodes = io.StringIO( - """\ - id is_sample time - 0 1 0 - 1 1 0 - 2 0 1 - 3 0 1.5 - """ - ) - edges = io.StringIO( - """\ - left right parent child - 0 0.4 2 0,1 - 0.4 1.0 3 0,1 - """ - ) - ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) - # Basic test - ibd_f = ibd.IbdFinder(ts) + +class TestIbdTwoSamplesTwoTrees(unittest.TestCase): + + # 2 + # | 3 + # 1 2 | / \ + # / \ | / \ + # 0 0 1 | 0 1 + # |------------|----------| + # 0.0 0.4 1.0 + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1.5 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 0.4 2 0,1 + 0.4 1.0 3 0,1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + # Basic test + def test_basic(self): + ibd_f = ibd.IbdFinder(self.ts) ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2), ibd.Segment(0.4, 1.0, 3)]} assert ibd_is_equal(ibd_segs, true_segs) - # Max time = 1.2 - ibd_f = ibd.IbdFinder(ts, max_time=1.2) + + # Max time = 1.2 + def test_time(self): + ibd_f = ibd.IbdFinder(self.ts, max_time=1.2) ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2)]} assert ibd_is_equal(ibd_segs, true_segs) - # Min length = 0.5 - ibd_f = ibd.IbdFinder(ts, max_time=1.2) + + # Min length = 0.5 + def 
test_length(self): + ibd_f = ibd.IbdFinder(self.ts, min_length=0.5) ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0, 1): [ibd.Segment(0.0, 0.4, 2)]} + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): [ibd.Segment(0.4, 1.0, 3)]} assert ibd_is_equal(ibd_segs, true_segs) - def test_unrelated_samples(self): - # - # 2 3 - # | | - # 0 1 - nodes = io.StringIO( - """\ - id is_sample time - 0 1 0 - 1 1 0 - 2 0 1 - 3 0 1 - """ - ) - edges = io.StringIO( - """\ - left right parent child - 0 1 2 0 - 0 1 3 1 - """ - ) - ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) - ibd_f = ibd.IbdFinder(ts) +class TestIbdUnrelatedSamples(unittest.TestCase): + + # + # 2 3 + # | | + # 0 1 + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 0 1 + 3 0 1 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 2 0 + 0 1 3 1 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_basic(self): + ibd_f = ibd.IbdFinder(self.ts) + ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): []} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_time(self): + ibd_f = ibd.IbdFinder(self.ts, max_time=1.2) + ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): []} + assert ibd_is_equal(ibd_segs, true_segs) + + def test_length(self): + ibd_f = ibd.IbdFinder(self.ts, min_length=0.2) ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0,1): []} + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): []} assert ibd_is_equal(ibd_segs, true_segs) + +class TestIbdNoSamples(unittest.TestCase): def test_no_samples(self): # # 2 @@ -317,12 +429,14 @@ def test_no_samples(self): 0 1 3 1 """ ) - ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) ibd_f = ibd.IbdFinder(ts) ibd_segs = ibd_f.find_ibd_segments() 
+ ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = {} assert ibd_is_equal(ibd_segs, true_segs) + class TestIbdSamplesAreDescendants(unittest.TestCase): # # 2 @@ -345,26 +459,24 @@ class TestIbdSamplesAreDescendants(unittest.TestCase): 0 1 2 1 """ ) - ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) def test_basic(self): ts = self.ts ibd_f = ibd.IbdFinder(ts) ibd_segs = ibd_f.find_ibd_segments() - # print(ibd_segs) - ### NOTE: At the moment, this returns an empty set: - # {(0, 1): []} - # Should modify code to account for this case. - # Should return that 1 is the ancestor of samples 0 and 1. + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 1)]} + assert ibd_is_equal(ibd_segs, true_segs) class TestIbdDifferentPaths(unittest.TestCase): - # + # # 4 | 4 | 4 # / \ | / \ | / \ # / \ | / 3 | / \ # / \ | 2 \ | / \ - # / \ | / \ | / \ + # / \ | / \ | / \ # 0 1 | 0 1 | 0 1 # | | # 0.2 0.7 @@ -398,8 +510,14 @@ def test_defaults(self): ts = self.ts ibd_f = ibd.IbdFinder(ts) ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0, 1): [ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.7, 1.0, 4), - ibd.Segment(0.2, 0.7, 4)]} + true_segs = { + (0, 1): [ + ibd.Segment(0.0, 0.2, 4), + ibd.Segment(0.7, 1.0, 4), + ibd.Segment(0.2, 0.7, 4), + ] + } + ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) def test_time(self): @@ -407,11 +525,18 @@ def test_time(self): ibd_f = ibd.IbdFinder(ts, max_time=1.8) ibd_segs = ibd_f.find_ibd_segments() true_segs = {(0, 1): []} + ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) ibd_f = ibd.IbdFinder(ts, max_time=2.8) ibd_segs = ibd_f.find_ibd_segments() - true_segs = {(0, 1): [ibd.Segment(0.0, 0.2, 4), ibd.Segment(0.7, 1.0, 4), - ibd.Segment(0.2, 0.7, 4)]} + true_segs = { + (0, 1): [ + ibd.Segment(0.0, 0.2, 4), + ibd.Segment(0.7, 1.0, 4), + 
ibd.Segment(0.2, 0.7, 4), + ] + } + ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) def test_length(self): @@ -419,75 +544,255 @@ def test_length(self): ibd_f = ibd.IbdFinder(ts, min_length=0.4) ibd_segs = ibd_f.find_ibd_segments() true_segs = {(0, 1): [ibd.Segment(0.2, 0.7, 4)]} + ibd_segs = convert_dict_of_segmentlists(ibd_segs) assert ibd_is_equal(ibd_segs, true_segs) -class TestIbdByLength(unittest.TestCase): - """ - Tests of length-based IBD function. - """ - # 11 * * * - # / \ * * * 10 - # / \ * 9 * * / \ - # / \ * / \ * 8 * / 8 - # | | * / \ * / \ * / / \ - # | 7 * / 7 * / 7 * / / 7 - # | / \ * | / \ * / / \ * / / / \ - # | / 6 * | / 6 * / / 6 * | / | | - # 5 | / \ * 5 | / \ * 5 | / \ * | 5 | | - # / \ | / \ * / \ | / \ * / \ | / \ * | / \ | | - # 0 4 1 2 3 * 0 4 1 2 3 * 0 4 1 2 3 * 3 0 4 1 2 - # - # ------------------------------------------------------------------------------ - # 0 0.10 0.50 0.75 1.00 - - small_tree_ex_nodes = """\ - id flags population individual time - 0 1 0 -1 0.00000000000000 - 1 1 0 -1 0.00000000000000 - 2 1 0 -1 0.00000000000000 - 3 1 0 -1 0.00000000000000 - 4 1 0 -1 0.00000000000000 - 5 0 0 -1 1.00000000000000 - 6 0 0 -1 2.00000000000000 - 7 0 0 -1 3.00000000000000 - 8 0 0 -1 4.00000000000000 - 9 0 0 -1 5.00000000000000 - 10 0 0 -1 6.00000000000000 - 11 0 0 -1 7.00000000000000 - """ - small_tree_ex_edges = """\ - id left right parent child - 0 0.00000000 1.00000000 5 0 - 1 0.00000000 1.00000000 5 4 - 2 0.00000000 0.75000000 6 2 - 3 0.00000000 0.75000000 6 3 - 4 0.00000000 1.00000000 7 1 - 5 0.75000000 1.00000000 7 2 - 6 0.00000000 0.75000000 7 6 - 7 0.50000000 1.00000000 8 5 - 8 0.50000000 1.00000000 8 7 - 9 0.10000000 0.50000000 9 5 - 10 0.10000000 0.50000000 9 7 - 11 0.75000000 1.00000000 10 3 - 12 0.75000000 1.00000000 10 8 - 13 0.00000000 0.10000000 11 5 - 14 0.00000000 0.10000000 11 7 +class TestIbdPolytomies(unittest.TestCase): + # + # 5 | 5 + # / \ | / \ + # 4 \ | 4 \ + # /|\ \ | /|\ 
\ + # / | \ \ | / | \ \ + # / | \ \ | / | \ \ + # / | \ \ | / | \ \ + # 0 1 2 3 | 0 1 3 2 + # | + # 0.3 + + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 1 0 + 2 1 0 + 3 1 0 + 4 0 2.5 + 5 0 3.5 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.0 1.0 4 0 + 0.0 1.0 4 1 + 0.0 0.3 4 2 + 0.3 1.0 4 3 + 0.3 1.0 5 2 + 0.0 0.3 5 3 + 0.0 1.0 5 4 """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_defaults(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + # print(ibd_segs[(0,1)]) + true_segs = { + (0, 1): [ibd.Segment(0, 1, 4)], + (0, 2): [ibd.Segment(0, 0.3, 4), ibd.Segment(0.3, 1, 5)], + (0, 3): [ibd.Segment(0, 0.3, 5), ibd.Segment(0.3, 1, 4)], + (1, 2): [ibd.Segment(0, 0.3, 4), ibd.Segment(0.3, 1, 5)], + (1, 3): [ibd.Segment(0, 0.3, 5), ibd.Segment(0.3, 1, 4)], + (2, 3): [ibd.Segment(0.3, 1, 5), ibd.Segment(0, 0.3, 5)], + } + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + # print(ibd_segs) + assert ibd_is_equal(ibd_segs, true_segs) + + def test_time(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts, max_time=3) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = { + (0, 1): [ibd.Segment(0, 1, 4)], + (0, 2): [ibd.Segment(0, 0.3, 4)], + (0, 3): [ibd.Segment(0.3, 1, 4)], + (1, 2): [ibd.Segment(0, 0.3, 4)], + (1, 3): [ibd.Segment(0.3, 1, 4)], + (2, 3): [], + } + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + assert ibd_is_equal(ibd_segs, true_segs) + + def test_length(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts, min_length=0.5) + ibd_segs = ibd_f.find_ibd_segments() + true_segs = { + (0, 1): [ibd.Segment(0, 1, 4)], + (0, 2): [ibd.Segment(0.3, 1, 5)], + (0, 3): [ibd.Segment(0.3, 1, 4)], + (1, 2): [ibd.Segment(0.3, 1, 5)], + (1, 3): [ibd.Segment(0.3, 1, 4)], + (2, 3): [ibd.Segment(0.3, 1, 5)], + } + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + assert ibd_is_equal(ibd_segs, true_segs) +class TestIbdInternalSamples(unittest.TestCase): + # + # + # 3 + # / 
\ + # / 2 + # / \ + # 0 (1) + nodes = io.StringIO( + """\ + id is_sample time + 0 1 0 + 1 0 0 + 2 1 1 + 3 0 2 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0.0 1.0 2 1 + 0.0 1.0 3 0 + 0.0 1.0 3 2 + """ + ) + ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) + + def test_defaults(self): + ts = self.ts + ibd_f = ibd.IbdFinder(ts) + ibd_segs = ibd_f.find_ibd_segments() + ibd_segs = convert_dict_of_segmentlists(ibd_segs) + true_segs = { + (0, 2): [ibd.Segment(0, 1, 3)], + } + assert ibd_is_equal(ibd_segs, true_segs) + + +class TestIbdRandomExamples(unittest.TestCase): + """ + Randomly generated test cases. + """ + + # Infinite sites, Hudson model. + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) def test_random_example1(self): - ts0 = msprime.simulate(sample_size=10, recombination_rate=.5, random_seed=2) - verify_equal_ibd(ts0) + ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=2) + verify_equal_ibd(ts) + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) def test_random_example2(self): - ts1 = msprime.simulate(sample_size=10, recombination_rate=.5, random_seed=23) - verify_equal_ibd(ts1) + ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=23) + verify_equal_ibd(ts) + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) def test_random_example3(self): - ts2 = msprime.simulate(sample_size=10, recombination_rate=.5, random_seed=232) - verify_equal_ibd(ts2) + ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=232) + verify_equal_ibd(ts) + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) def test_random_example4(self): - ts_r = 
msprime.simulate(sample_size=10, recombination_rate=.3, random_seed=726) - verify_equal_ibd(ts_r) + ts = msprime.simulate(sample_size=10, recombination_rate=0.3, random_seed=726) + verify_equal_ibd(ts) + + # Finite sites + def sim_finite_sites(self, random_seed, dtwf=False): + seq_length = int(1e5) + positions = random.sample(range(1, seq_length), 98) + [0, seq_length] + positions.sort() + rates = [random.uniform(1e-9, 1e-7) for _ in range(100)] + r_map = msprime.RecombinationMap( + positions=positions, rates=rates, discrete=True + ) + if dtwf: + model = "dtwf" + else: + model = "hudson" + ts = msprime.simulate( + sample_size=10, + recombination_map=r_map, + Ne=10, + random_seed=random_seed, + model=model, + ) + return ts + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_finite_sites1(self): + ts = self.sim_finite_sites(9257) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_finite_sites2(self): + ts = self.sim_finite_sites(835) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_finite_sites3(self): + ts = self.sim_finite_sites(27278) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_finite_sites4(self): + ts = self.sim_finite_sites(22446688) + verify_equal_ibd(ts) + + # DTWF + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_dtwf1(self): + ts = self.sim_finite_sites(84, dtwf=True) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) 
+ def test_dtwf2(self): + ts = self.sim_finite_sites(17482, dtwf=True) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_dtwf3(self): + ts = self.sim_finite_sites(846, dtwf=True) + verify_equal_ibd(ts) + + @unittest.skipIf( + version.parse(msprime.__version__) < version.parse("0.7.5"), + "Needs finite sites simulations to work.", + ) + def test_dtwf4(self): + ts = self.sim_finite_sites(273, dtwf=True) + verify_equal_ibd(ts) From bebf6f6ff10df77112fede890cefc161271d5571 Mon Sep 17 00:00:00 2001 From: Georgia Tsambos Date: Mon, 18 May 2020 15:36:19 +1000 Subject: [PATCH 4/4] Simplified algorithm based on Peter's comments. small bugfix forward DTWF sims are now in tests Samples no longer need to be sorted, sample_id_map is part of the constructor. mygen renamed to edges_iter current_time, current_parent and PARENT_TO_BE_ADDED are now local to find_ibd_segments() More small changes Removed dependency on msprime v1.0 Removed parent_list, added simpler oldest_parent attribute Removed processed_nodes object Small change to variable name. --- python/tests/ibd.py | 280 +++++++++++++++------------------------ python/tests/test_ibd.py | 99 +++++--------- 2 files changed, 137 insertions(+), 242 deletions(-) diff --git a/python/tests/ibd.py b/python/tests/ibd.py index 52625f252e..d23a4fefec 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -24,6 +24,8 @@ """ import argparse +import numpy as np + import tskit @@ -96,15 +98,6 @@ def __repr__(self): s = "[{}, ..., {}]".format(repr(self.head), repr(self.tail)) return s - def __len__(self): - # Returns the number of segments in the list. - count = 0 - seg = self.head - while seg is not None: - count += 1 - seg = seg.next - return count - def add(self, other): """ Use to append another SegmentList, or a single segment. 
@@ -143,22 +136,23 @@ def __init__(self, ts, samples=None, min_length=0, max_time=None): self.samples = ts.samples() else: self.samples = samples - self.samples.sort() + if len(self.samples) == 0: + raise ValueError("The tree sequence contains no samples.") + + self.sample_id_map = np.zeros(ts.num_nodes, dtype=int) - 1 + for index, u in enumerate(self.samples): + self.sample_id_map[u] = index self.min_length = min_length - self.max_time = max_time - self.current_parent = self.ts.tables.edges.parent[0] - self.current_time = 0 + if max_time is None: + self.max_time = 2 * ts.max_root_time + else: + self.max_time = max_time self.A = [None for _ in range(ts.num_nodes)] # Descendant segments self.tables = self.ts.tables - self.parent_list = [[] for i in range(0, self.ts.num_nodes)] - for e in self.ts.tables.edges: - if ( - len(self.parent_list[e.child]) == 0 - or e.parent != self.parent_list[e.child][-1] - ): - self.parent_list[e.child].append(e.parent) + + self.oldest_parent = self.get_oldest_parents() + # Objects below are needed for the IBD segment-holding object. 
- self.sample_id_map = self.get_sample_id_map() self.num_samples = len(self.samples) self.sample_pairs = self.get_sample_pairs() @@ -171,6 +165,18 @@ def __init__(self, ts, samples=None, min_length=0, max_time=None): for key in self.sample_pairs: self.ibd_segments[key] = None + def get_oldest_parents(self): + oldest_parents = [-1 for _ in range(self.ts.num_nodes)] + node_times = self.ts.tables.nodes.time + for e in self.ts.tables.edges: + c = e.child + if ( + oldest_parents[c] == -1 + or node_times[oldest_parents[c]] < node_times[e.parent] + ): + oldest_parents[c] = e.parent + return oldest_parents + def add_ibd_segments(self, sample0, sample1, seg): index = self.find_sample_pair_index(sample0, sample1) @@ -186,19 +192,6 @@ def add_ibd_segments(self, sample0, sample1, seg): else: self.ibd_segments[self.sample_pairs[index]].add(seg) - def get_sample_id_map(self): - """ - Returns id_map, a vector of length ts.num_nodes. For a node with id i, - id_map[i] is the position of the node in self.samples. If i is not a - user-specified sample, id_map[i] is -1. - Note: it assumes nodes in the TS are numbered 0 to ts.num_nodes - 1 - """ - id_map = [-1] * self.ts.num_nodes - for i, samp in enumerate(self.samples): - id_map[samp] = i - - return id_map - def get_sample_pairs(self): """ Returns a list of all pairs of samples. Replaces itertools.combinations. @@ -240,167 +233,104 @@ def find_sample_pair_index(self, sample0, sample1): return int(index) - def find_ibd_segments(self, min_length=0, max_time=None): + def find_ibd_segments(self): """ The wrapper for the procedure that calculates IBD segments. """ # Set up an iterator over the edges in the tree sequence. - mygen = iter(self.ts.edges()) - e = next(mygen) + edges_iter = iter(self.ts.edges()) + e = next(edges_iter) + parent_should_be_added = True + node_times = self.tables.nodes.time # Iterate over the edges. while e is not None: - # Process all edges with the same parent node. 
- S_head = None - S_tail = None - self.current_parent = e.parent - self.current_time = self.tables.nodes.time[self.current_parent] - if self.max_time is not None and self.current_time > self.max_time: + + current_parent = e.parent + current_time = node_times[current_parent] + if current_time > self.max_time: # Stop looking for IBD segments once the # processed nodes are older than the max time. break - # Create a list of segments, S, bookended by S_head and S_tail. - # The list holds all segments that are immediate descendants of - # the current parent node being processed. - while e is not None and self.current_parent == e.parent: - seg = Segment(e.left, e.right, e.child) - if S_head is None: - S_head = seg - S_tail = seg - else: - S_tail.next = seg - S_tail = S_tail.next - e = next(mygen, None) - - # Create a list of SegmentList objects, SegL, bookended by SegL_head and - # SegL_tail. Each SegmentList corresponds to a segment s in S, and - # contains all segments of samples that descend from s. - SegL_head = SegmentList() - SegL_tail = SegmentList() - - while S_head is not None: - u = S_head.node - list_to_add = SegmentList() - if u in self.samples: - list_to_add.add(Segment(S_head.left, S_head.right, u)) - else: - if self.A[u] is not None: - s = self.A[u].head - while s is not None: - intvl = ( - max(S_head.left, s.left), - min(S_head.right, s.right), - ) - if intvl[1] - intvl[0] > 0: - list_to_add.add(Segment(intvl[0], intvl[1], s.node)) - s = s.next - - # Add this list to the end of SegL. - if SegL_head.head is None: - SegL_head = list_to_add - SegL_tail = list_to_add - else: - SegL_tail.next = list_to_add - SegL_tail = list_to_add - - S_head = S_head.next - - # If we wanted to squash segments, we'd do it here. - - # Use the info in SegL to find new ibd segments that descend from the - # parent node currently being processed. - if SegL_head.next is not None or self.current_parent in self.samples: - nodes_to_remove = self.calculate_ibd_segs(SegL_head) - - # g. 
Remove elements of the list of ancestral segments A if they - # are no longer needed. (Memory-pruning step) - for n in nodes_to_remove: - if self.current_parent in self.parent_list[n]: - self.parent_list[n].remove(self.current_parent) - if len(self.parent_list[n]) == 0: - self.A[n] = [] - - # Create the list of ancestral segments, A, descending from current node. - # Concatenate all of the segment lists in SegL into a single list, and - # store in A. This is a list of all sample segments descending from - # the currently processed node. - seglist = SegL_head - self.A[self.current_parent] = seglist - seglist = seglist.next - while seglist is not None: - self.A[self.current_parent].add(seglist) - seglist = seglist.next + seg = Segment(e.left, e.right, e.child) + + # Create a SegmentList() holding all segments that descend from seg. + list_to_add = SegmentList() + u = seg.node + if self.sample_id_map[u] != tskit.NULL: + list_to_add.add(seg) + else: + if self.A[u] is not None: + s = self.A[u].head + while s is not None: + intvl = ( + max(seg.left, s.left), + min(seg.right, s.right), + ) + if intvl[1] - intvl[0] > 0: + list_to_add.add(Segment(intvl[0], intvl[1], s.node)) + s = s.next + + if list_to_add.head is not None: + self.calculate_ibd_segs(current_parent, list_to_add) + + # For parents that are also samples + if ( + self.sample_id_map[current_parent] != tskit.NULL + ) and parent_should_be_added: + singleton_seg = SegmentList() + singleton_seg.add(Segment(0, self.ts.sequence_length, current_parent)) + self.calculate_ibd_segs(current_parent, singleton_seg) + parent_should_be_added = False + + # Move to next edge. + e = next(edges_iter, None) + + # Remove any processed nodes that are no longer needed. 
+ if e is not None and e.parent != current_parent: + for i, n in enumerate(self.oldest_parent): + if current_parent == n: + self.A[i] = None return self.ibd_segments - def calculate_ibd_segs(self, SegL_head): + def calculate_ibd_segs(self, current_parent, list_to_add): """ - Calculates all new IBD segments found at the current iteration of the - algorithm. Returns information about nodes processed in this step, which - allows memory to be cleared as the procedure runs. + Write later. """ - list0 = SegL_head - # Iterate through all pairs of segment lists. - while list0.next is not None: - list1 = list0.next - while list1 is not None: - seg0 = list0.head - # Iterate through all segments with one taken from each list. - while seg0 is not None: - seg1 = list1.head - while seg1 is not None: - left = max(seg0.left, seg1.left) - right = min(seg0.right, seg1.right) - if left >= right: - seg1 = seg1.next - continue - nodes = [seg0.node, seg1.node] - nodes.sort() - - # If there are any overlapping segments, record as a new - # IBD relationship. - if right - left > self.min_length: - self.add_ibd_segments( - nodes[0], - nodes[1], - Segment(left, right, self.current_parent), - ) + if list_to_add.head is None: + return [] - seg1 = seg1.next + if self.A[current_parent] is None: + self.A[current_parent] = list_to_add - seg0 = seg0.next - list1 = list1.next - list0 = list0.next - - # If the current processed node is itself a sample, calculate IBD separately. 
- if self.current_parent in self.samples: - seglist = SegL_head - while seglist is not None: - seg = seglist.head - while seg is not None: - self.add_ibd_segments( - seg.node, - self.current_parent, - Segment(seg.left, seg.right, self.current_parent), - ) - seg = seg.next - seglist = seglist.next - - # Specify elements of A that can be removed (for memory-pruning step) - processed_child_nodes = [] - seglist = SegL_head - while seglist is not None: - seg = seglist.head - while seg is not None: - processed_child_nodes += [seg.node] - seg = seg.next - seglist = seglist.next - processed_child_nodes = list(set(processed_child_nodes)) - - return processed_child_nodes + else: + seg0 = self.A[current_parent].head + while seg0 is not None: + seg1 = list_to_add.head + while seg1 is not None: + left = max(seg0.left, seg1.left) + right = min(seg0.right, seg1.right) + if left >= right: + seg1 = seg1.next + continue + nodes = [seg0.node, seg1.node] + nodes.sort() + + # If there are any overlapping segments, record as a new + # IBD relationship. + if right - left > self.min_length: + self.add_ibd_segments( + nodes[0], nodes[1], Segment(left, right, current_parent), + ) + seg1 = seg1.next + seg0 = seg0.next + + # Add list_to_add to A[u]. + self.A[current_parent].add(list_to_add) if __name__ == "__main__": diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index 75b6870661..2093107b43 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -7,9 +7,9 @@ import unittest import msprime -from packaging import version import tests.ibd as ibd +import tests.test_wright_fisher as wf import tskit # Functions for computing IBD 'naively'. @@ -154,24 +154,17 @@ def verify_equal_ibd(treeSequence): # Convert each SegmentList object into a list of Segment objects. ibd0_tolist = {} for key, val in ibd0.items(): - ibd0_tolist[key] = convert_segmentlist_to_list(val) + if val is not None: + ibd0_tolist[key] = convert_segmentlist_to_list(val) # Check for equality. 
for key0, val0 in ibd0_tolist.items(): + assert key0 in ibd1.keys() val1 = ibd1[key0] val0.sort() val1.sort() - if val0 is None: # Get rid of this later -- don't want empty dict values at all - assert val1 is None - continue - elif val1 is None: - assert val0 is None - assert len(val0) == len(val1) - for i in range(0, len(val0)): - assert val0[i] == val1[i] - def convert_segmentlist_to_list(seglist): """ @@ -243,8 +236,6 @@ def segment_lists_are_equal(val1, val2): return False for i in range(len(val1)): if val1[i] != val2[i]: - # print(val1q, val2) - # print(val1[i], val2[i]) return False return True @@ -430,11 +421,8 @@ def test_no_samples(self): """ ) ts = tskit.load_text(nodes=nodes, edges=edges, strict=False) - ibd_f = ibd.IbdFinder(ts) - ibd_segs = ibd_f.find_ibd_segments() - ibd_segs = convert_dict_of_segmentlists(ibd_segs) - true_segs = {} - assert ibd_is_equal(ibd_segs, true_segs) + with self.assertRaises(ValueError): + ibd.IbdFinder(ts) class TestIbdSamplesAreDescendants(unittest.TestCase): @@ -467,6 +455,7 @@ def test_basic(self): ibd_segs = ibd_f.find_ibd_segments() ibd_segs = convert_dict_of_segmentlists(ibd_segs) true_segs = {(0, 1): [ibd.Segment(0.0, 1.0, 1)]} + assert ibd_is_equal(ibd_segs, true_segs) @@ -678,34 +667,18 @@ class TestIbdRandomExamples(unittest.TestCase): """ # Infinite sites, Hudson model. 
- @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_random_example1(self): ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=2) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_random_example2(self): ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=23) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_random_example3(self): ts = msprime.simulate(sample_size=10, recombination_rate=0.5, random_seed=232) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_random_example4(self): ts = msprime.simulate(sample_size=10, recombination_rate=0.3, random_seed=726) verify_equal_ibd(ts) @@ -717,7 +690,7 @@ def sim_finite_sites(self, random_seed, dtwf=False): positions.sort() rates = [random.uniform(1e-9, 1e-7) for _ in range(100)] r_map = msprime.RecombinationMap( - positions=positions, rates=rates, discrete=True + positions=positions, rates=rates, num_loci=seq_length ) if dtwf: model = "dtwf" @@ -732,67 +705,59 @@ def sim_finite_sites(self, random_seed, dtwf=False): ) return ts - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_finite_sites1(self): ts = self.sim_finite_sites(9257) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_finite_sites2(self): ts = self.sim_finite_sites(835) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites 
simulations to work.", - ) def test_finite_sites3(self): ts = self.sim_finite_sites(27278) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_finite_sites4(self): ts = self.sim_finite_sites(22446688) verify_equal_ibd(ts) # DTWF - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_dtwf1(self): ts = self.sim_finite_sites(84, dtwf=True) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_dtwf2(self): ts = self.sim_finite_sites(17482, dtwf=True) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_dtwf3(self): ts = self.sim_finite_sites(846, dtwf=True) verify_equal_ibd(ts) - @unittest.skipIf( - version.parse(msprime.__version__) < version.parse("0.7.5"), - "Needs finite sites simulations to work.", - ) def test_dtwf4(self): ts = self.sim_finite_sites(273, dtwf=True) verify_equal_ibd(ts) + + def test_sim_wright_fisher_generations(self): + # Uses the bespoke DTWF forward-time simulator. + number_of_gens = 3 + tables = wf.wf_sim(4, number_of_gens, deep_history=False, seed=83) + tables.sort() + ts = tables.tree_sequence() + verify_equal_ibd(ts) + + def test_sim_wright_fisher_generations2(self): + # Uses the bespoke DTWF forward-time simulator. + number_of_gens = 10 + tables = wf.wf_sim(10, number_of_gens, deep_history=False, seed=837) + tables.sort() + ts = tables.tree_sequence() + verify_equal_ibd(ts) + + def test_sim_wright_fisher_generations3(self): + # Uses the bespoke DTWF forward-time simulator. + number_of_gens = 10 + tables = wf.wf_sim(10, number_of_gens, deep_history=False, seed=37) + tables.sort() + ts = tables.tree_sequence() + verify_equal_ibd(ts)