Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6629,7 +6629,7 @@ test_single_tree_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
double b1;
double b1, b2;

tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
Expand All @@ -6645,6 +6645,13 @@ test_single_tree_balance(void)
CU_ASSERT_EQUAL(colless, 0);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL(b1, 2, 1e-8);
/* Test different bases for b2_index to high-precision */
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0);
CU_ASSERT_DOUBLE_EQUAL(b2, 0.6020599913279623, 1e-14);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 2, &b2), 0);
CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 2, 1e-16);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 3, &b2), 0);
CU_ASSERT_DOUBLE_EQUAL_FATAL(b2, 1.2618595071429148, 1e-14);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
Expand Down Expand Up @@ -6683,6 +6690,7 @@ test_multiroot_balance(void)
CU_ASSERT_EQUAL_FATAL(tsk_tree_colless_index(&t, NULL), TSK_ERR_UNDEFINED_MULTIROOT);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL(b1, 1.0, 1e-8);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, NULL), TSK_ERR_UNDEFINED_MULTIROOT);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
Expand All @@ -6701,7 +6709,7 @@ test_nonbinary_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
double b1;
double b1, b2;

tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL, 0);
ret = tsk_tree_init(&t, &ts, 0);
Expand All @@ -6716,6 +6724,8 @@ test_nonbinary_balance(void)
tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_NONBINARY);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), 0);
CU_ASSERT_DOUBLE_EQUAL_FATAL(b1, 0, 1e-8);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
Expand All @@ -6729,7 +6739,7 @@ test_empty_tree_balance(void)
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin, colless;
double b1;
double b1, b2;

ret = tsk_table_collection_init(&tables, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
Expand All @@ -6748,12 +6758,46 @@ test_empty_tree_balance(void)
tsk_tree_colless_index(&t, &colless), TSK_ERR_UNDEFINED_MULTIROOT);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b1_index(&t, &b1), 0);
CU_ASSERT_EQUAL(b1, 0);
CU_ASSERT_EQUAL_FATAL(tsk_tree_b2_index(&t, 10, &b2), TSK_ERR_UNDEFINED_MULTIROOT);

tsk_table_collection_free(&tables);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}

static void
test_b2_bad_base(void)
{
int ret;
tsk_treeseq_t ts;
tsk_tree_t t;
double result;
double bad_base[] = { -2, -1, 1 };
size_t j;

tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
ret = tsk_tree_init(&t, &ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_first(&t);
CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK);

for (j = 0; j < sizeof(bad_base) / sizeof(*bad_base); j++) {
ret = tsk_tree_b2_index(&t, bad_base[j], &result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_FALSE(tsk_isfinite(result));
}
CU_ASSERT_FATAL(j > 0);

/* this one is peculiar, in that base 0 seems to give a finite answer */
ret = tsk_tree_b2_index(&t, 0, &result);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL_FATAL(result, 0);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}

static void
test_tree_errors(void)
{
Expand Down Expand Up @@ -7950,6 +7994,7 @@ main(int argc, char **argv)
{ "test_multiroot_balance", test_multiroot_balance },
{ "test_nonbinary_balance", test_nonbinary_balance },
{ "test_empty_tree_balance", test_empty_tree_balance },
{ "test_b2_bad_base", test_b2_bad_base },

/* Misc */
{ "test_tree_errors", test_tree_errors },
Expand Down
62 changes: 62 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -5009,6 +5009,68 @@ tsk_tree_b1_index(const tsk_tree_t *self, double *result)
return ret;
}

static double
general_log(double x, double base)
{
return log(x) / log(base);
}

int
tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result)
{
struct stack_elem {
tsk_id_t node;
double path_product;
};
int ret = 0;
const tsk_id_t *restrict right_child = self->right_child;
const tsk_id_t *restrict left_sib = self->left_sib;
struct stack_elem *stack
= tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack));
int stack_top;
double total_proba = 0;
double num_children;
tsk_id_t u;
struct stack_elem s = { .node = TSK_NULL, .path_product = 1 };

