Skip to content

Commit

Permalink
Merge pull request #1282 from hyanwong/test-empty
Browse files Browse the repository at this point in the history
Test some wacky tree seqs
  • Loading branch information
mergify[bot] committed Apr 1, 2021
2 parents d40978f + 2e5626e commit 6782342
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 48 deletions.
121 changes: 82 additions & 39 deletions python/tests/test_highlevel.py
Expand Up @@ -229,8 +229,12 @@ def get_example_tree_sequences(back_mutations=True, gaps=True, internal_samples=
yield ts
yield tsutil.add_random_metadata(ts)
tables = ts.dump_tables()
tables.nodes.flags = np.zeros_like(tables.nodes.flags)
yield tables.tree_sequence() # no samples
tables = ts.dump_tables()
tables.edges.clear()
yield tables.tree_sequence() # empty tree sequence
yield tables.tree_sequence() # empty tree
yield tskit.TableCollection(sequence_length=1).tree_sequence() # empty tree seq


def get_bottleneck_examples():
Expand Down Expand Up @@ -470,11 +474,13 @@ def verify_trees(self, ts):
assert st1.left_root == st2.left_root
assert sorted(list(roots)) == sorted(st1.roots)
assert st1.roots == st2.roots
if len(roots) > 1:
if len(roots) == 0:
assert st1.root == tskit.NULL
elif len(roots) == 1:
assert st1.root == list(roots)[0]
else:
with pytest.raises(ValueError):
st1.root
else:
assert st1.root == list(roots)[0]
assert st2 == st1
assert not (st2 != st1)
left, right = st1.get_interval()
Expand Down Expand Up @@ -597,6 +603,10 @@ def test_mutations(self):

def verify_pairwise_diversity(self, ts):
haplotypes = ts.genotype_matrix(isolated_as_missing=False).T
if ts.num_samples == 0:
with pytest.raises(ValueError, match="at least one element"):
ts.get_pairwise_diversity()
return
pi1 = ts.get_pairwise_diversity()
pi2 = simple_get_pairwise_diversity(haplotypes)
assert pi1 == pytest.approx(pi2)
Expand Down Expand Up @@ -882,7 +892,7 @@ def test_first_last(self):
assert t1.parent_dict == t2.parent_dict
assert t1.index == 0
if "tracked_samples" in kwargs:
assert t1.num_tracked_samples() != 0
assert t1.num_tracked_samples() == ts.num_samples
else:
assert t1.num_tracked_samples() == 0

Expand All @@ -892,7 +902,7 @@ def test_first_last(self):
assert t1.parent_dict == t2.parent_dict
assert t1.index == ts.num_trees - 1
if "tracked_samples" in kwargs:
assert t1.num_tracked_samples() != 0
assert t1.num_tracked_samples() == ts.num_samples
else:
assert t1.num_tracked_samples() == 0

Expand All @@ -916,13 +926,19 @@ def test_trees_interface(self):

def test_get_pairwise_diversity(self):
for ts in get_example_tree_sequences():
with pytest.raises(ValueError):
with pytest.raises(ValueError, match="at least one element"):
ts.get_pairwise_diversity([])
samples = list(ts.samples())
assert ts.get_pairwise_diversity() == ts.get_pairwise_diversity(samples)
assert ts.get_pairwise_diversity(samples[:2]) == ts.get_pairwise_diversity(
list(reversed(samples[:2]))
)
if len(samples) == 0:
with pytest.raises(
ValueError, match="Sample sets must contain at least one element"
):
ts.get_pairwise_diversity()
else:
assert ts.get_pairwise_diversity() == ts.get_pairwise_diversity(samples)
assert ts.get_pairwise_diversity(
samples[:2]
) == ts.get_pairwise_diversity(list(reversed(samples[:2])))

def test_populations(self):
more_than_zero = False
Expand Down Expand Up @@ -975,7 +991,7 @@ def test_get_population(self):
ts.get_population(N)
with pytest.raises(ValueError):
ts.get_population(N + 1)
for node in [0, N - 1]:
for node in range(0, N - 1):
assert ts.get_population(node) == ts.node(node).population

def test_get_time(self):
Expand All @@ -993,10 +1009,18 @@ def test_get_time(self):

def test_max_root_time(self):
for ts in get_example_tree_sequences():
oldest = max(
max(tree.time(root) for root in tree.roots) for tree in ts.trees()
)
assert oldest == ts.max_root_time
oldest = None
for tree in ts.trees():
for root in tree.roots:
oldest = (
tree.time(root)
if oldest is None
else max(oldest, tree.time(root))
)
if oldest is None:
assert pytest.raises(ValueError, match="max()")
else:
assert oldest == ts.max_root_time

def test_max_root_time_corner_cases(self):
tables = tskit.TableCollection(1)
Expand Down Expand Up @@ -1129,9 +1153,11 @@ def test_simplify(self):
for ts in get_example_tree_sequences():
self.verify_tables_api_equality(ts)
self.verify_simplify_provenance(ts)
n = ts.get_sample_size()
num_mutations += ts.get_num_mutations()
sample_sizes = {0, 1}
n = ts.num_samples
num_mutations += ts.num_mutations
sample_sizes = {0}
if n > 1:
sample_sizes |= {1}
if n > 2:
sample_sizes |= {2, max(2, n // 2), n - 1}
for k in sample_sizes:
Expand Down Expand Up @@ -1354,8 +1380,8 @@ def test_list(self):
assert t1 == t2
assert t1.parent_dict == t2.parent_dict
if "tracked_samples" in kwargs:
assert t1.num_tracked_samples() != 0
assert t2.num_tracked_samples() != 0
assert t1.num_tracked_samples() == ts.num_samples
assert t2.num_tracked_samples() == ts.num_samples
else:
assert t1.num_tracked_samples() == 0
assert t2.num_tracked_samples() == 0
Expand All @@ -1382,7 +1408,7 @@ def test_at_index(self):
assert t1.interval == t2.interval
assert t1.parent_dict == t2.parent_dict
if "tracked_samples" in kwargs:
assert t2.num_tracked_samples() != 0
assert t2.num_tracked_samples() == ts.num_samples
else:
assert t2.num_tracked_samples() == 0

Expand All @@ -1405,7 +1431,7 @@ def test_at(self):
assert t3.interval == t2.interval
assert t3.parent_dict == t2.parent_dict
if "tracked_samples" in kwargs:
assert t2.num_tracked_samples() != 0
assert t2.num_tracked_samples() == ts.num_samples
else:
assert t2.num_tracked_samples() == 0

Expand Down Expand Up @@ -1591,7 +1617,12 @@ def test_tree_node_edges(self):
for mapping, tree in zip(ts._tree_node_edges(), ts.trees()):
node_mapped = mapping >= 0
edge_visited[mapping[node_mapped]] = True
assert np.sum(node_mapped) == len(list(tree.nodes())) - tree.num_roots
# Note that tree.nodes() does not necessarily list all the nodes
# in the tree topology, only the ones that descend from a root.
# Therefore if not all the topological trees in a single `Tree` have
# a root, we can have edges above nodes that are not listed. This
# happens, for example, in a tree with no sample nodes.
assert np.sum(node_mapped) >= len(list(tree.nodes())) - tree.num_roots
for u in tree.nodes():
if tree.parent(u) == tskit.NULL:
assert mapping[u] == tskit.NULL
Expand Down Expand Up @@ -2149,24 +2180,24 @@ def test_num_children(self):
assert tree.num_children(u) == len(tree.children(u))

def test_root_properties(self):
tested = set()
for ts in get_example_tree_sequences():
for tree in ts.trees():
if tree.has_single_root:
tested.add("single")
assert tree.num_roots == 1
assert tree.num_roots == 1
assert tree.root != tskit.NULL
elif tree.has_multiple_roots:
tested.add("multiple")
assert tree.num_roots > 1
with pytest.raises(ValueError, match="More than one root exists"):
_ = tree.root
else:
tested.add("zero")
assert tree.num_roots == 0
assert tree.root == tskit.NULL

def test_root_properties_empty_ts(self):
# NB - this can be removed once the example_tree_sequences contain an empty ts
tree = tskit.TableCollection(sequence_length=1).tree_sequence().first()
assert tree.num_roots == 0
assert tree.root == tskit.NULL
assert len(tested) == 3

def verify_newick(self, tree):
"""
Expand All @@ -2175,7 +2206,7 @@ def verify_newick(self, tree):
# TODO to make this work we may need to clamp the precision of node
# times because Python and C float printing algorithms work slightly
# differently. Seems to work OK now, so leaving alone.
if tree.num_roots == 1:
if tree.has_single_root:
py_tree = tests.PythonTree.from_tree(tree)
newick1 = tree.newick(precision=16)
newick2 = py_tree.newick()
Expand Down Expand Up @@ -2233,12 +2264,23 @@ def test_newick_topology_equiv(self):
replace_numeric = {ord(x): None for x in "1234567890:."}
for ts in get_example_tree_sequences():
for tree in ts.trees():
if tree.num_roots > 1:
continue
plain_newick = tree.newick(node_labels={}, include_branch_lengths=False)
newick1 = tree.newick().translate(replace_numeric)
newick2 = tree.newick(node_labels={}).translate(replace_numeric)
assert newick1 == newick2 == plain_newick
if not tree.has_single_root:
with pytest.raises(ValueError) as plain_newick_err:
tree.newick(node_labels={}, include_branch_lengths=False)
with pytest.raises(ValueError) as newick1_err:
tree.newick()
with pytest.raises(ValueError) as newick2_err:
tree.newick(node_labels={})
assert str(newick1_err) == str(newick2_err)
assert str(newick1_err) == str(plain_newick_err)
else:
plain_newick = tree.newick(
node_labels={}, include_branch_lengths=False
)
newick1 = tree.newick().translate(replace_numeric)
newick2 = tree.newick(node_labels={}).translate(replace_numeric)
assert newick1 == newick2
assert newick2 == plain_newick

def test_newick_buffer_too_small_bug(self):
nodes = io.StringIO(
Expand Down Expand Up @@ -2291,7 +2333,8 @@ def verify_nx_graph_topology(self, tree, g):
assert set(tree.leaves()) == {n for n in g.nodes if g.out_degree(n) == 0}

# test if tree has no in-degrees > 1
assert nx.is_branching(g)
if len(g) > 0:
assert nx.is_branching(g)

def verify_nx_algorithm_equivalence(self, tree, g):
for root in tree.roots:
Expand Down Expand Up @@ -2419,7 +2462,7 @@ def verify_traversals(self, tree):
"breadthfirst",
"minlex_postorder",
]
if tree.has_single_root:
if tree.num_roots == 1:
with pytest.raises(ValueError):
list(t1.nodes(order="bad order"))
assert list(t1.nodes()) == list(t1.nodes(t1.get_root()))
Expand Down
5 changes: 3 additions & 2 deletions python/tests/tsutil.py
Expand Up @@ -1534,8 +1534,9 @@ def iterate(self):
edge = edges[in_order[j]]
self.insert_edge(edge)
j += 1
while self.left_sib[self.left_root] != tskit.NULL:
self.left_root = self.left_sib[self.left_root]
if self.left_root != tskit.NULL:
while self.left_sib[self.left_root] != tskit.NULL:
self.left_root = self.left_sib[self.left_root]
right = sequence_length
if j < M:
right = min(right, edges[in_order[j]].left)
Expand Down
20 changes: 13 additions & 7 deletions python/tskit/trees.py
Expand Up @@ -2291,9 +2291,9 @@ def newick(
:rtype: str
"""
if root is None:
if self.num_roots > 1:
if not self.has_single_root:
raise ValueError(
"Cannot get newick for multiroot trees. Try "
"Cannot get newick unless a tree has a single root. Try "
"[t.newick(root) for root in t.roots] to get a list of "
"newick trees, one for each root."
)
Expand Down Expand Up @@ -3848,19 +3848,25 @@ def num_migrations(self):
@property
def max_root_time(self):
"""
Returns time of the oldest root in any of the trees in this tree sequence.
Returns the time of the oldest root in any of the trees in this tree sequence.
This is usually equal to ``np.max(ts.tables.nodes.time)`` but may not be
since there can be nodes that are not present in any tree. Consistent
with the definition of tree roots, if there are no edges in the tree
sequence we return the time of the oldest sample.
since there can be non-sample nodes that are not present in any tree. Note that
isolated samples are also defined as roots (so there can be a max_root_time
even in a tree sequence with no edges).
:return: The maximum time of a root in this tree sequence.
:rtype: float
:raises ValueError: If there are no samples in the tree, and hence no roots (as
roots are defined by the ends of the upward paths from the set of samples).
"""
if self.num_samples == 0:
raise ValueError(
"max_root_time is not defined in a tree sequence with 0 samples"
)
ret = max(self.node(u).time for u in self.samples())
if self.num_edges > 0:
# Edges are guaranteed to be listed in parent-time order, so we can get the
# last one to get the oldest root.
# last one to get the oldest root
edge = self.edge(self.num_edges - 1)
# However, we can have situations where there is a sample older than a
# 'proper' root
Expand Down

0 comments on commit 6782342

Please sign in to comment.