Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -6391,6 +6391,90 @@ test_genealogical_nearest_neighbours_errors(void)
free(A);
}

static void
test_single_tree_balance(void)
{
int ret;
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin;

tsk_treeseq_from_text(&ts, 1, single_tree_ex_nodes, single_tree_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
ret = tsk_tree_init(&t, &ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_first(&t);
CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK);

/* Balanced binary tree with 4 leaves */
CU_ASSERT_EQUAL_FATAL(tsk_tree_sackin_index(&t, &sackin), 0);
CU_ASSERT_EQUAL(sackin, 8);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}

static void
test_multiroot_balance(void)
{
int ret;
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin;

tsk_treeseq_from_text(&ts, 10, multiroot_ex_nodes, multiroot_ex_edges, NULL, NULL,
NULL, NULL, NULL, 0);
ret = tsk_tree_init(&t, &ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_first(&t);
CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK);

/* 0.80┊ 10 */
/* ┊ ┏┻┓ */
/* 0.40┊ 9 ┃ ┃ */
/* ┊ ┏━┻┓ ┃ ┃ */
/* 0.30┊ ┃ ┃ ┃ ┃ */
/* ┊ ┃ ┃ ┃ ┃ */
/* 0.20┊ ┃ 7 ┃ ┃ */
/* ┊ ┃ ┏┻┓ ┃ ┃ */
/* 0.10┊ ┃ ┃ ┃ ┃ ┃ */
/* ┊ ┃ ┃ ┃ ┃ ┃ */
/* 0.00┊ 5 2 3 4 0 1 */

CU_ASSERT_EQUAL_FATAL(tsk_tree_sackin_index(&t, &sackin), 0);
CU_ASSERT_EQUAL(sackin, 7);

tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}

static void
test_empty_tree_balance(void)
{
int ret;
tsk_table_collection_t tables;
tsk_treeseq_t ts;
tsk_tree_t t;
tsk_size_t sackin;

ret = tsk_table_collection_init(&tables, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
tables.sequence_length = 1.0;
ret = tsk_treeseq_init(&ts, &tables, TSK_TS_INIT_BUILD_INDEXES);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_init(&t, &ts, 0);
CU_ASSERT_EQUAL_FATAL(ret, 0);
ret = tsk_tree_first(&t);
CU_ASSERT_EQUAL_FATAL(ret, TSK_TREE_OK);

CU_ASSERT_EQUAL_FATAL(tsk_tree_sackin_index(&t, &sackin), 0);
CU_ASSERT_EQUAL(sackin, 0);

tsk_table_collection_free(&tables);
tsk_treeseq_free(&ts);
tsk_tree_free(&t);
}

static void
test_tree_errors(void)
{
Expand Down Expand Up @@ -7213,6 +7297,11 @@ main(int argc, char **argv)
{ "test_different_number_trees_kc", test_different_number_trees_kc },
{ "test_offset_trees_with_errors_kc", test_offset_trees_with_errors_kc },

/* Tree balance/imbalance index tests */
{ "test_single_tree_balance", test_single_tree_balance },
{ "test_multiroot_balance", test_multiroot_balance },
{ "test_empty_tree_balance", test_empty_tree_balance },

/* Misc */
{ "test_tree_errors", test_tree_errors },
{ "test_tree_copy_flags", test_tree_copy_flags },
Expand Down
59 changes: 59 additions & 0 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -4644,6 +4644,65 @@ tsk_tree_postorder_from(
return ret;
}

/* Balance/imbalance metrics */

/* Result is a tsk_size_t value here because we could imagine the total
* depth overflowing a 32bit integer for a large tree. */
int
tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result)
{
/* Keep the size of the stack elements to 8 bytes in total in the
* standard case. A tsk_id_t depth value is always safe, since
* depth counts the number of nodes encountered on a path.
*/
struct stack_elem {
tsk_id_t node;
tsk_id_t depth;
};
int ret = 0;
const tsk_id_t *restrict right_child = self->right_child;
const tsk_id_t *restrict left_sib = self->left_sib;
struct stack_elem *stack
= tsk_malloc(tsk_tree_get_size_bound(self) * sizeof(*stack));
int stack_top;
tsk_size_t total_depth;
tsk_id_t u;
struct stack_elem s = { .node = TSK_NULL, .depth = 0 };

if (stack == NULL) {
ret = TSK_ERR_NO_MEMORY;
goto out;
}

stack_top = -1;
for (u = right_child[self->virtual_root]; u != TSK_NULL; u = left_sib[u]) {
stack_top++;
s.node = u;
stack[stack_top] = s;
}
total_depth = 0;
while (stack_top >= 0) {
s = stack[stack_top];
stack_top--;
u = right_child[s.node];
if (u == TSK_NULL) {
total_depth += (tsk_size_t) s.depth;
} else {
s.depth++;
while (u != TSK_NULL) {
stack_top++;
s.node = u;
stack[stack_top] = s;
u = left_sib[u];
}
}
}
*result = total_depth;
out:
tsk_safe_free(stack);
return ret;
}

/* Parsimony methods */

static inline uint64_t
Expand Down
4 changes: 4 additions & 0 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,10 @@ int tsk_tree_map_mutations(tsk_tree_t *self, int32_t *genotypes, double *cost_ma
int tsk_tree_kc_distance(
const tsk_tree_t *self, const tsk_tree_t *other, double lambda, double *result);

/* Don't document these balance metrics for now so it doesn't get in the way of
* C API 1.0, but should be straightforward to document based on Python docs. */
int tsk_tree_sackin_index(const tsk_tree_t *self, tsk_size_t *result);

/* Things to consider removing: */

/* This is redundant, really */
Expand Down
25 changes: 25 additions & 0 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -10568,6 +10568,27 @@ Tree_get_kc_distance(Tree *self, PyObject *args, PyObject *kwds)
return ret;
}

static PyObject *
Tree_get_sackin_index(Tree *self)
{
PyObject *ret = NULL;
int err;
tsk_size_t result;

if (Tree_check_state(self) != 0) {
goto out;
}

err = tsk_tree_sackin_index(self->tree, &result);
if (err != 0) {
handle_library_error(err);
goto out;
}
ret = Py_BuildValue("K", (unsigned long long) result);
out:
return ret;
}

static PyObject *
Tree_get_root_threshold(Tree *self)
{
Expand Down Expand Up @@ -10963,6 +10984,10 @@ static PyMethodDef Tree_methods[] = {
.ml_meth = (PyCFunction) Tree_get_postorder,
.ml_flags = METH_VARARGS,
.ml_doc = "Returns the nodes in this tree in postorder." },
{ .ml_name = "get_sackin_index",
.ml_meth = (PyCFunction) Tree_get_sackin_index,
.ml_flags = METH_NOARGS,
.ml_doc = "Returns the root threshold for this tree." },
{ NULL } /* Sentinel */
};

Expand Down
11 changes: 11 additions & 0 deletions python/tests/test_balance_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,17 @@ def test_sackin(self):
assert self.tree().sackin_index() == 0


class TestTreeInNullState:
@tests.cached_example
def tree(self):
tree = tskit.Tree.generate_comb(5)
tree.clear()
return tree

def test_sackin(self):
assert self.tree().sackin_index() == 0


class TestAllRootsN5:
@tests.cached_example
def tree(self):
Expand Down
12 changes: 1 addition & 11 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,17 +2826,7 @@ def sackin_index(self):
:return: The Sackin imbalance index.
:rtype: int
"""
# TODO implement in C
stack = [(root, 0) for root in self.roots]
total_depth = 0
while len(stack) > 0:
u, depth = stack.pop()
if self.is_leaf(u):
total_depth += depth
else:
for v in self.children(u):
stack.append((v, depth + 1))
return total_depth
return self._ll_tree.get_sackin_index()

def split_polytomies(
self,
Expand Down