Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python/tests/test_balance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def colless_index_definition(tree):
)


def b1_index_definition(tree):
return sum(
1 / max(tree.path_length(n, leaf) for leaf in tree.leaves(n))
for n in tree.nodes()
if tree.parent(n) != tskit.NULL and tree.is_internal(n)
)


class TestDefinitions:
@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_sackin(self, ts):
Expand All @@ -75,6 +83,11 @@ def test_colless(self, ts):
else:
assert tree.colless_index() == colless_index_definition(tree)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
def test_b1(self, ts):
for tree in ts.trees():
assert tree.b1_index() == pytest.approx(b1_index_definition(tree))


class TestBalancedBinaryOdd:
# 2.00┊ 4 ┊
Expand All @@ -93,6 +106,9 @@ def test_sackin(self):
def test_colless(self):
assert self.tree().colless_index() == 1

def test_b1(self):
assert self.tree().b1_index() == 1


class TestBalancedBinaryEven:
# 2.00┊ 6 ┊
Expand All @@ -111,6 +127,9 @@ def test_sackin(self):
def test_colless(self):
assert self.tree().colless_index() == 0

def test_b1(self):
assert self.tree().b1_index() == 2


class TestBalancedTernary:
# 2.00┊ 12 ┊
Expand All @@ -130,6 +149,9 @@ def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 3


class TestStarN10:
# 1.00┊ 10 ┊
Expand All @@ -147,6 +169,9 @@ def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 0


class TestCombN5:
# 4.00┊ 8 ┊
Expand All @@ -169,6 +194,9 @@ def test_sackin(self):
def test_colless(self):
assert self.tree().colless_index() == 6

def test_b1(self):
assert self.tree().b1_index() == pytest.approx(1.833, rel=1e-3)


class TestMultiRootBinary:
# 3.00┊ 15 ┊
Expand Down Expand Up @@ -196,6 +224,9 @@ def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 4.5


class TestEmpty:
@tests.cached_example
Expand All @@ -210,6 +241,9 @@ def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 0


class TestTreeInNullState:
@tests.cached_example
Expand All @@ -221,6 +255,13 @@ def tree(self):
def test_sackin(self):
assert self.tree().sackin_index() == 0

def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 0


class TestAllRootsN5:
@tests.cached_example
Expand All @@ -236,3 +277,6 @@ def test_sackin(self):
def test_colless(self):
with pytest.raises(ValueError):
self.tree().colless_index()

def test_b1(self):
assert self.tree().b1_index() == 0
23 changes: 23 additions & 0 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2814,6 +2814,29 @@ def path_length(self, u, v):
return math.inf
return self.depth(u) + self.depth(v) - 2 * self.depth(mrca)

def b1_index(self):
"""
Returns the B1 balance index for this tree.
This is defined as the inverse of the sum of all longest paths
to leaves for each node besides roots.

.. seealso:: See `Shao and Sokal (1990)
<https://www.jstor.org/stable/2992186>`_ for details.

:return: The B1 balance index.
:rtype: float
"""
# TODO implement in C
max_path_length = np.zeros(self.tree_sequence.num_nodes, dtype=int)
total = 0.0
for u in self.postorder():
if self.parent(u) != tskit.NULL and self.is_internal(u):
max_path_length[u] = 1 + max(
max_path_length[v] for v in self.children(u)
)
total += 1 / max_path_length[u]
return total

def colless_index(self):
"""
Returns the Colless imbalance index for this tree.
Expand Down