diff --git a/python/CHANGELOG.rst b/python/CHANGELOG.rst index 00a4aaf916..a520399904 100644 --- a/python/CHANGELOG.rst +++ b/python/CHANGELOG.rst @@ -63,6 +63,12 @@ - Add Colless tree imbalance index. (:user:`jeremyguez`, :user:`jeromekelleher`, :issue:`2250`, :pr:`2266`, :pr:`2344`). +- Add ``direction`` argument to ``TreeSequence.edge_diffs``, allowing iteration + over diffs in the reverse direction. NOTE: this comes with a ~10% performance + regression as the implementation was moved from C to Python for simplicity + and maintainability. Please open an issue if this affects your application. + (:user:`jeromekelleher`, :user:`benjeffery`, :pr:`2120`). + **Breaking Changes** - The JSON metadata codec now interprets the empty string as an empty object. This means diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index d46aaecfa1..92865dd426 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -143,12 +143,6 @@ typedef struct { tsk_tree_t *tree; } Tree; -typedef struct { - PyObject_HEAD - TreeSequence *tree_sequence; - tsk_diff_iter_t *tree_diff_iterator; -} TreeDiffIterator; - typedef struct { PyObject_HEAD TreeSequence *tree_sequence; @@ -11264,165 +11258,6 @@ static PyTypeObject TreeType = { // clang-format on }; -/*=================================================================== - * TreeDiffIterator - *=================================================================== - */ - -static int -TreeDiffIterator_check_state(TreeDiffIterator *self) -{ - int ret = 0; - if (self->tree_diff_iterator == NULL) { - PyErr_SetString(PyExc_SystemError, "iterator not initialised"); - ret = -1; - } - return ret; -} - -static void -TreeDiffIterator_dealloc(TreeDiffIterator *self) -{ - if (self->tree_diff_iterator != NULL) { - tsk_diff_iter_free(self->tree_diff_iterator); - PyMem_Free(self->tree_diff_iterator); - self->tree_diff_iterator = NULL; - } - Py_XDECREF(self->tree_sequence); - Py_TYPE(self)->tp_free((PyObject *) self); -} - -static int -TreeDiffIterator_init(TreeDiffIterator *self, PyObject *args, PyObject *kwds) -{ - int ret = -1; - int err; - static char *kwlist[] = { "tree_sequence", "include_terminal", NULL }; - TreeSequence *tree_sequence; - int include_terminal = 0; - tsk_flags_t options = 0; - - self->tree_diff_iterator = NULL; - self->tree_sequence = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|p", kwlist, &TreeSequenceType, - &tree_sequence, &include_terminal)) { - goto out; - } - if (include_terminal) { - options |= TSK_INCLUDE_TERMINAL; - } - self->tree_sequence = tree_sequence; - Py_INCREF(self->tree_sequence); - if (TreeSequence_check_state(self->tree_sequence) != 0) { - goto out; - } - self->tree_diff_iterator = PyMem_Malloc(sizeof(tsk_diff_iter_t)); - if (self->tree_diff_iterator == NULL) { - PyErr_NoMemory(); - goto out; - } - memset(self->tree_diff_iterator, 0, sizeof(tsk_diff_iter_t)); - err = tsk_diff_iter_init( - self->tree_diff_iterator, self->tree_sequence->tree_sequence, options); - if (err != 0) { - handle_library_error(err); - goto out; - } - ret = 0; -out: - return ret; -} - -static PyObject * -TreeDiffIterator_next(TreeDiffIterator *self) -{ - PyObject *ret = NULL; - PyObject *out_list = NULL; - PyObject *in_list = NULL; - PyObject *value = NULL; - int err; - double left, right; - tsk_size_t list_size, j; - tsk_edge_list_node_t *record; - tsk_edge_list_t records_out, records_in; - - if (TreeDiffIterator_check_state(self) != 0) { - goto out; - } - err = tsk_diff_iter_next( - self->tree_diff_iterator, &left, &right, &records_out, &records_in); - if (err < 0) { - handle_library_error(err); - goto out; - } - if (err == 1) { - /* out records */ - record = records_out.head; - list_size = 0; - while (record != NULL) { - list_size++; - record = record->next; - } - out_list = PyList_New(list_size); - if (out_list == NULL) { - goto out; - } - record = records_out.head; - j = 0; - while (record != NULL) { - value = make_edge(&record->edge, true); - if (value == NULL) { - goto out; - } - PyList_SET_ITEM(out_list, j, value); - record = record->next; - j++; - } - /* in records */ - record = records_in.head; - list_size = 0; - while (record != NULL) { - list_size++; - record = record->next; - } - in_list = PyList_New(list_size); - if (in_list == NULL) { - goto out; - } - record = records_in.head; - j = 0; - while (record != NULL) { - value = make_edge(&record->edge, true); - if (value == NULL) { - goto out; - } - PyList_SET_ITEM(in_list, j, value); - record = record->next; - j++; - } - ret = Py_BuildValue("(dd)OO", left, right, out_list, in_list); - } -out: - Py_XDECREF(out_list); - Py_XDECREF(in_list); - return ret; -} - -static PyTypeObject TreeDiffIteratorType = { - // clang-format off - PyVarObject_HEAD_INIT(NULL, 0) - .tp_name = "_tskit.TreeDiffIterator", - .tp_basicsize = sizeof(TreeDiffIterator), - .tp_dealloc = (destructor) TreeDiffIterator_dealloc, - .tp_flags = Py_TPFLAGS_DEFAULT, - .tp_doc = "TreeDiffIterator objects", - .tp_iter = PyObject_SelfIter, - .tp_iternext = (iternextfunc) TreeDiffIterator_next, - .tp_init = (initproc) TreeDiffIterator_init, - .tp_new = PyType_GenericNew, - // clang-format on -}; - /*=================================================================== * Variant *=================================================================== @@ -12679,13 +12514,6 @@ PyInit__tskit(void) Py_INCREF(&TreeType); PyModule_AddObject(module, "Tree", (PyObject *) &TreeType); - /* TreeDiffIterator type */ - if (PyType_Ready(&TreeDiffIteratorType) < 0) { - return NULL; - } - Py_INCREF(&TreeDiffIteratorType); - PyModule_AddObject(module, "TreeDiffIterator", (PyObject *) &TreeDiffIteratorType); - /* Variant type */ if (PyType_Ready(&VariantType) < 0) { return NULL; diff --git a/python/tests/__init__.py b/python/tests/__init__.py index 5965ac4af7..cb67f45345 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -178,41 +178,6 @@ def make_mutation(id_): ) ) - def edge_diffs(self): - M = self._tree_sequence.num_edges - sequence_length = self._tree_sequence.sequence_length - edges = list(self._tree_sequence.edges()) - time = [self._tree_sequence.node(edge.parent).time for edge in edges] - in_order = sorted( - range(M), - key=lambda j: (edges[j].left, time[j], edges[j].parent, edges[j].child), - ) - out_order = sorted( - range(M), - key=lambda j: (edges[j].right, -time[j], -edges[j].parent, -edges[j].child), - ) - j = 0 - k = 0 - left = 0.0 - while j < M or left < sequence_length: - e_out = [] - e_in = [] - while k < M and edges[out_order[k]].right == left: - h = out_order[k] - e_out.append(edges[h]) - k += 1 - while j < M and edges[in_order[j]].left == left: - h = in_order[j] - e_in.append(edges[h]) - j += 1 - right = sequence_length - if j < M: - right = min(right, edges[in_order[j]].left) - if k < M: - right = min(right, edges[out_order[k]].right) - yield (left, right), e_out, e_in - left = right - def trees(self): pt = PythonTree(self._tree_sequence.get_num_nodes()) pt.index = 0 diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 9d894565b5..ec4b8a79c1 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -1411,71 +1411,6 @@ def test_pairwise_diversity(self): for ts in get_example_tree_sequences(): self.verify_pairwise_diversity(ts) - def verify_edge_diffs(self, ts): - pts = tests.PythonTreeSequence(ts) - d1 = list(ts.edge_diffs()) - d2 = list(pts.edge_diffs()) - assert d1 == d2 - - # check that we have the correct set of children at all nodes. - children = collections.defaultdict(set) - trees = iter(ts.trees()) - tree = next(trees) - edge_ids = [] - last_right = 0 - for (left, right), edges_out, edges_in in ts.edge_diffs(): - assert left == last_right - last_right = right - for edge in edges_out: - assert edge == ts.edge(edge.id) - children[edge.parent].remove(edge.child) - for edge in edges_in: - edge_ids.append(edge.id) - assert edge == ts.edge(edge.id) - children[edge.parent].add(edge.child) - while tree.interval.right <= left: - tree = next(trees) - assert left >= tree.interval.left - assert right <= tree.interval.right - for u in tree.nodes(): - if tree.is_internal(u): - assert u in children - assert children[u] == set(tree.children(u)) - # check that we have seen all the edge ids - assert np.array_equal(np.unique(edge_ids), np.arange(0, ts.num_edges)) - - def test_edge_diffs(self): - for ts in get_example_tree_sequences(): - self.verify_edge_diffs(ts) - - def test_edge_diffs_names(self, simple_degree2_ts_fixture): - for val in simple_degree2_ts_fixture.edge_diffs(): - assert len(val) == 3 - assert val[0] == val.interval - assert val[1] == val.edges_out - assert val[2] == val.edges_in - - def test_edge_diffs_include_terminal(self): - for ts in get_example_tree_sequences(): - edges = set() - i = 0 - breakpoints = list(ts.breakpoints()) - for (left, right), e_out, e_in in ts.edge_diffs(include_terminal=True): - assert left == breakpoints[i] - if i == ts.num_trees: - # Last iteration, right==left==sequence_length - assert left == ts.sequence_length - assert right == ts.sequence_length - else: - assert right == breakpoints[i + 1] - for e in e_out: - edges.remove(e.id) - for e in e_in: - edges.add(e.id) - i += 1 - assert i == ts.num_trees + 1 - assert len(edges) == 0 - def verify_edgesets(self, ts): """ Verifies that the edgesets we return are equivalent to the original edges. @@ -2656,6 +2591,85 @@ def test_individual_properties(self, n): self.verify_individual_properties(ts) +class TestEdgeDiffs: + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_correct_trees_forward(self, ts): + parent = np.full(ts.num_nodes + 1, tskit.NULL, dtype=np.int32) + for edge_diff, tree in itertools.zip_longest(ts.edge_diffs(), ts.trees()): + assert edge_diff.interval == tree.interval + for edge in edge_diff.edges_out: + parent[edge.child] = tskit.NULL + for edge in edge_diff.edges_in: + parent[edge.child] = edge.parent + np.testing.assert_array_equal(parent, tree.parent_array) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + def test_correct_trees_reverse(self, ts): + parent = np.full(ts.num_nodes + 1, tskit.NULL, dtype=np.int32) + iterator = itertools.zip_longest( + ts.edge_diffs(direction=tskit.REVERSE), reversed(ts.trees()) + ) + for edge_diff, tree in iterator: + assert edge_diff.interval == tree.interval + for edge in edge_diff.edges_out: + parent[edge.child] = tskit.NULL + for edge in edge_diff.edges_in: + parent[edge.child] = edge.parent + np.testing.assert_array_equal(parent, tree.parent_array) + + def test_elements_are_like_named_tuple(self, simple_degree2_ts_fixture): + for val in simple_degree2_ts_fixture.edge_diffs(): + assert len(val) == 3 + assert val[0] == val.interval + assert val[1] == val.edges_out + assert val[2] == val.edges_in + + @pytest.mark.parametrize("direction", [-6, "forward", None]) + def test_bad_direction(self, direction, simple_degree2_ts_fixture): + ts = simple_degree2_ts_fixture + with pytest.raises(ValueError, match="direction must be"): + ts.edge_diffs(direction=direction) + + @pytest.mark.parametrize("direction", [tskit.FORWARD, tskit.REVERSE]) + def test_edge_properties(self, direction, simple_degree2_ts_fixture): + ts = simple_degree2_ts_fixture + edge_ids = set() + for _, e_out, e_in in ts.edge_diffs(direction=direction): + for edge in e_in: + assert edge.id not in edge_ids + edge_ids.add(edge.id) + assert ts.edge(edge.id) == edge + for edge in e_out: + assert ts.edge(edge.id) == edge + assert edge_ids == set(range(ts.num_edges)) + + @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("direction", [tskit.FORWARD, tskit.REVERSE]) + def test_include_terminal(self, ts, direction): + edges = set() + i = 0 + diffs = ts.edge_diffs(include_terminal=True, direction=direction) + parent = np.full(ts.num_nodes + 1, tskit.NULL, dtype=np.int32) + for (left, right), e_out, e_in in diffs: # noqa: B007 + for e in e_out: + edges.remove(e.id) + parent[e.child] = tskit.NULL + for e in e_in: + edges.add(e.id) + parent[e.child] = e.parent + i += 1 + assert np.all(parent == tskit.NULL) + assert i == ts.num_trees + 1 + assert len(edges) == 0 + # On last iteration, interval is empty + if direction == tskit.FORWARD: + assert left == ts.sequence_length + assert right == ts.sequence_length + else: + assert left == 0 + assert right == 0 + + class TestTreeSequenceMethodSignatures: ts = msprime.simulate(10, random_seed=1234) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 44bbe27ef7..553d1fb6e4 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -2288,34 +2288,6 @@ def f(x): ts.general_stat(W, lambda x: bad_array, 1, ts.get_breakpoints()) -class TestTreeDiffIterator(LowLevelTestCase): - """ - Tests for the low-level tree diff iterator. - """ - - def test_uninitialised_tree_sequence(self): - ts = _tskit.TreeSequence() - with pytest.raises(ValueError): - _tskit.TreeDiffIterator(ts) - - def test_constructor(self): - with pytest.raises(TypeError): - _tskit.TreeDiffIterator() - with pytest.raises(TypeError): - _tskit.TreeDiffIterator(None) - ts = self.get_example_tree_sequence() - before = list(_tskit.TreeDiffIterator(ts)) - iterator = _tskit.TreeDiffIterator(ts) - del ts - # We should keep a reference to the tree sequence. - after = list(iterator) - assert before == after - - def test_iterator(self): - ts = self.get_example_tree_sequence() - self.verify_iterator(_tskit.TreeDiffIterator(ts)) - - class TestVariant(LowLevelTestCase): """ Tests for the Variant class. @@ -2938,7 +2910,6 @@ def test_while_loop_semantics(self): def test_count_all_samples(self): for ts in self.get_example_tree_sequences(): - self.verify_iterator(_tskit.TreeDiffIterator(ts)) st = _tskit.Tree(ts) # Without initialisation we should be 0 samples for every node # that is not a sample. diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py index 3eb80acc46..7274d66c98 100644 --- a/python/tskit/metadata.py +++ b/python/tskit/metadata.py @@ -751,38 +751,44 @@ def __set__(self, row, value): __builtins__object__setattr__(row, "_metadata", value) -def lazy_decode(cls): - """ - Modifies a dataclass such that it lazily decodes metadata, if it is encoded. - If the metadata passed to the constructor is encoded a `metadata_decoder` parameter - must be also be passed. - """ - wrapped_init = cls.__init__ - - # Intercept the init to record the decoder - def new_init(self, *args, metadata_decoder=None, **kwargs): - __builtins__object__setattr__(self, "_metadata_decoder", metadata_decoder) - wrapped_init(self, *args, **kwargs) - - cls.__init__ = new_init - - # Add a descriptor to the class to decode and cache metadata - cls.metadata = _CachedMetadata() - - # Add slots needed to the class - slots = cls.__slots__ - slots.extend(["_metadata", "_metadata_decoder"]) - dict_ = dict() - sloted_members = dict() - for k, v in cls.__dict__.items(): - if k not in slots: - dict_[k] = v - elif not isinstance(v, types.MemberDescriptorType): - sloted_members[k] = v - new_cls = type(cls.__name__, cls.__bases__, dict_) - for k, v in sloted_members.items(): - setattr(new_cls, k, v) - return new_cls +def lazy_decode(own_init=False): + def _lazy_decode(cls): + """ + Modifies a dataclass such that it lazily decodes metadata, if it is encoded. + If the metadata passed to the constructor is encoded a `metadata_decoder` + parameter must be also be passed. + """ + if not own_init: + wrapped_init = cls.__init__ + + # Intercept the init to record the decoder + def new_init(self, *args, metadata_decoder=None, **kwargs): + __builtins__object__setattr__( + self, "_metadata_decoder", metadata_decoder + ) + wrapped_init(self, *args, **kwargs) + + cls.__init__ = new_init + + # Add a descriptor to the class to decode and cache metadata + cls.metadata = _CachedMetadata() + + # Add slots needed to the class + slots = cls.__slots__ + slots.extend(["_metadata", "_metadata_decoder"]) + dict_ = dict() + sloted_members = dict() + for k, v in cls.__dict__.items(): + if k not in slots: + dict_[k] = v + elif not isinstance(v, types.MemberDescriptorType): + sloted_members[k] = v + new_cls = type(cls.__name__, cls.__bases__, dict_) + for k, v in sloted_members.items(): + setattr(new_cls, k, v) + return new_cls + + return _lazy_decode class MetadataProvider: diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 156fbc14f6..eb96a66cd0 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -61,7 +61,7 @@ class NOTSET(metaclass=NotSetMeta): pass -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class IndividualTableRow(util.Dataclass): """ @@ -97,7 +97,7 @@ def __eq__(self, other): ) -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class NodeTableRow(util.Dataclass): """ @@ -127,7 +127,7 @@ class NodeTableRow(util.Dataclass): """ -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class EdgeTableRow(util.Dataclass): """ @@ -157,7 +157,7 @@ class EdgeTableRow(util.Dataclass): """ -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class MigrationTableRow(util.Dataclass): """ @@ -195,7 +195,7 @@ class MigrationTableRow(util.Dataclass): """ -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class SiteTableRow(util.Dataclass): """ @@ -217,7 +217,7 @@ class SiteTableRow(util.Dataclass): """ -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class MutationTableRow(util.Dataclass): """ @@ -268,7 +268,7 @@ def __eq__(self, other): ) -@metadata.lazy_decode +@metadata.lazy_decode() @dataclass(**dataclass_options) class PopulationTableRow(util.Dataclass): """ diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 77571c9a1c..8a2682408c 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -1,4 +1,3 @@ -# # MIT License # # Copyright (c) 2018-2022 Tskit Developers @@ -107,7 +106,7 @@ def new_init(self, *args, tree_sequence=None, **kwargs): @store_tree_sequence -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Individual(util.Dataclass): """ @@ -189,7 +188,7 @@ def __eq__(self, other): ) -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Node(util.Dataclass): """ @@ -239,7 +238,7 @@ def is_sample(self): return self.flags & NODE_IS_SAMPLE -@metadata_module.lazy_decode +@metadata_module.lazy_decode(own_init=True) @dataclass class Edge(util.Dataclass): """ @@ -282,13 +281,23 @@ class Edge(util.Dataclass): """ # Custom init to define default values with slots - def __init__(self, left, right, parent, child, metadata=b"", id=None): # noqa A003 + def __init__( + self, + left, + right, + parent, + child, + metadata=b"", + id=None, # noqa A002 + metadata_decoder=None, + ): self.id = id self.left = left self.right = right self.parent = parent self.child = child self.metadata = metadata + self._metadata_decoder = metadata_decoder @property def span(self): @@ -301,7 +310,7 @@ def span(self): return self.right - self.left -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Site(util.Dataclass): """ @@ -349,7 +358,7 @@ def __eq__(self, other): ) -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Mutation(util.Dataclass): """ @@ -458,7 +467,7 @@ def __eq__(self, other): ) -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Migration(util.Dataclass): """ @@ -509,7 +518,7 @@ class Migration(util.Dataclass): """ -@metadata_module.lazy_decode +@metadata_module.lazy_decode() @dataclass class Population(util.Dataclass): """ @@ -4304,7 +4313,117 @@ def edgesets(self): edgeset.children = sorted(children[edgeset.parent]) yield edgeset - def edge_diffs(self, include_terminal=False): + def _edge_diffs_forward(self, include_terminal=False): + metadata_decoder = self.table_metadata_schemas.edge.decode_row + tables = self.tables + edges = tables.edges + edge_left = edges.left + edge_right = edges.right + sequence_length = self.sequence_length + in_order = tables.indexes.edge_insertion_order + out_order = tables.indexes.edge_removal_order + M = len(edges) + j = 0 + k = 0 + left = 0.0 + while j < M or left < sequence_length: + edges_out = [] + edges_in = [] + while k < M and edge_right[out_order[k]] == left: + edges_out.append( + Edge( + *self._ll_tree_sequence.get_edge(out_order[k]), + id=out_order[k], + metadata_decoder=metadata_decoder, + ) + ) + k += 1 + while j < M and edge_left[in_order[j]] == left: + edges_in.append( + Edge( + *self._ll_tree_sequence.get_edge(in_order[j]), + id=in_order[j], + metadata_decoder=metadata_decoder, + ) + ) + j += 1 + right = sequence_length + if j < M: + right = min(right, edge_left[in_order[j]]) + if k < M: + right = min(right, edge_right[out_order[k]]) + yield EdgeDiff(Interval(left, right), edges_out, edges_in) + left = right + + if include_terminal: + edges_out = [] + while k < M: + edges_out.append( + Edge( + *self._ll_tree_sequence.get_edge(out_order[k]), + id=out_order[k], + metadata_decoder=metadata_decoder, + ) + ) + k += 1 + yield EdgeDiff(Interval(left, right), edges_out, []) + + def _edge_diffs_reverse(self, include_terminal=False): + metadata_decoder = self.table_metadata_schemas.edge.decode_row + tables = self.tables + edges = tables.edges + edge_left = edges.left + edge_right = edges.right + sequence_length = self.sequence_length + in_order = tables.indexes.edge_removal_order + out_order = tables.indexes.edge_insertion_order + M = len(edges) + j = M - 1 + k = M - 1 + right = sequence_length + while j >= 0 or right > 0: + edges_out = [] + edges_in = [] + while k >= 0 and edge_left[out_order[k]] == right: + edges_out.append( + Edge( + *self._ll_tree_sequence.get_edge(out_order[k]), + id=out_order[k], + metadata_decoder=metadata_decoder, + ) + ) + k -= 1 + while j >= 0 and edge_right[in_order[j]] == right: + edges_in.append( + Edge( + *self._ll_tree_sequence.get_edge(in_order[j]), + id=in_order[j], + metadata_decoder=metadata_decoder, + ) + ) + j -= 1 + left = 0 + if j >= 0: + left = max(left, edge_right[in_order[j]]) + if k >= 0: + left = max(left, edge_left[out_order[k]]) + yield EdgeDiff(Interval(left, right), edges_out, edges_in) + right = left + + if include_terminal: + edges_out = [] + while k >= 0: + edges_out.append( + Edge( + *self._ll_tree_sequence.get_edge(out_order[k]), + id=out_order[k], + metadata_decoder=metadata_decoder, + ) + ) + k -= 1 + yield EdgeDiff(Interval(left, right), edges_out, []) + + def edge_diffs(self, include_terminal=False, *, direction=tskit.FORWARD): """ Returns an iterator over all the :ref:`edges ` that are inserted and removed to build the trees as we move from left-to-right along @@ -4323,6 +4442,10 @@ def edge_diffs(self, include_terminal=False): descending parent time, parent id, then child_id). This means that within each list, edges with the same parent appear consecutively. + The ``direction`` argument can be used to control whether diffs are produced + in the forward (left-to-right, increasing genome coordinate value) + or reverse (right-to-left, decreasing genome coordinate value) direction. + :param bool include_terminal: If False (default), the iterator terminates after the final interval in the tree sequence (i.e., it does not report a final removal of all remaining edges), and the number @@ -4330,21 +4453,20 @@ def edge_diffs(self, include_terminal=False): sequence. If True, an additional iteration takes place, with the last ``edges_out`` value reporting all the edges contained in the final tree (with both ``left`` and ``right`` equal to the sequence length). + :param int direction: The direction of travel along the sequence for + diffs. Must be one of :data:`.FORWARD` or :data:`.REVERSE`. + (Default: :data:`.FORWARD`). :return: An iterator over the (interval, edges_out, edges_in) tuples. This is a named tuple, so the 3 values can be accessed by position (e.g. ``returned_tuple[0]``) or name (e.g. ``returned_tuple.interval``). :rtype: :class:`collections.abc.Iterable` """ - iterator = _tskit.TreeDiffIterator(self._ll_tree_sequence, include_terminal) - metadata_decoder = self.table_metadata_schemas.edge.decode_row - for interval, edge_tuples_out, edge_tuples_in in iterator: - edges_out = [ - Edge(*e, metadata_decoder=metadata_decoder) for e in edge_tuples_out - ] - edges_in = [ - Edge(*e, metadata_decoder=metadata_decoder) for e in edge_tuples_in - ] - yield EdgeDiff(Interval(*interval), edges_out, edges_in) + if direction == _tskit.FORWARD: + return self._edge_diffs_forward(include_terminal=include_terminal) + elif direction == _tskit.REVERSE: + return self._edge_diffs_reverse(include_terminal=include_terminal) + else: + raise ValueError("direction must be either tskit.FORWARD or tskit.REVERSE") def sites(self): """