diff --git a/python/tests/data/svg/mut_tree.svg b/python/tests/data/svg/mut_tree.svg new file mode 100644 index 0000000000..b237646e0d --- /dev/null +++ b/python/tests/data/svg/mut_tree.svg @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + 2 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 9 + + + + 0 + + + + 1 + + + diff --git a/python/tests/data/svg/tree.svg b/python/tests/data/svg/tree.svg index 81feb6f487..34e8feed99 100644 --- a/python/tests/data/svg/tree.svg +++ b/python/tests/data/svg/tree.svg @@ -1,47 +1,46 @@ - + - - + + - - - - - - - 0 + + + + + + + 0 - - - - 1 - - - - 4 - - - 0 + + + + 1 + + + 4 - - - - - 2 + + + + + 2 - - - - 3 + + + + 3 - - - 5 + + + 5 - - 9 + + + 7 diff --git a/python/tests/data/svg/ts.svg b/python/tests/data/svg/ts.svg index 6e68a58407..9e2fdb73bc 100644 --- a/python/tests/data/svg/ts.svg +++ b/python/tests/data/svg/ts.svg @@ -1,8 +1,10 @@ - + - - + + @@ -11,188 +13,214 @@ - - - - - - - 0 - - - - - 1 - - - - 4 - - - 0 - + + + + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + 2 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 9 + + + + 0 + + + + 1 - - - - - 2 - - - - - 3 - - - - 5 - - - 9 - - - - - - - 0 - - - - - 1 - - - - 4 - - - - - - 2 - - - - - 3 - - - - 5 + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 7 - - 7 - - - - - - - 0 - - - - - 1 - - - - 4 + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 6 - - - - - 2 - - - - - 3 - - - - 5 - - - 6 - - - - - - - 0 - - - - - 1 - - - - 4 + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 7 - - - - - 2 - - - - - 3 - - - - 5 - - - 7 - - - - - - - 0 - - - - - 1 - - - - 4 - - - - - - 2 - - - - - 3 - - - - 5 + + + + + + + + 0 + + + + + 1 + + + + 4 + + + + + + 2 + + + + + 3 + + + + 5 + + + + 8 - - 8 diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index 58f7a1483a..c3de6954bf 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -188,13 +188,16 @@ def get_simple_ts(self): sites = io.StringIO( """\ position ancestral_state - 0.01 A + 0.05 A + 0.06 0 """ ) mutations = io.StringIO( """\ site node derived_state parent - 0 4 T -1 + 0 9 T -1 + 0 9 G 0 + 0 4 1 -1 """ ) return tskit.load_text( @@ -1419,12 +1422,15 @@ def verify_basic_svg(self, svg, width=200, height=200): trees = g break self.assertIsNotNone(trees) # Must have found a trees group - first_tree = trees.find(prefix + "g") + first_treebox = trees.find(prefix + "g") + self.assertIn("class", first_treebox.attrib) + self.assertRegexpMatches(first_treebox.attrib["class"], r"\btreebox\b") + first_tree = first_treebox.find(prefix + "g") self.assertIn("class", first_tree.attrib) self.assertRegexpMatches(first_tree.attrib["class"], r"\btree\b") else: first_tree = root_group - # Check that we have edges, symbols, and labels groups + # Check that the first grouping is labelled as a root groups = first_tree.findall(prefix + "g") self.assertGreater(len(groups), 0) for group in groups: @@ -1744,8 +1750,8 @@ def test_max_tree_height(self): svg1 = ts.at_index(0).draw(max_tree_height="ts") svg2 = ts.at_index(1).draw(max_tree_height="ts") - # when scaled, node 3 should be at the *same* height in both trees, so the label - # should be the same + # when scaled, node 3 should be at the *same* height in both trees, so the edge + # definition should be the same self.verify_basic_svg(svg1) self.verify_basic_svg(svg2) str_pos = svg1.find(">0<") @@ -1765,13 +1771,14 @@ def test_draw_simple_ts(self): self.verify_basic_svg(svg, width=200 * ts.num_trees) def test_draw_integer_breaks_ts(self): - r_map = msprime.RecombinationMap.uniform_map(1000, 0.001, num_loci=1000) r_map = msprime.RecombinationMap.uniform_map(1000, 0.005, num_loci=1000) ts = msprime.simulate(5, recombination_map=r_map, random_seed=1) + self.assertGreater(ts.num_trees, 2) svg = ts.draw_svg() self.verify_basic_svg(svg, width=200 * ts.num_trees) + axis_pos = svg.find('class="axis"') for b in ts.breakpoints(): - self.assertNotEqual(svg.find(f">{b:.0f}<"), -1) + self.assertNotEqual(svg.find(f">{b:.0f}<", axis_pos), -1) def test_draw_even_height_ts(self): ts = msprime.simulate(5, recombination_rate=1, random_seed=1) @@ -1808,8 +1815,27 @@ def test_bad_x_scale(self): with self.assertRaises(ValueError): ts.draw_svg(x_scale=bad_x_scale) - def test_known_svg_tree(self): - tree = self.get_simple_ts().first() + def test_tree_root_branch(self): + # in the simple_ts, there are root mutations in the first tree but not the second + ts = self.get_simple_ts() + tree_with_root_mutations = ts.at_index(0) + root1 = tree_with_root_mutations.root + tree_without_root_mutations = ts.at_index(1) + root2 = tree_without_root_mutations.root + svg1 = tree_with_root_mutations.draw_svg() + svg2 = tree_without_root_mutations.draw_svg() + self.verify_basic_svg(svg1) + self.verify_basic_svg(svg2) + edge_str = ' circle {r: 3px; fill: black; stroke: none}" - ".tree text {dominant-baseline: middle}" # not inherited in css 1.1 - ".mut > text.lft {transform: translateX(0.5em); text-anchor: start}" - ".mut > text.rgt {transform: translateX(-0.5em); text-anchor: end}" - ".root > text {transform: translateY(-0.8em)}" # Root - ".leaf > text {transform: translateY(1em)}" # Leaves - ".node > text.lft {transform: translate(0.5em, -0.5em); text-anchor: start}" - ".node > text.rgt {transform: translate(-0.5em, -0.5em); text-anchor: end}" - ".mut {fill: red; font-style: italic}" + ".node > .sym {r: 3px; fill: black; stroke: none}" + ".node > .lab {transform: translateY(-0.8em)}" # Root + ".node.leaf > .lab {transform: translateY(1em)}" # Leaves + ".tree .lab.rgt {text-anchor: start}" + ".tree .lab.lft {text-anchor: end}" + ".mut > .lab.rgt {transform: translateX(0.5em);}" + ".mut > .lab.lft {transform: translateX(-0.5em);}" + ".node > .lab.rgt {transform: translate(0.35em, -0.5em);}" + ".node > .lab.lft {transform: translate(-0.35em, -0.5em);}" + ".mut > .lab {fill: red; font-style: italic}" + ".mut > .sym {fill: red;}" ) @staticmethod @@ -417,6 +438,7 @@ def __init__( root_svg_attributes=None, style=None, order=None, + root_branch=None, ): self.tree = tree self.traversal_order = check_order(order) @@ -429,6 +451,10 @@ def __init__( self.drawing = self.setup_drawing() if style is not None: self.drawing.defs.add(self.drawing.style(style)) + if root_branch is None: + # put a root branch in if we have mutations over the root + root_branch = any(tree.parent(mut.node) == NULL for mut in tree.mutations()) + self.root_branch = root_branch self.treebox_x_offset = 10 self.treebox_y_offset = 10 self.treebox_width = size[0] - 2 * self.treebox_x_offset @@ -448,15 +474,18 @@ def __init__( self.edge_attrs[u] = {} if edge_attrs is not None and u in edge_attrs: self.edge_attrs[u].update(edge_attrs[u]) + self.add_class(self.edge_attrs[u], "edge") self.node_attrs[u] = {} if node_attrs is not None and u in node_attrs: self.node_attrs[u].update(node_attrs[u]) + self.add_class(self.node_attrs[u], "sym") # class 'sym' for symbol label = "" if node_labels is None: label = str(u) elif u in node_labels: label = str(node_labels[u]) self.node_label_attrs[u] = {"text": label} + self.add_class(self.node_label_attrs[u], "lab") # class 'lab' for label if node_label_attrs is not None and u in node_label_attrs: self.node_label_attrs[u].update(node_label_attrs[u]) @@ -472,6 +501,7 @@ def __init__( } if mutation_attrs is not None and m in mutation_attrs: self.mutation_attrs[m].update(mutation_attrs[m]) + self.add_class(self.mutation_attrs[m], "sym") # class 'sym' for symbol label = "" if mutation_labels is None: label = str(m) @@ -480,6 +510,7 @@ def __init__( self.mutation_label_attrs[m] = {"text": label} if mutation_label_attrs is not None and m in mutation_label_attrs: self.mutation_label_attrs[m].update(mutation_label_attrs[m]) + self.add_class(self.mutation_label_attrs[m], "lab") self.draw() @@ -536,18 +567,14 @@ def assign_y_coordinates(self, tree_height_scale, max_tree_height): # node labels within the treebox label_padding = 10 y_padding = self.treebox_y_offset + 2 * label_padding - mutations_over_root = any( - any(tree.parent(mut.node) == NULL for mut in tree.mutations()) - for tree in ts.trees() - ) - root_branch_length = 0 + self.root_branch_length = 0 height = self.image_size[1] - if mutations_over_root: + if self.root_branch: # Allocate a fixed about of space to show the mutations on the # 'root branch' - root_branch_length = height / 10 # FIXME just draw branch?? + self.root_branch_length = height / 10 # FIXME what scaling to use? # y scaling - padding_numerator = height - root_branch_length - 2 * y_padding + padding_numerator = height - self.root_branch_length - 2 * y_padding if tree_height_scale == "log_time": # again shift time by 1 in log(max_tree_height), so consistent y_scale = padding_numerator / (np.log(max_tree_height + 1)) @@ -577,38 +604,83 @@ def assign_x_coordinates(self, tree, x_start, width): node_x_coord_map[u] = a + (b - a) / 2 return node_x_coord_map - def info_classes(self, focal_node): + def add_node(self, curr_svg_group, focal, dx, dy): """ - For a focal node id, return a set of classes that encode this useful information: - "nA": where A == focal node id - "pB" or "root": where B == parent id (or "root" if the focal node is a root) - "sample": a class present if the focal node is a sample - "leaf": a class present if the focal node is a leaf - "mC": where C == mutation id of all mutations above this focal node - "sD": where D == site id of the sites associated with all mutations - above this focal node + Return a list of SvgGroupInfo objects to add to the stack """ - # Add a new group for each node, and give it classes for css targetting - classes = set() - classes.add(f"node n{focal_node}") - v = self.tree.parent(focal_node) + ret = [] + dwg = self.drawing + grp = curr_svg_group + v = self.tree.parent(focal) + classes = [f"n{focal}"] + offset_x = dx + offset_y = dy if v == NULL: - classes.add(f"root") + classes.append(f"root") + # set the origin of the + dx = 0 + dy = self.root_branch_length + edge_x = 0 + edge_height = dy / (len(self.node_mutations[focal]) + 1) + offset_y = offset_y - self.root_branch_length + edge_height else: - classes.add(f"p{v}") - if self.tree.is_sample(focal_node): - classes.add("sample") - if self.tree.is_leaf(focal_node): - classes.add("leaf") - for mutation in self.node_mutations[focal_node]: - # Adding mutations and sites above this node allows identification - # of the tree under any specific mutation - classes.add(f"m{mutation.id}") - classes.add(f"s{mutation.site}") - return sorted(classes) + classes.append(f"a{v}") + edge_x = offset_x + edge_height = dy / (len(self.node_mutations[focal]) + 1) + offset_y = edge_height + + # Add mut group for each mutation, and give it these classes for css targetting: + # "mut" + # "a" or "root": where == id of immediate ancestor (parent) node + # "n": where == focal node id + # "m": where == mutation id + # "s": where == site id + for m in reversed(self.node_mutations[focal]): + mutation_classes = ["mut", f"m{m.id}", f"s{m.site}"] + grp = grp.add( + dwg.g( + class_=" ".join(classes + mutation_classes), + transform=f"translate({rnd(offset_x)} {rnd(offset_y)})", + ) + ) + ret.append( + SvgGroupInfo( + g=grp, edge_dxy=(edge_x, edge_height), node=focal, mutation=m.id + ) + ) + # after the first sideways line of an edge all further movements go downwards + offset_x = 0 + edge_x = 0 + offset_y = edge_height + + # Add a new group for each node, and give it these classes for css targetting: + # "node" + # "nA": where A == focal node id + # "aB" or "root": where B == parent id (or "root" if the focal node is a root) + # "sample": a class present if the focal node is a sample + # "leaf": a class present if the focal node is a leaf + classes.append("node") + focal_node = self.tree.tree_sequence.node(focal) + if focal_node.individual != NULL: + classes.append(f"i{focal_node.individual}") + if focal_node.population != NULL: + classes.append(f"p{focal_node.population}") + if focal_node.is_sample(): + classes.append("sample") + if self.tree.is_leaf(focal): + classes.append("leaf") + grp = grp.add( + dwg.g( + class_=" ".join(classes), + transform=f"translate({rnd(offset_x)} {rnd(offset_y)})", + ) + ) + ret.append(SvgGroupInfo(g=grp, edge_dxy=(dx, dy), node=focal, mutation=None)) + return ret def draw(self): dwg = self.drawing + o = (0, 0) node_x_coord_map = self.node_x_coord_map node_y_coord_map = self.node_y_coord_map tree = self.tree @@ -617,73 +689,58 @@ def draw(self): # Iterate over nodes, adding groups to reflect the tree heirarchy stack = [] for u in tree.roots: - grp = dwg.g( - class_=" ".join(self.info_classes(u)), - transform=f"translate({rnd(node_x_coord_map[u])} " - f"{rnd(node_y_coord_map[u])})", + stack.extend( + self.add_node( + self.root_group, u, node_x_coord_map[u], node_y_coord_map[u], + ) ) - stack.append((u, self.root_group.add(grp))) while len(stack) > 0: - u, curr_svg_group = stack.pop() - pu = node_x_coord_map[u], node_y_coord_map[u] - for focal in tree.children(u): - fx = node_x_coord_map[focal] - pu[0] - fy = node_y_coord_map[focal] - pu[1] - new_svg_group = curr_svg_group.add( - dwg.g( - class_=" ".join(self.info_classes(focal)), - transform=f"translate({rnd(fx)} {rnd(fy)})", - ) - ) - stack.append((focal, new_svg_group)) - - o = (0, 0) - v = tree.parent(u) + curr = stack.pop() + u = curr.node + if curr.mutation is None: + # This is a tree node, not a mutation group + for c in tree.children(u): + dx = node_x_coord_map[c] - node_x_coord_map[u] + dy = node_y_coord_map[c] - node_y_coord_map[u] + stack.extend(self.add_node(curr.g, c, dx, dy)) # Add edge first => below - if v != NULL: - self.add_class(self.edge_attrs[u], "edge") - pv = node_x_coord_map[v], node_y_coord_map[v] - dx = pv[0] - pu[0] - dy = pv[1] - pu[1] + dx, dy = curr.edge_dxy + if dx == 0 and dy == 0: + path = dwg.path([("M", o)], **self.edge_attrs[u]) # e.g. at root + else: + # allowing "H 0" means that animating transitions works correctly path = dwg.path( - [("M", o), ("V", rnd(dy)), ("H", rnd(dx))], **self.edge_attrs[u] + [("M", o), ("V", -rnd(dy)), ("H", -rnd(dx))], **self.edge_attrs[u] ) - curr_svg_group.add(path) # Edges in parent group, so + curr.g.add(path) + + if curr.mutation is None: + # Node symbol + curr.g.add(dwg.circle(**self.node_attrs[u])) + # Labels + if not tree.is_leaf(u): + if tree.parent(u) == NULL: + if self.root_branch: + self.add_class(self.node_label_attrs[u], "rgt") + else: + if u == left_child[tree.parent(u)]: + self.add_class(self.node_label_attrs[u], "lft") + else: + self.add_class(self.node_label_attrs[u], "rgt") + curr.g.add(dwg.text(**self.node_label_attrs[u])) else: - # FIXME this is pretty crappy for spacing mutations over a root. - pv = (pu[0], pu[1] - 20) - - # Add node symbol + label next (visually above the edge subtending this node) - # Symbols - curr_svg_group.add(dwg.circle(**self.node_attrs[u])) - # Labels - if not tree.is_leaf(u) and tree.parent(u) != NULL: + # Mutation symbol + curr.g.add(dwg.rect(insert=o, **self.mutation_attrs[curr.mutation])) + # Labels if u == left_child[tree.parent(u)]: - self.add_class(self.node_label_attrs[u], "rgt") + mut_label_class = "lft" else: - self.add_class(self.node_label_attrs[u], "lft") - curr_svg_group.add(dwg.text(**self.node_label_attrs[u])) - - # Add mutation symbols + labels - delta = (pv[1] - pu[1]) / (len(self.node_mutations[u]) + 1) - for i, mutation in enumerate(reversed(self.node_mutations[u])): - # TODO get rid of these manual positioning tweaks and add them - # as offsets the user can access via a transform or something. - dy = (i + 1) * delta - mutation_class = f"mut m{mutation.id} s{mutation.site}" - mut_group = curr_svg_group.add( - dwg.g(class_=mutation_class, transform=f"translate(0 {rnd(dy)})") - ) - # Symbols - mut_group.add(dwg.rect(insert=o, **self.mutation_attrs[mutation.id])) - # Labels - if mutation.node == left_child[tree.parent(mutation.node)]: mut_label_class = "rgt" - else: - mut_label_class = "lft" - self.add_class(self.mutation_label_attrs[mutation.id], mut_label_class) - mut_group.add(dwg.text(**self.mutation_label_attrs[mutation.id])) + self.add_class( + self.mutation_label_attrs[curr.mutation], mut_label_class + ) + curr.g.add(dwg.text(**self.mutation_label_attrs[curr.mutation])) class TextTreeSequence: @@ -913,7 +970,7 @@ def __init__( # If we don't specify node_labels, default to node ID self.node_labels[u] = str(u) else: - # If we do specify node_labels, default an empty line + # If we do specify node_labels, default to an empty line self.node_labels[u] = self.default_node_label if node_labels is not None: for node, label in node_labels.items():