From 254199e866a591ec079e1f7c63bbef65c6fe286f Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 15 Jan 2020 13:03:27 +0100 Subject: [PATCH] Introduce capability to merge single loop children in further cases --- qupulse/_program/_loop.py | 29 +++++++++++++++++++----- tests/_program/loop_tests.py | 43 ++++++++++++------------------------ 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/qupulse/_program/_loop.py b/qupulse/_program/_loop.py index 4817f50e7..6173a43db 100644 --- a/qupulse/_program/_loop.py +++ b/qupulse/_program/_loop.py @@ -232,7 +232,7 @@ def copy_tree_structure(self, new_parent: Union['Loop', bool]=False) -> 'Loop': waveform=self._waveform, repetition_count=self.repetition_count, repetition_parameter=self._repetition_parameter, - measurements=self._measurements, + measurements=None if self._measurements is None else list(self._measurements), children=(child.copy_tree_structure() for child in self)) def _get_measurement_windows(self) -> Mapping[str, np.ndarray]: @@ -347,7 +347,7 @@ def flatten_and_balance(self, depth: int) -> None: # subprogram is balanced with the correct depth i += 1 - elif len(sub_program) == 1 and not sub_program._measurements: + elif sub_program._has_single_child_that_can_be_merged(): # subprogram is balanced but to deep and has no measurements -> we can "lift" the sub-sub-program # TODO: There was a len(sub_sub_program) == 1 check here that I cannot explain sub_program._merge_single_child() @@ -360,14 +360,31 @@ def flatten_and_balance(self, depth: int) -> None: # we land in this case if the function gets called with depth == 0 and the current subprogram is a leaf i += 1 + def _has_single_child_that_can_be_merged(self) -> bool: + if len(self) == 1: + child = cast(Loop, self[0]) + return not self._measurements or (child.repetition_count == 1 and child.repetition_parameter is None) + else: + return False + def _merge_single_child(self): - """Lift the single child to current level""" + """Lift the single child to current level. Requires _has_single_child_that_can_be_merged to be true""" assert len(self) == 1, "bug: _merge_single_child called on loop with len != 1" - assert not self._measurements, "bug: _merge_single_child called on loop with measurements" + child = cast(Loop, self[0]) + + # if the child has a fixed repetition count of 1 the measurements can be merged + mergable_measurements = child.repetition_count == 1 and child.repetition_parameter is None + + assert not self._measurements or mergable_measurements, "bug: _merge_single_child called on loop with measurements" assert not self._waveform, "bug: _merge_single_child called on loop with children and waveform" - child = cast(Loop, self[0]) measurements = child._measurements + if self._measurements: + if measurements: + measurements.extend(self._measurements) + else: + measurements = self._measurements + repetition_count = self.repetition_count * child.repetition_count if self._repetition_parameter is None and child._repetition_parameter is None: @@ -434,7 +451,7 @@ def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')): for child in self: child.cleanup(actions) - if 'merge_single_child' in actions and len(self) == 1 and not self._measurements: + if 'merge_single_child' in actions and self._has_single_child_that_can_be_merged(): self._merge_single_child() def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]: diff --git a/tests/_program/loop_tests.py b/tests/_program/loop_tests.py index a2a6bbd57..6fdf03090 100644 --- a/tests/_program/loop_tests.py +++ b/tests/_program/loop_tests.py @@ -6,7 +6,7 @@ from qupulse.utils.types import TimeType, time_from_float from qupulse._program._loop import Loop, MultiChannelProgram, _make_compatible, _is_compatible, _CompatibilityLevel,\ - RepetitionWaveform, SequenceWaveform, make_compatible, MakeCompatibleWarning + RepetitionWaveform, SequenceWaveform, make_compatible, MakeCompatibleWarning, DroppedMeasurementWarning from qupulse._program.instructions import InstructionBlock, ImmutableInstructionBlock from tests.pulses.sequencing_dummies import DummyWaveform from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform @@ -330,32 +330,6 @@ def test_unroll(self): Loop(waveform=wf3)]) self.assertEqual(expected, root) - def test_remove_empty_loops(self): - wfs = [DummyWaveform(duration=i) for i in range(2)] - - root = Loop(children=[ - Loop(waveform=wfs[0]), - Loop(waveform=None), - Loop(children=[Loop(waveform=None)]), - Loop(children=[Loop(waveform=wfs[1])]) - ]) - - expected = Loop(children=[ - Loop(waveform=wfs[0]), - Loop(children=[Loop(waveform=wfs[1])]) - ]) - - root.remove_empty_loops() - - self.assertEqual(expected, root) - - root = Loop(children=[ - Loop(measurements=[('m', 0, 1)]) - ]) - - with self.assertWarnsRegex(UserWarning, 'Dropping measurement'): - root.remove_empty_loops() - def test_cleanup(self): wfs = [DummyWaveform(duration=i) for i in range(3)] @@ -377,18 +351,29 @@ def test_cleanup(self): self.assertEqual(expected, root) + def test_cleanup_single_rep(self): + wf = DummyWaveform(duration=1) + measurements = [('n', 0, 1)] + + root = Loop(children=[Loop(waveform=wf, repetition_count=1)], + measurements=measurements, repetition_count=10) + + expected = Loop(waveform=wf, repetition_count=10, measurements=measurements) + root.cleanup() + self.assertEqual(expected, root) + def test_cleanup_warnings(self): root = Loop(children=[ Loop(measurements=[('m', 0, 1)]) ]) - with self.assertWarnsRegex(UserWarning, 'Dropping measurement'): + with self.assertWarnsRegex(DroppedMeasurementWarning, 'Dropping measurement'): root.cleanup() root = Loop(children=[ Loop(measurements=[('m', 0, 1)], children=[Loop()]) ]) - with self.assertWarnsRegex(UserWarning, 'Dropping measurement since there is no waveform in children'): + with self.assertWarnsRegex(DroppedMeasurementWarning, 'Dropping measurement since there is no waveform in children'): root.cleanup()