diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 840e6b208c..a158f0dfbe 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -194,8 +194,8 @@ def __init__( mutation_label_attrs=None, root_svg_attributes=None, style=None, + root_branch=None, ): - self.ts = ts if size is None: size = (200 * ts.num_trees, 200) if root_svg_attributes is None: @@ -206,9 +206,15 @@ def __init__( self.drawing = svgwrite.Drawing( size=self.image_size, debug=True, **root_svg_attributes ) + if node_labels is None: + node_labels = {u: str(u) for u in range(ts.num_nodes)} if style is not None: self.drawing.defs.add(self.drawing.style(style)) - self.node_labels = {u: str(u) for u in range(ts.num_nodes)} + if root_branch is None: + root_branch = any( + any(tree.parent(mut.node) == NULL for mut in tree.mutations()) + for tree in ts.trees() + ) # TODO add general padding arguments following matplotlib's terminology. self.axes_x_offset = 15 self.axes_y_offset = 10 @@ -231,6 +237,7 @@ def __init__( node_label_attrs=node_label_attrs, mutation_attrs=mutation_attrs, mutation_label_attrs=mutation_label_attrs, + root_branch=root_branch, ) for tree in ts.trees() ] @@ -285,7 +292,7 @@ class SvgTree: """ A class to draw a tree in SVG format. - See :meth:`Tree.draw_svg` for a description of usage and parameters. + See :meth:`Tree.draw_svg` for a description of usage and frequently used parameters. """ def __init__( @@ -303,6 +310,7 @@ def __init__( mutation_label_attrs=None, root_svg_attributes=None, style=None, + root_branch=None, ): self.tree = tree if size is None: @@ -314,6 +322,10 @@ def __init__( self.drawing = self.setup_drawing() if style is not None: self.drawing.defs.add(self.drawing.style(style)) + if root_branch is None: + # put a root branch in if we have mutations over the root + root_branch = any(tree.parent(mut.node) == NULL for mut in tree.mutations()) + self.root_branch = root_branch self.treebox_x_offset = 10 self.treebox_y_offset = 10 self.treebox_width = size[0] - 2 * self.treebox_x_offset @@ -434,13 +446,9 @@ def assign_y_coordinates(self, tree_height_scale, max_tree_height): # node labels within the treebox label_padding = 10 y_padding = self.treebox_y_offset + 2 * label_padding - mutations_over_root = any( - any(tree.parent(mut.node) == NULL for mut in tree.mutations()) - for tree in ts.trees() - ) root_branch_length = 0 height = self.image_size[1] - if mutations_over_root: + if self.root_branch: # Allocate a fixed about of space to show the mutations on the # 'root branch' root_branch_length = height / 10 # FIXME just draw branch??