if (stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}
if (tsk_tree_get_num_roots(self) != 1) {
ret = TSK_ERR_UNDEFINED_MULTIROOT;
goto out;
}

stack_top = 0;
s.node = tsk_tree_get_left_root(self);
stack[stack_top] = s;

while (stack_top >= 0) {
s = stack[stack_top];
stack_top--;
u = right_child[s.node];
if (u == TSK_NULL) {
total_proba -= s.path_product * general_log(s.path_product, base);
} else {
num_children = 0;
for (; u != TSK_NULL; u = left_sib[u]) {
num_children++;
}
s.path_product *= 1 / num_children;
for (u = right_child[s.node]; u != TSK_NULL; u = left_sib[u]) {
stack_top++;
s.node = u;
stack[stack_top] = s;
}
}
}
*result = total_proba;
out:
tsk_safe_free(stack);
return ret;
}

/* Parsimony methods */

static inline uint64_t
Expand Down
5 changes: 5 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,11 @@ int tsk_tree_kc_distance(
int tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result);
int tsk_tree_colless_index(const tsk_tree_t *self, tsk_size_t *result);
int tsk_tree_b1_index(const tsk_tree_t *self, double *result);
/* NOTE: if we document this as part of the C API we'll have to be more careful
* about the error behaviour on bad log bases. At the moment we're just returning
* the resulting value which can be nan, inf etc, but some surprising results
* happen like a base 0 seems to return a finite value. */
int tsk_tree_b2_index(const tsk_tree_t *self, double base, double *result);

/* Things to consider removing: */

Expand Down
1 change: 1 addition & 0 deletions docs/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ Functions and static methods
Tree.colless_index
Tree.sackin_index
Tree.b1_index
Tree.b2_index
```

(sec_python_api_trees_sites_mutations)=
Expand Down
3 changes: 3 additions & 0 deletions python/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
- Add B1 tree balance index.
(:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2251`, :pr:`2281`, :pr:`2346`).

