diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 0f640b4570..623469fc40 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -16,6 +16,9 @@ - Add ``contig_id`` and ``isolated_as_missing`` to ``VcfModelMapping`` (:user:`benjeffery`, :pr:`3219`, :issue:`3177`) +- Add ``TreeSequence.mutations_edge`` which returns the edge ID for each mutation's + edge. (:user:`benjeffery`, :pr:`3226`, :issue:`3189`) + **Bugfixes** - Fix bug in ``TreeSequence.pair_coalescence_counts`` when ``span_normalise=True`` diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 66d53a6489..b358ae91ab 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8780,6 +8780,34 @@ TreeSequence_get_individuals_nodes(TreeSequence *self) return ret; } +static PyObject * +TreeSequence_get_mutations_edge(TreeSequence *self) +{ + PyObject *ret = NULL; + PyArrayObject *array = NULL; + npy_intp num_mutations; + tsk_size_t j; + tsk_id_t *data; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + + num_mutations = (npy_intp) tsk_treeseq_get_num_mutations(self->tree_sequence); + array = (PyArrayObject *) PyArray_SimpleNew(1, &num_mutations, NPY_INT32); + if (array == NULL) { + goto out; + } + + data = (tsk_id_t *) PyArray_DATA(array); + for (j = 0; j < (tsk_size_t) num_mutations; j++) { + data[j] = self->tree_sequence->site_mutations_mem[j].edge; + } + ret = (PyObject *) array; +out: + return ret; +} + static PyObject * TreeSequence_genealogical_nearest_neighbours( TreeSequence *self, PyObject *args, PyObject *kwds) @@ -11484,6 +11512,10 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_get_individuals_nodes, .ml_flags = METH_NOARGS, .ml_doc = "Returns an array of the node ids for each individual" }, + { .ml_name = "get_mutations_edge", + .ml_meth = (PyCFunction) TreeSequence_get_mutations_edge, + .ml_flags = METH_NOARGS, + .ml_doc = "Returns an array of the edge ids of each mutation's edge" }, { .ml_name = "genealogical_nearest_neighbours", .ml_meth = (PyCFunction) TreeSequence_genealogical_nearest_neighbours, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 7d98d2184a..a5835a4803 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -5993,3 +5993,9 @@ def test_isolated_as_missing(self): result = ts.map_to_vcf_model() assert result.isolated_as_missing is True + + +@pytest.mark.parametrize("ts", get_example_tree_sequences()) +def test_mutations_edge(ts): + for mut, mut_edge in itertools.zip_longest(ts.mutations(), ts.mutations_edge): + assert mut.edge == mut_edge diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 49dadd301f..ffdff495a2 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1949,20 +1949,22 @@ def test_array_lifetime(self, name, ts_fixture): a2[:] = 0 assert a3 is not a2 - def test_individuals_nodes(self, ts_fixture): + @pytest.mark.parametrize("name", ("individuals_nodes", "mutations_edge")) + def test_generated_columns(self, ts_fixture, name): + name = f"get_{name}" ts_fixture = ts_fixture.ll_tree_sequence # Properties - a = ts_fixture.get_individuals_nodes() + a = getattr(ts_fixture, name)() assert a.flags.aligned assert a.flags.c_contiguous assert a.flags.owndata - b = ts_fixture.get_individuals_nodes() + b = getattr(ts_fixture, name)() assert a is not b assert np.all(a == b) # Lifetime - a1 = ts_fixture.get_individuals_nodes() + a1 = getattr(ts_fixture, name)() a2 = a1.copy() assert a1 is not a2 del ts_fixture diff --git a/python/tskit/trees.py b/python/tskit/trees.py index cdbdd4e806..59cf5a26ae 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4137,6 +4137,7 @@ def __init__(self, ll_tree_sequence): self._individuals_population = None self._individuals_location = None self._individuals_nodes = None + self._mutations_edge = None # NOTE: when we've implemented read-only access via the underlying # tables we can replace these arrays with reference to the read-only # tables here (and remove the low-level boilerplate). @@ -6021,6 +6022,18 @@ def mutations_metadata(self): self._mutations_metadata ) + @property + def mutations_edge(self): + """ + Return an array of the ID of the edge each mutation sits on in the tree sequence. + + :return: Array of shape (num_mutations,) containing edge IDs. + :rtype: numpy.ndarray (dtype=np.int32) + """ + if self._mutations_edge is None: + self._mutations_edge = self._ll_tree_sequence.get_mutations_edge() + return self._mutations_edge + @property def migrations_left(self): """