Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions c/tests/test_trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -7576,25 +7576,6 @@ test_split_edges_no_populations(void)
}
tsk_treeseq_free(&split_ts);

/* And again with imputed population value */
ret = tsk_treeseq_split_edges(&ts, time, 1234, 0, metadata, strlen(metadata),
TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3);
CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12);

for (j = 0; j < num_new_nodes; j++) {
ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL(node.time, time);
CU_ASSERT_EQUAL(node.flags, 1234);
CU_ASSERT_EQUAL(node.individual, TSK_NULL);
CU_ASSERT_EQUAL(node.population, TSK_NULL);
CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata));
CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0);
}
tsk_treeseq_free(&split_ts);

tsk_table_collection_free(&tables);
tsk_treeseq_free(&ts);
}
Expand Down Expand Up @@ -7639,16 +7620,6 @@ test_split_edges_populations(void)
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL(node.population, population);
tsk_treeseq_free(&split_ts);

ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0,
TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1);
CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3);
ret = tsk_treeseq_get_node(&split_ts, 2, &node);
CU_ASSERT_EQUAL_FATAL(ret, 0);
CU_ASSERT_EQUAL(node.population, 0);
tsk_treeseq_free(&split_ts);
}

tsk_table_collection_free(&tables);
Expand Down Expand Up @@ -7693,12 +7664,6 @@ test_split_edges_errors(void)
ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS);
tsk_treeseq_free(&split_ts);

/* We always check population values, even if they aren't used */
ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0,
TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts);
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS);
tsk_treeseq_free(&split_ts);
}
tsk_treeseq_free(&ts);

Expand Down
13 changes: 4 additions & 9 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -3318,12 +3318,11 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples,
int TSK_WARN_UNUSED
tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags,
tsk_id_t population, const char *metadata, tsk_size_t metadata_length,
tsk_flags_t options, tsk_treeseq_t *output)
tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output)
{
int ret = 0;
tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables));
const double *restrict node_time = self->tables->nodes.time;
const tsk_id_t *restrict node_population = self->tables->nodes.population;
const tsk_size_t num_edges = self->tables->edges.num_rows;
const tsk_size_t num_mutations = self->tables->mutations.num_rows;
tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge));
Expand All @@ -3332,7 +3331,6 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag
tsk_edge_t edge;
tsk_mutation_t mutation;
tsk_bookmark_t sort_start;
bool impute_population = options & TSK_SPLIT_EDGES_IMPUTE_POPULATION;

memset(output, 0, sizeof(*output));
if (split_edge == NULL) {
Expand All @@ -3347,6 +3345,9 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag
ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED;
goto out;
}
/* We could catch this below in add_row, but it's simpler to guarantee
* that we always catch the error in corner cases where the values
* aren't used. */
if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) {
ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS;
goto out;
Expand All @@ -3365,12 +3366,6 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag
ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge);
tsk_bug_assert(ret == 0);
if (node_time[edge.child] < time && time < node_time[edge.parent]) {
if (impute_population) {
population = TSK_NULL;
if (node_population[edge.child] != TSK_NULL) {
population = node_population[edge.child];
}
}
u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL,
metadata, metadata_length);
if (u < 0) {
Expand Down
2 changes: 0 additions & 2 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -886,8 +886,6 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples,

/** @} */

#define TSK_SPLIT_EDGES_IMPUTE_POPULATION (1 << 1)

int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags,
tsk_id_t population, const char *metadata, tsk_size_t metadata_length,
tsk_flags_t options, tsk_treeseq_t *output);
Expand Down
19 changes: 4 additions & 15 deletions python/_tskitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -9466,38 +9466,27 @@ static PyObject *
TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds)
{
PyObject *ret = NULL;
static char *kwlist[]
= { "time", "flags", "population", "metadata", "impute_population", NULL };
static char *kwlist[] = { "time", "flags", "population", "metadata", NULL };
double time;
tsk_flags_t flags;
tsk_id_t population;
int impute_population;
PyObject *py_metadata = Py_None;
char *metadata;
Py_ssize_t metadata_length;
int err;
tsk_flags_t options = 0;
TreeSequence *output = NULL;

if (TreeSequence_check_state(self) != 0) {
goto out;
}
/* NOTE: we could make most of these arguments optional and reason about
* impute_population through the value of the population parameter,
* but we're trying to keep this code as simple as possible and put
* the logic in the high-level Python code */
if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&Oi", kwlist, &time,
&uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata,
&impute_population)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&O", kwlist, &time,
&uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata)) {
goto out;
}

if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) {
goto out;
}
if (impute_population) {
options |= TSK_SPLIT_EDGES_IMPUTE_POPULATION;
}

