Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions qupulse/_program/_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def copy_tree_structure(self, new_parent: Union['Loop', bool]=False) -> 'Loop':
waveform=self._waveform,
repetition_count=self.repetition_count,
repetition_parameter=self._repetition_parameter,
measurements=self._measurements,
measurements=None if self._measurements is None else list(self._measurements),
children=(child.copy_tree_structure() for child in self))

def _get_measurement_windows(self) -> Mapping[str, np.ndarray]:
Expand Down Expand Up @@ -347,7 +347,7 @@ def flatten_and_balance(self, depth: int) -> None:
# subprogram is balanced with the correct depth
i += 1

elif len(sub_program) == 1 and not sub_program._measurements:
elif sub_program._has_single_child_that_can_be_merged():
# subprogram is balanced but to deep and has no measurements -> we can "lift" the sub-sub-program
# TODO: There was a len(sub_sub_program) == 1 check here that I cannot explain
sub_program._merge_single_child()
Expand All @@ -360,14 +360,31 @@ def flatten_and_balance(self, depth: int) -> None:
# we land in this case if the function gets called with depth == 0 and the current subprogram is a leaf
i += 1

def _has_single_child_that_can_be_merged(self) -> bool:
if len(self) == 1:
child = cast(Loop, self[0])
return not self._measurements or (child.repetition_count == 1 and child.repetition_parameter is None)
else:
return False

def _merge_single_child(self):
"""Lift the single child to current level"""
"""Lift the single child to current level. Requires _has_single_child_that_can_be_merged to be true"""
assert len(self) == 1, "bug: _merge_single_child called on loop with len != 1"
assert not self._measurements, "bug: _merge_single_child called on loop with measurements"
child = cast(Loop, self[0])

# if the child has a fixed repetition count of 1 the measurements can be merged
mergable_measurements = child.repetition_count == 1 and child.repetition_parameter is None

assert not self._measurements or mergable_measurements, "bug: _merge_single_child called on loop with measurements"
assert not self._waveform, "bug: _merge_single_child called on loop with children and waveform"

child = cast(Loop, self[0])
measurements = child._measurements
if self._measurements:
if measurements:
measurements.extend(self._measurements)
else:
measurements = self._measurements

repetition_count = self.repetition_count * child.repetition_count

if self._repetition_parameter is None and child._repetition_parameter is None:
Expand Down Expand Up @@ -434,7 +451,7 @@ def cleanup(self, actions=('remove_empty_loops', 'merge_single_child')):
for child in self:
child.cleanup(actions)

if 'merge_single_child' in actions and len(self) == 1 and not self._measurements:
if 'merge_single_child' in actions and self._has_single_child_that_can_be_merged():
self._merge_single_child()

def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]:
Expand Down
43 changes: 14 additions & 29 deletions tests/_program/loop_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from qupulse.utils.types import TimeType, time_from_float
from qupulse._program._loop import Loop, MultiChannelProgram, _make_compatible, _is_compatible, _CompatibilityLevel,\
RepetitionWaveform, SequenceWaveform, make_compatible, MakeCompatibleWarning
RepetitionWaveform, SequenceWaveform, make_compatible, MakeCompatibleWarning, DroppedMeasurementWarning
from qupulse._program.instructions import InstructionBlock, ImmutableInstructionBlock
from tests.pulses.sequencing_dummies import DummyWaveform
from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform
Expand Down Expand Up @@ -330,32 +330,6 @@ def test_unroll(self):
Loop(waveform=wf3)])
self.assertEqual(expected, root)

def test_remove_empty_loops(self):
wfs = [DummyWaveform(duration=i) for i in range(2)]

root = Loop(children=[
Loop(waveform=wfs[0]),
Loop(waveform=None),
Loop(children=[Loop(waveform=None)]),
Loop(children=[Loop(waveform=wfs[1])])
])

expected = Loop(children=[
Loop(waveform=wfs[0]),
Loop(children=[Loop(waveform=wfs[1])])
])

root.remove_empty_loops()

self.assertEqual(expected, root)

root = Loop(children=[
Loop(measurements=[('m', 0, 1)])
])

with self.assertWarnsRegex(UserWarning, 'Dropping measurement'):
root.remove_empty_loops()

def test_cleanup(self):
wfs = [DummyWaveform(duration=i) for i in range(3)]

Expand All @@ -377,18 +351,29 @@ def test_cleanup(self):

self.assertEqual(expected, root)

def test_cleanup_single_rep(self):
wf = DummyWaveform(duration=1)
measurements = [('n', 0, 1)]

root = Loop(children=[Loop(waveform=wf, repetition_count=1)],
measurements=measurements, repetition_count=10)

expected = Loop(waveform=wf, repetition_count=10, measurements=measurements)
root.cleanup()
self.assertEqual(expected, root)

def test_cleanup_warnings(self):
root = Loop(children=[
Loop(measurements=[('m', 0, 1)])
])

with self.assertWarnsRegex(UserWarning, 'Dropping measurement'):
with self.assertWarnsRegex(DroppedMeasurementWarning, 'Dropping measurement'):
root.cleanup()

root = Loop(children=[
Loop(measurements=[('m', 0, 1)], children=[Loop()])
])
with self.assertWarnsRegex(UserWarning, 'Dropping measurement since there is no waveform in children'):
with self.assertWarnsRegex(DroppedMeasurementWarning, 'Dropping measurement since there is no waveform in children'):
root.cleanup()


Expand Down