From 6aaa04b3046fd6efa05e73d4694f1f6d915e35b9 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 17 Jun 2022 11:18:26 +0100 Subject: [PATCH] Simplify population semantics in split_edges Closes #2334 --- c/tests/test_trees.c | 35 --------------------------- c/tskit/trees.c | 13 +++------- c/tskit/trees.h | 2 -- python/_tskitmodule.c | 19 +++------------ python/tests/test_lowlevel.py | 10 ++------ python/tests/test_table_transforms.py | 18 +++----------- python/tskit/trees.py | 17 ++++--------- 7 files changed, 19 insertions(+), 95 deletions(-) diff --git a/c/tests/test_trees.c b/c/tests/test_trees.c index 392c29bff5..fc679abb8e 100644 --- a/c/tests/test_trees.c +++ b/c/tests/test_trees.c @@ -7576,25 +7576,6 @@ test_split_edges_no_populations(void) } tsk_treeseq_free(&split_ts); - /* And again with imputed population value */ - ret = tsk_treeseq_split_edges(&ts, time, 1234, 0, metadata, strlen(metadata), - TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 3); - CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 12); - - for (j = 0; j < num_new_nodes; j++) { - ret = tsk_treeseq_get_node(&split_ts, new_nodes[j], &node); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(node.time, time); - CU_ASSERT_EQUAL(node.flags, 1234); - CU_ASSERT_EQUAL(node.individual, TSK_NULL); - CU_ASSERT_EQUAL(node.population, TSK_NULL); - CU_ASSERT_EQUAL(node.metadata_length, strlen(metadata)); - CU_ASSERT_EQUAL(strncmp(node.metadata, metadata, strlen(metadata)), 0); - } - tsk_treeseq_free(&split_ts); - tsk_table_collection_free(&tables); tsk_treeseq_free(&ts); } @@ -7639,16 +7620,6 @@ test_split_edges_populations(void) CU_ASSERT_EQUAL_FATAL(ret, 0); CU_ASSERT_EQUAL(node.population, population); tsk_treeseq_free(&split_ts); - - ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, - TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); - CU_ASSERT_EQUAL_FATAL(ret, 0); - 
CU_ASSERT_EQUAL(tsk_treeseq_get_num_trees(&split_ts), 1); - CU_ASSERT_EQUAL(tsk_treeseq_get_num_nodes(&split_ts), 3); - ret = tsk_treeseq_get_node(&split_ts, 2, &node); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL(node.population, 0); - tsk_treeseq_free(&split_ts); } tsk_table_collection_free(&tables); @@ -7693,12 +7664,6 @@ test_split_edges_errors(void) ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, 0, &split_ts); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); tsk_treeseq_free(&split_ts); - - /* We always check population values, even if they aren't used */ - ret = tsk_treeseq_split_edges(&ts, time, 0, population, NULL, 0, - TSK_SPLIT_EDGES_IMPUTE_POPULATION, &split_ts); - CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_POPULATION_OUT_OF_BOUNDS); - tsk_treeseq_free(&split_ts); } tsk_treeseq_free(&ts); diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1e7f51ef8d..a4ad97d02f 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3318,12 +3318,11 @@ tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, int TSK_WARN_UNUSED tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, tsk_id_t population, const char *metadata, tsk_size_t metadata_length, - tsk_flags_t options, tsk_treeseq_t *output) + tsk_flags_t TSK_UNUSED(options), tsk_treeseq_t *output) { int ret = 0; tsk_table_collection_t *tables = tsk_malloc(sizeof(*tables)); const double *restrict node_time = self->tables->nodes.time; - const tsk_id_t *restrict node_population = self->tables->nodes.population; const tsk_size_t num_edges = self->tables->edges.num_rows; const tsk_size_t num_mutations = self->tables->mutations.num_rows; tsk_id_t *split_edge = tsk_malloc(num_edges * sizeof(*split_edge)); @@ -3332,7 +3331,6 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag tsk_edge_t edge; tsk_mutation_t mutation; tsk_bookmark_t sort_start; - bool impute_population = options & TSK_SPLIT_EDGES_IMPUTE_POPULATION; 
memset(output, 0, sizeof(*output)); if (split_edge == NULL) { @@ -3347,6 +3345,9 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag ret = TSK_ERR_MIGRATIONS_NOT_SUPPORTED; goto out; } + /* We could catch this below in add_row, but it's simpler to guarantee + * that we always catch the error in corner cases where the values + * aren't used. */ if (population < -1 || population >= (tsk_id_t) self->tables->populations.num_rows) { ret = TSK_ERR_POPULATION_OUT_OF_BOUNDS; goto out; @@ -3365,12 +3366,6 @@ tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flag ret = tsk_edge_table_get_row(&self->tables->edges, j, &edge); tsk_bug_assert(ret == 0); if (node_time[edge.child] < time && time < node_time[edge.parent]) { - if (impute_population) { - population = TSK_NULL; - if (node_population[edge.child] != TSK_NULL) { - population = node_population[edge.child]; - } - } u = tsk_node_table_add_row(&tables->nodes, flags, time, population, TSK_NULL, metadata, metadata_length); if (u < 0) { diff --git a/c/tskit/trees.h b/c/tskit/trees.h index cbc4f7ca98..dcc07d491c 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -886,8 +886,6 @@ int tsk_treeseq_simplify(const tsk_treeseq_t *self, const tsk_id_t *samples, /** @} */ -#define TSK_SPLIT_EDGES_IMPUTE_POPULATION (1 << 1) - int tsk_treeseq_split_edges(const tsk_treeseq_t *self, double time, tsk_flags_t flags, tsk_id_t population, const char *metadata, tsk_size_t metadata_length, tsk_flags_t options, tsk_treeseq_t *output); diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index b40f9421c7..a70c2ba73d 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -9466,38 +9466,27 @@ static PyObject * TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds) { PyObject *ret = NULL; - static char *kwlist[] - = { "time", "flags", "population", "metadata", "impute_population", NULL }; + static char *kwlist[] = { "time", "flags", "population", 
"metadata", NULL }; double time; tsk_flags_t flags; tsk_id_t population; - int impute_population; PyObject *py_metadata = Py_None; char *metadata; Py_ssize_t metadata_length; int err; - tsk_flags_t options = 0; TreeSequence *output = NULL; if (TreeSequence_check_state(self) != 0) { goto out; } - /* NOTE: we could make most of these arguments optional and reason about - * impute_population through the value of the population parameter, - * but we're trying to keep this code as simple as possible and put - * the logic in the high-level Python code */ - if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&Oi", kwlist, &time, - &uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata, - &impute_population)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "dO&O&O", kwlist, &time, + &uint32_converter, &flags, &tsk_id_converter, &population, &py_metadata)) { goto out; } if (PyBytes_AsStringAndSize(py_metadata, &metadata, &metadata_length) < 0) { goto out; } - if (impute_population) { - options |= TSK_SPLIT_EDGES_IMPUTE_POPULATION; - } output = (TreeSequence *) _PyObject_New((PyTypeObject *) &TreeSequenceType); if (output == NULL) { @@ -9509,7 +9498,7 @@ TreeSequence_split_edges(TreeSequence *self, PyObject *args, PyObject *kwds) goto out; } err = tsk_treeseq_split_edges(self->tree_sequence, time, flags, population, metadata, - metadata_length, options, output->tree_sequence); + metadata_length, 0, output->tree_sequence); if (err != 0) { handle_library_error(err); goto out; diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index ca2b96165d..44bbe27ef7 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1544,21 +1544,18 @@ def test_discrete_time(self): def test_split_edges_return_type(self): ts = self.get_example_tree_sequence() - split = ts.split_edges( - time=0, flags=0, population=0, metadata=b"", impute_population=True - ) + split = ts.split_edges(time=0, flags=0, population=0, metadata=b"") assert 
isinstance(split, _tskit.TreeSequence) def test_split_edges_bad_types(self): ts = self.get_example_tree_sequence() - def f(time=0, flags=0, population=0, metadata=b"", impute_population=False): + def f(time=0, flags=0, population=0, metadata=b""): return ts.split_edges( time=time, flags=flags, population=population, metadata=metadata, - impute_population=impute_population, ) with pytest.raises(TypeError): @@ -1567,8 +1564,6 @@ def f(time=0, flags=0, population=0, metadata=b"", impute_population=False): f(flags="0") with pytest.raises(TypeError): f(metadata="0") - with pytest.raises(TypeError): - f(impute_population="0") def test_split_edges_bad_population(self): ts = self.get_example_tree_sequence() @@ -1578,7 +1573,6 @@ def test_split_edges_bad_population(self): flags=0, population=ts.get_num_populations(), metadata=b"", - impute_population=False, ) diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py index a137c8d074..8c7250c453 100644 --- a/python/tests/test_table_transforms.py +++ b/python/tests/test_table_transforms.py @@ -324,23 +324,16 @@ def test_older(self, time): def split_edges_definition(ts, time, *, flags=0, population=None, metadata=None): + population = -1 if population is None else population tables = ts.dump_tables() if ts.num_migrations > 0: raise ValueError("Migrations not supported") - default_population = population is None - if not default_population: - # -1 is a valid value - if population < -1 or population >= ts.num_populations: - raise ValueError("Population out of bounds") node_time = tables.nodes.time - node_population = tables.nodes.population tables.edges.clear() split_edge = np.full(ts.num_edges, tskit.NULL, dtype=int) for edge in ts.edges(): if node_time[edge.child] < time < node_time[edge.parent]: - if default_population: - population = node_population[edge.child] u = tables.nodes.add_row( flags=flags, time=time, population=population, metadata=metadata ) @@ -673,7 +666,7 @@ def 
ts_with_schema(self): def test_default_population(self): ts = self.ts().split_edges(0.5) - assert ts.node(2).population == 0 + assert ts.node(2).population == -1 @pytest.mark.parametrize("population", range(-1, 5)) def test_specify_population(self, population): @@ -712,8 +705,7 @@ def decapitate_definition(ts, time, *, flags=0, population=None, metadata=None): """ Simple loop implementation of the decapitate operation """ - default_population = population is None - + population = -1 if population is None else population tables = ts.dump_tables() node_time = tables.nodes.time tables.edges.clear() @@ -721,8 +713,6 @@ def decapitate_definition(ts, time, *, flags=0, population=None, metadata=None): if node_time[edge.parent] <= time: tables.edges.append(edge) elif node_time[edge.child] < time: - if default_population: - population = tables.nodes[edge.child].population new_parent = tables.nodes.add_row( time=time, population=population, flags=flags, metadata=metadata ) @@ -1137,7 +1127,7 @@ def ts_with_schema(self): def test_default_population(self): ts = self.ts().decapitate(0.5) - assert ts.node(2).population == 0 + assert ts.node(2).population == tskit.NULL @pytest.mark.parametrize("population", range(-1, 5)) def test_specify_population(self, population): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index aa848e3760..e8600d1db8 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -6035,9 +6035,8 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None): added nodes will be assigned these values. Otherwise, default values will be used. The default metadata is an empty dictionary if a metadata schema is defined for the node table, and is an empty byte string - otherwise. The default population for the new node will be derived from - the population of the edge's child. Newly added have a default - ``flags`` value of 0. + otherwise. The default population for the new node is + :data:`tskit.NULL`. 
Newly added nodes have a default ``flags`` value of 0. Any metadata associated with a split edge will be copied to the new edge. @@ -6055,18 +6054,14 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None): :param float time: The cutoff time. :param int flags: The flags value for newly-inserted nodes. (Default = 0) :param int population: The population value for newly inserted nodes. - Defaults to the population of the child node of the split edge - if not specified. + Defaults to ``tskit.NULL`` if not specified. :param metadata: The metadata for any newly inserted nodes. See :meth:`.NodeTable.add_row` for details on how default metadata is produced for a given schema (or none). :return: A copy of this tree sequence with edges split at the specified time. :rtype: tskit.TreeSequence """ - impute_population = False - if population is None: - impute_population = True - population = tskit.NULL + population = tskit.NULL if population is None else population flags = 0 if flags is None else flags schema = self.table_metadata_schemas.node if metadata is None: @@ -6077,7 +6072,6 @@ def split_edges(self, time, *, flags=None, population=None, metadata=None): flags=flags, population=population, metadata=metadata, - impute_population=impute_population, ) return TreeSequence(ll_ts) @@ -6112,8 +6106,7 @@ def decapitate(self, time, *, flags=None, population=None, metadata=None): :param float time: The cutoff time. :param int flags: The flags value for newly-inserted nodes. (Default = 0) :param int population: The population value for newly inserted nodes. - Defaults to the population of the child node of the split edge - if not specified. + Defaults to ``tskit.NULL`` if not specified. :param metadata: The metadata for any newly inserted nodes. See :meth:`.NodeTable.add_row` for details on how default metadata is produced for a given schema (or none).