diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg index 7745c45d91..81feb6f487 100644 --- a/python/tests/data/svg/tree.svg +++ b/python/tests/data/svg/tree.svg @@ -1,70 +1,47 @@ - + - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - 5 - + + + + + + 0 - - - 9 - - - 0 - - - 1 - - - 2 - - - 3 - + + + + 1 - - - 4 - + + + 4 + + + 0 - - - - - 0 - + + + + + 2 + + + + + 3 + + + 5 + + 9 diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg index 02a1758184..6e68a58407 100644 --- a/python/tests/data/svg/ts.svg +++ b/python/tests/data/svg/ts.svg @@ -1,304 +1,198 @@ + - - - + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - 5 - - - - - 9 - - - 0 - - - 1 - - - 2 - - - 3 - - - - - 4 - - - - - - - - 0 - - - + + + + + + 0 + + + + + 1 + + + + 4 + + + 0 + + + + + + + 2 + + + + + 3 + + + + 5 + + + 9 - - - - - - - - - - - - - - - - - - - - - - - - - 5 - - - - - 7 - - - 0 - - - 1 - - - 2 - - - 3 - - - - - 4 - - - - - - - + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + 7 - - - - - - - - - - - - - - - - - - - - - - - - - 5 - - - - - 6 - - - 0 - - - 1 - - - 2 - - - 3 - - - - - 4 - - - - - - - + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + 6 - - - - - - - - - - - - - - - - - - - - - - - - - 5 - - - - - 7 - - - 0 - - - 1 - - - 2 - - - 3 - - - - - 4 - - - - - - - + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + 7 - - - - - - - - - - - - - - - - - - - - - - - - - 5 - - - - - 8 - - - 0 - - - 1 - - - 2 - - - 3 - - - - - 4 - - - - - - - + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + 8 @@ -308,20 +202,20 @@ 0.00 - - + + 0.06 - - + + 0.79 - - + + 0.91 - - + + 0.91 diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 381bba595f..f9989457a9 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -25,6 +25,7 @@ """ import collections import io +import math import os import tempfile import unittest @@ -1410,13 +1411,7 @@ def verify_basic_svg(self, svg, width=200, height=200): for group in groups: self.assertIn("class", group.attrib) cls = group.attrib["class"] - self.assertRegexpMatches(cls, r"\b(edges|symbols|labels)\b") - if "symbols" in cls or "labels" in cls: - # Check that we have nodes & mutations subgroups - for subgroup in group.findall(prefix + "g"): - self.assertIn("class", subgroup.attrib) - subcls = subgroup.attrib["class"] - self.assertRegexpMatches(subcls, r"\b(nodes|mutations)\b") + self.assertRegexpMatches(cls, r"\broot\b") def test_draw_file(self): t = self.get_binary_tree() @@ -1719,13 +1714,13 @@ def test_max_tree_height(self): svg1 = ts.at_index(0).draw() svg2 = ts.at_index(1).draw() - # if not scaled to ts, node 3 is at a different height in both trees, because the - # root is at a different height. We expect a label looking something like - # 3 where XXXX is different - str_pos = svg1.find(">3<") - snippet1 = svg1[svg1.rfind("3<") - snippet2 = svg2[svg2.rfind("0 + str_pos = svg1.find(">0<") + snippet1 = svg1[svg1.rfind("edge", 0, str_pos) : str_pos] + str_pos = svg2.find(">0<") + snippet2 = svg2[svg2.rfind("edge", 0, str_pos) : str_pos] self.assertNotEqual(snippet1, snippet2) svg1 = ts.at_index(0).draw(max_tree_height="ts") @@ -1734,10 +1729,10 @@ def test_max_tree_height(self): # should be the same self.verify_basic_svg(svg1) self.verify_basic_svg(svg2) - str_pos = svg1.find(">3<") - snippet1 = svg1[svg1.rfind("3<") - snippet2 = svg2[svg2.rfind("0<") + snippet1 = svg1[svg1.rfind("edge", 0, str_pos) : str_pos] + str_pos = svg2.find(">0<") + snippet2 = svg2[svg2.rfind("edge", 0, str_pos) : str_pos] self.assertEqual(snippet1, snippet2) def test_draw_sized_tree(self): @@ -1814,3 +1809,15 @@ def test_known_svg_ts(self): with open(svg_fn, "rb") as file: expected_svg = file.read() self.assertXmlEquivalentOutputs(svg, expected_svg) + + +class TestRounding(unittest.TestCase): + def test_rnd(self): + self.assertEqual(0, drawing.rnd(0)) + self.assertEqual(math.inf, drawing.rnd(math.inf)) + self.assertEqual(1, drawing.rnd(1)) + self.assertEqual(1.1, drawing.rnd(1.1)) + self.assertEqual(1.11111, drawing.rnd(1.111111)) + self.assertEqual(1111110, drawing.rnd(1111111)) + self.assertEqual(123.457, drawing.rnd(123.4567)) + self.assertEqual(123.456, drawing.rnd(123.4564)) diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index 7d904b1d86..2f2d9cc5c6 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -24,6 +24,7 @@ Module responsible for visualisations. """ import collections +import math import numbers import numpy as np @@ -116,11 +117,22 @@ def check_x_scale(x_scale): return x_scale +def rnd(x): + """ + Round a number so that the output SVG doesn't have unneeded precision + """ + digits = 6 + if x == 0 or not math.isfinite(x): + return x + digits -= math.ceil(math.log10(abs(x))) + return round(x, digits) + + def add_text_in_group(dwg, elem, x, y, text, **kwargs): """ Add the text to the elem within a group. This allows text rotations to work smoothly """ - grp = elem.add(dwg.g(transform=f"translate({x}, {y})")) + grp = elem.add(dwg.g(transform=f"translate({rnd(x)}, {rnd(y)})")) grp.add(dwg.text(text, **kwargs)) @@ -242,10 +254,9 @@ def __init__( if max_tree_height is None: max_tree_height = "ts" self.image_size = size - self.drawing = svgwrite.Drawing( - size=self.image_size, debug=True, **root_svg_attributes - ) - dwg = self.drawing + dwg = svgwrite.Drawing(size=self.image_size, debug=True, **root_svg_attributes) + self.drawing = dwg + dwg.defs.add(dwg.style(SvgTree.standard_style)) if style is not None: dwg.defs.add(dwg.style(style)) root_group = dwg.add(dwg.g(class_="tree-sequence")) @@ -256,7 +267,6 @@ def __init__( else: axis_top_padding = 5 tick_len = (5, 5) - self.node_labels = {u: str(u) for u in range(ts.num_nodes)} # TODO add general padding arguments following matplotlib's terminology. self.axes_x_offset = 15 @@ -292,14 +302,13 @@ def __init__( break_x = self.treebox_x_offset for svg_tree, tree in zip(svg_trees, ts.trees()): - svg_tree.root_group["transform"] = f"translate({tree_x} {y})" + svg_tree.root_group["transform"] = f"translate({rnd(tree_x)} {rnd(y)})" trees.add(svg_tree.root_group) ticks.append((tree_x, break_x, tree.interval[0])) tree_x += tree_width break_x += tree.span * drawing_scale ticks.append((tree_x, break_x, ts.sequence_length)) - # TODO - add the ability to show the commented section below as a flag # # Debug --- draw the tree and axes boxes # w = self.image_size[0] - 2 * self.treebox_x_offset # h = self.image_size[1] - 2 * self.treebox_y_offset @@ -330,21 +339,25 @@ def __init__( background.add( dwg.polygon( [ - (prev_break_x, y + tick_len[1]), - (prev_break_x, y), - (prev_tree_x, y - axis_top_padding), - (prev_tree_x, 0), - (tree_x, 0), - (tree_x, y - axis_top_padding), - (break_x, y), - (break_x, y + tick_len[1]), + (rnd(prev_break_x), rnd(y + tick_len[1])), + (rnd(prev_break_x), rnd(y)), + (rnd(prev_tree_x), rnd(y - axis_top_padding)), + (rnd(prev_tree_x), 0), + (rnd(tree_x), 0), + (rnd(tree_x), rnd(y - axis_top_padding)), + (rnd(break_x), rnd(y)), + (rnd(break_x), rnd(y + tick_len[1])), ], fill="#F1F1F1", ) ) axis.add( - dwg.line((x, y - tick_len[0]), (x, y + tick_len[1]), stroke="black") + dwg.line( + (rnd(x), rnd(y - tick_len[0])), + (rnd(x), rnd(y + tick_len[1])), + stroke="black", + ) ) add_text_in_group( dwg, @@ -365,6 +378,29 @@ class SvgTree: See :meth:`Tree.draw_svg` for a description of usage and parameters. """ + standard_style = ( + ".axis {font-weight: bold}" + ".tree, .axis {font-size: 14px; text-anchor:middle;}" + ".edge {stroke: black; fill: none}" + ".node > circle {r: 3px; fill: black; stroke: none}" + ".tree text {dominant-baseline: middle}" # not inherited in css 1.1 + ".mut > text.lft {transform: translateX(0.5em); text-anchor: start}" + ".mut > text.rgt {transform: translateX(-0.5em); text-anchor: end}" + ".root > text {transform: translateY(-0.8em)}" # Root + ".leaf > text {transform: translateY(1em)}" # Leaves + ".node > text.lft {transform: translate(0.5em, -0.5em); text-anchor: start}" + ".node > text.rgt {transform: translate(-0.5em, -0.5em); text-anchor: end}" + ".mut {fill: red; font-style: italic}" + ) + + @staticmethod + def add_class(attrs_dict, classes_str): + """Adds the classes_str to the 'class' key in attrs_dict, or creates it""" + try: + attrs_dict["class"] += " " + classes_str + except KeyError: + attrs_dict["class"] = classes_str + def __init__( self, tree, @@ -400,6 +436,10 @@ def __init__( self.node_x_coord_map = self.assign_x_coordinates( tree, self.treebox_x_offset, self.treebox_width ) + self.node_mutations = collections.defaultdict(list) + for site in tree.sites(): + for mutation in site.mutations: + self.node_mutations[mutation.node].append(mutation) self.edge_attrs = {} self.node_attrs = {} @@ -408,7 +448,7 @@ def __init__( self.edge_attrs[u] = {} if edge_attrs is not None and u in edge_attrs: self.edge_attrs[u].update(edge_attrs[u]) - self.node_attrs[u] = {"r": 3} + self.node_attrs[u] = {} if node_attrs is not None and u in node_attrs: self.node_attrs[u].update(node_attrs[u]) label = "" @@ -448,26 +488,9 @@ def setup_drawing(self): dwg = svgwrite.Drawing( size=self.image_size, debug=True, **self.root_svg_attributes ) + dwg.defs.add(dwg.style(self.standard_style)) tree_class = f"tree t{self.tree.index}" self.root_group = dwg.add(dwg.g(class_=tree_class)) - self.edges = self.root_group.add( - dwg.g(class_="edges", stroke="black", fill="none") - ) - self.symbols = self.root_group.add(dwg.g(class_="symbols")) - self.nodes = self.symbols.add(dwg.g(class_="nodes")) - self.mutations = self.symbols.add(dwg.g(class_="mutations", fill="red")) - self.labels = self.root_group.add( - dwg.g(class_="labels", font_size=14, dominant_baseline="middle") - ) - self.node_labels = self.labels.add(dwg.g(class_="nodes")) - self.mutation_labels = self.labels.add( - dwg.g(class_="mutations", font_style="italic") - ) - self.left_labels = self.node_labels.add(dwg.g(text_anchor="start")) - self.mid_labels = self.node_labels.add(dwg.g(text_anchor="middle")) - self.right_labels = self.node_labels.add(dwg.g(text_anchor="end")) - self.mutation_left_labels = self.mutation_labels.add(dwg.g(text_anchor="start")) - self.mutation_right_labels = self.mutation_labels.add(dwg.g(text_anchor="end")) return dwg def assign_y_coordinates(self, tree_height_scale, max_tree_height): @@ -554,92 +577,112 @@ def assign_x_coordinates(self, tree, x_start, width): node_x_coord_map[u] = a + (b - a) / 2 return node_x_coord_map + def info_classes(self, focal_node): + """ + For a focal node id, return a set of classes that encode this useful information: + "nA": where A == focal node id + "pB" or "root": where B == parent id (or "root" if the focal node is a root) + "sample": a class present if the focal node is a sample + "leaf": a class present if the focal node is a leaf + "mC": where C == mutation id of all mutations above this focal node + "sD": where D == site id of the sites associated with all mutations + above this focal node + """ + # Add a new group for each node, and give it classes for css targetting + classes = set() + classes.add(f"node n{focal_node}") + v = self.tree.parent(focal_node) + if v == NULL: + classes.add(f"root") + else: + classes.add(f"p{v}") + if self.tree.is_sample(focal_node): + classes.add("sample") + if self.tree.is_leaf(focal_node): + classes.add("leaf") + for mutation in self.node_mutations[focal_node]: + # Adding mutations and sites above this node allows identification + # of the tree under any specific mutation + classes.add(f"m{mutation.id}") + classes.add(f"s{mutation.site}") + return sorted(classes) + def draw(self): dwg = self.drawing node_x_coord_map = self.node_x_coord_map node_y_coord_map = self.node_y_coord_map tree = self.tree - node_mutations = collections.defaultdict(list) - for site in tree.sites(): - for mutation in site.mutations: - node_mutations[mutation.node].append(mutation) - - for u in tree.nodes(): - pu = node_x_coord_map[u], node_y_coord_map[u] - node_class = f"n{u}" - if tree.is_sample(u): - node_class += " sample" - self.nodes.add( - dwg.circle(center=pu, class_=node_class, **self.node_attrs[u]) - ) - dx = 0 - dy = -5 - labels = self.mid_labels - if tree.is_leaf(u): - dy = 20 - elif tree.parent(u) != NULL: - dx = 5 - if tree.left_sib(u) == NULL: - dx *= -1 - labels = self.right_labels - else: - labels = self.left_labels - # TODO get rid of these manual positioning tweaks and add them - # as offsets the user can access via a transform or something. - add_text_in_group( - dwg, - labels, - pu[0] + dx, - pu[1] + dy, - class_=node_class, - **self.node_label_attrs[u], + # Iterate over nodes, adding groups to reflect the tree heirarchy + stack = [] + for u in tree.roots: + grp = dwg.g( + class_=" ".join(self.info_classes(u)), + transform=f"translate({rnd(node_x_coord_map[u])} " + f"{rnd(node_y_coord_map[u])})", ) + stack.append((u, self.root_group.add(grp))) + while len(stack) > 0: + u, curr_svg_group = stack.pop() + pu = node_x_coord_map[u], node_y_coord_map[u] + for focal in tree.children(u): + fx = node_x_coord_map[focal] - pu[0] + fy = node_y_coord_map[focal] - pu[1] + new_svg_group = curr_svg_group.add( + dwg.g( + class_=" ".join(self.info_classes(focal)), + transform=f"translate({rnd(fx)} {rnd(fy)})", + ) + ) + stack.append((focal, new_svg_group)) + + o = (0, 0) v = tree.parent(u) + + # Add edge first => below if v != NULL: - edge_class = f"p{v} c{u}" + self.add_class(self.edge_attrs[u], "edge") pv = node_x_coord_map[v], node_y_coord_map[v] + dx = pv[0] - pu[0] + dy = pv[1] - pu[1] path = dwg.path( - [("M", pu), ("V", pv[1]), ("H", pv[0])], - class_=edge_class, - **self.edge_attrs[u], + [("M", o), ("V", rnd(dy)), ("H", rnd(dx))], **self.edge_attrs[u] ) - self.edges.add(path) + curr_svg_group.add(path) # Edges in parent group, so else: # FIXME this is pretty crappy for spacing mutations over a root. pv = (pu[0], pu[1] - 20) - num_mutations = len(node_mutations[u]) - delta = (pv[1] - pu[1]) / (num_mutations + 1) - x = pu[0] - y = pv[1] - delta - for mutation in reversed(node_mutations[u]): - mutation_class = f"m{mutation.id} s{mutation.site} n{u}" - self.mutations.add( - dwg.rect( - insert=(x, y), - class_=mutation_class, - **self.mutation_attrs[mutation.id], - ) - ) - dx = 5 - if tree.left_sib(mutation.node) == NULL: - dx *= -1 - labels = self.mutation_right_labels + # Add node symbol + label next (visually above the edge subtending this node) + # Symbols + curr_svg_group.add(dwg.circle(**self.node_attrs[u])) + # Labels + if not tree.is_leaf(u) and tree.parent(u) != NULL: + if tree.left_sib(u) == NULL: + self.add_class(self.node_label_attrs[u], "rgt") else: - labels = self.mutation_left_labels + self.add_class(self.node_label_attrs[u], "lft") + curr_svg_group.add(dwg.text(**self.node_label_attrs[u])) + + # Add mutation symbols + labels + delta = (pv[1] - pu[1]) / (len(self.node_mutations[u]) + 1) + for i, mutation in enumerate(reversed(self.node_mutations[u])): # TODO get rid of these manual positioning tweaks and add them # as offsets the user can access via a transform or something. - dy = 4 - add_text_in_group( - dwg, - labels, - x + dx, - y + dy, - class_=mutation_class, - **self.mutation_label_attrs[mutation.id], + dy = (i + 1) * delta + mutation_class = f"mut m{mutation.id} s{mutation.site}" + mut_group = curr_svg_group.add( + dwg.g(class_=mutation_class, transform=f"translate(0 {rnd(dy)})") ) - y -= delta + # Symbols + mut_group.add(dwg.rect(insert=o, **self.mutation_attrs[mutation.id])) + # Labels + if tree.left_sib(mutation.node) == NULL: + mut_label_class = "rgt" + else: + mut_label_class = "lft" + self.add_class(self.mutation_label_attrs[mutation.id], mut_label_class) + mut_group.add(dwg.text(**self.mutation_label_attrs[mutation.id])) class TextTreeSequence: diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0aa2409760..028250b236 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1315,86 +1315,86 @@ def draw_svg( >>> SVG(tree.draw_svg()) - - - The elements in the tree are placed into - different `SVG groups `_ for - easy styling and manipulation. Both these groups and their component items - are marked with SVG classes so that they can be targetted. This allows - individual components of the drawing to be hidden, styled, or otherwise + The elements in the tree are grouped according to the structure of the tree, + using `SVG groups `_. This allows + easy styling and manipulation of elements and subtrees. Elements in the SVG file + are marked with SVG classes so that they can be targetted, allowing + different components of the drawing to be hidden, styled, or otherwise manipulated. For example, when drawing (say) the first tree from a tree sequence, all the SVG components will be placed in a group of class ``tree``. The group will have the additional class ``t0``, indicating that this tree has index 0 in the tree sequence. The general SVG structure is as follows: - * The *tree* group (classes ``tree`` and ``tN`` where `N` is the tree index). - This contains the following three groups: - - * The *edges* group (class ``edges``), containing edges. Each edge has - classes ``pX``, and ``cY`` where `X` and `Y` are the ids of the parent - and child nodes. - * The *symbols* group (class ``symbols``), containing two subgroups: - - * The *node symbols* group (class ``nodes``) containing a - `circle `_ for - each node. Each node symbol has a class ``nX`` where ``X`` is the node - id. Symbols corresponding to sample nodes are additionally labelled - with a class of ``sample``. - * The *mutation symbols* group (class ``mutations``) containing a - `rectangle `_ for - each mutation. Each mutation symbol has classes ``mX``, ``sY``, and - ``nZ`` where `X` is the mutation id, `Y` is the site id, and `Z` is - the id of the node above which the mutation occurs. - - * The *labels* group (class ``labels``) containing two subgroups: - - * The *node labels* group (class ``nodes``) containing text for each - node. Each `text `_ - element in this group corresponds to a node, and has the same set of - classes as its equivalent node symbol (i.e. ``nX`` and potentially - ``sample``) - * The *mutation labels* group (class `mutations`) containing containing - text for each mutation. Each - `text `_ element in - this group corresponds to a mutation, and has the same set of classes - as its equivalent mutation symbol (i.e. ``mX``, ``sY``, and ``nZ``) + Each tree is contained in a group of class ``tree``. Additionally, this group + has a class ``tN`` where `N` is the tree index. + + Within the ``tree`` group there is a nested hierarchy of groups corresponding + to the tree structure. Any particular node in the tree will have a corresponding + group containing child groups (if any) followed by the edge above that node, a + node symbol, and (potentially) text containing the node label. For example, a + simple two tip tree, with tip node ids 0 and 1, and a root node id of 2 will have + a structure similar to the following: + + .. code-block:: + + + + + + + Node 1 + + + + + Node 0 + + + + Root (Node 2) + + The classes can be used to manipulate the element, e.g. by using `stylesheets `_. Style strings can be embedded in the svg by using the ``style`` parameter, or added to html pages which contain the raw SVG (e.g. within a Jupyter notebook by using the - IPython HTML() function). As a simple example, the following style string - will hide all labels: + IPython ``HTML()`` function). As a simple example, passing the following + string as the ``style`` parameter will hide all labels: .. code-block:: css - .tree .labels {visibility: hidden} + .tree text {visibility: hidden} You can also change the format of various items: the following styles will - display the symbols of the *sample* nodes only in blue, rotate the sample - node labels by 90 degrees, and hide the internal node labels: + rotate the leaf nodes labels by 90 degrees, colour the leaf nodes (which are + adjacent siblings to the edge lines) blue, and hide the non-sample node labels: .. code-block:: css - .tree .symbols .nodes .sample {fill: blue} - .tree .labels .nodes text.sample {transform: rotate(90deg)} - .tree .labels .nodes text:not(.sample) {visibility: hidden} + .tree .node.leaf > text { + transform: translateY(0.5em) rotate(90deg); text-anchor: start} + .tree .node.leaf > .edge + * {fill: blue} + .tree .node:not(.sample) > text {visibility: hidden} Specific nodes can be targetted by number. The following style will display - node 10 in red, and also colour in red the edges whose parent is node 10: + a large symbol for node 10, coloured red with a black border, and will also use + thick red lines for all the edges that have it as a direct or indirect parent: .. code-block:: css - .tree .symbols .nodes .n10 {fill: red} - .tree .edges .p10 {stroke: red} - - Mutations can be targetted by id, site id, or node number. The following - style displays all mutations immediately above node 10 as yellow with a black - border + .tree .node.n10 > .edge + * {fill: red; stroke: black; r: 8px} + .tree .node.p10 .edge {stroke: red; stroke-width: 2px} - .. code-block:: css + .. note:: - .tree .symbols .mutations .n10 {fill: yellow; stroke: black} + A feature of SVG style commands is that they apply not just to the contents + within the container, but to the entire file. Thus if an SVG file is + embedded in a larger document, such as an HTML file (e.g. when an SVG + is displayed inline in a Jupyter notebook), the style will apply to all SVG + drawings in the notebook. To avoid this, you can tag the SVG with a unique + SVG using ``root_svg_attributes={'id':'MY_UID'}``, and prepend this to the + style string, as in ``#MY_UID .tree .edges {stroke: gray}``. :param str path: The path to the file to write the output. If None, do not write to file. @@ -4389,12 +4389,12 @@ def draw_svg( described in :meth:`Tree.draw_svg`, so that visual elements pertaining to one or more trees targetted as documented in that method. For instance, the following style will change the colour of all the edges of the *initial* - tree in the sequence and hide the internal node labels in *all* the trees + tree in the sequence and hide the non-sample node labels in *all* the trees .. code-block:: css .tree.t0 .edges {stroke: blue} - .tree .labels .nodes text:not(.sample) {visibility: hidden} + .tree .node:not(.sample) > text {visibility: hidden} See :meth:`Tree.draw_svg` for further details.