Skip to content

Commit

Permalink
Adding path_length method and the corresponding tests
Browse files Browse the repository at this point in the history
Minor change in path_length documentation

Minor changes and added tests for path_length method

Add tests for path_length method
  • Loading branch information
jeremyguez committed May 11, 2022
1 parent 5bff374 commit 04da881
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 0 deletions.
62 changes: 62 additions & 0 deletions python/tests/test_highlevel.py
Expand Up @@ -885,6 +885,68 @@ def test_virtual_root_arg(self):
assert self.t.mrca(0, 5) == 5


class TestPathLength:
t = tskit.Tree.generate_balanced(9)
# 16
# ┏━━━━┻━━━┓
# ┃ 15
# ┃ ┏━━┻━┓
# 11 ┃ 14
# ┏━┻━┓ ┃ ┏━┻┓
# 9 10 12 ┃ 13
# ┏┻┓ ┏┻┓ ┏┻┓ ┃ ┏┻┓
# 0 1 2 3 4 5 6 7 8

def test_tmrca_leaf(self):
assert self.t.path_length(0, 16) == 3
assert self.t.path_length(16, 0) == 3
assert self.t.path_length(7, 16) == 4

def test_equal_depth(self):
assert self.t.path_length(5, 16) == self.t.depth(5)

def test_two_leaves(self):
assert self.t.path_length(0, 8) == 7

def test_two_leaves_depth(self):
assert self.t.path_length(0, 8) == self.t.depth(0) + self.t.depth(8)

@pytest.mark.parametrize("args", [[], [1], [1, 2, 3]])
def test_bad_num_args(self, args):
with pytest.raises(TypeError):
self.t.path_length(*args)

@pytest.mark.parametrize("bad_arg", [[], "1"])
def test_bad_arg_type(self, bad_arg):
with pytest.raises(TypeError):
self.t.path_length(0, bad_arg)
with pytest.raises(TypeError):
self.t.path_length(bad_arg, 0)

def test_same_args(self):
assert self.t.path_length(10, 10) == 0

def test_different_tree_levels(self):
assert self.t.path_length(1, 10) == 3

def test_out_of_bounds_args(self):
with pytest.raises(ValueError):
self.t.path_length(0, 20)

@pytest.mark.parametrize("u", range(17))
def test_virtual_root_arg(self, u):
assert self.t.path_length(u, self.t.virtual_root) == self.t.depth(u) + 1
assert self.t.path_length(self.t.virtual_root, u) == self.t.depth(u) + 1

def test_both_args_virtual_root(self):
assert self.t.path_length(self.t.virtual_root, self.t.virtual_root) == 0

def test_no_mrca(self):
tree = self.t.copy()
tree.clear()
assert math.isinf(tree.path_length(0, 1))


class TestMRCACalculator:
"""
Class to test the Schieber-Vishkin algorithm.
Expand Down
23 changes: 23 additions & 0 deletions python/tskit/trees.py
Expand Up @@ -2793,6 +2793,29 @@ def kc_distance(self, other, lambda_=0.0):
"""
return self._ll_tree.get_kc_distance(other._ll_tree, lambda_)

def path_length(self, u, v):

"""
Returns the path length between two nodes
(i.e., the number of edges between two nodes in this tree).
If the two nodes have a most recent common ancestor, then this is defined as
``tree.depth(u) + tree.depth(v) - 2 * tree.depth(tree.mrca(u, v))``. If the nodes
do not have an MRCA (i.e., they are in disconnected subtrees) the path length
is infinity.
.. seealso:: See also the :meth:`.depth` method
:param int u: The first node for path length computation.
:param int v: The second node for path length computation.
:return: The number of edges between the two nodes.
:rtype: int
"""

mrca = self.mrca(u, v)
if mrca == -1:
return math.inf
return self.depth(u) + self.depth(v) - 2 * self.depth(mrca)

def sackin_index(self):
"""
Returns the Sackin imbalance index for this tree. This is defined
Expand Down

0 comments on commit 04da881

Please sign in to comment.