From 16166da3743989be4cfcf1608fae100f83eef5d0 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 21 Jun 2022 11:03:13 +0100 Subject: [PATCH] Implement B2 index in C Closes #2252 Closes #2256 --- c/tests/test_trees.c | 51 +++++++++++++++++++++-- c/tskit/trees.c | 62 ++++++++++++++++++++++++++++ c/tskit/trees.h | 5 +++ docs/python-api.md | 1 + python/CHANGELOG.rst | 3 ++ python/_tskitmodule.c | 28 +++++++++++++ python/tests/test_balance_metrics.py | 56 ++++++++++++++++++++----- python/tests/test_lowlevel.py | 9 ++++ python/tskit/trees.py | 28 ++++--------- 9 files changed, 209 insertions(+), 34 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 8e14d57e7c..11f2ac2b4b 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -6629,7 +6629,7 @@ test_single_tree_balance(void) tsk_treeseq_t ts; tsk_tree_t t; tsk_size_t sackin, colless; - double b1; + double b1, b2; tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, NULL, NULL, NULL, 0); @@ -6645,6 +6645,13 @@ test_single_tree_balance(void) CU_ASSERT_EQUAL(colless, 0); CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0); CU_ASSERT_DOUBLE_EQUAL(b1, 2, 1e-8); + /* Test different bases for b2_index to high-precision */ + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0); + CU_ASSERT_DOUBLE_EQUAL(b2, 0.6020599913279623, 1e-14); + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 2, &b2), 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 2, 1e-16); + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 3, &b2), 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 1.2618595071429148, 1e-14); tsk_treeseq_free(&ts); tsk_tree_free(&t); @@ -6683,6 +6690,7 @@ test_multiroot_balance(void) CU_ASSERT_EQUAL_FATAL(tsk_tree_colless_index(&t, NULL), TSK_ERR_UNDEFINED_MULTIROOT); CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0); CU_ASSERT_DOUBLE_EQUAL(b1, 1.0, 1e-8); + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, NULL), TSK_ERR_UNDEFINED_MULTIROOT); tsk_treeseq_free(&ts); tsk_tree_free(&t); @@ -6701,7 +6709,7 @@ test_nonbinary_balance(void) tsk_treeseq_t ts; tsk_tree_t t; tsk_size_t sackin, colless; - double b1; + double b1, b2; tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0); ret = tsk_tree_init(&t, &ts, 0); @@ -6716,6 +6724,8 @@ test_nonbinary_balance(void) tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_NONBINARY); CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0); CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8); + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0); + CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8); tsk_treeseq_free(&ts); tsk_tree_free(&t); @@ -6729,7 +6739,7 @@ test_empty_tree_balance(void) tsk_treeseq_t ts; tsk_tree_t t; tsk_size_t sackin, colless; - double b1; + double b1, b2; ret = tsk_table_collection_init(&tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); @@ -6748,12 +6758,46 @@ test_empty_tree_balance(void) tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_MULTIROOT); CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0); CU_ASSERT_EQUAL(b1, 0); + CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), TSK_ERR_UNDEFINED_MULTIROOT); tsk_table_collection_free(&tables); tsk_treeseq_free(&ts); tsk_tree_free(&t); } +static void +test_b2_bad_base(void) +{ + int ret; + tsk_treeseq_t ts; + tsk_tree_t t; + double result; + double bad_base[] = { -2, -1, 1 }; + size_t j; + + tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL, + NULL, NULL, NULL, 0); + ret = tsk_tree_init(&t, &ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = tsk_tree_first(&t); + CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK); + + for (j = 0; j < sizeof(bad_base) / sizeof(*bad_base); j++) { + ret = tsk_tree_b2_index(&t, bad_base[j], &result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_FALSE(tsk_isfinite(result)); + } + CU_ASSERT_FATAL(j > 0); + + /* this one is peculiar, in that base 0 seems to give a finite answer */ + ret = tsk_tree_b2_index(&t, 0, &result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + CU_ASSERT_EQUAL_FATAL(result, 0); + + tsk_treeseq_free(&ts); + tsk_tree_free(&t); +} + static void test_tree_errors(void) { @@ -7950,6 +7994,7 @@ main(int argc, char **argv) { "test_multiroot_balance", test_multiroot_balance }, { "test_nonbinary_balance", test_nonbinary_balance }, { "test_empty_tree_balance", test_empty_tree_balance }, + { "test_b2_bad_base", test_b2_bad_base }, /* Misc */ { "test_tree_errors", test_tree_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index b921254ff8..92cc789c5b 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -5009,6 +5009,68 @@ tsk_tree_b1_index(const tsk_tree_t *self, double *result) return ret; } +static double +general_log(double x, double base) +{ + return log(x) / log(base); +} + +int +tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result) +{ + struct stack_elem { + tsk_id_t node; + double path_product; + }; + int ret = 0; + const tsk_id_t *restrict right_child = self->right_child; + const tsk_id_t *restrict left_sib = self->left_sib; + struct stack_elem *stack + = tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack)); + int stack_top; + double total_proba = 0; + double num_children; + tsk_id_t u; + struct stack_elem s = { .node = TSK_NULL, .path_product = 1 }; + + if (stack == NULL) { + ret = TSK_ERR_NO_MEMORY; + goto out; + } + if (tsk_tree_get_num_roots(self) != 1) { + ret = TSK_ERR_UNDEFINED_MULTIROOT; + goto out; + } + + stack_top = 0; + s.node = tsk_tree_get_left_root(self); + stack[stack_top] = s; + + while (stack_top >= 0) { + s = stack[stack_top]; + stack_top--; + u = right_child[s.node]; + if (u == TSK_NULL) { + total_proba -= s.path_product * general_log(s.path_product, base); + } else { + num_children = 0; + for (; u != TSK_NULL; u = left_sib[u]) { + num_children++; + } + s.path_product *= 1 / num_children; + for (u = right_child[s.node]; u != TSK_NULL; u = left_sib[u]) { + stack_top++; + s.node = u; + stack[stack_top] = s; + } + } + } + *result = total_proba; +out: + tsk_safe_free(stack); + return ret; +} + /* Parsimony methods */ static inline uint64_t diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 6cd03dc986..509e287a03 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1696,6 +1696,11 @@ int tsk_tree_kc_distance( int tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result); int tsk_tree_colless_index(const tsk_tree_t *self, tsk_size_t *result); int tsk_tree_b1_index(const tsk_tree_t *self, double *result); +/* NOTE: if we document this as part of the C API we'll have to be more careful + * about the error behaviour on bad log bases. At the moment we're just returning + * the resulting value which can be nan, inf etc, but some surprising results + * happen like a base 0 seems to return a finite value. */ +int tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result); /* Things to consider removing: */ diff --git a/docs/python-api.md b/docs/python-api.md index 79a44c05cc..7aad6f155a 100644 --- a/docs/python-api.md +++ b/docs/python-api.md @@ -555,6 +555,7 @@ Functions and static methods Tree.colless_index Tree.sackin_index Tree.b1_index + Tree.b2_index ``` (sec_python_api_trees_sites_mutations)= diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index ff74095ab6..94039b9a22 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -57,6 +57,9 @@ - Add B1 tree balance index. (:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2251`, :pr:`2281`, :pr:`2346`). +- Add B2 tree balance index. + (:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2252`, :pr:`2353`, :pr:`2354`). + - Add Sackin tree imbalance index. (:user:`jeremyguez`, :user:`jeromekelleher`, :pr:`2246`, :pr:`2258`). diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 92865dd426..b00de231c9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10817,6 +10817,30 @@ Tree_get_b1_index(Tree *self) return ret; } +static PyObject * +Tree_get_b2_index(Tree *self, PyObject *args) +{ + PyObject *ret = NULL; + int err; + double base; + double result; + + if (Tree_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "d", &base)) { + goto out; + } + err = tsk_tree_b2_index(self->tree, base, &result); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("d", result); +out: + return ret; +} + static PyObject * Tree_get_root_threshold(Tree *self) { @@ -11240,6 +11264,10 @@ static PyMethodDef Tree_methods[] = { .ml_meth = (PyCFunction) Tree_get_b1_index, .ml_flags = METH_NOARGS, .ml_doc = "Returns the B1 index for this tree." }, + { .ml_name = "get_b2_index", + .ml_meth = (PyCFunction) Tree_get_b2_index, + .ml_flags = METH_VARARGS, + .ml_doc = "Returns the B2 index for this tree." }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index 2b31753d75..f3c16790ea 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -35,7 +35,7 @@ # we can remove this. -def path(tree, u): +def node_path(tree, u): path = [] u = tree.parent(u) while u != tskit.NULL: @@ -79,7 +79,7 @@ def b2_index_definition(tree, base=10): if tree.num_roots != 1: raise ValueError("B2 index is only defined for trees with one root") proba = [ - np.prod([1 / tree.num_children(u) for u in path(tree, leaf)]) + np.prod([1 / tree.num_children(u) for u in node_path(tree, leaf)]) for leaf in tree.leaves() ] return -sum(p * math.log(p, base) for p in proba) @@ -111,18 +111,27 @@ def test_b1(self, ts): assert tree.b1_index() == pytest.approx(b1_index_definition(tree)) @pytest.mark.parametrize("ts", get_example_tree_sequences()) - @pytest.mark.parametrize("base", [2, 10, math.e, np.array([3])[0]]) - def test_b2_base(self, ts, base): + def test_b2(self, ts): for tree in ts.trees(): if tree.num_roots != 1: + with pytest.raises(tskit.LibraryError, match="MULTIROOT"): + tree.b2_index() with pytest.raises(ValueError): + b2_index_definition(tree) + else: + assert tree.b2_index() == b2_index_definition(tree) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("base", [0.1, 1.1, 2, 10, math.e, np.array([3])[0]]) + def test_b2_base(self, ts, base): + for tree in ts.trees(): + if tree.num_roots != 1: + with pytest.raises(tskit.LibraryError, match="MULTIROOT"): tree.b2_index(base) with pytest.raises(ValueError): b2_index_definition(tree, base) else: - assert tree.b2_index(base) == pytest.approx( - b2_index_definition(tree, base) - ) + assert tree.b2_index(base) == b2_index_definition(tree, base) class TestBalancedBinaryOdd: @@ -172,6 +181,31 @@ def test_b1(self): def test_b2(self): assert self.tree().b2_index() == pytest.approx(0.602, rel=1e-3) + @pytest.mark.parametrize( + ("base", "expected"), + [ + (2, 2), + (3, 1.2618595071429148), + (4, 1.0), + (5, 0.8613531161467861), + (10, 0.6020599913279623), + (100, 0.30102999566398114), + (1000000, 0.10034333188799373), + (2.718281828459045, 1.3862943611198906), + ], + ) + def test_b2_base(self, base, expected): + assert self.tree().b2_index(base) == expected + + @pytest.mark.parametrize("base", [0, -0.001, -1, -1e-6, -1e200]) + def test_b2_bad_base(self, base): + with pytest.raises(ValueError, match="math domain"): + self.tree().b2_index(base=base) + + def test_b2_base1(self): + with pytest.raises(ZeroDivisionError): + self.tree().b2_index(base=1) + class TestBalancedTernary: # 2.00┊ 12 ┊ @@ -279,7 +313,7 @@ def test_b1(self): assert self.tree().b1_index() == 4.5 def test_b2(self): - with pytest.raises(ValueError): + with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"): self.tree().b2_index() @@ -300,7 +334,7 @@ def test_b1(self): assert self.tree().b1_index() == 0 def test_b2(self): - with pytest.raises(ValueError): + with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"): self.tree().b2_index() @@ -322,7 +356,7 @@ def test_b1(self): assert self.tree().b1_index() == 0 def test_b2(self): - with pytest.raises(ValueError): + with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"): self.tree().b2_index() @@ -345,5 +379,5 @@ def test_b1(self): assert self.tree().b1_index() == 0 def test_b2(self): - with pytest.raises(ValueError): + with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"): self.tree().b2_index() diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 553d1fb6e4..058a9beefc 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -3194,6 +3194,15 @@ def test_equality(self): assert not t2.equals(t1) last_ts = ts + def test_b2_errors(self): + ts1 = self.get_example_tree_sequence(10) + t1 = _tskit.Tree(ts1) + t1.first() + with pytest.raises(TypeError): + t1.get_b2_index() + with pytest.raises(TypeError): + t1.get_b2_index("asdf") + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) t1 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index f99c5f86ee..f3de2208a2 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -2793,7 +2793,7 @@ def path_length(self, u, v): def b1_index(self): """ Returns the - `B1 balance index `_ + `B1 balance index `_ for this tree. This is defined as the inverse of the sum of all longest paths to leaves for each node besides roots. @@ -2807,11 +2807,12 @@ def b1_index(self): def b2_index(self, base=10): """ - Returns the B2 balance index for this tree. + Returns the + `B2 balance index `_ + this tree. This is defined as the Shannon entropy of the probability distribution to reach leaves assuming a random walk - from a root. The base used for default is 10, to match with - Shao and Sokal (1990). + from a root. The default base is 10, following Shao and Sokal (1990). .. seealso:: See `Shao and Sokal (1990) `_ for details. @@ -2821,22 +2822,9 @@ def b2_index(self, base=10): :return: The B2 balance index. :rtype: float """ - # TODO implement in C - # Note that this will take into account the number of roots also, by considering - # them as children of the virtual root. - if self.num_roots != 1: - raise ValueError("B2 index is only defined for trees with one root") - stack = [(self.root, 1)] - total_proba = 0 - while len(stack) > 0: - u, path_product = stack.pop() - if self.is_leaf(u): - total_proba -= path_product * math.log(path_product, base) - else: - path_product *= 1 / self.num_children(u) - for v in self.children(u): - stack.append((v, path_product)) - return total_proba + # Let Python decide if the base is acceptable + math.log(10, base) + return self._ll_tree.get_b2_index(base) def colless_index(self): """