Skip to content

Commit

Permalink
core sequences support equality testing. Split._sequences renamed to …
Browse files Browse the repository at this point in the history
…._seqs. Run adapter supports equality testing.
  • Loading branch information
ynikitenko committed Aug 31, 2023
1 parent 2ac895c commit 8b85a93
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 15 deletions.
10 changes: 10 additions & 0 deletions lena/core/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,16 @@ def _fc_run(self, flow):
results = self._el.compute()
return results

def __eq__(self, other):
if not isinstance(other, Run):
# Run(el) != el
return NotImplemented
if self._run_name == "run":
if other._run_name not in (None, "run"):
return False
return (self._el == other._el and
self._run_name == other._run_name)

def __repr__(self):
if self._run_name:
return "Run({}, run={})".format(repr(self._el), self._run_name)
Expand Down
6 changes: 5 additions & 1 deletion lena/core/fill_compute_seq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Sequence with a FillCompute element."""
from __future__ import print_function

from . import lena_sequence
from . import sequence
Expand Down Expand Up @@ -107,3 +106,8 @@ def compute(self):

results = self._after.run(vals)
return results

def __eq__(self, other):
if not isinstance(other, FillComputeSeq):
return NotImplemented
return self._seq == other._seq
6 changes: 5 additions & 1 deletion lena/core/fill_seq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""FillSeq sequence and its helpers."""
from __future__ import print_function

from . import lena_sequence
from . import adapters
Expand Down Expand Up @@ -89,3 +88,8 @@ def fill(self, value):
this sequence, and after that fills the last element.
"""
raise exceptions.LenaNotImplementedError

def __eq__(self, other):
if not isinstance(other, FillSeq):
return NotImplemented
return self._seq == other._seq
5 changes: 5 additions & 0 deletions lena/core/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,8 @@ def run(self, flow):
# Most important is that the function is evaluated immediately,
# and raises in case of errors.
return flow

def __eq__(self, other):
if not isinstance(other, Sequence):
return NotImplemented
return self._seq == other._seq
13 changes: 11 additions & 2 deletions lena/core/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ def __init__(self, *args):
Following arguments (if present) form a sequence of elements,
each accepting computational flow from the previous element.
>>> from lena.flow import CountFrom
>>> s = Source(CountFrom())
>>> from lena.flow import CountFrom, Slice
>>> s = Source(CountFrom(), Slice(5))
>>> # iterate in a cycle
>>> for i in s():
... if i == 5:
... break
... print(i, end=" ")
0 1 2 3 4
>>> # if called twice, results depend on the generator
>>> list(s()) == list(range(5, 10))
True
For a *sequence* which transforms the incoming flow,
use :class:`Sequence`.
Expand Down Expand Up @@ -56,3 +60,8 @@ def __call__(self):
return self._sequence.run(arg)
else:
return functions.flow_to_iter(arg)

def __eq__(self, other):
if not isinstance(other, Source):
return NotImplemented
return self._seq == other._seq
27 changes: 16 additions & 11 deletions lena/core/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, seqs, bufsize=1000, copy_buf=True):
"{} provided".format(seqs)
)
seqs = [meta.alter_sequence(seq) for seq in seqs]
self._sequences = []
self._seqs = []
self._seq_types = []

for sequence in seqs:
Expand All @@ -119,7 +119,7 @@ def __init__(self, seqs, bufsize=1000, copy_buf=True):
"FillComputeSeq, FillRequestSeq or Source, "
"{} provided".format(sequence)
)
self._sequences.append(seq)
self._seqs.append(seq)
self._seq_types.append(seq_type)

different_seq_types = set(self._seq_types)
Expand Down Expand Up @@ -156,31 +156,31 @@ def __call__(self):
:class:`.Source`,
otherwise runtime :exc:`.LenaAttributeError` is raised.
"""
if self._n_seq_types != 1 or not ct.is_source(self._sequences[0]):
if self._n_seq_types != 1 or not ct.is_source(self._seqs[0]):
raise exceptions.LenaAttributeError(
"Split has no method '__call__'. It should contain "
"only Source sequences to be callable"
)
# todo: use itertools.chain and check performance difference
for seq in self._sequences:
for seq in self._seqs:
for result in seq():
yield result

def _fill(self, val):
for seq in self._sequences[:-1]:
for seq in self._seqs[:-1]:
if self._copy_buf:
seq.fill(copy.deepcopy(val))
else:
seq.fill(val)
self._sequences[-1].fill(val)
self._seqs[-1].fill(val)

def _compute(self):
for seq in self._sequences:
for seq in self._seqs:
for val in seq.compute():
yield val

def _request(self):
for seq in self._sequences:
for seq in self._seqs:
for val in seq.request():
yield val

Expand Down Expand Up @@ -222,7 +222,7 @@ def run(self, flow):
then the buffer for each sequence except the last one is a deep copy
of the current buffer.
"""
active_seqs = self._sequences[:]
active_seqs = self._seqs[:]
active_seq_types = self._seq_types[:]

n_of_active_seqs = len(active_seqs)
Expand Down Expand Up @@ -339,9 +339,9 @@ def repr_maybe_nested(el, base_indent, indent):
elems = el_separ.join((repr_maybe_nested(el, base_indent=base_indent,
indent=indent)
# diff here
for el in self._sequences))
for el in self._seqs))

if "\n" in el_separ and self._sequences:
if "\n" in el_separ and self._seqs:
# maybe new line
mnl = "\n"
# maybe base indent
Expand All @@ -353,5 +353,10 @@ def repr_maybe_nested(el, base_indent, indent):
return "".join([base_indent, "Split",
"([", mnl, elems, mnl, mbi, "])"])

def __eq__(self, other):
if not isinstance(other, Split):
return NotImplemented
return self._seqs == other._seqs

def __repr__(self):
return self._repr_nested()

0 comments on commit 8b85a93

Please sign in to comment.