diff --git a/python/tests/simplify.py b/python/tests/simplify.py
index bfd9c387d5..4ae37f3755 100644
--- a/python/tests/simplify.py
+++ b/python/tests/simplify.py
@@ -108,6 +108,7 @@ def __init__(
         filter_populations=True,
         filter_individuals=True,
         keep_unary=False,
+        founders=None,
     ):
         self.ts = ts
         self.n = len(sample)
@@ -140,6 +141,9 @@ def __init__(
         self.position_lookup = None
         if self.reduce_to_site_topology:
             self.position_lookup = np.hstack([[0], position, [self.sequence_length]])
+        # Optional collection of input node IDs treated as founders; None disables.
+        self.founders = founders
+        self.keep_founders = founders is not None
 
     def record_node(self, input_id, is_sample=False):
         """
@@ -263,6 +267,8 @@ def merge_labeled_ancestors(self, S, input_id):
         """
         output_id = self.node_id_map[input_id]
         is_sample = output_id != -1
+        # Always bound; short-circuits so self.founders is only read when set.
+        is_founder = self.keep_founders and input_id in self.founders
         if is_sample:
             # Free up the existing ancestry mapping.
             x = self.A_tail[input_id]
@@ -281,6 +287,11 @@ def merge_labeled_ancestors(self, S, input_id):
                     if output_id == -1:
                         output_id = self.record_node(input_id)
                     self.record_edge(left, right, output_id, ancestry_node)
+                elif is_founder:
+                    # Founders are recorded the same way as the branch above.
+                    if output_id == -1:
+                        output_id = self.record_node(input_id)
+                    self.record_edge(left, right, output_id, ancestry_node)
             else:
                 if output_id == -1:
                     output_id = self.record_node(input_id)
diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py
index 39be7b800a..3cb296d674 100644
--- a/python/tests/test_topology.py
+++ b/python/tests/test_topology.py
@@ -4550,6 +4550,7 @@ def do_simplify(
         filter_populations=True,
         filter_individuals=True,
         keep_unary=False,
+        founders=None,
     ):
         """
         Runs the Python test implementation of simplify.
@@ -4563,6 +4564,7 @@ def do_simplify(
             filter_populations=filter_populations,
             filter_individuals=filter_individuals,
             keep_unary=keep_unary,
+            founders=founders,
         )
         new_ts, node_map = s.simplify()
         if compare_lib:
@@ -5413,6 +5415,84 @@ def test_many_trees_recurrent_mutations_internal_samples(self):
         for keep in [True, False]:
             self.verify_simplify_haplotypes(ts, samples, keep_unary=keep)
 
+    def test_keep_founders(self):
+        """
+        Smoke test: the Python simplifier runs to completion when node 14
+        (the oldest node in this five-tree example) is passed as a founder.
+        """
+        # 1.76┊ 14 ┊ 14 ┊ 14 ┊ 14 ┊ ┊
+        # ┊ ┏━━┻━┓ ┊ ┏━━┻━━━┓ ┊ ┏━━┻━━┓ ┊ ┏━━┻━━┓ ┊ ┊
+        # 1.51┊ ┃ ┃ ┊ ┃ 13 ┊ ┃ 13 ┊ ┃ 13 ┊ 13 ┊
+        # ┊ ┃ ┃ ┊ ┃ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊ ┃ ┏━┻━┓ ┊ ┏━━┻━━┓ ┊
+        # 1.27┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ 12 ┃ ┊ ┃ 12 ┊
+        # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏┻┓ ┊
+        # 1.21┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ 11 ┃ ┃ ┊
+        # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┏━┻━┓ ┃ ┃ ┊
+        # 0.64┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ 10 ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
+        # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┏┻━┓ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
+        # 0.49┊ ┃ 9 ┊ ┃ 9 ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
+        # ┊ ┃ ┏━┻━┓ ┊ ┃ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
+        # 0.31┊ ┃ ┃ 8 ┊ ┃ ┃ 8 ┃ ┊ ┃ ┃ 8 ┃ ┊ ┃ ┃ ┃ 8 ┊ ┃ 8 ┃ ┃ ┊
+        # ┊ ┃ ┃ ┏━┻┓ ┊ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┃ ┊
+        # 0.26┊ ┃ ┃ ┃ 7 ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊
+        # ┊ ┃ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊
+        # 0.03┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊
+        # ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊
+        # 0.00┊ 0 5 1 2 3 4 ┊ 0 5 1 2 4 3 ┊ 0 5 1 2 4 3 ┊ 0 5 1 3 2 4 ┊ 0 5 2 4 1 3 ┊
+        # 0.00 0.18 0.37 0.75 0.91 1.00
+
+        nodes = io.StringIO(
+            """\
+        id is_sample population time
+        0 1 -1 0.00000000000000
+        1 1 -1 0.00000000000000
+        2 0 -1 0.00000000000000
+        3 0 -1 0.00000000000000
+        4 0 -1 0.00000000000000
+        5 0 -1 0.00000000000000
+        6 0 -1 0.03000000000000
+        7 0 -1 0.26000000000000
+        8 0 -1 0.31000000000000
+        9 0 -1 0.49000000000000
+        10 0 -1 0.64000000000000
+        11 0 -1 1.21000000000000
+        12 0 -1 1.27000000000000
+        13 0 -1 1.51000000000000
+        14 0 -1 1.76000000000000
+        """
+        )
+        edges = io.StringIO(
+            """\
+        id left right parent child
+        0 0.00000000 1.00000000 6 0,5
+        1 0.00000000 0.18000000 7 3,4
+        2 0.00000000 1.00000000 8 2
+        3 0.18000000 1.00000000 8 4
+        4 0.00000000 0.18000000 8 7
+        5 0.00000000 0.37000000 9 1,8
+        6 0.37000000 0.75000000 10 1,6
+        7 0.91000000 1.00000000 11 6,8
+        8 0.75000000 1.00000000 12 1,3
+        9 0.18000000 0.75000000 13 3
+        10 0.37000000 0.91000000 13 8
+        11 0.18000000 0.37000000 13 9
+        12 0.91000000 1.00000000 13 11
+        13 0.75000000 1.00000000 13 12
+        14 0.00000000 0.37000000 14 6
+        15 0.75000000 0.91000000 14 6
+        16 0.00000000 0.18000000 14 9
+        17 0.37000000 0.75000000 14 10
+        18 0.18000000 0.91000000 14 13
+        """
+        )
+        ts = tskit.load_text(nodes=nodes, edges=edges, strict=False)
+        # NOTE(review): compare_lib=False presumably because the library
+        # simplify has no founders option yet - confirm.
+        new_ts, node_map = self.do_simplify(
+            ts, samples=ts.samples(), founders=[14], compare_lib=False
+        )
+        # TODO: assert on new_ts / node_map once the expected output is pinned down.
+
 
 class TestMapToAncestors(unittest.TestCase):
     """