Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/tests/test_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def verify(self, tree):

m1 = drawing.get_left_neighbour(tree, "minlex_postorder")
m2 = get_left_neighbour(tree, "minlex_postorder")
np.testing.assert_array_equal(m1, m2)

def test_2_binary(self):
ts = msprime.simulate(2, random_seed=2)
Expand Down Expand Up @@ -281,6 +282,24 @@ def test_zero_roots(self):
def test_multiroot(self):
self.verify(self.get_multiroot_tree())

def test_left_child(self):
t = self.get_nonbinary_tree()
left_child = drawing.get_left_child(t, "postorder")
for u in t.nodes(order="postorder"):
if t.num_children(u) > 0:
self.assertEqual(left_child[u], t.children(u)[0])

def test_null_node_left_child(self):
t = self.get_nonbinary_tree()
left_child = drawing.get_left_child(t, "minlex_postorder")
self.assertEqual(left_child[tskit.NULL], tskit.NULL)

def test_leaf_node_left_child(self):
t = self.get_nonbinary_tree()
left_child = drawing.get_left_child(t, "minlex_postorder")
for u in t.samples():
self.assertEqual(left_child[u], tskit.NULL)


class TestOrder(TestTreeDraw):
"""
Expand Down
19 changes: 17 additions & 2 deletions python/tskit/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ def draw(self):
node_x_coord_map = self.node_x_coord_map
node_y_coord_map = self.node_y_coord_map
tree = self.tree
left_child = get_left_child(tree, self.traversal_order)

# Iterate over nodes, adding groups to reflect the tree heirarchy
stack = []
Expand Down Expand Up @@ -658,7 +659,7 @@ def draw(self):
curr_svg_group.add(dwg.circle(**self.node_attrs[u]))
# Labels
if not tree.is_leaf(u) and tree.parent(u) != NULL:
if tree.left_sib(u) == NULL:
if u == left_child[tree.parent(u)]:
self.add_class(self.node_label_attrs[u], "rgt")
else:
self.add_class(self.node_label_attrs[u], "lft")
Expand All @@ -677,7 +678,7 @@ def draw(self):
# Symbols
mut_group.add(dwg.rect(insert=o, **self.mutation_attrs[mutation.id]))
# Labels
if tree.left_sib(mutation.node) == NULL:
if mutation.node == left_child[tree.parent(mutation.node)]:
mut_label_class = "rgt"
else:
mut_label_class = "lft"
Expand Down Expand Up @@ -805,6 +806,20 @@ def find_neighbours(u, neighbour):
return left_neighbour[:-1]


def get_left_child(tree, traversal_order):
"""
Returns the left-most child of each node in the tree according to the
specified traversal order. If a node has no children or NULL is passed
in, return NULL.
"""
left_child = np.full(tree.num_nodes + 1, NULL, dtype=int)
for u in tree.nodes(order=traversal_order):
parent = tree.parent(u)
if parent != NULL and left_child[parent] == NULL:
left_child[parent] = u
return left_child


def node_time_depth(tree, min_branch_length=None, max_tree_height="tree"):
"""
Returns a dictionary mapping nodes in the specified tree to their depth
Expand Down