Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/tests/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
filter_populations=True,
filter_individuals=True,
keep_unary=False,
founders=None,
):
self.ts = ts
self.n = len(sample)
Expand Down Expand Up @@ -140,6 +141,11 @@ def __init__(
self.position_lookup = None
if self.reduce_to_site_topology:
self.position_lookup = np.hstack([[0], position, [self.sequence_length]])
self.founders = founders
if self.founders is not None:
self.keep_founders = True
else:
self.keep_founders = False

def record_node(self, input_id, is_sample=False):
"""
Expand Down Expand Up @@ -263,6 +269,8 @@ def merge_labeled_ancestors(self, S, input_id):
"""
output_id = self.node_id_map[input_id]
is_sample = output_id != -1
if self.keep_founders:
is_founder = input_id in self.founders
if is_sample:
# Free up the existing ancestry mapping.
x = self.A_tail[input_id]
Expand All @@ -281,6 +289,11 @@ def merge_labeled_ancestors(self, S, input_id):
if output_id == -1:
output_id = self.record_node(input_id)
self.record_edge(left, right, output_id, ancestry_node)
elif self.keep_founders and is_founder:
if output_id == -1:
output_id = self.record_node(input_id)
self.record_edge(left, right, output_id, ancestry_node)

else:
if output_id == -1:
output_id = self.record_node(input_id)
Expand Down
79 changes: 79 additions & 0 deletions python/tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4550,6 +4550,7 @@ def do_simplify(
filter_populations=True,
filter_individuals=True,
keep_unary=False,
founders=None,
):
"""
Runs the Python test implementation of simplify.
Expand All @@ -4563,6 +4564,7 @@ def do_simplify(
filter_populations=filter_populations,
filter_individuals=filter_individuals,
keep_unary=keep_unary,
founders=founders,
)
new_ts, node_map = s.simplify()
if compare_lib:
Expand Down Expand Up @@ -5413,6 +5415,83 @@ def test_many_trees_recurrent_mutations_internal_samples(self):
for keep in [True, False]:
self.verify_simplify_haplotypes(ts, samples, keep_unary=keep)

def test_keep_founders(self):
    """
    Simplify a small hand-written tree sequence with node 14 passed as a
    founder and check the result.

    The tables below define 15 nodes, of which nodes 0 and 1 are flagged
    as samples, and 19 edge rows over the interval [0, 1). The test runs
    the Python test simplifier only (``compare_lib=False``), then checks
    that the founder was given an output node ID and that the sample
    count is unchanged.
    """
    # Tree diagram for the input (column alignment was lost in this view):
    #
    # 1.76┊ 14 ┊ 14 ┊ 14 ┊ 14 ┊ ┊
    # ┊ ┏━━┻━┓ ┊ ┏━━┻━━━┓ ┊ ┏━━┻━━┓ ┊ ┏━━┻━━┓ ┊ ┊
    # 1.51┊ ┃ ┃ ┊ ┃ 13 ┊ ┃ 13 ┊ ┃ 13 ┊ 13 ┊
    # ┊ ┃ ┃ ┊ ┃ ┏━┻━┓ ┊ ┃ ┏┻━┓ ┊ ┃ ┏━┻━┓ ┊ ┏━━┻━━┓ ┊
    # 1.27┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ 12 ┃ ┊ ┃ 12 ┊
    # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┏┻┓ ┊
    # 1.21┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ 11 ┃ ┃ ┊
    # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┏━┻━┓ ┃ ┃ ┊
    # 0.64┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ 10 ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    # ┊ ┃ ┃ ┊ ┃ ┃ ┃ ┊ ┏┻━┓ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    # 0.49┊ ┃ 9 ┊ ┃ 9 ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    # ┊ ┃ ┏━┻━┓ ┊ ┃ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┊
    # 0.31┊ ┃ ┃ 8 ┊ ┃ ┃ 8 ┃ ┊ ┃ ┃ 8 ┃ ┊ ┃ ┃ ┃ 8 ┊ ┃ 8 ┃ ┃ ┊
    # ┊ ┃ ┃ ┏━┻┓ ┊ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┃ ┏┻┓ ┊ ┃ ┏┻┓ ┃ ┃ ┊
    # 0.26┊ ┃ ┃ ┃ 7 ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊
    # ┊ ┃ ┃ ┃ ┏┻┓ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊ ┃ ┃ ┃ ┃ ┃ ┊
    # 0.03┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊ 6 ┃ ┃ ┃ ┃ ┊
    # ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊ ┏┻┓ ┃ ┃ ┃ ┃ ┊
    # 0.00┊ 0 5 1 2 3 4 ┊ 0 5 1 2 4 3 ┊ 0 5 1 2 4 3 ┊ 0 5 1 3 2 4 ┊ 0 5 2 4 1 3 ┊
    # 0.00 0.18 0.37 0.75 0.91 1.00

    nodes = io.StringIO(
        """\
id is_sample population time
0 1 -1 0.00000000000000
1 1 -1 0.00000000000000
2 0 -1 0.00000000000000
3 0 -1 0.00000000000000
4 0 -1 0.00000000000000
5 0 -1 0.00000000000000
6 0 -1 0.03000000000000
7 0 -1 0.26000000000000
8 0 -1 0.31000000000000
9 0 -1 0.49000000000000
10 0 -1 0.64000000000000
11 0 -1 1.21000000000000
12 0 -1 1.27000000000000
13 0 -1 1.51000000000000
14 0 -1 1.76000000000000
"""
    )
    edges = io.StringIO(
        """\
id left right parent child
0 0.00000000 1.00000000 6 0,5
1 0.00000000 0.18000000 7 3,4
2 0.00000000 1.00000000 8 2
3 0.18000000 1.00000000 8 4
4 0.00000000 0.18000000 8 7
5 0.00000000 0.37000000 9 1,8
6 0.37000000 0.75000000 10 1,6
7 0.91000000 1.00000000 11 6,8
8 0.75000000 1.00000000 12 1,3
9 0.18000000 0.75000000 13 3
10 0.37000000 0.91000000 13 8
11 0.18000000 0.37000000 13 9
12 0.91000000 1.00000000 13 11
13 0.75000000 1.00000000 13 12
14 0.00000000 0.37000000 14 6
15 0.75000000 0.91000000 14 6
16 0.00000000 0.18000000 14 9
17 0.37000000 0.75000000 14 10
18 0.18000000 0.91000000 14 13
"""
    )
    # strict=False because the tables above are space- rather than
    # tab-delimited.
    ts = tskit.load_text(nodes=nodes, edges=edges, strict=False)

    founder = 14
    # Keep the input under its own name so it can be compared with the
    # simplified output. No file is written: the test must not leave
    # artifacts in the working directory.
    simplified, node_map = self.do_simplify(
        ts, samples=ts.samples(), founders=[founder], compare_lib=False
    )

    # -1 is the value the simplifier uses for "no output node".
    assert node_map[founder] != -1
    # Simplifying with respect to all samples must not drop any of them.
    assert simplified.num_samples == len(ts.samples())


class TestMapToAncestors(unittest.TestCase):
"""
Expand Down