From 448d7dcb1e7b673b48eb39726765039127f98026 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 22 Sep 2025 13:56:35 +0100 Subject: [PATCH] Add ts.mutations_inherited_state --- python/CHANGELOG.rst | 3 +++ python/tests/test_highlevel.py | 27 ++++++++++++++++++++++++++- python/tskit/trees.py | 24 ++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 584f6ece7f..f82be8f972 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -30,6 +30,9 @@ - Add ``TreeSequence.mutations_edge`` which returns the edge ID for each mutation's edge. (:user:`benjeffery`, :pr:`3226`, :issue:`3189`) +- Add ``TreeSequence.mutations_inherited_state`` which returns the inherited state + for each mutation. (:user:`benjeffery`, :pr:`3276`, :issue:`2631`) + **Bugfixes** - In some tables with mutations out-of-order `TableCollection.sort` did not re-order diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index c4f5aadbce..ac42a5ca9e 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -5546,9 +5546,34 @@ def test_equality_mutations_derived_state(self, ts): [mutation.derived_state for mutation in ts.mutations()], ) + @pytest.mark.skipif(not _tskit.HAS_NUMPY_2, reason="Requires NumPy 2.0 or higher") + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) + def test_mutations_inherited_state(self, ts): + inherited_state = ts.mutations_inherited_state + assert len(inherited_state) == ts.num_mutations + assert isinstance(inherited_state, np.ndarray) + assert inherited_state.shape == (ts.num_mutations,) + assert inherited_state.dtype == np.dtype("T") + assert inherited_state.size == ts.num_mutations + + for mut in ts.mutations(): + state0 = ts.site(mut.site).ancestral_state + if mut.parent != -1: + state0 = ts.mutation(mut.parent).derived_state + assert state0 == inherited_state[mut.id] + + # Test caching - second access should return the same object + inherited_state2 = ts.mutations_inherited_state + assert inherited_state is inherited_state2 + @pytest.mark.skipif(_tskit.HAS_NUMPY_2, reason="Test only on Numpy 1.X") @pytest.mark.parametrize( - "column", ["sites_ancestral_state", "mutations_derived_state"] + "column", + [ + "sites_ancestral_state", + "mutations_derived_state", + "mutations_inherited_state", + ], ) def test_ragged_array_not_supported(self, column): tables = tskit.TableCollection(sequence_length=100) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index d2aea24bad..5874392bd0 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -4140,6 +4140,7 @@ def __init__(self, ll_tree_sequence): self._individuals_location = None self._individuals_nodes = None self._mutations_edge = None + self._mutations_inherited_state = None self._sites_ancestral_state = None self._mutations_derived_state = None # NOTE: when we've implemented read-only access via the underlying @@ -6068,6 +6069,29 @@ def mutations_edge(self): self._mutations_edge = self._ll_tree_sequence.get_mutations_edge() return self._mutations_edge + @property + def mutations_inherited_state(self): + """ + Return an array of the inherited state for each mutation in the tree sequence. + + The inherited state for a mutation is the state that existed at the site + before the mutation occurred. This is either the ancestral state of the site + (if the mutation has no parent) or the derived state of the mutation's + parent mutation (if it has a parent). + + :return: Array of shape (num_mutations,) containing inherited states. + :rtype: numpy.ndarray + """ + if self._mutations_inherited_state is None: + inherited_state = self.sites_ancestral_state[self.mutations_site] + mutations_with_parent = self.mutations_parent != -1 + parent = self.mutations_parent[mutations_with_parent] + inherited_state[mutations_with_parent] = self.mutations_derived_state[ + parent + ] + self._mutations_inherited_state = inherited_state + return self._mutations_inherited_state + @property def migrations_left(self): """