diff --git a/appveyor.yml b/appveyor.yml index e017c42b7b..a100748590 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -26,6 +26,7 @@ build_script: - cmd: python -m pip install PyVCF - cmd: python -m pip install newick - cmd: python -m pip install python_jsonschema_objects + - cmd: python -m pip install xmlunittest - cmd: python -m nose after_test: diff --git a/docs/conf.py b/docs/conf.py index 064bb8eae0..707d81dd71 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -279,6 +279,7 @@ def handle_item(fieldarg, content): intersphinx_mapping = { "https://docs.python.org/": None, "http://docs.scipy.org/doc/numpy/": None, + "https://svgwrite.readthedocs.io/en/stable/": None, } # -- Options for todo extension ---------------------------------------------- diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 46294f9069..f04dfb187f 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -6,6 +6,11 @@ In development **New features** +- Add classes to SVG drawings to allow easy adjustment and styling, and document the new + ``tskit.Tree.draw_svg()`` and ``tskit.TreeSequence.draw_svg()`` methods. This also fixes + :issue:`467` for duplicate SVG entity ``id`` s in Jupyter notebooks. + (:user:`hyanwong`, :pr:`555`) + - Add a ``nexus`` function that outputs a tree sequence in Nexus format (:user:`saunack`, :pr:`550`). 
diff --git a/python/requirements/CI/requirements.txt b/python/requirements/CI/requirements.txt index 743267fa54..391e3e1626 100644 --- a/python/requirements/CI/requirements.txt +++ b/python/requirements/CI/requirements.txt @@ -22,4 +22,5 @@ sphinx==2.4.4 sphinx-argparse==0.2.5 sphinx-issues==1.2.0 sphinx_rtd_theme==0.4.3 -svgwrite==1.4 \ No newline at end of file +svgwrite==1.4 +xmlunittest==0.5.0 \ No newline at end of file diff --git a/python/requirements/development.txt b/python/requirements/development.txt index 8d046aa8be..09394b8e38 100644 --- a/python/requirements/development.txt +++ b/python/requirements/development.txt @@ -22,4 +22,5 @@ sphinx==2.4.4 #Pinned as breathe v3 compatibility is rough for now. sphinx-argparse sphinx-issues sphinx_rtd_theme -svgwrite \ No newline at end of file +svgwrite +xmlunittest \ No newline at end of file diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg new file mode 100644 index 0000000000..7745c45d91 --- /dev/null +++ b/python/tests/data/svg/tree.svg @@ -0,0 +1,70 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + + 9 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + 0 + + + + + + diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg new file mode 100644 index 0000000000..90f540ba64 --- /dev/null +++ b/python/tests/data/svg/ts.svg @@ -0,0 +1,329 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + + 9 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + 0 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + + 7 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + + 6 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + + 7 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 5 + + + + 
+ 8 + + + 0 + + + 1 + + + 2 + + + 3 + + + + + 4 + + + + + + + + + + + + + + + 0.00 + + + + 0.06 + + + + 0.79 + + + + 0.91 + + + + 0.91 + + + + 1.00 + + + + diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 6ee98675f7..9a290fb4be 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -30,6 +30,7 @@ import xml.etree import msprime +import xmlunittest import tests.tsutil as tsutil import tskit @@ -44,18 +45,20 @@ def get_binary_tree(self): ts = msprime.simulate(10, random_seed=1, mutation_rate=1) return next(ts.trees()) - def get_nonbinary_tree(self): + def get_nonbinary_ts(self): demographic_events = [ msprime.SimpleBottleneck(time=0.1, population=0, proportion=0.5) ] - ts = msprime.simulate( + return msprime.simulate( 10, recombination_rate=5, mutation_rate=10, demographic_events=demographic_events, random_seed=1, ) - for t in ts.trees(): + + def get_nonbinary_tree(self): + for t in self.get_nonbinary_ts().trees(): for u in t.nodes(): if len(t.children(u)) > 2: return t @@ -140,6 +143,60 @@ def get_empty_tree(self): ts = tables.tree_sequence() return next(ts.trees()) + def get_simple_ts(self): + """ + return a simple tree seq that does not depend on msprime + """ + nodes = io.StringIO( + """\ + id is_sample population individual time metadata + 0 1 0 -1 0.00000000000000 + 1 1 0 -1 0.00000000000000 + 2 1 0 -1 0.00000000000000 + 3 1 0 -1 0.00000000000000 + 4 0 0 -1 0.02445014598813 + 5 0 0 -1 0.11067965364865 + 6 0 0 -1 1.75005250750382 + 7 0 0 -1 2.31067154311640 + 8 0 0 -1 3.57331354884652 + 9 0 0 -1 9.08308317451295 + """ + ) + edges = io.StringIO( + """\ + id left right parent child + 0 0.00000000 1.00000000 4 0 + 1 0.00000000 1.00000000 4 1 + 2 0.00000000 1.00000000 5 2 + 3 0.00000000 1.00000000 5 3 + 4 0.79258618 0.90634460 6 4 + 5 0.79258618 0.90634460 6 5 + 6 0.05975243 0.79258618 7 4 + 7 0.90634460 0.91029435 7 4 + 8 0.05975243 0.79258618 7 5 + 9 0.90634460 0.91029435 7 5 + 10 0.91029435 1.00000000 8 4 + 
11 0.91029435 1.00000000 8 5 + 12 0.00000000 0.05975243 9 4 + 13 0.00000000 0.05975243 9 5 + """ + ) + sites = io.StringIO( + """\ + position ancestral_state + 0.01 A + """ + ) + mutations = io.StringIO( + """\ + site node derived_state parent + 0 4 T -1 + """ + ) + return tskit.load_text( + nodes, edges, sites=sites, mutations=mutations, strict=False + ) + class TestFormats(TestTreeDraw): """ @@ -364,7 +421,7 @@ def test_bad_orientation(self): t.draw_text(orientation=bad_orientation) -class TestDrawTextExamples(unittest.TestCase): +class TestDrawTextExamples(TestTreeDraw): """ Verify that we get the correct rendering for some examples. """ @@ -954,42 +1011,7 @@ def test_draw_multiroot_forky_tree(self): self.verify_text_rendering(t.draw_text(), tree) def test_simple_tree_sequence(self): - nodes = io.StringIO( - """\ - id is_sample population individual time metadata - 0 1 0 -1 0.00000000000000 - 1 1 0 -1 0.00000000000000 - 2 1 0 -1 0.00000000000000 - 3 1 0 -1 0.00000000000000 - 4 0 0 -1 0.02445014598813 - 5 0 0 -1 0.11067965364865 - 6 0 0 -1 1.75005250750382 - 7 0 0 -1 2.31067154311640 - 8 0 0 -1 3.57331354884652 - 9 0 0 -1 9.08308317451295 - """ - ) - edges = io.StringIO( - """\ - id left right parent child - 0 0.00000000 1.00000000 4 0 - 1 0.00000000 1.00000000 4 1 - 2 0.00000000 1.00000000 5 2 - 3 0.00000000 1.00000000 5 3 - 4 0.79258618 0.90634460 6 4 - 5 0.79258618 0.90634460 6 5 - 6 0.05975243 0.79258618 7 4 - 7 0.90634460 0.91029435 7 4 - 8 0.05975243 0.79258618 7 5 - 9 0.90634460 0.91029435 7 5 - 10 0.91029435 1.00000000 8 4 - 11 0.91029435 1.00000000 8 5 - 12 0.00000000 0.05975243 9 4 - 13 0.00000000 0.05975243 9 5 - """ - ) - ts = tskit.load_text(nodes, edges, strict=False) - + ts = self.get_simple_ts() ts_drawing = ( "9.08┊ 9 ┊ ┊ ┊ ┊ ┊\n" " ┊ ┏━┻━┓ ┊ ┊ ┊ ┊ ┊\n" @@ -1065,41 +1087,7 @@ def test_simple_tree_sequence(self): ) def test_max_tree_height(self): - nodes = io.StringIO( - """\ - id is_sample population individual time metadata - 0 1 0 -1 
0.00000000000000 - 1 1 0 -1 0.00000000000000 - 2 1 0 -1 0.00000000000000 - 3 1 0 -1 0.00000000000000 - 4 0 0 -1 0.02445014598813 - 5 0 0 -1 0.11067965364865 - 6 0 0 -1 1.75005250750382 - 7 0 0 -1 2.31067154311640 - 8 0 0 -1 3.57331354884652 - 9 0 0 -1 9.08308317451295 - """ - ) - edges = io.StringIO( - """\ - id left right parent child - 0 0.00000000 1.00000000 4 0 - 1 0.00000000 1.00000000 4 1 - 2 0.00000000 1.00000000 5 2 - 3 0.00000000 1.00000000 5 3 - 4 0.79258618 0.90634460 6 4 - 5 0.79258618 0.90634460 6 5 - 6 0.05975243 0.79258618 7 4 - 7 0.90634460 0.91029435 7 4 - 8 0.05975243 0.79258618 7 5 - 9 0.90634460 0.91029435 7 5 - 10 0.91029435 1.00000000 8 4 - 11 0.91029435 1.00000000 8 5 - 12 0.00000000 0.05975243 9 4 - 13 0.00000000 0.05975243 9 5 - """ - ) - ts = tskit.load_text(nodes, edges, strict=False) + ts = self.get_simple_ts() tree = ( " 9 \n" " ┏━┻━┓ \n" @@ -1134,17 +1122,47 @@ def test_max_tree_height(self): t.draw_text(max_tree_height=bad_max_tree_height) -class TestDrawSvg(TestTreeDraw): +class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestCase): """ Tests the SVG tree drawing. 
""" def verify_basic_svg(self, svg, width=200, height=200): + prefix = "{http://www.w3.org/2000/svg}" root = xml.etree.ElementTree.fromstring(svg) - self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg") + self.assertEqual(root.tag, prefix + "svg") self.assertEqual(width, int(root.attrib["width"])) self.assertEqual(height, int(root.attrib["height"])) + # Verify the class structure of the svg + root_group = root.find(prefix + "g") + self.assertIn("class", root_group.attrib) + self.assertRegexpMatches( + root_group.attrib["class"], r"\b(tree|tree-sequence)\b" + ) + if "tree-sequence" in root_group.attrib["class"]: + trees = root_group.find(prefix + "g") + self.assertIn("class", trees.attrib) + self.assertRegexpMatches(trees.attrib["class"], r"\btrees\b") + first_tree = trees.find(prefix + "g") + self.assertIn("class", first_tree.attrib) + self.assertRegexpMatches(first_tree.attrib["class"], r"\btree\b") + else: + first_tree = root_group + # Check that we have edges, symbols, and labels groups + groups = first_tree.findall(prefix + "g") + self.assertGreater(len(groups), 0) + for group in groups: + self.assertIn("class", group.attrib) + cls = group.attrib["class"] + self.assertRegexpMatches(cls, r"\b(edges|symbols|labels)\b") + if "symbols" in cls or "labels" in cls: + # Check that we have nodes & mutations subgroups + for subgroup in group.findall(prefix + "g"): + self.assertIn("class", subgroup.attrib) + subcls = subgroup.attrib["class"] + self.assertRegexpMatches(subcls, r"\b(nodes|mutations)\b") + def test_draw_file(self): t = self.get_binary_tree() fd, filename = tempfile.mkstemp(prefix="tskit_viz_") @@ -1346,7 +1364,7 @@ def test_height_scale_rank_and_max_tree_height(self): self.verify_basic_svg(svg1) svg2 = t.draw_svg(tree_height_scale="rank") self.assertEqual(svg1, svg2) - svg3 = t.draw_svg("tmp.svg", max_tree_height="ts", tree_height_scale="rank") + svg3 = t.draw_svg(max_tree_height="ts", tree_height_scale="rank") self.assertNotEqual(svg1, svg3) 
self.verify_basic_svg(svg3) # Numeric max tree height not supported for rank scale. @@ -1448,11 +1466,11 @@ def test_max_tree_height(self): svg2 = ts.at_index(1).draw() # if not scaled to ts, node 3 is at a different height in both trees, because the # root is at a different height. We expect a label looking something like - # 3 where XXXX is different + # 3 where XXXX is different str_pos = svg1.find(">3<") - snippet1 = svg1[svg1.rfind("<", 0, str_pos) : str_pos] + snippet1 = svg1[svg1.rfind("3<") - snippet2 = svg2[svg2.rfind("<", 0, str_pos) : str_pos] + snippet2 = svg2[svg2.rfind("3<") - snippet1 = svg1[svg1.rfind("<", 0, str_pos) : str_pos] + snippet1 = svg1[svg1.rfind("3<") - snippet2 = svg2[svg2.rfind("<", 0, str_pos) : str_pos] + snippet2 = svg2[svg2.rfind(">> SVG(tree.draw_svg()) + + + + The elements in the tree are placed into + different `SVG groups `_ for + easy styling and manipulation. Both these groups and their component items + are marked with SVG classes so that they can be targetted. This allows + individual components of the drawing to be hidden, styled, or otherwise + manipulated. For example, when drawing (say) the first tree from a tree + sequence, all the SVG components will be placed in a group of class ``tree``. + The group will have the additional class ``t0``, indicating that this tree + has index 0 in the tree sequence. The general SVG structure is as follows: + + * The *tree* group (classes ``tree`` and ``tN`` where `N` is the tree index). + This contains the following three groups: + + * The *edges* group (class ``edges``), containing edges. Each edge has + classes ``pX``, and ``cY`` where `X` and `Y` are the ids of the parent + and child nodes. + * The *symbols* group (class ``symbols``), containing two subgroups: + + * The *node symbols* group (class ``nodes``) containing a + `circle `_ for + each node. Each node symbol has a class ``nX`` where ``X`` is the node + id. 
Symbols corresponding to sample nodes are additionally labelled + with a class of ``sample``. + * The *mutation symbols* group (class ``mutations``) containing a + `rectangle `_ for + each mutation. Each mutation symbol has classes ``mX``, ``sY``, and + ``nZ`` where `X` is the mutation id, `Y` is the site id, and `Z` is + the id of the node above which the mutation occurs. + + * The *labels* group (class ``labels``) containing two subgroups: + + * The *node labels* group (class ``nodes``) containing text for each + node. Each `text `_ + element in this group corresponds to a node, and has the same set of + classes as its equivalent node symbol (i.e. ``nX`` and potentially + ``sample``) + * The *mutation labels* group (class ``mutations``) containing + text for each mutation. Each + `text `_ element in + this group corresponds to a mutation, and has the same set of classes + as its equivalent mutation symbol (i.e. ``mX``, ``sY``, and ``nZ``) + + The classes can be used to manipulate the element, e.g. by using + `stylesheets `_. Style strings can + be embedded in the svg by using the ``style`` parameter, or added to html + pages which contain the raw SVG (e.g. within a Jupyter notebook by using the + IPython HTML() function). As a simple example, the following style string + will hide all labels: + + .. code-block:: css + + .tree .labels {visibility: hidden} + + You can also change the format of various items: the following styles will + display the symbols of the *sample* nodes only in blue, rotate the sample + node labels by 90 degrees, and hide the internal node labels: + + .. code-block:: css + + .tree .symbols .nodes .sample {fill: blue} + .tree .labels .nodes text.sample {transform: rotate(90deg)} + .tree .labels .nodes text:not(.sample) {visibility: hidden} + + Specific nodes can be targeted by number. The following style will display + node 10 in red, and also colour in red the edges whose parent is node 10: + + .. 
code-block:: css + + .tree .symbols .nodes .n10 {fill: red} + .tree .edges .p10 {stroke: red} + + Mutations can be targeted by id, site id, or node number. The following + style displays all mutations immediately above node 10 as yellow with a black + border: + + .. code-block:: css + + .tree .symbols .mutations .n10 {fill: yellow; stroke: black} + + :param str path: The path to the file to write the output. If None, do not + write to file. + :param size: A tuple of (width, height) giving the width and height of the + produced SVG drawing in abstract user units (usually interpreted as pixels on + initial display). + :type size: tuple(int, int) + :param str tree_height_scale: Control how height values for nodes are computed. + If this is equal to ``"time"`` (the default), node heights are proportional + to their time values. If this is equal to ``"log_time"``, node heights are + proportional to their log(time) values. If it is equal to ``"rank"``, node + heights are spaced equally according to their ranked times. + :param str,float max_tree_height: The maximum tree height value in the current + scaling system (see ``tree_height_scale``). Can be either a string or a + numeric value. If equal to ``"tree"`` (the default), the maximum tree height + is set to be that of the oldest root in the tree. If equal to ``"ts"`` the + maximum height is set to be the height of the oldest root in the tree + sequence; this is useful when drawing trees from the same tree sequence as it + ensures that node heights are consistent. If a numeric value, this is used as + the maximum tree height by which to scale other nodes. + :param node_labels: If specified, show custom labels for the nodes + (specified by ID) that are present in this map; any nodes not present will + not have a label. 
+ :type node_labels: dict(int, str) + :param mutation_labels: If specified, show custom labels for the + mutations (specified by ID) that are present in the map; any mutations + not present will not have a label. + :type mutation_labels: dict(int, str) + :param dict root_svg_attributes: Additional attributes, such as an id, that will + be embedded in the root ```` tag of the generated drawing. + :param str style: A + `css style string `_ that will be + included in the ``