output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType);
if (output == NULL) {
Expand All @@ -9509,7 +9498,7 @@ TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds)
goto out;
}
err = tsk_treeseq_split_edges(self->tree_sequence, time, flags, population, metadata,
metadata_length, options, output->tree_sequence);
metadata_length, 0, output->tree_sequence);
if (err != 0) {
handle_library_error(err);
goto out;
Expand Down
10 changes: 2 additions & 8 deletions python/tests/test_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,21 +1544,18 @@ def test_discrete_time(self):

def test_split_edges_return_type(self):
ts = self.get_example_tree_sequence()
split = ts.split_edges(
time=0, flags=0, population=0, metadata=b"", impute_population=True
)
split = ts.split_edges(time=0, flags=0, population=0, metadata=b"")
assert isinstance(split, _tskit.TreeSequence)

def test_split_edges_bad_types(self):
ts = self.get_example_tree_sequence()

def f(time=0, flags=0, population=0, metadata=b"", impute_population=False):
def f(time=0, flags=0, population=0, metadata=b""):
return ts.split_edges(
time=time,
flags=flags,
population=population,
metadata=metadata,
impute_population=impute_population,
)

with pytest.raises(TypeError):
Expand All @@ -1567,8 +1564,6 @@ def f(time=0, flags=0, population=0, metadata=b"", impute_population=False):
f(flags="0")
with pytest.raises(TypeError):
f(metadata="0")
with pytest.raises(TypeError):
f(impute_population="0")

def test_split_edges_bad_population(self):
ts = self.get_example_tree_sequence()
Expand All @@ -1578,7 +1573,6 @@ def test_split_edges_bad_population(self):
flags=0,
population=ts.get_num_populations(),
metadata=b"",
impute_population=False,
)


Expand Down
18 changes: 4 additions & 14 deletions python/tests/test_table_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,23 +324,16 @@ def test_older(self, time):


def split_edges_definition(ts, time, *, flags=0, population=None, metadata=None):
population = -1 if population is None else population
tables = ts.dump_tables()
if ts.num_migrations > 0:
raise ValueError("Migrations not supported")
default_population = population is None
if not default_population:
# -1 is a valid value
if population < -1 or population >= ts.num_populations:
raise ValueError("Population out of bounds")

node_time = tables.nodes.time
node_population = tables.nodes.population
tables.edges.clear()
split_edge = np.full(ts.num_edges, tskit.NULL, dtype=int)
for edge in ts.edges():
if node_time[edge.child] < time < node_time[edge.parent]:
if default_population:
population = node_population[edge.child]
u = tables.nodes.add_row(
flags=flags, time=time, population=population, metadata=metadata
)
Expand Down Expand Up @@ -673,7 +666,7 @@ def ts_with_schema(self):

def test_default_population(self):
ts = self.ts().split_edges(0.5)
assert ts.node(2).population == 0
assert ts.node(2).population == -1

@pytest.mark.parametrize("population", range(-1, 5))
def test_specify_population(self, population):
Expand Down Expand Up @@ -712,17 +705,14 @@ def decapitate_definition(ts, time, *, flags=0, population=None, metadata=None):
"""
Simple loop implementation of the decapitate operation
"""
default_population = population is None

population = -1 if population is None else population
tables = ts.dump_tables()
node_time = tables.nodes.time
tables.edges.clear()
for edge in ts.edges():
if node_time[edge.parent] <= time:
tables.edges.append(edge)
elif node_time[edge.child] < time:
if default_population:
population = tables.nodes[edge.child].population
new_parent = tables.nodes.add_row(
time=time, population=population, flags=flags, metadata=metadata
)
Expand Down Expand Up @@ -1137,7 +1127,7 @@ def ts_with_schema(self):

def test_default_population(self):
ts = self.ts().decapitate(0.5)
assert ts.node(2).population == 0
assert ts.node(2).population == tskit.NULL

@pytest.mark.parametrize("population", range(-1, 5))
def test_specify_population(self, population):
Expand Down
17 changes: 5 additions & 12 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -6035,9 +6035,8 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None):
added nodes will be assigned these values. Otherwise, default values
will be used. The default metadata is an empty dictionary if a metadata
schema is defined for the node table, and is an empty byte string
otherwise. The default population for the new node will be derived from
the population of the edge's child. Newly added have a default
``flags`` value of 0.
otherwise. The default population for the new node is
:data:`tskit.NULL`. Newly added have a default ``flags`` value of 0.

Any metadata associated with a split edge will be copied to the new edge.

Expand All @@ -6055,18 +6054,14 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None):
:param float time: The cutoff time.
:param int flags: The flags value for newly-inserted nodes. (Default = 0)
:param int population: The population value for newly inserted nodes.
Defaults to the population of the child node of the split edge
if not specified.
Defaults to ``tskit.NULL`` if not specified.
:param metadata: The metadata for any newly inserted nodes. See
:meth:`.NodeTable.add_row` for details on how default metadata
is produced for a given schema (or none).
:return: A copy of this tree sequence with edges split at the specified time.
:rtype: tskit.TreeSequence
"""
impute_population = False
if population is None:
impute_population = True
population = tskit.NULL
population = tskit.NULL if population is None else population
flags = 0 if flags is None else flags
schema = self.table_metadata_schemas.node
if metadata is None:
Expand All @@ -6077,7 +6072,6 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None):
flags=flags,
population=population,
metadata=metadata,
impute_population=impute_population,
)
return TreeSequence(ll_ts)

Expand Down Expand Up @@ -6112,8 +6106,7 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None):
:param float time: The cutoff time.
:param int flags: The flags value for newly-inserted nodes. (Default = 0)
:param int population: The population value for newly inserted nodes.
Defaults to the population of the child node of the split edge
if not specified.
Defaults to ``tskit.NULL`` if not specified.
:param metadata: The metadata for any newly inserted nodes. See
:meth:`.NodeTable.add_row` for details on how default metadata
is produced for a given schema (or none).
Expand Down