diff --git a/appveyor.yml b/appveyor.yml
index e017c42b7b..a100748590 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -26,6 +26,7 @@ build_script:
- cmd: python -m pip install PyVCF
- cmd: python -m pip install newick
- cmd: python -m pip install python_jsonschema_objects
+ - cmd: python -m pip install xmlunittest
- cmd: python -m nose
after_test:
diff --git a/docs/conf.py b/docs/conf.py
index 064bb8eae0..707d81dd71 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -279,6 +279,7 @@ def handle_item(fieldarg, content):
intersphinx_mapping = {
"https://docs.python.org/": None,
"http://docs.scipy.org/doc/numpy/": None,
+ "https://svgwrite.readthedocs.io/en/stable/": None,
}
# -- Options for todo extension ----------------------------------------------
diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst
index 46294f9069..f04dfb187f 100644
--- a/python/CHANGELOG.rst
+++ b/python/CHANGELOG.rst
@@ -6,6 +6,11 @@ In development
**New features**
+- Add classes to SVG drawings to allow easy adjustment and styling, and document the new
+ ``tskit.Tree.draw_svg()`` and ``tskit.TreeSequence.draw_svg()`` methods. This also fixes
+ :issue:`467` for duplicate SVG entity ``id`` s in Jupyter notebooks.
+ (:user:`hyanwong`, :pr:`555`)
+
- Add a ``nexus`` function that outputs a tree sequence in Nexus format
(:user:`saunack`, :pr:`550`).
diff --git a/python/requirements/CI/requirements.txt b/python/requirements/CI/requirements.txt
index 743267fa54..391e3e1626 100644
--- a/python/requirements/CI/requirements.txt
+++ b/python/requirements/CI/requirements.txt
@@ -22,4 +22,5 @@ sphinx==2.4.4
sphinx-argparse==0.2.5
sphinx-issues==1.2.0
sphinx_rtd_theme==0.4.3
-svgwrite==1.4
\ No newline at end of file
+svgwrite==1.4
+xmlunittest==0.5.0
\ No newline at end of file
diff --git a/python/requirements/development.txt b/python/requirements/development.txt
index 8d046aa8be..09394b8e38 100644
--- a/python/requirements/development.txt
+++ b/python/requirements/development.txt
@@ -22,4 +22,5 @@ sphinx==2.4.4 #Pinned as breathe v3 compatibility is rough for now.
sphinx-argparse
sphinx-issues
sphinx_rtd_theme
-svgwrite
\ No newline at end of file
+svgwrite
+xmlunittest
\ No newline at end of file
diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg
new file mode 100644
index 0000000000..7745c45d91
--- /dev/null
+++ b/python/tests/data/svg/tree.svg
@@ -0,0 +1,70 @@
+
+
diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg
new file mode 100644
index 0000000000..90f540ba64
--- /dev/null
+++ b/python/tests/data/svg/ts.svg
@@ -0,0 +1,329 @@
+
+
diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py
index 6ee98675f7..9a290fb4be 100644
--- a/python/tests/test_drawing.py
+++ b/python/tests/test_drawing.py
@@ -30,6 +30,7 @@
import xml.etree
import msprime
+import xmlunittest
import tests.tsutil as tsutil
import tskit
@@ -44,18 +45,20 @@ def get_binary_tree(self):
ts = msprime.simulate(10, random_seed=1, mutation_rate=1)
return next(ts.trees())
- def get_nonbinary_tree(self):
+ def get_nonbinary_ts(self):
demographic_events = [
msprime.SimpleBottleneck(time=0.1, population=0, proportion=0.5)
]
- ts = msprime.simulate(
+ return msprime.simulate(
10,
recombination_rate=5,
mutation_rate=10,
demographic_events=demographic_events,
random_seed=1,
)
- for t in ts.trees():
+
+ def get_nonbinary_tree(self):
+ for t in self.get_nonbinary_ts().trees():
for u in t.nodes():
if len(t.children(u)) > 2:
return t
@@ -140,6 +143,60 @@ def get_empty_tree(self):
ts = tables.tree_sequence()
return next(ts.trees())
+ def get_simple_ts(self):
+ """
+ return a simple tree seq that does not depend on msprime
+ """
+ nodes = io.StringIO(
+ """\
+ id is_sample population individual time metadata
+ 0 1 0 -1 0.00000000000000
+ 1 1 0 -1 0.00000000000000
+ 2 1 0 -1 0.00000000000000
+ 3 1 0 -1 0.00000000000000
+ 4 0 0 -1 0.02445014598813
+ 5 0 0 -1 0.11067965364865
+ 6 0 0 -1 1.75005250750382
+ 7 0 0 -1 2.31067154311640
+ 8 0 0 -1 3.57331354884652
+ 9 0 0 -1 9.08308317451295
+ """
+ )
+ edges = io.StringIO(
+ """\
+ id left right parent child
+ 0 0.00000000 1.00000000 4 0
+ 1 0.00000000 1.00000000 4 1
+ 2 0.00000000 1.00000000 5 2
+ 3 0.00000000 1.00000000 5 3
+ 4 0.79258618 0.90634460 6 4
+ 5 0.79258618 0.90634460 6 5
+ 6 0.05975243 0.79258618 7 4
+ 7 0.90634460 0.91029435 7 4
+ 8 0.05975243 0.79258618 7 5
+ 9 0.90634460 0.91029435 7 5
+ 10 0.91029435 1.00000000 8 4
+ 11 0.91029435 1.00000000 8 5
+ 12 0.00000000 0.05975243 9 4
+ 13 0.00000000 0.05975243 9 5
+ """
+ )
+ sites = io.StringIO(
+ """\
+ position ancestral_state
+ 0.01 A
+ """
+ )
+ mutations = io.StringIO(
+ """\
+ site node derived_state parent
+ 0 4 T -1
+ """
+ )
+ return tskit.load_text(
+ nodes, edges, sites=sites, mutations=mutations, strict=False
+ )
+
class TestFormats(TestTreeDraw):
"""
@@ -364,7 +421,7 @@ def test_bad_orientation(self):
t.draw_text(orientation=bad_orientation)
-class TestDrawTextExamples(unittest.TestCase):
+class TestDrawTextExamples(TestTreeDraw):
"""
Verify that we get the correct rendering for some examples.
"""
@@ -954,42 +1011,7 @@ def test_draw_multiroot_forky_tree(self):
self.verify_text_rendering(t.draw_text(), tree)
def test_simple_tree_sequence(self):
- nodes = io.StringIO(
- """\
- id is_sample population individual time metadata
- 0 1 0 -1 0.00000000000000
- 1 1 0 -1 0.00000000000000
- 2 1 0 -1 0.00000000000000
- 3 1 0 -1 0.00000000000000
- 4 0 0 -1 0.02445014598813
- 5 0 0 -1 0.11067965364865
- 6 0 0 -1 1.75005250750382
- 7 0 0 -1 2.31067154311640
- 8 0 0 -1 3.57331354884652
- 9 0 0 -1 9.08308317451295
- """
- )
- edges = io.StringIO(
- """\
- id left right parent child
- 0 0.00000000 1.00000000 4 0
- 1 0.00000000 1.00000000 4 1
- 2 0.00000000 1.00000000 5 2
- 3 0.00000000 1.00000000 5 3
- 4 0.79258618 0.90634460 6 4
- 5 0.79258618 0.90634460 6 5
- 6 0.05975243 0.79258618 7 4
- 7 0.90634460 0.91029435 7 4
- 8 0.05975243 0.79258618 7 5
- 9 0.90634460 0.91029435 7 5
- 10 0.91029435 1.00000000 8 4
- 11 0.91029435 1.00000000 8 5
- 12 0.00000000 0.05975243 9 4
- 13 0.00000000 0.05975243 9 5
- """
- )
- ts = tskit.load_text(nodes, edges, strict=False)
-
+ ts = self.get_simple_ts()
ts_drawing = (
"9.08┊ 9 ┊ ┊ ┊ ┊ ┊\n"
" ┊ ┏━┻━┓ ┊ ┊ ┊ ┊ ┊\n"
@@ -1065,41 +1087,7 @@ def test_simple_tree_sequence(self):
)
def test_max_tree_height(self):
- nodes = io.StringIO(
- """\
- id is_sample population individual time metadata
- 0 1 0 -1 0.00000000000000
- 1 1 0 -1 0.00000000000000
- 2 1 0 -1 0.00000000000000
- 3 1 0 -1 0.00000000000000
- 4 0 0 -1 0.02445014598813
- 5 0 0 -1 0.11067965364865
- 6 0 0 -1 1.75005250750382
- 7 0 0 -1 2.31067154311640
- 8 0 0 -1 3.57331354884652
- 9 0 0 -1 9.08308317451295
- """
- )
- edges = io.StringIO(
- """\
- id left right parent child
- 0 0.00000000 1.00000000 4 0
- 1 0.00000000 1.00000000 4 1
- 2 0.00000000 1.00000000 5 2
- 3 0.00000000 1.00000000 5 3
- 4 0.79258618 0.90634460 6 4
- 5 0.79258618 0.90634460 6 5
- 6 0.05975243 0.79258618 7 4
- 7 0.90634460 0.91029435 7 4
- 8 0.05975243 0.79258618 7 5
- 9 0.90634460 0.91029435 7 5
- 10 0.91029435 1.00000000 8 4
- 11 0.91029435 1.00000000 8 5
- 12 0.00000000 0.05975243 9 4
- 13 0.00000000 0.05975243 9 5
- """
- )
- ts = tskit.load_text(nodes, edges, strict=False)
+ ts = self.get_simple_ts()
tree = (
" 9 \n"
" ┏━┻━┓ \n"
@@ -1134,17 +1122,47 @@ def test_max_tree_height(self):
t.draw_text(max_tree_height=bad_max_tree_height)
-class TestDrawSvg(TestTreeDraw):
+class TestDrawSvg(TestTreeDraw, xmlunittest.XmlTestCase):
"""
Tests the SVG tree drawing.
"""
def verify_basic_svg(self, svg, width=200, height=200):
+ prefix = "{http://www.w3.org/2000/svg}"
root = xml.etree.ElementTree.fromstring(svg)
- self.assertEqual(root.tag, "{http://www.w3.org/2000/svg}svg")
+ self.assertEqual(root.tag, prefix + "svg")
self.assertEqual(width, int(root.attrib["width"]))
self.assertEqual(height, int(root.attrib["height"]))
+ # Verify the class structure of the svg
+ root_group = root.find(prefix + "g")
+ self.assertIn("class", root_group.attrib)
+ self.assertRegexpMatches(
+ root_group.attrib["class"], r"\b(tree|tree-sequence)\b"
+ )
+ if "tree-sequence" in root_group.attrib["class"]:
+ trees = root_group.find(prefix + "g")
+ self.assertIn("class", trees.attrib)
+ self.assertRegexpMatches(trees.attrib["class"], r"\btrees\b")
+ first_tree = trees.find(prefix + "g")
+ self.assertIn("class", first_tree.attrib)
+ self.assertRegexpMatches(first_tree.attrib["class"], r"\btree\b")
+ else:
+ first_tree = root_group
+ # Check that we have edges, symbols, and labels groups
+ groups = first_tree.findall(prefix + "g")
+ self.assertGreater(len(groups), 0)
+ for group in groups:
+ self.assertIn("class", group.attrib)
+ cls = group.attrib["class"]
+ self.assertRegexpMatches(cls, r"\b(edges|symbols|labels)\b")
+ if "symbols" in cls or "labels" in cls:
+ # Check that we have nodes & mutations subgroups
+ for subgroup in group.findall(prefix + "g"):
+ self.assertIn("class", subgroup.attrib)
+ subcls = subgroup.attrib["class"]
+ self.assertRegexpMatches(subcls, r"\b(nodes|mutations)\b")
+
def test_draw_file(self):
t = self.get_binary_tree()
fd, filename = tempfile.mkstemp(prefix="tskit_viz_")
@@ -1346,7 +1364,7 @@ def test_height_scale_rank_and_max_tree_height(self):
self.verify_basic_svg(svg1)
svg2 = t.draw_svg(tree_height_scale="rank")
self.assertEqual(svg1, svg2)
- svg3 = t.draw_svg("tmp.svg", max_tree_height="ts", tree_height_scale="rank")
+ svg3 = t.draw_svg(max_tree_height="ts", tree_height_scale="rank")
self.assertNotEqual(svg1, svg3)
self.verify_basic_svg(svg3)
# Numeric max tree height not supported for rank scale.
@@ -1448,11 +1466,11 @@ def test_max_tree_height(self):
svg2 = ts.at_index(1).draw()
# if not scaled to ts, node 3 is at a different height in both trees, because the
# root is at a different height. We expect a label looking something like
- # 3 where XXXX is different
+ # 3 where XXXX is different
str_pos = svg1.find(">3<")
- snippet1 = svg1[svg1.rfind("<", 0, str_pos) : str_pos]
+ snippet1 = svg1[svg1.rfind("3<")
- snippet2 = svg2[svg2.rfind("<", 0, str_pos) : str_pos]
+ snippet2 = svg2[svg2.rfind("3<")
- snippet1 = svg1[svg1.rfind("<", 0, str_pos) : str_pos]
+ snippet1 = svg1[svg1.rfind("3<")
- snippet2 = svg2[svg2.rfind("<", 0, str_pos) : str_pos]
+ snippet2 = svg2[svg2.rfind(">> SVG(tree.draw_svg())
+
+
+
+ The elements in the tree are placed into
+ different `SVG groups `_ for
+ easy styling and manipulation. Both these groups and their component items
+ are marked with SVG classes so that they can be targetted. This allows
+ individual components of the drawing to be hidden, styled, or otherwise
+ manipulated. For example, when drawing (say) the first tree from a tree
+ sequence, all the SVG components will be placed in a group of class ``tree``.
+ The group will have the additional class ``t0``, indicating that this tree
+ has index 0 in the tree sequence. The general SVG structure is as follows:
+
+ * The *tree* group (classes ``tree`` and ``tN`` where `N` is the tree index).
+ This contains the following three groups:
+
+ * The *edges* group (class ``edges``), containing edges. Each edge has
+ classes ``pX``, and ``cY`` where `X` and `Y` are the ids of the parent
+ and child nodes.
+ * The *symbols* group (class ``symbols``), containing two subgroups:
+
+ * The *node symbols* group (class ``nodes``) containing a
+ `circle `_ for
+ each node. Each node symbol has a class ``nX`` where ``X`` is the node
+ id. Symbols corresponding to sample nodes are additionally labelled
+ with a class of ``sample``.
+ * The *mutation symbols* group (class ``mutations``) containing a
+ `rectangle `_ for
+ each mutation. Each mutation symbol has classes ``mX``, ``sY``, and
+ ``nZ`` where `X` is the mutation id, `Y` is the site id, and `Z` is
+ the id of the node above which the mutation occurs.
+
+ * The *labels* group (class ``labels``) containing two subgroups:
+
+ * The *node labels* group (class ``nodes``) containing text for each
+ node. Each `text `_
+ element in this group corresponds to a node, and has the same set of
+ classes as its equivalent node symbol (i.e. ``nX`` and potentially
+ ``sample``)
+ * The *mutation labels* group (class `mutations`) containing containing
+ text for each mutation. Each
+ `text `_ element in
+ this group corresponds to a mutation, and has the same set of classes
+ as its equivalent mutation symbol (i.e. ``mX``, ``sY``, and ``nZ``)
+
+ The classes can be used to manipulate the element, e.g. by using
+ `stylesheets `_. Style strings can
+ be embedded in the svg by using the ``style`` parameter, or added to html
+ pages which contain the raw SVG (e.g. within a Jupyter notebook by using the
+ IPython HTML() function). As a simple example, the following style string
+ will hide all labels:
+
+ .. code-block:: css
+
+ .tree .labels {visibility: hidden}
+
+ You can also change the format of various items: the following styles will
+ display the symbols of the *sample* nodes only in blue, rotate the sample
+ node labels by 90 degrees, and hide the internal node labels:
+
+ .. code-block:: css
+
+ .tree .symbols .nodes .sample {fill: blue}
+ .tree .labels .nodes text.sample {transform: rotate(90deg)}
+ .tree .labels .nodes text:not(.sample) {visibility: hidden}
+
+ Specific nodes can be targetted by number. The following style will display
+ node 10 in red, and also colour in red the edges whose parent is node 10:
+
+ .. code-block:: css
+
+ .tree .symbols .nodes .n10 {fill: red}
+ .tree .edges .p10 {stroke: red}
+
+ Mutations can be targetted by id, site id, or node number. The following
+ style displays all mutations immediately above node 10 as yellow with a black
+ border
+
+ .. code-block:: css
+
+ .tree .symbols .mutations .n10 {fill: yellow; stroke: black}
+
+ :param str path: The path to the file to write the output. If None, do not
+ write to file.
+ :param size: A tuple of (width, height) giving the width and height of the
+ produced SVG drawing in abstract user units (usually interpreted as pixels on
+ initial display).
+ :type size: tuple(int, int)
+ :param str tree_height_scale: Control how height values for nodes are computed.
+ If this is equal to ``"time"`` (the default), node heights are proportional
+ to their time values. If this is equal to ``"log_time"``, node heights are
+ proportional to their log(time) values. If it is equal to ``"rank"``, node
+ heights are spaced equally according to their ranked times.
+ :param str,float max_tree_height: The maximum tree height value in the current
+ scaling system (see ``tree_height_scale``). Can be either a string or a
+ numeric value. If equal to ``"tree"`` (the default), the maximum tree height
+ is set to be that of the oldest root in the tree. If equal to ``"ts"`` the
+ maximum height is set to be the height of the oldest root in the tree
+ sequence; this is useful when drawing trees from the same tree sequence as it
+ ensures that node heights are consistent. If a numeric value, this is used as
+ the maximum tree height by which to scale other nodes.
+ :param node_labels: If specified, show custom labels for the nodes
+ (specified by ID) that are present in this map; any nodes not present will
+ not have a label.
+ :type node_labels: dict(int, str)
+ :param mutation_labels: If specified, show custom labels for the
+ mutations (specified by ID) that are present in the map; any mutations
+ not present will not have a label.
+ :type mutation_labels: dict(int, str)
+ :param dict root_svg_attributes: Additional attributes, such as an id, that will
+ be embedded in the root ``