Skip to content

Commit

Permalink
flow.MapGroup issues a warning when its sequence produces empty resul…
Browse files Browse the repository at this point in the history
…ts. Rearrange code.
  • Loading branch information
ynikitenko committed Sep 27, 2023
1 parent 1107169 commit a154829
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 45 deletions.
97 changes: 53 additions & 44 deletions lena/flow/group_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,48 @@
"""

import copy
import warnings

import lena.core
import lena.flow


# group common context transform should update value context
def _update_with_group(context, new_grp_context, old_inter_context):
# can context.output.changed be any different value?
context_changed = lena.context.get_recursively(
context, "output.changed", None
)
# copied from GroupPlots
all_changed = set(
(lena.context.get_recursively(c, "output.changed", None)
for c in new_grp_context)
)
all_changed.add(context_changed)
if any(all_changed):
changed = True
elif False in all_changed:
# at least one is not changed
# (this is known, not None)
changed = False
else:
changed = None
# output.changed is unlikely in the intersection,
# but it will work if so.
if changed is not None:
lena.context.update_recursively(
context, "output.changed", changed
)

new_inter_context = lena.context.intersection(*new_grp_context)
context_update = lena.context.difference(new_inter_context,
old_inter_context)
# hopefully there is no "group" in these context intersection.
lena.context.update_recursively(context,
copy.deepcopy(context_update))
context["group"] = new_grp_context


class MapGroup(object):
"""Apply a sequence to groups."""

Expand Down Expand Up @@ -122,11 +159,10 @@ def run(self, flow):
if "group" not in context or not hasattr(data, "__iter__"):
if not self._map_scalars:
yield val
continue

# process scalars
for res in self._seq.run([val]):
yield res
else:
# process scalars
for res in self._seq.run([val]):
yield res
continue

if len(data) != len(context["group"]):
Expand All @@ -139,8 +175,8 @@ def run(self, flow):

# apply seq to each value from group
new_vals = []
for i in range(len(data)):
res_i = list(self._seq.run([(data[i], context["group"][i])]))
for i, dt in enumerate(data):
res_i = list(self._seq.run([(dt, context["group"][i])]))
new_vals.append(res_i)

# check that new values have same length
Expand All @@ -160,48 +196,21 @@ def run(self, flow):
new_group_context = [get_context(val[i]) for val in new_vals]
results.append((new_data, new_group_context))

# group common context transform should update value context
def update_with_group(context, new_grp_context, old_inter_context):
# can context.output.changed be any different value?
context_changed = lena.context.get_recursively(
context, "output.changed", None
)
# copied from GroupPlots
all_changed = set(
(lena.context.get_recursively(c, "output.changed", None)
for c in new_grp_context)
)
all_changed.add(context_changed)
if any(all_changed):
changed = True
elif False in all_changed:
# at least one is not changed
# (this is known, not None)
changed = False
else:
changed = None
# output.changed is unlikely in the intersection,
# but it will work if so.
if changed is not None:
lena.context.update_recursively(
context, "output.changed", changed
)

new_inter_context = lena.context.intersection(*new_grp_context)
context_update = lena.context.difference(new_inter_context,
old_inter_context)
# hopefully there is no "group" in these context intersection.
lena.context.update_recursively(context,
copy.deepcopy(context_update))
context["group"] = new_grp_context

for new_val in results[:-1]:
newc = copy.deepcopy(context)
update_with_group(newc, new_val[1], old_inter_context)
_update_with_group(newc, new_val[1], old_inter_context)
yield (new_val[0], newc)

if not results:
warnings.warn(
"empty results produced in MapGroup({}) for {}"\
.format(self._seq, context),
RuntimeWarning, stacklevel=2
)
continue

# save one deep copy if there is only one resulting value
update_with_group(context, results[-1][1], old_inter_context)
_update_with_group(context, results[-1][1], old_inter_context)
yield (results[-1][0], context)


Expand Down
11 changes: 10 additions & 1 deletion tests/flow/test_group_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,16 @@ def run(self, flow):
with pytest.raises(lena.core.LenaRuntimeError):
assert list(mg4.run([([1, 2], {"group": [{}, {}]})]))

class RunBadly():

def run(self, flow):
if False:
yield "smth"

mg5 = MapGroup(RunBadly())
with pytest.warns(RuntimeWarning):
assert list(mg5.run([grp2])) == []


def test_group_plots():
data = [1, 2]
Expand Down Expand Up @@ -138,7 +148,6 @@ def tp(data):
gp = GroupPlots(tp, lambda _: True, transform=(), yield_selected=False)
# groups are yielded in arbitrary order (because they are in a dict)
results = list(gp.run(data))
print(results)
expected_results = [
(
[
Expand Down

0 comments on commit a154829

Please sign in to comment.