diff --git a/docs/_static/different_time_samples.svg b/docs/_static/different_time_samples.svg
new file mode 100644
index 0000000000..b93b0f4dca
--- /dev/null
+++ b/docs/_static/different_time_samples.svg
@@ -0,0 +1,40 @@
+
+
diff --git a/docs/examples.py b/docs/examples.py
index 6d68f8b177..5d11597f87 100644
--- a/docs/examples.py
+++ b/docs/examples.py
@@ -331,6 +331,29 @@ def preorder_dist(tree):
print(list(preorder_dist(tree)))
+def finding_nearest_neighbors():
+ samples = [
+ msprime.Sample(0, 0),
+ msprime.Sample(0, 1),
+ msprime.Sample(0, 20),
+ ]
+ ts = msprime.simulate(
+ Ne=1e6,
+ samples=samples,
+ demographic_events=[
+ msprime.PopulationParametersChange(
+ time=10, growth_rate=2, population_id=0
+ ),
+ ],
+ random_seed=42,
+ )
+
+ tree = ts.first()
+ tree.draw_svg("_static/different_time_samples.svg",
+ tree_height_scale="rank")
+
+
+
# moving_along_tree_sequence()
# parsimony()
# allele_frequency_spectra()
@@ -338,4 +361,5 @@ def preorder_dist(tree):
# stats()
# tree_structure()
tree_traversal()
+finding_nearest_neighbors()
diff --git a/docs/tutorial.rst b/docs/tutorial.rst
index b185478ccf..8609c77991 100644
--- a/docs/tutorial.rst
+++ b/docs/tutorial.rst
@@ -111,6 +111,136 @@ Running this on the example above gives us::
[(7, 0), (6, 1), (4, 2), (3, 2), (5, 1), (2, 2), (1, 2), (0, 2)]
+
+++++++++++++++++++++++++
+Traversals with networkx
+++++++++++++++++++++++++
+
+Traversals and other network analysis can also be performed using the sizeable
+`networkx <https://networkx.org/>`_
+library. This can be achieved by calling :meth:`Tree.as_dict_of_dicts` to
+convert a :class:`Tree` instance to a format that can be imported by networkx to
+create a graph::
+
+ import networkx as nx
+
+ g = nx.DiGraph(tree.as_dict_of_dicts())
+ print(sorted(g.edges))
+
+::
+
+ [(5, 0), (5, 1), (5, 2), (6, 3), (6, 4), (7, 5), (7, 6)]
+
+++++++++++++++++++
+Traversing upwards
+++++++++++++++++++
+
+We can revisit the above examples and traverse upwards with
+networkx using a depth-first search algorithm::
+
+ import networkx as nx
+
+ g = nx.DiGraph(tree.as_dict_of_dicts())
+ for u in tree.samples():
+ path = [u] + [parent for parent, child, _ in
+ nx.edge_dfs(g, source=u, orientation="reverse")]
+ print(u, "->", path)
+
+giving::
+
+ 0 -> [0, 5, 7]
+ 1 -> [1, 5, 7]
+ 2 -> [2, 5, 7]
+ 3 -> [3, 6, 7]
+ 4 -> [4, 6, 7]
+
++++++++++++++++++++++++++++++++++
+Calculating distances to the root
++++++++++++++++++++++++++++++++++
+
+Similarly, we can obtain the nodes of a tree along with their distance to the
+root using networkx. Because networkx returns the distances in its own
+traversal order rather than in pre-order, we sort them before printing::
+
+    import networkx as nx
+
+    g = nx.DiGraph(tree.as_dict_of_dicts())
+    for root in tree.roots:
+        print(sorted(nx.shortest_path_length(g, source=root).items()))
+
+Running this on the example above gives us the same distances as before, now
+sorted by node ID::
+
+    [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1), (6, 1), (7, 0)]
+
++++++++++++++++++++++++++++++++++++++++
+Finding nearest neighbors
++++++++++++++++++++++++++++++++++++++++
+
+If some samples in a tree are not at time 0, then finding the nearest neighbor
+of a sample is a bit more involved. Instead of writing our own traversal code
+we can again draw on a networkx algorithm.
+Let us start with an example tree with three samples that were sampled at
+different time points:
+
+.. image:: _static/different_time_samples.svg
+ :width: 200px
+ :alt: An example tree with samples taken at different times
+
+The times (in generations) of the nodes in this tree are:
+
+.. code-block:: python
+
+ for u in tree.nodes():
+ print(u, tree.time(u))
+
+giving::
+
+ 4 20.005398778263334
+ 2 20.0
+ 3 17.833492457579652
+ 0 0.0
+ 1 1.0
+
+Note that samples 0 and 1 are about 35 generations apart from each other even though
+they were sampled at almost the same time. This is why samples 0 and 1 are
+closer to sample 2 than to each other.
+
+For this nearest neighbor search we will be traversing up and down the tree,
+so it is easier to treat the tree as an undirected graph::
+
+ g = nx.Graph(tree.as_dict_of_dicts())
+
+When converting the tree to a networkx graph the edges are annotated with their
+branch length::
+
+ print(g.edges(data=True))
+
+giving::
+
+ [(4, 2, {'branch_length': 0.005398778263334236}),
+ (4, 3, {'branch_length': 2.171906320683682}),
+ (3, 0, {'branch_length': 17.833492457579652}),
+ (3, 1, {'branch_length': 16.833492457579652})]
+
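+We can now use the "branch_length" field as a weight for a weighted shortest path
+search (the snippet below also needs ``import collections`` and
+``import itertools`` from the standard library)::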
+
+ # a dictionary of dictionaries to represent our distance matrix
+ dist_dod = collections.defaultdict(dict)
+ for source, target in itertools.combinations(tree.samples(), 2):
+ dist_dod[source][target] = nx.shortest_path_length(
+ g, source=source, target=target, weight="branch_length"
+ )
+ dist_dod[target][source] = dist_dod[source][target]
+
+ # extract the nearest neighbor of nodes 0, 1, and 2
+ nearest_neighbor_of = [min(dist_dod[u], key=dist_dod[u].get) for u in range(3)]
+
+ print(dict(zip(range(3), nearest_neighbor_of)))
+
+gives::
+
+ {0: 2, 1: 2, 2: 1}
+
.. _sec_tutorial_moving_along_a_tree_sequence:
****************************
diff --git a/python/requirements/conda-minimal.txt b/python/requirements/conda-minimal.txt
index fde8aac5b2..c7104af9bf 100644
--- a/python/requirements/conda-minimal.txt
+++ b/python/requirements/conda-minimal.txt
@@ -6,3 +6,4 @@ svgwrite
msprime
kastore
biopython
+networkx
diff --git a/python/requirements/development.txt b/python/requirements/development.txt
index e136ab217b..31edbe3218 100644
--- a/python/requirements/development.txt
+++ b/python/requirements/development.txt
@@ -19,3 +19,4 @@ pysam
PyVCF
python_jsonschema_objects
biopython
+networkx
\ No newline at end of file
diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py
index 5dc1a1aba7..3602fc636c 100644
--- a/python/tests/test_highlevel.py
+++ b/python/tests/test_highlevel.py
@@ -45,6 +45,7 @@
import tests as tests
import tests.tsutil as tsutil
import tests.simplify as simplify
+import networkx as nx
def insert_uniform_mutations(tables, num_mutations, nodes):
@@ -602,7 +603,6 @@ def verify_edge_diffs(self, ts):
children[edge.parent].add(edge.child)
while tree.interval[1] <= left:
tree = next(trees)
- # print(left, right, tree.interval)
self.assertTrue(left >= tree.interval[0])
self.assertTrue(right <= tree.interval[1])
for u in tree.nodes():
@@ -1660,6 +1660,139 @@ def test_newick(self):
for tree in ts.trees():
self.verify_newick(tree)
+ def test_as_dict_of_dicts(self):
+ for ts in get_example_tree_sequences():
+ tree = next(ts.trees())
+ adj_dod = tree.as_dict_of_dicts()
+ g = nx.DiGraph(adj_dod)
+
+ self.verify_nx_graph_topology(tree, g)
+ self.verify_nx_algorithm_equivalence(tree, g)
+ self.verify_nx_for_tutorial_algorithms(tree, g)
+ self.verify_nx_nearest_neighbor_search()
+
+ def verify_nx_graph_topology(self, tree, g):
+ self.assertSetEqual(set(tree.nodes()), set(g.nodes))
+
+ self.assertSetEqual(
+ set(tree.roots),
+ {n for n in g.nodes if g.in_degree(n) == 0}
+ )
+
+ self.assertSetEqual(
+ set(tree.leaves()),
+ {n for n in g.nodes if g.out_degree(n) == 0}
+ )
+
+ # test if tree has no in-degrees > 1
+ self.assertTrue(nx.is_branching(g))
+
+ def verify_nx_algorithm_equivalence(self, tree, g):
+ for root in tree.roots:
+ self.assertTrue(nx.is_directed_acyclic_graph(g))
+
+ # test descendants
+ self.assertSetEqual(
+ set(u for u in tree.nodes() if tree.is_descendant(u, root)),
+ set(nx.descendants(g, root)) | {root}
+ )
+
+ # test MRCA
+ if tree.num_nodes < 20:
+ for u, v in itertools.combinations(tree.nodes(), 2):
+ mrca = nx.lowest_common_ancestor(g, u, v)
+ if mrca is None:
+ mrca = -1
+ self.assertEqual(tree.mrca(u, v), mrca)
+
+ # test node traversal modes
+ self.assertEqual(
+ list(tree.nodes(root=root, order="breadthfirst")),
+ [root] + [v for u, v in nx.bfs_edges(g, root)]
+ )
+ self.assertEqual(
+ list(tree.nodes(root=root, order="preorder")),
+ list(nx.dfs_preorder_nodes(g, root))
+ )
+
+ def verify_nx_for_tutorial_algorithms(self, tree, g):
+ # traversing upwards
+ for u in tree.leaves():
+ path = []
+ v = u
+ while v != tskit.NULL:
+ path.append(v)
+ v = tree.parent(v)
+
+ self.assertSetEqual(set(path), {u} | nx.ancestors(g, u))
+ self.assertEqual(
+ path,
+ [u] +
+ [n1 for n1, n2, _ in nx.edge_dfs(g, u, orientation="reverse")]
+ )
+
+ # traversals with information
+ def preorder_dist(tree, root):
+ stack = [(root, 0)]
+ while len(stack) > 0:
+ u, distance = stack.pop()
+ yield u, distance
+ for v in tree.children(u):
+ stack.append((v, distance + 1))
+
+ for root in tree.roots:
+ self.assertDictEqual(
+ {k: v for k, v in preorder_dist(tree, root)},
+ nx.shortest_path_length(g, source=root)
+ )
+
+ for root in tree.roots:
+ # new traversal: measuring time between root and MRCA
+ for u, v in itertools.combinations(nx.descendants(g, root), 2):
+ mrca = tree.mrca(u, v)
+ tmrca = tree.time(mrca)
+ self.assertAlmostEqual(
+ tree.time(root) - tmrca,
+ nx.shortest_path_length(
+ g,
+ source=root,
+ target=mrca,
+ weight='branch_length'
+ )
+ )
+
+ def verify_nx_nearest_neighbor_search(self):
+ samples = [
+ msprime.Sample(0, 0),
+ msprime.Sample(0, 1),
+ msprime.Sample(0, 20),
+ ]
+ ts = msprime.simulate(
+ Ne=1e6,
+ samples=samples,
+ demographic_events=[
+ msprime.PopulationParametersChange(
+ time=10, growth_rate=2, population_id=0
+ ),
+ ],
+ random_seed=42,
+ )
+
+ tree = ts.first()
+ g = nx.Graph(tree.as_dict_of_dicts())
+
+ dist_dod = collections.defaultdict(dict)
+ for source, target in itertools.combinations(tree.samples(), 2):
+ dist_dod[source][target] = nx.shortest_path_length(
+ g, source=source, target=target, weight='branch_length'
+ )
+ dist_dod[target][source] = dist_dod[source][target]
+
+ nearest_neighbor_of = [
+ min(dist_dod[u], key=dist_dod[u].get) for u in range(3)
+ ]
+ self.assertEqual([2, 2, 1], [nearest_neighbor_of[u] for u in range(3)])
+
def test_traversals(self):
for ts in get_example_tree_sequences():
tree = next(ts.trees())
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index 8fdcaad374..dfa87e5476 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -1600,6 +1600,33 @@ def newick(self, precision=14, root=None, node_labels=None):
return self.__build_newick(root, precision, node_labels) + ";"
return s
+ def as_dict_of_dicts(self):
+ """
+ Convert tree to dict of dicts for conversion to a
+ `networkx graph `_.
+
+ For example::
+
+ >>> import networkx as nx
+ >>> nx.DiGraph(tree.as_dict_of_dicts())
+ >>> # undirected graphs work as well
+ >>> nx.Graph(tree.as_dict_of_dicts())
+
+ :return: Dictionary of dictionaries of dictionaries where the first key
+ is the source, the second key is the target of an edge, and the
+ third key is an edge annotation. At this point the only annotation
+ is "branch_length", the length of the branch (in generations).
+ """
+ dod = {}
+ for parent in self.nodes():
+ dod[parent] = {}
+ for child in self.children(parent):
+ dod[parent][child] = {
+ 'branch_length': self.branch_length(child)
+ }
+ return dod
+
@property
def parent_dict(self):
return self.get_parent_dict()