diff --git a/docs/_static/different_time_samples.svg b/docs/_static/different_time_samples.svg new file mode 100644 index 0000000000..b93b0f4dca --- /dev/null +++ b/docs/_static/different_time_samples.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + 3 + + + 4 + 2 + 0 + 1 + + + + + + + + + + diff --git a/docs/examples.py b/docs/examples.py index 6d68f8b177..5d11597f87 100644 --- a/docs/examples.py +++ b/docs/examples.py @@ -331,6 +331,29 @@ def preorder_dist(tree): print(list(preorder_dist(tree))) +def finding_nearest_neighbors(): + samples = [ + msprime.Sample(0, 0), + msprime.Sample(0, 1), + msprime.Sample(0, 20), + ] + ts = msprime.simulate( + Ne=1e6, + samples=samples, + demographic_events=[ + msprime.PopulationParametersChange( + time=10, growth_rate=2, population_id=0 + ), + ], + random_seed=42, + ) + + tree = ts.first() + tree.draw_svg("_static/different_time_samples.svg", + tree_height_scale="rank") + + + # moving_along_tree_sequence() # parsimony() # allele_frequency_spectra() @@ -338,4 +361,5 @@ def preorder_dist(tree): # stats() # tree_structure() tree_traversal() +finding_nearest_neighbors() diff --git a/docs/tutorial.rst b/docs/tutorial.rst index b185478ccf..8609c77991 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -111,6 +111,136 @@ Running this on the example above gives us:: [(7, 0), (6, 1), (4, 2), (3, 2), (5, 1), (2, 2), (1, 2), (0, 2)] + +++++++++++++++++++++++++ +Traversals with networkx +++++++++++++++++++++++++ + +Traversals and other network analysis can also be performed using the sizeable +`networkx `_ +library. This can be achieved by calling :meth:`Tree.as_dict_of_dicts` to +convert a :class:`Tree` instance to a format that can be imported by networkx to +create a graph:: + + import networkx as nx + + g = nx.DiGraph(tree.as_dict_of_dicts()) + print(sorted(g.edges)) + +:: + + [(5, 0), (5, 1), (5, 2), (6, 3), (6, 4), (7, 5), (7, 6)] + +++++++++++++++++++ +Traversing upwards +++++++++++++++++++ + +We can revisit the above examples and traverse upwards with +networkx using a depth-first search algorithm:: + + import networkx as nx + + g = nx.DiGraph(tree.as_dict_of_dicts()) + for u in tree.samples(): + path = [u] + [parent for parent, child, _ in + nx.edge_dfs(g, source=u, orientation="reverse")] + print(u, "->", path) + +giving:: + + 0 -> [0, 5, 7] + 1 -> [1, 5, 7] + 2 -> [2, 5, 7] + 3 -> [3, 6, 7] + 4 -> [4, 6, 7] + ++++++++++++++++++++++++++++++++++ +Calculating distances to the root ++++++++++++++++++++++++++++++++++ + +Similarly, we can yield the nodes of a tree along with their distance to the +root in pre-order in networkx as well:: + + import networkx as nx + + g = nx.DiGraph(tree.as_dict_of_dicts()) + for root in tree.roots: + print(nx.shortest_path_length(g, source=root).items()) + +Running this on the example above gives us the same result as before:: + + [(7, 0), (6, 1), (4, 2), (3, 2), (5, 1), (2, 2), (1, 2), (0, 2)] + ++++++++++++++++++++++++++++++++++++++++ +Finding nearest neighbors ++++++++++++++++++++++++++++++++++++++++ + +If some samples in a tree are not at time 0, then finding the nearest neighbor +of a sample is a bit more involved. Instead of writing our own traversal code +we can again draw on a networkx algorithm. +Let us start with an example tree with three samples that were sampled at +different time points: + +.. image:: _static/different_time_samples.svg + :width: 200px + :alt: An example tree with samples taken at different times + +The generation times for these nodes are: + +.. code-block:: python + + for u in tree.nodes(): + print(u, tree.time(u)) + +giving:: + + 4 20.005398778263334 + 2 20.0 + 3 17.833492457579652 + 0 0.0 + 1 1.0 + +Note that samples 0 and 1 are about 35 generations apart from each other even though +they were sampled at almost the same time. This is why samples 0 and 1 are +closer to sample 2 than to each other. + +For this nearest neighbor search we will be traversing up and down the tree, +so it is easier to treat the tree as an undirected graph:: + + g = nx.Graph(tree.as_dict_of_dicts()) + +When converting the tree to a networkx graph the edges are annotated with their +branch length:: + + print(g.edges(data=True)) + +giving:: + + [(4, 2, {'branch_length': 0.005398778263334236}), + (4, 3, {'branch_length': 2.171906320683682}), + (3, 0, {'branch_length': 17.833492457579652}), + (3, 1, {'branch_length': 16.833492457579652})] + +We can now use the "branch_length" field as a weight for a weighted shortest path +search:: + + # a dictionary of dictionaries to represent our distance matrix + dist_dod = collections.defaultdict(dict) + for source, target in itertools.combinations(tree.samples(), 2): + dist_dod[source][target] = nx.shortest_path_length( + g, source=source, target=target, weight="branch_length" + ) + dist_dod[target][source] = dist_dod[source][target] + + # extract the nearest neighbor of nodes 0, 1, and 2 + nearest_neighbor_of = [min(dist_dod[u], key=dist_dod[u].get) for u in range(3)] + + print(dict(zip(range(3), nearest_neighbor_of))) + +gives:: + + {0: 2, 1: 2, 2: 1} + .. _sec_tutorial_moving_along_a_tree_sequence: **************************** diff --git a/python/requirements/conda-minimal.txt b/python/requirements/conda-minimal.txt index fde8aac5b2..c7104af9bf 100644 --- a/python/requirements/conda-minimal.txt +++ b/python/requirements/conda-minimal.txt @@ -6,3 +6,4 @@ svgwrite msprime kastore biopython +networkx diff --git a/python/requirements/development.txt b/python/requirements/development.txt index e136ab217b..31edbe3218 100644 --- a/python/requirements/development.txt +++ b/python/requirements/development.txt @@ -19,3 +19,4 @@ pysam PyVCF python_jsonschema_objects biopython +networkx \ No newline at end of file diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 5dc1a1aba7..3602fc636c 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -45,6 +45,7 @@ import tests as tests import tests.tsutil as tsutil import tests.simplify as simplify +import networkx as nx def insert_uniform_mutations(tables, num_mutations, nodes): @@ -602,7 +603,6 @@ def verify_edge_diffs(self, ts): children[edge.parent].add(edge.child) while tree.interval[1] <= left: tree = next(trees) - # print(left, right, tree.interval) self.assertTrue(left >= tree.interval[0]) self.assertTrue(right <= tree.interval[1]) for u in tree.nodes(): @@ -1660,6 +1660,139 @@ def test_newick(self): for tree in ts.trees(): self.verify_newick(tree) + def test_as_dict_of_dicts(self): + for ts in get_example_tree_sequences(): + tree = next(ts.trees()) + adj_dod = tree.as_dict_of_dicts() + g = nx.DiGraph(adj_dod) + + self.verify_nx_graph_topology(tree, g) + self.verify_nx_algorithm_equivalence(tree, g) + self.verify_nx_for_tutorial_algorithms(tree, g) + self.verify_nx_nearest_neighbor_search() + + def verify_nx_graph_topology(self, tree, g): + self.assertSetEqual(set(tree.nodes()), set(g.nodes)) + + self.assertSetEqual( + set(tree.roots), + {n for n in g.nodes if g.in_degree(n) == 0} + ) + + self.assertSetEqual( + set(tree.leaves()), + {n for n in g.nodes if g.out_degree(n) == 0} + ) + + # test if tree has no in-degrees > 1 + self.assertTrue(nx.is_branching(g)) + + def verify_nx_algorithm_equivalence(self, tree, g): + for root in tree.roots: + self.assertTrue(nx.is_directed_acyclic_graph(g)) + + # test descendants + self.assertSetEqual( + set(u for u in tree.nodes() if tree.is_descendant(u, root)), + set(nx.descendants(g, root)) | {root} + ) + + # test MRCA + if tree.num_nodes < 20: + for u, v in itertools.combinations(tree.nodes(), 2): + mrca = nx.lowest_common_ancestor(g, u, v) + if mrca is None: + mrca = -1 + self.assertEqual(tree.mrca(u, v), mrca) + + # test node traversal modes + self.assertEqual( + list(tree.nodes(root=root, order="breadthfirst")), + [root] + [v for u, v in nx.bfs_edges(g, root)] + ) + self.assertEqual( + list(tree.nodes(root=root, order="preorder")), + list(nx.dfs_preorder_nodes(g, root)) + ) + + def verify_nx_for_tutorial_algorithms(self, tree, g): + # traversing upwards + for u in tree.leaves(): + path = [] + v = u + while v != tskit.NULL: + path.append(v) + v = tree.parent(v) + + self.assertSetEqual(set(path), {u} | nx.ancestors(g, u)) + self.assertEqual( + path, + [u] + + [n1 for n1, n2, _ in nx.edge_dfs(g, u, orientation="reverse")] + ) + + # traversals with information + def preorder_dist(tree, root): + stack = [(root, 0)] + while len(stack) > 0: + u, distance = stack.pop() + yield u, distance + for v in tree.children(u): + stack.append((v, distance + 1)) + + for root in tree.roots: + self.assertDictEqual( + {k: v for k, v in preorder_dist(tree, root)}, + nx.shortest_path_length(g, source=root) + ) + + for root in tree.roots: + # new traversal: measuring time between root and MRCA + for u, v in itertools.combinations(nx.descendants(g, root), 2): + mrca = tree.mrca(u, v) + tmrca = tree.time(mrca) + self.assertAlmostEqual( + tree.time(root) - tmrca, + nx.shortest_path_length( + g, + source=root, + target=mrca, + weight='branch_length' + ) + ) + + def verify_nx_nearest_neighbor_search(self): + samples = [ + msprime.Sample(0, 0), + msprime.Sample(0, 1), + msprime.Sample(0, 20), + ] + ts = msprime.simulate( + Ne=1e6, + samples=samples, + demographic_events=[ + msprime.PopulationParametersChange( + time=10, growth_rate=2, population_id=0 + ), + ], + random_seed=42, + ) + + tree = ts.first() + g = nx.Graph(tree.as_dict_of_dicts()) + + dist_dod = collections.defaultdict(dict) + for source, target in itertools.combinations(tree.samples(), 2): + dist_dod[source][target] = nx.shortest_path_length( + g, source=source, target=target, weight='branch_length' + ) + dist_dod[target][source] = dist_dod[source][target] + + nearest_neighbor_of = [ + min(dist_dod[u], key=dist_dod[u].get) for u in range(3) + ] + self.assertEqual([2, 2, 1], [nearest_neighbor_of[u] for u in range(3)]) + def test_traversals(self): for ts in get_example_tree_sequences(): tree = next(ts.trees()) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 8fdcaad374..dfa87e5476 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1600,6 +1600,33 @@ def newick(self, precision=14, root=None, node_labels=None): return self.__build_newick(root, precision, node_labels) + ";" return s + def as_dict_of_dicts(self): + """ + Convert tree to dict of dicts for conversion to a + `networkx graph `_. + + For example:: + + >>> import networkx as nx + >>> nx.DiGraph(tree.as_dict_of_dicts()) + >>> # undirected graphs work as well + >>> nx.Graph(tree.as_dict_of_dicts()) + + :return: Dictionary of dictionaries of dictionaries where the first key + is the source, the second key is the target of an edge, and the + third key is an edge annotation. At this point the only annotation + is "branch_length", the length of the branch (in generations). + """ + dod = {} + for parent in self.nodes(): + dod[parent] = {} + for child in self.children(parent): + dod[parent][child] = { + 'branch_length': self.branch_length(child) + } + return dod + @property def parent_dict(self): return self.get_parent_dict()