diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c
index 8e14d57e7c..11f2ac2b4b 100644
--- a/c/tests/test_trees.c
+++ b/c/tests/test_trees.c
@@ -6629,7 +6629,7 @@ test_single_tree_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
- double b1;
+ double b1, b2;
tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
@@ -6645,6 +6645,13 @@ test_single_tree_balance(void)
CU_ASSERT_EQUAL(colless, 0);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL(b1, 2, 1e-8);
+ /* Test different bases for b2_index to high precision */
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0);
+ CU_ASSERT_DOUBLE_EQUAL(b2, 0.6020599913279623, 1e-14);
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 2, &b2), 0);
+ CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 2, 1e-16);
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 3, &b2), 0);
+ CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 1.2618595071429148, 1e-14);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
@@ -6683,6 +6690,7 @@ test_multiroot_balance(void)
CU_ASSERT_EQUAL_FATAL(tsk_tree_colless_index(&t, NULL), TSK_ERR_UNDEFINED_MULTIROOT);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL(b1, 1.0, 1e-8);
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, NULL), TSK_ERR_UNDEFINED_MULTIROOT);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
@@ -6701,7 +6709,7 @@ test_nonbinary_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
- double b1;
+ double b1, b2;
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
ret = tsk_tree_init(&t, &ts, 0);
@@ -6716,6 +6724,8 @@ test_nonbinary_balance(void)
tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_NONBINARY);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8);
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0);
+ CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
@@ -6729,7 +6739,7 @@ test_empty_tree_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
- double b1;
+ double b1, b2;
ret = tsk_table_collection_init(&tables, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
@@ -6748,12 +6758,46 @@ test_empty_tree_balance(void)
tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_MULTIROOT);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_EQUAL(b1, 0);
+ CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), TSK_ERR_UNDEFINED_MULTIROOT);
tsk_table_collection_free(&tables);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}
+ /* Exercises tsk_tree_b2_index on the single-tree example with logarithm bases
+  * that are not mathematically valid. The library does not reject them: the call
+  * reports success and only the stored value shows that something is off. */
+ static void
+ test_b2_bad_base(void)
+ {
+     /* Bases for which the computed index is expected to be non-finite */
+     const double non_finite_bases[] = { -2, -1, 1 };
+     const size_t num_bases = sizeof(non_finite_bases) / sizeof(non_finite_bases[0]);
+     size_t j;
+     int ret;
+     double result;
+     tsk_treeseq_t ts;
+     tsk_tree_t tree;
+
+     tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
+         NULL, NULL, NULL, 0);
+     ret = tsk_tree_init(&tree, &ts, 0);
+     CU_ASSERT_EQUAL_FATAL(ret, 0);
+     ret = tsk_tree_first(&tree);
+     CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK);
+
+     j = 0;
+     while (j < num_bases) {
+         ret = tsk_tree_b2_index(&tree, non_finite_bases[j], &result);
+         CU_ASSERT_EQUAL_FATAL(ret, 0);
+         CU_ASSERT_FALSE(tsk_isfinite(result));
+         j++;
+     }
+     /* Make sure the loop above really ran at least once */
+     CU_ASSERT_FATAL(j > 0);
+
+     /* Base 0 is the odd one out: the value that comes back is finite (zero) */
+     ret = tsk_tree_b2_index(&tree, 0, &result);
+     CU_ASSERT_EQUAL_FATAL(ret, 0);
+     CU_ASSERT_EQUAL_FATAL(result, 0);
+
+     tsk_treeseq_free(&ts);
+     tsk_tree_free(&tree);
+ }
+
static void
test_tree_errors(void)
{
@@ -7950,6 +7994,7 @@ main(int argc, char **argv)
{ "test_multiroot_balance", test_multiroot_balance },
{ "test_nonbinary_balance", test_nonbinary_balance },
{ "test_empty_tree_balance", test_empty_tree_balance },
+ { "test_b2_bad_base", test_b2_bad_base },
/* Misc */
{ "test_tree_errors", test_tree_errors },
diff --git a/c/tskit/trees.c b/c/tskit/trees.c
index b921254ff8..92cc789c5b 100644
--- a/c/tskit/trees.c
+++ b/c/tskit/trees.c
@@ -5009,6 +5009,68 @@ tsk_tree_b1_index(const tsk_tree_t *self, double *result)
return ret;
}
+ /* Logarithm of x in an arbitrary base, computed with the change-of-base
+  * formula log(x) / log(base). Neither argument is checked, so the result is
+  * whatever that floating-point division produces. */
+ static double
+ general_log(double x, double base)
+ {
+     const double numerator = log(x);
+     const double denominator = log(base);
+
+     return numerator / denominator;
+ }
+
+ /* Computes the B2 balance index of the tree: the Shannon entropy (logarithms
+  * taken in the given base) of the probability distribution over leaves reached
+  * by a walk from the root that weights each child of a node equally.
+  *
+  * Returns 0 and stores the index in *result on success. Returns
+  * TSK_ERR_NO_MEMORY if the traversal stack cannot be allocated, or
+  * TSK_ERR_UNDEFINED_MULTIROOT if the tree does not have exactly one root; in
+  * both of those cases *result is not written. The base is not validated. */
+ int
+ tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result)
+ {
+     struct stack_elem {
+         tsk_id_t node;
+         double path_product;
+     };
+     int ret = 0;
+     const tsk_id_t *restrict right_child = self->right_child;
+     const tsk_id_t *restrict left_sib = self->left_sib;
+     struct stack_elem *stack
+         = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack));
+     struct stack_elem current = { .node = TSK_NULL, .path_product = 1 };
+     struct stack_elem child;
+     int top = -1;
+     double entropy = 0;
+     double fanout;
+     tsk_id_t v;
+
+     if (stack == NULL) {
+         ret = TSK_ERR_NO_MEMORY;
+         goto out;
+     }
+     if (tsk_tree_get_num_roots(self) != 1) {
+         ret = TSK_ERR_UNDEFINED_MULTIROOT;
+         goto out;
+     }
+
+     /* Seed the traversal with the root, which is reached with probability 1 */
+     current.node = tsk_tree_get_left_root(self);
+     stack[++top] = current;
+
+     while (top >= 0) {
+         current = stack[top--];
+         if (right_child[current.node] == TSK_NULL) {
+             /* No children: this is a leaf, so add its -p * log(p) term */
+             entropy -= current.path_product * general_log(current.path_product, base);
+             continue;
+         }
+         /* First pass over the children just counts them */
+         fanout = 0;
+         for (v = right_child[current.node]; v != TSK_NULL; v = left_sib[v]) {
+             fanout++;
+         }
+         /* Every child inherits the parent's probability split evenly */
+         child.path_product = current.path_product * (1 / fanout);
+         for (v = right_child[current.node]; v != TSK_NULL; v = left_sib[v]) {
+             child.node = v;
+             stack[++top] = child;
+         }
+     }
+     *result = entropy;
+ out:
+     tsk_safe_free(stack);
+     return ret;
+ }
+
/* Parsimony methods */
static inline uint64_t
diff --git a/c/tskit/trees.h b/c/tskit/trees.h
index 6cd03dc986..509e287a03 100644
--- a/c/tskit/trees.h
+++ b/c/tskit/trees.h
@@ -1696,6 +1696,11 @@ int tsk_tree_kc_distance(
int tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result);
int tsk_tree_colless_index(const tsk_tree_t *self, tsk_size_t *result);
int tsk_tree_b1_index(const tsk_tree_t *self, double *result);
+ /* NOTE: if we document this as part of the C API we'll have to be more careful
+ * about the error behaviour on bad log bases. At the moment we just return
+ * the resulting value, which can be nan, inf etc., but some results are
+ * surprising: for example, a base of 0 seems to return a finite value. */
+int tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result);
/* Things to consider removing: */
diff --git a/docs/python-api.md b/docs/python-api.md
index 79a44c05cc..7aad6f155a 100644
--- a/docs/python-api.md
+++ b/docs/python-api.md
@@ -555,6 +555,7 @@ Functions and static methods
Tree.colless_index
Tree.sackin_index
Tree.b1_index
+ Tree.b2_index
```
(sec_python_api_trees_sites_mutations)=
diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst
index ff74095ab6..94039b9a22 100644
--- a/python/CHANGELOG.rst
+++ b/python/CHANGELOG.rst
@@ -57,6 +57,9 @@
- Add B1 tree balance index.
(:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2251`, :pr:`2281`, :pr:`2346`).
+- Add B2 tree balance index.
+ (:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2252`, :pr:`2353`, :pr:`2354`).
+
- Add Sackin tree imbalance index.
(:user:`jeremyguez`, :user:`jeromekelleher`, :pr:`2246`, :pr:`2258`).
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index 92865dd426..b00de231c9 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -10817,6 +10817,30 @@ Tree_get_b1_index(Tree *self)
return ret;
}
+ /* Low-level binding for tsk_tree_b2_index. Expects a single argument that
+  * parses as a C double (the logarithm base) and returns the index as a Python
+  * float. Returns NULL if the tree state check fails, if the argument cannot
+  * be parsed, or if the library call reports an error. */
+ static PyObject *
+ Tree_get_b2_index(Tree *self, PyObject *args)
+ {
+     double base;
+     double b2;
+     int err;
+
+     if (Tree_check_state(self) != 0) {
+         return NULL;
+     }
+     if (!PyArg_ParseTuple(args, "d", &base)) {
+         return NULL;
+     }
+     err = tsk_tree_b2_index(self->tree, base, &b2);
+     if (err != 0) {
+         handle_library_error(err);
+         return NULL;
+     }
+     return Py_BuildValue("d", b2);
+ }
+
static PyObject *
Tree_get_root_threshold(Tree *self)
{
@@ -11240,6 +11264,10 @@ static PyMethodDef Tree_methods[] = {
.ml_meth = (PyCFunction) Tree_get_b1_index,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns the B1 index for this tree." },
+ { .ml_name = "get_b2_index",
+ .ml_meth = (PyCFunction) Tree_get_b2_index,
+ .ml_flags = METH_VARARGS,
+ .ml_doc = "Returns the B2 index for this tree." },
{ NULL } /* Sentinel */
};
diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py
index 2b31753d75..f3c16790ea 100644
--- a/python/tests/test_balance_metrics.py
+++ b/python/tests/test_balance_metrics.py
@@ -35,7 +35,7 @@
# we can remove this.
-def path(tree, u):
+def node_path(tree, u):
path = []
u = tree.parent(u)
while u != tskit.NULL:
@@ -79,7 +79,7 @@ def b2_index_definition(tree, base=10):
if tree.num_roots != 1:
raise ValueError("B2 index is only defined for trees with one root")
proba = [
- np.prod([1 / tree.num_children(u) for u in path(tree, leaf)])
+ np.prod([1 / tree.num_children(u) for u in node_path(tree, leaf)])
for leaf in tree.leaves()
]
return -sum(p * math.log(p, base) for p in proba)
@@ -111,18 +111,27 @@ def test_b1(self, ts):
assert tree.b1_index() == pytest.approx(b1_index_definition(tree))
@pytest.mark.parametrize("ts", get_example_tree_sequences())
- @pytest.mark.parametrize("base", [2, 10, math.e, np.array([3])[0]])
- def test_b2_base(self, ts, base):
+ def test_b2(self, ts):
for tree in ts.trees():
if tree.num_roots != 1:
+ with pytest.raises(tskit.LibraryError, match="MULTIROOT"):
+ tree.b2_index()
with pytest.raises(ValueError):
+ b2_index_definition(tree)
+ else:
+ assert tree.b2_index() == b2_index_definition(tree)
+
+ @pytest.mark.parametrize("ts", get_example_tree_sequences())
+ @pytest.mark.parametrize("base", [0.1, 1.1, 2, 10, math.e, np.array([3])[0]])
+ def test_b2_base(self, ts, base):
+ for tree in ts.trees():
+ if tree.num_roots != 1:
+ with pytest.raises(tskit.LibraryError, match="MULTIROOT"):
tree.b2_index(base)
with pytest.raises(ValueError):
b2_index_definition(tree, base)
else:
- assert tree.b2_index(base) == pytest.approx(
- b2_index_definition(tree, base)
- )
+ assert tree.b2_index(base) == b2_index_definition(tree, base)
class TestBalancedBinaryOdd:
@@ -172,6 +181,31 @@ def test_b1(self):
def test_b2(self):
assert self.tree().b2_index() == pytest.approx(0.602, rel=1e-3)
+ @pytest.mark.parametrize(
+ ("base", "expected"),
+ [
+ (2, 2),
+ (3, 1.2618595071429148),
+ (4, 1.0),
+ (5, 0.8613531161467861),
+ (10, 0.6020599913279623),
+ (100, 0.30102999566398114),
+ (1000000, 0.10034333188799373),
+ (2.718281828459045, 1.3862943611198906),
+ ],
+ )
+ def test_b2_base(self, base, expected):
+ assert self.tree().b2_index(base) == expected
+
+ @pytest.mark.parametrize("base", [0, -0.001, -1, -1e-6, -1e200])
+ def test_b2_bad_base(self, base):
+ with pytest.raises(ValueError, match="math domain"):
+ self.tree().b2_index(base=base)
+
+ def test_b2_base1(self):
+ with pytest.raises(ZeroDivisionError):
+ self.tree().b2_index(base=1)
+
class TestBalancedTernary:
# 2.00┊ 12 ┊
@@ -279,7 +313,7 @@ def test_b1(self):
assert self.tree().b1_index() == 4.5
def test_b2(self):
- with pytest.raises(ValueError):
+ with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()
@@ -300,7 +334,7 @@ def test_b1(self):
assert self.tree().b1_index() == 0
def test_b2(self):
- with pytest.raises(ValueError):
+ with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()
@@ -322,7 +356,7 @@ def test_b1(self):
assert self.tree().b1_index() == 0
def test_b2(self):
- with pytest.raises(ValueError):
+ with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()
@@ -345,5 +379,5 @@ def test_b1(self):
assert self.tree().b1_index() == 0
def test_b2(self):
- with pytest.raises(ValueError):
+ with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()
diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py
index 553d1fb6e4..058a9beefc 100644
--- a/python/tests/test_lowlevel.py
+++ b/python/tests/test_lowlevel.py
@@ -3194,6 +3194,15 @@ def test_equality(self):
assert not t2.equals(t1)
last_ts = ts
+ def test_b2_errors(self):
+ ts1 = self.get_example_tree_sequence(10)
+ t1 = _tskit.Tree(ts1)
+ t1.first()
+ with pytest.raises(TypeError):
+ t1.get_b2_index()
+ with pytest.raises(TypeError):
+ t1.get_b2_index("asdf")
+
def test_kc_distance_errors(self):
ts1 = self.get_example_tree_sequence(10)
t1 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS)
diff --git a/python/tskit/trees.py b/python/tskit/trees.py
index f99c5f86ee..f3de2208a2 100644
--- a/python/tskit/trees.py
+++ b/python/tskit/trees.py
@@ -2793,7 +2793,7 @@ def path_length(self, u, v):
def b1_index(self):
"""
Returns the
- `B1 balance index `_
+ `B1 balance index `_
for this tree. This is defined as the inverse of the sum of all
longest paths to leaves for each node besides roots.
@@ -2807,11 +2807,12 @@ def b1_index(self):
def b2_index(self, base=10):
"""
- Returns the B2 balance index for this tree.
+ Returns the
+ `B2 balance index `_
+ for this tree.
This is defined as the Shannon entropy of the probability
distribution to reach leaves assuming a random walk
- from a root. The base used for default is 10, to match with
- Shao and Sokal (1990).
+ from a root. The default base is 10, following Shao and Sokal (1990).
.. seealso:: See `Shao and Sokal (1990)
`_ for details.
@@ -2821,22 +2822,9 @@ def b2_index(self, base=10):
:return: The B2 balance index.
:rtype: float
"""
- # TODO implement in C
- # Note that this will take into account the number of roots also, by considering
- # them as children of the virtual root.
- if self.num_roots != 1:
- raise ValueError("B2 index is only defined for trees with one root")
- stack = [(self.root, 1)]
- total_proba = 0
- while len(stack) > 0:
- u, path_product = stack.pop()
- if self.is_leaf(u):
- total_proba -= path_product * math.log(path_product, base)
- else:
- path_product *= 1 / self.num_children(u)
- for v in self.children(u):
- stack.append((v, path_product))
- return total_proba
+ # Let Python decide if the base is acceptable
+ math.log(10, base)
+ return self._ll_tree.get_b2_index(base)
def colless_index(self):
"""