- Add B2 tree balance index.
(:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2252`, :pr:`2353`, :pr:`2354`).

- Add Sackin tree imbalance index.
(:user:`jeremyguez`, :user:`jeromekelleher`, :pr:`2246`, :pr:`2258`).

Expand Down
28 changes: 28 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -10817,6 +10817,30 @@ Tree_get_b1_index(Tree *self)
return ret;
}

static PyObject *
Tree_get_b2_index(Tree *self, PyObject *args)
{
PyObject *ret = NULL;
int err;
double base;
double result;

if (Tree_check_state(self) != 0) {
goto out;
}
if (!PyArg_ParseTuple(args, "d", &base)) {
goto out;
}
err = tsk_tree_b2_index(self->tree, base, &result);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = Py_BuildValue("d", result);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming this coverage error is spurious.

out:
return ret;
}

static PyObject *
Tree_get_root_threshold(Tree *self)
{
Expand Down Expand Up @@ -11240,6 +11264,10 @@ static PyMethodDef Tree_methods[] = {
.ml_meth = (PyCFunction) Tree_get_b1_index,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns the B1 index for this tree." },
{ .ml_name = "get_b2_index",
.ml_meth = (PyCFunction) Tree_get_b2_index,
.ml_flags = METH_VARARGS,
.ml_doc = "Returns the B2 index for this tree." },
{ NULL } /* Sentinel */
};

Expand Down
56 changes: 45 additions & 11 deletions python/tests/test_balance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# we can remove this.


def path(tree, u):
def node_path(tree, u):
path = []
u = tree.parent(u)
while u != tskit.NULL:
Expand Down Expand Up @@ -79,7 +79,7 @@ def b2_index_definition(tree, base=10):
if tree.num_roots != 1:
raise ValueError("B2 index is only defined for trees with one root")
proba = [
np.prod([1 / tree.num_children(u) for u in path(tree, leaf)])
np.prod([1 / tree.num_children(u) for u in node_path(tree, leaf)])
for leaf in tree.leaves()
]
return -sum(p * math.log(p, base) for p in proba)
Expand Down Expand Up @@ -111,18 +111,27 @@ def test_b1(self, ts):
assert tree.b1_index() == pytest.approx(b1_index_definition(tree))

@pytest.mark.parametrize("ts", get_example_tree_sequences())
@pytest.mark.parametrize("base", [2, 10, math.e, np.array([3])[0]])
def test_b2_base(self, ts, base):
def test_b2(self, ts):
for tree in ts.trees():
if tree.num_roots != 1:
with pytest.raises(tskit.LibraryError, match="MULTIROOT"):
tree.b2_index()
with pytest.raises(ValueError):
b2_index_definition(tree)
else:
assert tree.b2_index() == b2_index_definition(tree)

@pytest.mark.parametrize("ts", get_example_tree_sequences())
@pytest.mark.parametrize("base", [0.1, 1.1, 2, 10, math.e, np.array([3])[0]])
def test_b2_base(self, ts, base):
for tree in ts.trees():
if tree.num_roots != 1:
with pytest.raises(tskit.LibraryError, match="MULTIROOT"):
tree.b2_index(base)
with pytest.raises(ValueError):
b2_index_definition(tree, base)
else:
assert tree.b2_index(base) == pytest.approx(
b2_index_definition(tree, base)
)
assert tree.b2_index(base) == b2_index_definition(tree, base)


class TestBalancedBinaryOdd:
Expand Down Expand Up @@ -172,6 +181,31 @@ def test_b1(self):
def test_b2(self):
assert self.tree().b2_index() == pytest.approx(0.602, rel=1e-3)

@pytest.mark.parametrize(
("base", "expected"),
[
(2, 2),
(3, 1.2618595071429148),
(4, 1.0),
(5, 0.8613531161467861),
(10, 0.6020599913279623),
(100, 0.30102999566398114),
(1000000, 0.10034333188799373),
(2.718281828459045, 1.3862943611198906),
],
)
def test_b2_base(self, base, expected):
assert self.tree().b2_index(base) == expected

@pytest.mark.parametrize("base", [0, -0.001, -1, -1e-6, -1e200])
def test_b2_bad_base(self, base):
with pytest.raises(ValueError, match="math domain"):
self.tree().b2_index(base=base)

def test_b2_base1(self):
with pytest.raises(ZeroDivisionError):
self.tree().b2_index(base=1)


class TestBalancedTernary:
# 2.00┊ 12 ┊
Expand Down Expand Up @@ -279,7 +313,7 @@ def test_b1(self):
assert self.tree().b1_index() == 4.5

def test_b2(self):
with pytest.raises(ValueError):
with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()


Expand All @@ -300,7 +334,7 @@ def test_b1(self):
assert self.tree().b1_index() == 0

def test_b2(self):
with pytest.raises(ValueError):
with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()


Expand All @@ -322,7 +356,7 @@ def test_b1(self):
assert self.tree().b1_index() == 0

def test_b2(self):
with pytest.raises(ValueError):
with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()


Expand All @@ -345,5 +379,5 @@ def test_b1(self):
assert self.tree().b1_index() == 0

def test_b2(self):
with pytest.raises(ValueError):
with pytest.raises(tskit.LibraryError, match="UNDEFINED_MULTIROOT"):
self.tree().b2_index()
9 changes: 9 additions & 0 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3194,6 +3194,15 @@ def test_equality(self):
assert not t2.equals(t1)
last_ts = ts

def test_b2_errors(self):
ts1 = self.get_example_tree_sequence(10)
t1 = _tskit.Tree(ts1)
t1.first()
with pytest.raises(TypeError):
t1.get_b2_index()
with pytest.raises(TypeError):
t1.get_b2_index("asdf")

def test_kc_distance_errors(self):
ts1 = self.get_example_tree_sequence(10)
t1 = _tskit.Tree(ts1, options=_tskit.SAMPLE_LISTS)
Expand Down
Loading