From 72474ec0920504a86163dd07c96d3dcde1aadd75 Mon Sep 17 00:00:00 2001 From: "j.guez" Date: Thu, 19 May 2022 13:42:03 +0200 Subject: [PATCH] Adding B1 balance index and tests --- python/tests/test_balance_metrics.py | 44 ++++++++++++++++++++++++++++ python/tskit/trees.py | 23 +++++++++++++++ 2 files changed, 67 insertions(+) diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index 28b816b8c2..d9e125e1c0 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -55,6 +55,14 @@ def colless_index_definition(tree): ) +def b1_index_definition(tree): + return sum( + 1 / max(tree.path_length(n, leaf) for leaf in tree.leaves(n)) + for n in tree.nodes() + if tree.parent(n) != tskit.NULL and tree.is_internal(n) + ) + + class TestDefinitions: @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_sackin(self, ts): @@ -75,6 +83,11 @@ def test_colless(self, ts): else: assert tree.colless_index() == colless_index_definition(tree) + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_b1(self, ts): + for tree in ts.trees(): + assert tree.b1_index() == pytest.approx(b1_index_definition(tree)) + class TestBalancedBinaryOdd: # 2.00┊ 4 ┊ @@ -93,6 +106,9 @@ def test_sackin(self): def test_colless(self): assert self.tree().colless_index() == 1 + def test_b1(self): + assert self.tree().b1_index() == 1 + class TestBalancedBinaryEven: # 2.00┊ 6 ┊ @@ -111,6 +127,9 @@ def test_sackin(self): def test_colless(self): assert self.tree().colless_index() == 0 + def test_b1(self): + assert self.tree().b1_index() == 2 + class TestBalancedTernary: # 2.00┊ 12 ┊ @@ -130,6 +149,9 @@ def test_colless(self): with pytest.raises(ValueError): self.tree().colless_index() + def test_b1(self): + assert self.tree().b1_index() == 3 + class TestStarN10: # 1.00┊ 10 ┊ @@ -147,6 +169,9 @@ def test_colless(self): with pytest.raises(ValueError): self.tree().colless_index() + def test_b1(self): + assert self.tree().b1_index() == 0 + class TestCombN5: # 4.00┊ 8 ┊ @@ -169,6 +194,9 @@ def test_sackin(self): def test_colless(self): assert self.tree().colless_index() == 6 + def test_b1(self): + assert self.tree().b1_index() == pytest.approx(1.833, rel=1e-3) + class TestMultiRootBinary: # 3.00┊ 15 ┊ @@ -196,6 +224,9 @@ def test_colless(self): with pytest.raises(ValueError): self.tree().colless_index() + def test_b1(self): + assert self.tree().b1_index() == 4.5 + class TestEmpty: @tests.cached_example @@ -210,6 +241,9 @@ def test_colless(self): with pytest.raises(ValueError): self.tree().colless_index() + def test_b1(self): + assert self.tree().b1_index() == 0 + class TestTreeInNullState: @tests.cached_example @@ -221,6 +255,13 @@ def tree(self): def test_sackin(self): assert self.tree().sackin_index() == 0 + def test_colless(self): + with pytest.raises(ValueError): + self.tree().colless_index() + + def test_b1(self): + assert self.tree().b1_index() == 0 + class TestAllRootsN5: @tests.cached_example @@ -236,3 +277,6 @@ def test_sackin(self): def test_colless(self): with pytest.raises(ValueError): self.tree().colless_index() + + def test_b1(self): + assert self.tree().b1_index() == 0 diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f432dfaca6..2282b8f541 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2814,6 +2814,29 @@ def path_length(self, u, v): return math.inf return self.depth(u) + self.depth(v) - 2 * self.depth(mrca) + def b1_index(self): + """ + Returns the B1 balance index for this tree. + This is defined as the inverse of the sum of all longest paths + to leaves for each node besides roots. + + .. seealso:: See `Shao and Sokal (1990) + `_ for details. + + :return: The B1 balance index. + :rtype: float + """ + # TODO implement in C + max_path_length = np.zeros(self.tree_sequence.num_nodes, dtype=int) + total = 0.0 + for u in self.postorder(): + if self.parent(u) != tskit.NULL and self.is_internal(u): + max_path_length[u] = 1 + max( + max_path_length[v] for v in self.children(u) + ) + total += 1 / max_path_length[u] + return total + def colless_index(self): """ Returns the Colless imbalance index for this tree.