From 2ea6dd00498ae97a0ce7e7e207bcaaf35e945486 Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Tue, 26 Apr 2022 22:26:15 +0100 Subject: [PATCH] use min_time to allow negative times in ts --- python/CHANGELOG.rst | 4 + python/tests/test_drawing.py | 113 +++++++++++++++++++++-- python/tskit/drawing.py | 172 +++++++++++++++++++++++------------ python/tskit/trees.py | 25 ++++- 4 files changed, 247 insertions(+), 67 deletions(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 498c37fe7f..931dd584ab 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -4,6 +4,10 @@ **Changes** +- A ``min_time`` parameter in ``draw_svg`` enables the youngest node as the y axis min + value, allowing negative times. + (:user:`hyanwong`, :issue:`2197`, :pr:`2215`) + - ``VcfWriter.write`` now prints the site ID of variants in the ID field of the output VCF files. (:user:`roohy`, :issue:`2103`, :pr:`2107`) diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index cbd9521a4d..a0d9dc1e45 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -221,6 +221,28 @@ def get_simple_ts(self, use_mutation_times=False): tables.mutations.time = np.full_like(tables.mutations.time, tskit.UNKNOWN_TIME) return tables.tree_sequence() + def get_ts_varying_min_times(self, *args, **kwargs): + """ + Like get_simple_ts but return a tree sequence with negative times, and some trees + with different min times (i.e. with dangling nonsample nodes at negative times) + """ + ts = self.get_simple_ts(*args, **kwargs) + tables = ts.dump_tables() + time = tables.nodes.time + time[time == 0] = 0.1 + time[3] = -9.99 + tables.nodes.time = time + # set node 3 to be non-sample node lower than the rest + flags = tables.nodes.flags + flags[3] = 0 + tables.nodes.flags = flags + edges = tables.edges + assert edges[3].child == 3 and edges[3].parent == 5 + edges[3] = edges[3].replace(left=ts.breakpoints(True)[1]) + tables.sort() + tables.nodes.flags = flags + return tables.tree_sequence() + def fail(self, *args, **kwargs): """ Required for xmlunittest.XmlTestMixin to work with pytest not unittest @@ -591,6 +613,8 @@ def test_unused_args(self): t.draw(format=self.drawing_format, node_colours={}) with pytest.raises(ValueError): t.draw(format=self.drawing_format, max_time=1234) + with pytest.raises(ValueError): + t.draw(format=self.drawing_format, min_time=1234) with pytest.raises(ValueError): with pytest.warns(FutureWarning): t.draw(format=self.drawing_format, max_tree_height=1234) @@ -1500,8 +1524,6 @@ def test_nonimplemented_base_class(self): plot.set_spacing() with pytest.raises(NotImplementedError): plot.draw_x_axis(tick_positions=ts.breakpoints(as_array=True)) - with pytest.raises(NotImplementedError): - plot.draw_y_axis(ticks={0: "0"}) def test_bad_tick_spacing(self): # Integer y_ticks to give auto-generated tick locs is not currently implemented @@ -1514,7 +1536,7 @@ def test_bad_tick_spacing(self): def test_no_mixed_yscales(self): ts = self.get_simple_ts() - with pytest.raises(ValueError, match="varying yscales"): + with pytest.raises(ValueError, match="vary in timescale"): ts.draw_svg(y_axis=True, max_time="tree") def test_draw_defaults(self): @@ -1704,6 +1726,21 @@ def test_bad_max_time(self): with pytest.warns(FutureWarning): t.draw_svg(max_tree_height=bad_height) + def test_bad_min_time(self): + t = self.get_binary_tree() + for bad_min in ["te", "asdf", "", [], b"23"]: + with pytest.raises(ValueError): + t.draw_svg(min_time=bad_min) + with pytest.raises(ValueError): + with pytest.warns(FutureWarning): + t.draw_svg(max_tree_height=bad_min) + + def test_bad_neg_log_time(self): + t = self.get_ts_varying_min_times().at_index(1) + assert min(t.time(u) for u in t.nodes()) < 0 + with pytest.raises(ValueError, match="negative times"): + t.draw_svg(t.draw_svg(time_scale="log_time")) + def test_time_scale_time_and_max_time(self): ts = msprime.simulate(5, recombination_rate=2, random_seed=2) t = ts.first() @@ -1727,16 +1764,40 @@ def test_time_scale_rank_and_max_time(self): ts = msprime.simulate(5, recombination_rate=2, random_seed=2) t = ts.first() # The default should be the same as tree. - svg1 = t.draw_svg(max_time="tree", time_scale="rank") + svg1 = t.draw_svg(max_time="tree", time_scale="rank", y_axis=True) self.verify_basic_svg(svg1) - svg2 = t.draw_svg(time_scale="rank") + svg2 = t.draw_svg(time_scale="rank", y_axis=True) assert svg1 == svg2 - svg3 = t.draw_svg(max_time="ts", time_scale="rank") + svg3 = t.draw_svg(max_time="ts", time_scale="rank", y_axis=True) assert svg1 != svg3 self.verify_basic_svg(svg3) # Numeric max time not supported for rank scale. with pytest.raises(ValueError): - t.draw_svg(max_time=2, time_scale="rank") + t.draw_svg(max_time=2, time_scale="rank", y_axis=True) + + def test_min_tree_time(self): + ts = self.get_ts_varying_min_times() + t = ts.first() + # The default should be the same as tree. + svg1 = t.draw_svg(min_time="tree", y_axis=True) + self.verify_basic_svg(svg1) + svg2 = t.draw_svg(y_axis=True) + assert svg1 == svg2 + svg3 = t.draw_svg(min_time="ts", y_axis=True) + assert svg1 != svg3 + svg4 = t.draw_svg(min_time=min(ts.tables.nodes.time), y_axis=True) + assert svg3 == svg4 + + def test_min_ts_time(self): + ts = self.get_ts_varying_min_times() + svg1 = ts.draw_svg(y_axis=True) + self.verify_basic_svg(svg1, width=200 * ts.num_trees) + svg2 = ts.draw_svg(min_time="ts", y_axis=True) + assert svg1 == svg2 + with pytest.raises(ValueError, match="vary in timescale"): + ts.draw_svg(min_time="tree", y_axis=True) + svg3 = ts.draw_svg(min_time=min(ts.tables.nodes.time), y_axis=True) + assert svg2 == svg3 # # TODO: update the tests below here to check the new SVG based interface. @@ -1894,6 +1955,42 @@ def test_max_time(self): snippet2 = svg2[svg2.rfind("edge", 0, str_pos) : str_pos] assert snippet1 == snippet2 + def test_min_time(self): + nodes = io.StringIO( + """\ + id is_sample time + 0 0 -1.11 + 1 1 2.22 + 2 1 2.22 + 3 0 3.33 + 4 0 4.44 + 5 0 5.55 + """ + ) + edges = io.StringIO( + """\ + left right parent child + 0 1 5 2 + 0 1 5 3 + 1 2 4 2 + 1 2 4 3 + 0 1 3 0 + 0 2 3 1 + """ + ) + ts = tskit.load_text(nodes, edges, strict=False) + svg1a = ts.at_index(0).draw_svg(y_axis=True) + svg1b = ts.at_index(0).draw_svg(y_axis=True, min_time="ts") + svg2a = ts.at_index(1).draw_svg(y_axis=True) + svg2b = ts.at_index(1).draw_svg(y_axis=True, min_time="ts") + # axis should start at -1.11 + assert svg1a == svg1b + assert ">-1.11<" in svg1a + # 2nd tree should be different depending on whether min_time is "tree" or "ts" + assert svg2a != svg2b + assert ">-1.11<" not in svg2a + assert ">-1.11<" not in svg2b + def test_draw_sized_tree(self): tree = self.get_binary_tree() svg = tree.draw_svg(size=(600, 400)) @@ -1987,7 +2084,7 @@ def test_x_axis(self): assert svg_no_css.count("y-axis") == 0 def test_y_axis(self): - tree = msprime.simulate(4, random_seed=2).first() + tree = self.get_simple_ts().first() for hscale, label in [ (None, "Time"), ("time", "Time"), diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index cf3df33d72..2d025a22f0 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -62,6 +62,38 @@ class Offsets: mutation: int = 0 +@dataclass +class Timescaling: + "Class used to transform the time axis" + max_time: float + min_time: float + plot_min: float + plot_range: float + use_log_transform: bool + + def __post_init__(self): + if self.plot_range < 0: + raise ValueError("Image size too small to allow space to plot tree") + if self.use_log_transform: + if self.min_time < 0: + raise ValueError("Cannot use a log scale if there are negative times") + self.transform = self.log_transform + else: + self.transform = self.linear_transform + + def log_transform(self, y): + "Standard log transform but allowing for values of 0 by adding 1" + delta = 1 if self.min_time == 0 else 0 + log_max = np.log(self.max_time + delta) + log_min = np.log(self.min_time + delta) + y_scale = self.plot_range / (log_max - log_min) + return self.plot_min - (np.log(y + delta) - log_min) * y_scale + + def linear_transform(self, y): + y_scale = self.plot_range / (self.max_time - self.min_time) + return self.plot_min - (y - self.min_time) * y_scale + + def check_orientation(orientation): if orientation is None: orientation = TOP @@ -84,6 +116,21 @@ def check_max_time(max_time, allow_numeric=True): return max_time +def check_min_time(min_time, allow_numeric=True): + if min_time is None: + min_time = "tree" + if allow_numeric: + is_numeric = isinstance(min_time, numbers.Real) + if min_time not in ["tree", "ts"] and not is_numeric: + raise ValueError( + "min_time must be a numeric value or one of 'tree' or 'ts'" + ) + else: + if min_time not in ["tree", "ts"]: + raise ValueError("min_time must be 'tree' or 'ts'") + return min_time + + def check_time_scale(time_scale): if time_scale is None: time_scale = "time" @@ -308,6 +355,7 @@ def draw_tree( time_scale=None, tree_height_scale=None, max_time=None, + min_time=None, max_tree_height=None, order=None, ): @@ -360,6 +408,7 @@ def remap_style(original_map, new_key, none_value): mutation_labels=mutation_labels, time_scale=time_scale, max_time=max_time, + min_time=min_time, node_attrs=node_attrs, edge_attrs=edge_attrs, node_label_attrs=node_label_attrs, @@ -389,6 +438,7 @@ def remap_style(original_map, new_key, none_value): tree, node_labels=node_labels, max_time=max_time, + min_time=min_time, use_ascii=use_ascii, orientation=TOP, order=order, @@ -723,11 +773,11 @@ def draw_y_axis( if self.y_axis: y_axis.add(dwg.line((x, rnd(lower)), (x, rnd(upper)))) ticks_group = y_axis.add(dwg.g(class_="ticks")) - for pos, label in ticks.items(): + for y, label in ticks.items(): tick = ticks_group.add( dwg.g( class_="tick", - transform=f"translate({x} {rnd(self.y_transform(pos))})", + transform=f"translate({x} {rnd(self.timescaling.transform(y))})", ) ) if gridlines: @@ -800,9 +850,6 @@ def x_transform(self, x): "No transform func defined for genome pos -> plot coords" ) - def y_transform(self, y): - raise NotImplementedError("No transform func defined for time -> plot pos") - class SvgTreeSequence(SvgPlot): """ @@ -832,6 +879,7 @@ def __init__( y_gridlines, x_lim=None, max_time=None, + min_time=None, node_attrs=None, mutation_attrs=None, edge_attrs=None, @@ -862,6 +910,8 @@ def __init__( size = (200 * num_trees, 200) if max_time is None: max_time = "ts" + if min_time is None: + min_time = "ts" # X axis shown by default if x_axis is None: x_axis = True @@ -901,6 +951,7 @@ def __init__( force_root_branch=force_root_branch, symbol_size=symbol_size, max_time=max_time, + min_time=min_time, node_attrs=node_attrs, mutation_attrs=mutation_attrs, edge_attrs=edge_attrs, @@ -922,14 +973,13 @@ def __init__( ) y_low = self.tree_plotbox.bottom if y_axis is not None: - self.y_transform = lambda x: svg_trees[0].y_transform(x) + y + self.timescaling = svg_trees[0].timescaling for svg_tree in svg_trees: - if self.y_transform(1.234) != svg_tree.y_transform(1.234) + y: - # Slight hack: check an arbitrary value is transformed identically + if self.timescaling != svg_tree.timescaling: raise ValueError( - "Can't draw a tree sequence Y axis for trees of varying yscales" + "Can't draw a tree sequence Y axis if trees vary in timescale" ) - y_low = self.y_transform(0) # if poss use zero point for lowest axis value + y_low = self.timescaling.transform(self.timescaling.min_time) if y_ticks is None: y_ticks = np.unique(ts.tables.nodes.time[referenced_nodes(ts)]) if self.time_scale == "rank": @@ -1027,6 +1077,7 @@ def __init__( tree, size=None, max_time=None, + min_time=None, max_tree_height=None, node_labels=None, mutation_labels=None, @@ -1203,7 +1254,7 @@ def __init__( add_class(self.mutation_label_attrs[m], "lab") self.set_spacing(top=10, left=20, bottom=15, right=20) - self.assign_y_coordinates(max_time, force_root_branch) + self.assign_y_coordinates(max_time, min_time, force_root_branch) self.assign_x_coordinates() tick_length_lower = self.default_tick_length # TODO - parameterize tick_length_upper = self.default_tick_length_site # TODO - parameterize @@ -1230,7 +1281,7 @@ def __init__( self.draw_y_axis( ticks=check_y_ticks(y_ticks), - lower=self.y_transform(0), + lower=self.timescaling.transform(self.timescaling.min_time), tick_length_left=self.default_tick_length, gridlines=y_gridlines, ) @@ -1260,20 +1311,22 @@ def process_mutations_over_node(self, u, low_bound, high_bound, ignore_times=Fal def assign_y_coordinates( self, max_time, + min_time, force_root_branch, bottom_space=SvgPlot.line_height, top_space=SvgPlot.line_height, ): """ - Create a self.node_height dict, a self.y_transform func and + Create a self.node_height dict, a self.timescaling instance and self.min_root_branch_plot_length for use in plotting. Allow extra space within the plotbox, at the bottom for leaf labels, and (potentially, if no root branches are plotted) above the topmost root node for root labels. """ max_time = check_max_time(max_time, self.time_scale != "rank") + min_time = check_min_time(min_time, self.time_scale != "rank") node_time = self.ts.tables.nodes.time mut_time = self.ts.tables.mutations.time - root_branch_length = 0 + root_branch_len = 0 if self.time_scale == "rank": t = np.zeros_like(node_time) if max_time == "tree": @@ -1290,16 +1343,19 @@ def assign_y_coordinates( max_node_height = len(times) depth = {t: j for j, t in enumerate(times)} if self.mutations_over_roots or force_root_branch: - root_branch_length = 1 # Will get scaled later - max_time = max(depth.values()) + root_branch_length - # In pathological cases, all the roots are at 0 - if max_time == 0: - max_time = 1 + root_branch_len = 1 # Will get scaled later + max_time = max(depth.values()) + root_branch_len + if min_time in (None, "tree", "ts"): + assert min(depth.values()) == 0 + min_time = 0 + # In pathological cases, all the nodes are at the same time + if max_time == min_time: + max_time = min_time + 1 self.node_height = {u: depth[node_time[u]] for u in self.tree.nodes()} for u in self.node_mutations.keys(): parent = self.tree.parent(u) if parent == NULL: - top = self.node_height[u] + root_branch_length + top = self.node_height[u] + root_branch_len else: top = self.node_height[parent] self.process_mutations_over_node( @@ -1313,62 +1369,59 @@ def assign_y_coordinates( max_mut_height = np.nanmax( [0] + [mut.time for m in self.node_mutations.values() for mut in m] ) - else: + max_time = max(max_node_height, max_mut_height) # Reuse variable + elif max_time == "ts": max_node_height = self.ts.max_root_time max_mut_height = np.nanmax(np.append(mut_time, 0)) - max_time = max(max_node_height, max_mut_height) # Reuse variable - # In pathological cases, all the roots are at 0 - if max_time == 0: - max_time = 1 - + max_time = max(max_node_height, max_mut_height) # Reuse variable + if min_time == "tree": + min_time = min(self.node_height.values()) + # don't need to check mutation times, as they must be above a node + elif min_time == "ts": + min_time = np.min(self.ts.tables.nodes.time[referenced_nodes(self.ts)]) + # In pathological cases, all the nodes are at the same time + if min_time == max_time: + max_time = min_time + 1 if self.mutations_over_roots or force_root_branch: # Define a minimum root branch length, after transformation if necessary if self.time_scale != "log_time": - root_branch_length = max_time * self.root_branch_fraction + root_branch_len = (max_time - min_time) * self.root_branch_fraction else: - log_height = np.log(max_time + 1) - root_branch_length = ( - np.exp(log_height * (1 + self.root_branch_fraction)) - - 1 - - max_time - ) + max_plot_y = np.log(max_time + 1) + diff_plot_y = max_plot_y - np.log(min_time + 1) + root_plot_y = max_plot_y + diff_plot_y * self.root_branch_fraction + root_branch_len = np.exp(root_plot_y) - 1 - max_time # If necessary, allow for this extra branch in max_time - if max_node_height + root_branch_length > max_time: - max_time = max_node_height + root_branch_length + if max_node_height + root_branch_len > max_time: + max_time = max_node_height + root_branch_len for u in self.node_mutations.keys(): parent = self.tree.parent(u) if parent == NULL: # This is a root: if muts have no times we must specify an upper time - top = self.node_height[u] + root_branch_length + top = self.node_height[u] + root_branch_len else: top = self.node_height[parent] self.process_mutations_over_node(u, self.node_height[u], top) assert float(max_time) == max_time - + assert float(min_time) == min_time # Add extra space above the top and below the bottom of the tree to keep the # node labels within the plotbox (but top label space not needed if the # existence of a root branch pushes the whole tree + labels downwards anyway) - top_space = 0 if root_branch_length > 0 else top_space - zero_pos = self.plotbox.height + self.plotbox.top - bottom_space - padding_numerator = self.plotbox.height - top_space - bottom_space - if padding_numerator < 0: - raise ValueError("Image size too small to allow space to plot tree") - # Transform the y values into plot space (inverted y with 0 at the top of screen) - if self.time_scale == "log_time": - # add 1 so that don't reach log(0) = -inf error. - # just shifts entire timeset by 1 unit so shouldn't affect anything - y_scale = padding_numerator / np.log(max_time + 1) - self.y_transform = lambda y: zero_pos - np.log(y + 1) * y_scale - else: - y_scale = padding_numerator / max_time - self.y_transform = lambda y: zero_pos - y * y_scale + top_space = 0 if root_branch_len > 0 else top_space + self.timescaling = Timescaling( + max_time=max_time, + min_time=min_time, + plot_min=self.plotbox.height + self.plotbox.top - bottom_space, + plot_range=self.plotbox.height - top_space - bottom_space, + use_log_transform=(self.time_scale == "log_time"), + ) # Calculate default root branch length to use (in plot coords). This is a # minimum, as branches with deep root mutations could be longer - self.min_root_branch_plot_length = self.y_transform( - max_time - ) - self.y_transform(max_time + root_branch_length) + self.min_root_branch_plot_length = self.timescaling.transform( + self.timescaling.max_time + ) - self.timescaling.transform(self.timescaling.max_time + root_branch_len) def assign_x_coordinates(self): num_leaves = len(list(self.tree.leaves())) @@ -1437,7 +1490,9 @@ def info_classes(self, focal_node_id): def draw_tree(self): dwg = self.drawing node_x_coord_map = self.node_x_coord_map - node_y_coord_map = {u: self.y_transform(h) for u, h in self.node_height.items()} + node_y_coord_map = { + u: self.timescaling.transform(h) for u, h in self.node_height.items() + } tree = self.tree left_child = get_left_child(tree, self.traversal_order) @@ -1485,7 +1540,8 @@ def draw_tree(self): mutation = self.node_mutations[u][0] # Oldest on this branch root_branch_l = max( root_branch_l, - node_y_coord_map[u] - self.y_transform(mutation.time), + node_y_coord_map[u] + - self.timescaling.transform(mutation.time), ) path = dwg.path( [("M", o), ("V", rnd(-root_branch_l)), ("H", 0)], @@ -1498,7 +1554,7 @@ def draw_tree(self): for mutation in self.node_mutations[u]: # TODO get rid of these manual positioning tweaks and add them # as offsets the user can access via a transform or something. - dy = self.y_transform(mutation.time) - pu[1] + dy = self.timescaling.transform(mutation.time) - pu[1] mutation_id = mutation.id + self.offsets.mutation mutation_class = ( f"mut m{mutation_id} " f"s{mutation.site+ self.offsets.site}" @@ -1743,6 +1799,7 @@ def __init__( tree, node_labels=None, max_time=None, + min_time=None, use_ascii=False, orientation=None, order=None, @@ -1750,6 +1807,7 @@ def __init__( self.tree = tree self.traversal_order = check_order(order) self.max_time = check_max_time(max_time, allow_numeric=False) + self.min_time = check_min_time(min_time, allow_numeric=False) self.use_ascii = use_ascii self.orientation = check_orientation(orientation) self.horizontal_line_char = "━" diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 2041ac42cd..34d430145f 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1593,6 +1593,7 @@ def draw_svg( time_scale=None, tree_height_scale=None, max_time=None, + min_time=None, max_tree_height=None, node_labels=None, mutation_labels=None, @@ -1629,14 +1630,22 @@ def draw_svg( heights are spaced equally according to their ranked times. :param str tree_height_scale: Deprecated alias for time_scale. (Deprecated in 0.3.6) - :param str,float max_time: The maximum time value in the current + :param str,float max_time: The maximum plotted time value in the current scaling system (see ``time_scale``). Can be either a string or a numeric value. If equal to ``"tree"`` (the default), the maximum time is set to be that of the oldest root in the tree. If equal to ``"ts"`` the maximum time is set to be the time of the oldest root in the tree sequence; this is useful when drawing trees from the same tree sequence as it ensures that node heights are consistent. If a numeric value, this is used as - the maximum time by which to scale other nodes. + the maximum plotted time by which to scale other nodes. + :param str,float min_time: The minimum plotted time value in the current + scaling system (see ``time_scale``). Can be either a string or a + numeric value. If equal to ``"tree"`` (the default), the minimum time + is set to be that of the youngest node in the tree. If equal to ``"ts"`` the + minimum time is set to be the time of the youngest node in the tree + sequence; this is useful when drawing trees from the same tree sequence as it + ensures that node heights are consistent. If a numeric value, this is used as + the minimum plotted time. :param str,float max_tree_height: Deprecated alias for max_time. (Deprecated in 0.3.6) :param node_labels: If specified, show custom labels for the nodes @@ -1696,6 +1705,7 @@ def draw_svg( time_scale=time_scale, tree_height_scale=tree_height_scale, max_time=max_time, + min_time=min_time, max_tree_height=max_tree_height, node_labels=node_labels, mutation_labels=mutation_labels, @@ -1733,6 +1743,7 @@ def draw( time_scale=None, tree_height_scale=None, max_time=None, + min_time=None, max_tree_height=None, order=None, ): @@ -1824,6 +1835,15 @@ def draw( that node heights are consistent. If a numeric value, this is used as the maximum time by which to scale other nodes. This parameter is not currently supported for text output. + :param str,float min_time: The minimum time value in the current + scaling system (see ``time_scale``). Can be either a string or a + numeric value. If equal to ``"tree"``, the minimum time is set to be + that of the youngest node in the tree. If equal to ``"ts"`` the minimum + time is set to be the time of the youngest node in the tree sequence; + this is useful when drawing trees from the same tree sequence as it ensures + that node heights are consistent. If a numeric value, this is used as the + minimum time to display. This parameter is not currently supported for text + output. :param str max_tree_height: Deprecated alias for max_time. (Deprecated in 0.3.6) :param str order: The left-to-right ordering of child nodes in the drawn tree. @@ -1849,6 +1869,7 @@ def draw( time_scale=time_scale, tree_height_scale=tree_height_scale, max_time=max_time, + min_time=min_time, max_tree_height=max_tree_height, order=order, )