diff --git a/qupulse/_program/seqc.py b/qupulse/_program/seqc.py index 9ec7ae415..352697447 100644 --- a/qupulse/_program/seqc.py +++ b/qupulse/_program/seqc.py @@ -12,7 +12,8 @@ - `ProgramWaveformManager` and `HDAWGProgramEntry`: Program wise handling of waveforms and seqc-code classes that convert `Loop` objects""" -from typing import Optional, Union, Sequence, Dict, Iterator, Tuple, Callable, NamedTuple, MutableMapping, Mapping +from typing import Optional, Union, Sequence, Dict, Iterator, Tuple, Callable, NamedTuple, MutableMapping, Mapping,\ + Iterable from types import MappingProxyType import abc import itertools @@ -21,6 +22,7 @@ import os.path import hashlib from collections import OrderedDict +import string import numpy as np from pathlib import Path @@ -314,7 +316,38 @@ def prepare_delete(self): del self._memory.concatenated_waveforms[self._program_name] +class UserRegisterManager: + """This class keeps track of the user registered that are used in a certain context""" + def __init__(self, available: Iterable[int], name_template: str): + assert 'register' in (x[1] for x in string.Formatter().parse(name_template)) + + self._available = set(available) + self._name_template = name_template + self._used = {} + + def require(self, obj) -> str: + for register, registered_obj in self._used.items(): + if obj == registered_obj: + return self._name_template.format(register=register) + if self._available: + register = self._available.pop() + self._used[register] = obj + return self._name_template.format(register=register) + else: + raise ValueError("No register available for %r" % obj) + + def iter_used_registers(self) -> Iterator[Tuple[int, str]]: + """ + + Returns: + An iterator over (register index, register name) pairs + """ + return ((register, self._name_template.format(register=register)) for register in self._used.keys()) + + class HDAWGProgramEntry(ProgramEntry): + USER_REG_NAME_TEMPLATE = 'user_reg_{register}' + def __init__(self, loop: Loop, selection_index: int, waveform_memory: WaveformMemory, program_name: str, channels: Tuple[Optional[ChannelID], Optional[ChannelID]], markers: Tuple[Optional[ChannelID], Optional[ChannelID], Optional[ChannelID], Optional[ChannelID]], @@ -336,12 +369,15 @@ def __init__(self, loop: Loop, selection_index: int, waveform_memory: WaveformMe self._seqc_node = None self._seqc_source = None self._var_declarations = None + self._user_registers = None + self._user_register_source = None def compile(self, min_repetitions_for_for_loop: int, min_repetitions_for_shared_wf: int, indentation: str, - trigger_wait_code: str): + trigger_wait_code: str, + available_registers: Iterable[int]): """Compile the loop representation to an internal sequencing c one using `loop_to_seqc` Args: @@ -349,6 +385,7 @@ def compile(self, min_repetitions_for_shared_wf: See `loop_to_seqc` indentation: Each line is prefixed with this trigger_wait_code: The code is put before the playback start + available_registers Returns: """ @@ -356,10 +393,23 @@ def compile(self, if self._seqc_node: self._waveform_manager.clear_requested() + + user_registers = UserRegisterManager(available_registers, self.USER_REG_NAME_TEMPLATE) + self._seqc_node = loop_to_seqc(self._loop, min_repetitions_for_for_loop=min_repetitions_for_for_loop, min_repetitions_for_shared_wf=min_repetitions_for_shared_wf, - waveform_to_bin=self.get_binary_waveform) + waveform_to_bin=self.get_binary_waveform, + user_registers=user_registers) + + self._user_register_source = '\n'.join( + '{indentation}var {user_reg_name} = getUserReg({register});'.format(indentation=indentation, + user_reg_name=user_reg_name, + register=register) + for register, user_reg_name in user_registers.iter_used_registers() + ) + self._user_registers = user_registers + self._var_declarations = '{indentation}var {pos_var_name} = 0;'.format(pos_var_name=pos_var_name, indentation=indentation) self._trigger_wait_code = indentation + trigger_wait_code @@ -371,15 +421,16 @@ def compile(self, @property def seqc_node(self) -> 'SEQCNode': - if self._seqc_node is None: - raise RuntimeError('compile not called') + assert self._seqc_node is not None, "compile not called" return self._seqc_node @property def seqc_source(self) -> str: - if self._seqc_source is None: - raise RuntimeError('compile not called') - return '\n'.join([self._var_declarations, self._trigger_wait_code, self._seqc_source]) + assert self._seqc_source is not None, "compile not called" + return '\n'.join([self._var_declarations, + self._user_register_source, + self._trigger_wait_code, + self._seqc_source]) @property def name(self) -> str: @@ -435,26 +486,55 @@ def _get_low_unused_index(self): if idx not in existing and idx != self.GLOBAL_CONSTS['PROG_SEL_NONE']: return idx - def add_program(self, name, loop: Loop, + def add_program(self, name: str, loop: Loop, channels: Tuple[Optional[ChannelID], Optional[ChannelID]], markers: Tuple[Optional[ChannelID], Optional[ChannelID], Optional[ChannelID], Optional[ChannelID]], amplitudes: Tuple[float, float], offsets: Tuple[float, float], voltage_transformations: Tuple[Optional[Callable], Optional[Callable]], sample_rate: TimeType): - """""" + """Register the given program and translate it to seqc. + + TODO: Add an interface to change the trigger mode + + Args: + name: Human readable name of the program (used f.i. for the function name) + loop: The program to upload + channels: see AWG.upload + markers: see AWG.upload + amplitudes: Used to sample the waveforms + offsets: Used to sample the waveforms + voltage_transformations: see AWG.upload + sample_rate: Used to sample the waveforms + """ assert name not in self._programs selection_index = self._get_low_unused_index() + # TODO: verify total number of registers + available_registers = range(2, 16) + program_entry = HDAWGProgramEntry(loop, selection_index, self._waveform_memory, name, channels, markers, amplitudes, offsets, voltage_transformations, sample_rate) # TODO: de-hardcode these parameters and put compilation in seperate function - program_entry.compile(20, 1000, ' ', self.WAIT_FOR_SOFTWARE_TRIGGER) + program_entry.compile(20, 1000, ' ', self.WAIT_FOR_SOFTWARE_TRIGGER, + available_registers=available_registers) self._programs[name] = program_entry + def get_register_values_to_update_volatile_parameters(self, name: str, parameters: Mapping[str, float]) -> Mapping[int, int]: + """ + + Args: + name: Program name + parameters: new values for volatile parameters + + Returns: + A dict register->value that reflects the new parameter values + """ + raise NotImplementedError() + @property def programs(self) -> Mapping[str, HDAWGProgramEntry]: return MappingProxyType(self._programs) @@ -568,7 +648,8 @@ def to_node_clusters(loop: Union[Sequence[Loop], Loop], loop_to_seqc_kwargs: dic def loop_to_seqc(loop: Loop, min_repetitions_for_for_loop: int, min_repetitions_for_shared_wf: int, - waveform_to_bin: Callable[[Waveform], BinaryWaveform]) -> 'SEQCNode': + waveform_to_bin: Callable[[Waveform], BinaryWaveform], + user_registers: UserRegisterManager) -> 'SEQCNode': assert min_repetitions_for_for_loop <= min_repetitions_for_shared_wf # At which point do we switch from indexed to shared @@ -579,12 +660,13 @@ def loop_to_seqc(loop: Loop, node = loop_to_seqc(loop[0], min_repetitions_for_for_loop=min_repetitions_for_for_loop, min_repetitions_for_shared_wf=min_repetitions_for_shared_wf, - waveform_to_bin=waveform_to_bin) + waveform_to_bin=waveform_to_bin, user_registers=user_registers) else: node_clusters = to_node_clusters(loop, dict(min_repetitions_for_for_loop=min_repetitions_for_for_loop, min_repetitions_for_shared_wf=min_repetitions_for_shared_wf, - waveform_to_bin=waveform_to_bin)) + waveform_to_bin=waveform_to_bin, + user_registers=user_registers)) seqc_nodes = [] @@ -603,7 +685,11 @@ def loop_to_seqc(loop: Loop, node = Scope(seqc_nodes) - if loop.repetition_count != 1: + if loop.repetition_parameter is not None: + register_var = user_registers.require(loop.repetition_parameter) + return Repeat(scope=node, repetition_count=register_var) + + elif loop.repetition_count != 1: return Repeat(scope=node, repetition_count=loop.repetition_count) else: return node @@ -706,10 +792,10 @@ def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_gen class Repeat(SEQCNode): - """ - stepping: if False resets the pos to initial value after each iteration""" + """""" __slots__ = ('repetition_count', 'scope') INITIAL_POSITION_NAME_TEMPLATE = 'init_pos_{node_name}' + FOR_LOOP_NAME_TEMPLATE = 'idx_{node_name}' class _AdvanceStrategy: """describes what happens how this node interacts with the position variable""" @@ -717,8 +803,17 @@ class _AdvanceStrategy: POST_ADVANCE = 'post_advance' IGNORE = 'ignore' - def __init__(self, repetition_count: int, scope: SEQCNode): - assert repetition_count > 1 + def __init__(self, repetition_count: Union[int, str], scope: SEQCNode): + """ + Args: + repetition_count: A const integer value or a string that is expected to be a "var" + scope: The repeated scope + """ + if isinstance(repetition_count, int): + assert repetition_count > 1 + else: + assert isinstance(repetition_count, str) and repetition_count.isidentifier() + self.repetition_count = repetition_count self.scope = scope @@ -766,17 +861,34 @@ def to_source_code(self, waveform_manager: ProgramWaveformManager, node_name_gen advance_strategy = self._get_position_advance_strategy() if advance_pos_var else self._AdvanceStrategy.IGNORE inner_advance_pos_var = advance_strategy == self._AdvanceStrategy.INITIAL_RESET + def get_node_name(): + """Helper to assert node name only generated when needed and only generated once""" + if getattr(get_node_name, 'node_name', None) is None: + get_node_name.node_name = next(node_name_generator) + return get_node_name.node_name + if advance_strategy == self._AdvanceStrategy.INITIAL_RESET: - node_name = next(node_name_generator) - initial_position_name = self.INITIAL_POSITION_NAME_TEMPLATE.format(node_name=node_name) + initial_position_name = self.INITIAL_POSITION_NAME_TEMPLATE.format(node_name=get_node_name()) # store initial position yield '{line_prefix}var {init_pos_name} = {pos_var_name};'.format(line_prefix=line_prefix, init_pos_name=initial_position_name, pos_var_name=pos_var_name) - yield '{line_prefix}repeat({repetition_count}) {{'.format(line_prefix=line_prefix, - repetition_count=self.repetition_count) + if isinstance(self.repetition_count, int): + yield '{line_prefix}repeat({repetition_count}) {{'.format(line_prefix=line_prefix, + repetition_count=self.repetition_count) + else: + # repeat requires a const-expression so we need to use a for loop for user reg vars + assert isinstance(self.repetition_count, str) + loop_var = self.FOR_LOOP_NAME_TEMPLATE.format(node_name=get_node_name()) + yield '{line_prefix}var {loop_var};'.format(line_prefix=line_prefix, loop_var=loop_var) + yield ('{line_prefix}for({loop_var} = 0; ' + '{loop_var} < {repetition_count}; ' + '{loop_var} = {loop_var} + 1) {{').format(line_prefix=line_prefix, + loop_var=loop_var, + repetition_count=self.repetition_count) + if advance_strategy == self._AdvanceStrategy.INITIAL_RESET: yield ('{body_prefix}{pos_var_name} = {init_pos_name};' '').format(body_prefix=body_prefix, diff --git a/qupulse/pulses/parameters.py b/qupulse/pulses/parameters.py index 5d1b58d03..5164c5573 100644 --- a/qupulse/pulses/parameters.py +++ b/qupulse/pulses/parameters.py @@ -155,6 +155,13 @@ def requires_stop(self) -> bool: except KeyError as err: raise ParameterNotProvidedException(err.args[0]) from err + def __eq__(self, other): + if type(other) == type(self): + return (self._expression == other._expression and + self._namespace == other._namespace) + else: + return NotImplemented + def __repr__(self) -> str: try: value = self.get_value() diff --git a/tests/_program/seqc_tests.py b/tests/_program/seqc_tests.py index a6db87ebf..7854a5060 100644 --- a/tests/_program/seqc_tests.py +++ b/tests/_program/seqc_tests.py @@ -6,9 +6,11 @@ import numpy as np +from qupulse.expressions import ExpressionScalar +from qupulse.pulses.parameters import MappedParameter, ConstantParameter from qupulse._program._loop import Loop from qupulse._program.seqc import BinaryWaveform, loop_to_seqc, WaveformPlayback, Repeat, SteppingRepeat, Scope,\ - to_node_clusters, find_sharable_waveforms, mark_sharable_waveforms + to_node_clusters, find_sharable_waveforms, mark_sharable_waveforms, UserRegisterManager from tests.pulses.sequencing_dummies import DummyWaveform @@ -56,6 +58,9 @@ def complex_program_as_loop(unique_wfs, wf_same): root.append_child(waveform=unique_wfs[0], repetition_count=21) root.append_child(waveform=wf_same, repetition_count=23) + mapped = MappedParameter(ExpressionScalar('n + 4'), {'n': ConstantParameter(3)}) + root.append_child(waveform=wf_same, repetition_count=23, repetition_parameter=mapped) + return root @@ -71,6 +76,7 @@ def complex_program_as_seqc(unique_wfs, wf_same): ]), Repeat(21, WaveformPlayback(make_binary_waveform(unique_wfs[0]))), Repeat(23, WaveformPlayback(make_binary_waveform(wf_same))), + Repeat('test_14', WaveformPlayback(make_binary_waveform(wf_same))) ]) ) @@ -232,6 +238,9 @@ def test_get_position_advance_strategy(self): class LoopToSEQCTranslationTests(TestCase): def test_loop_to_seqc_leaf(self): """Test the translation of leaves""" + # we use None because it is not used in this test + user_registers = None + wf = DummyWaveform(duration=32) loop = Loop(waveform=wf) @@ -239,7 +248,7 @@ def test_loop_to_seqc_leaf(self): loop.repetition_count = 15 waveform_to_bin = mock.Mock(wraps=make_binary_waveform) expected = Repeat(loop.repetition_count, WaveformPlayback(waveform=make_binary_waveform(wf))) - result = loop_to_seqc(loop, 1, 1, waveform_to_bin) + result = loop_to_seqc(loop, 1, 1, waveform_to_bin, user_registers=user_registers) waveform_to_bin.assert_called_once_with(wf) self.assertEqual(expected, result) @@ -247,17 +256,21 @@ def test_loop_to_seqc_leaf(self): loop.repetition_count = 1 waveform_to_bin = mock.Mock(wraps=make_binary_waveform) expected = WaveformPlayback(waveform=make_binary_waveform(wf)) - result = loop_to_seqc(loop, 1, 1, waveform_to_bin) + result = loop_to_seqc(loop, 1, 1, waveform_to_bin, user_registers=user_registers) waveform_to_bin.assert_called_once_with(wf) self.assertEqual(expected, result) def test_loop_to_seqc_len_1(self): """Test the translation of loops with len(loop) == 1""" + # we use None because it is not used in this test + user_registers = None + loop = Loop(children=[Loop()]) waveform_to_bin = mock.Mock(wraps=make_binary_waveform) loop_to_seqc_kwargs = dict(min_repetitions_for_for_loop=2, min_repetitions_for_shared_wf=3, - waveform_to_bin=waveform_to_bin) + waveform_to_bin=waveform_to_bin, + user_registers=user_registers) expected = 'asdf' with mock.patch('qupulse._program.seqc.loop_to_seqc', return_value=expected) as mocked_loop_to_seqc: @@ -334,14 +347,18 @@ def test_mark_sharable_waveforms(self): def test_loop_to_seqc_cluster_handling(self): """Test handling of clusters""" + + # we use None because it is not used in this test + user_registers = None + with self.assertRaises(AssertionError): loop_to_seqc(Loop(repetition_count=12, children=[Loop()]), min_repetitions_for_for_loop=3, min_repetitions_for_shared_wf=2, - waveform_to_bin=make_binary_waveform) + waveform_to_bin=make_binary_waveform, user_registers=user_registers) loop_to_seqc_kwargs = dict(min_repetitions_for_for_loop=3, min_repetitions_for_shared_wf=4, - waveform_to_bin=make_binary_waveform) + waveform_to_bin=make_binary_waveform, user_registers=user_registers) wf_same = map(WaveformPlayback, map(make_binary_waveform, get_unique_wfs(100000, 32))) wf_sep, = map(WaveformPlayback, map(make_binary_waveform, get_unique_wfs(1, 64))) @@ -381,13 +398,15 @@ def dummy_find_sharable_waveforms(cluster): def test_program_translation(self): """Integration test""" + user_registers = UserRegisterManager(range(14, 15), 'test_{register}') + unique_wfs = get_unique_wfs() same_wf = DummyWaveform(duration=32, sample_output=np.ones(32)) root = complex_program_as_loop(unique_wfs, wf_same=same_wf) t0 = time.perf_counter() - seqc = loop_to_seqc(root, 50, 100, make_binary_waveform) + seqc = loop_to_seqc(root, 50, 100, make_binary_waveform, user_registers=user_registers) t1 = time.perf_counter() print('took', t1 - t0, 's') @@ -583,6 +602,23 @@ def node_name_gen(): playWaveIndexed(0, pos, 48); // advance disabled do to parent repetition } pos = pos + 48; + var idx_1; + for(idx_1 = 0; idx_1 < test_14; idx_1 = idx_1 + 1) { + playWaveIndexed(0, pos, 48); // advance disabled do to parent repetition + } + pos = pos + 48; }""" self.assertEqual(expected, seqc_code) + +class UserRegisterManagerTest(unittest.TestCase): + def test_require(self): + manager = UserRegisterManager([7, 8, 9], 'test{register}') + + required = [manager.require(0), manager.require(1), manager.require(2)] + + self.assertEqual({'test7', 'test8', 'test9'}, set(required)) + self.assertEqual(required[1], manager.require(1)) + + with self.assertRaisesRegex(ValueError, "No register"): + manager.require(3) diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index bbbe02624..bf840e9d7 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -388,7 +388,7 @@ def test_create_program(self) -> None: global_transformation=global_transformation, volatile=volatile) - validate_parameter_constraints.assert_called_once_with(parameters=parameters) + validate_parameter_constraints.assert_called_once_with(parameters=parameters, volatile=volatile) get_measurement_windows.assert_called_once_with(expected_meas_params, measurement_mapping) self.assertEqual(body_create_program.call_args_list, expected_create_program_calls) diff --git a/tests/pulses/measurement_tests.py b/tests/pulses/measurement_tests.py index a1512654c..140169c75 100644 --- a/tests/pulses/measurement_tests.py +++ b/tests/pulses/measurement_tests.py @@ -1,7 +1,7 @@ import unittest from qupulse.pulses.parameters import ParameterConstraint, ParameterConstraintViolation,\ - ParameterNotProvidedException, ParameterConstrainer, ConstantParameter + ParameterNotProvidedException, ParameterConstrainer, ConstantParameter, ConstrainedParameterIsVolatileWarning from qupulse.pulses.measurement import MeasurementDefiner from qupulse._program.instructions import InstructionBlock, MEASInstruction @@ -86,6 +86,7 @@ class ParameterConstrainerTest(unittest.TestCase): def __init__(self, *args, to_test_constructor=None, **kwargs): super().__init__(*args, **kwargs) + # TODO: Figure out what is going on here if to_test_constructor is None: self.to_test_constructor = lambda parameter_constraints=None:\ ParameterConstrainer(parameter_constraints=parameter_constraints) @@ -104,27 +105,30 @@ def test_parameter_constraints(self): def test_validate_parameter_constraints(self): to_test = self.to_test_constructor() - to_test.validate_parameter_constraints(dict()) - to_test.validate_parameter_constraints(dict(a=1)) + to_test.validate_parameter_constraints(dict(), set()) + to_test.validate_parameter_constraints(dict(a=1), set()) to_test = self.to_test_constructor(['a < b']) with self.assertRaises(ParameterNotProvidedException): - to_test.validate_parameter_constraints(dict()) + to_test.validate_parameter_constraints(dict(), set()) with self.assertRaises(ParameterConstraintViolation): - to_test.validate_parameter_constraints(dict(a=1, b=0.8)) - to_test.validate_parameter_constraints(dict(a=1, b=2)) + to_test.validate_parameter_constraints(dict(a=1, b=0.8), set()) + to_test.validate_parameter_constraints(dict(a=1, b=2), set()) to_test = self.to_test_constructor(['a < b', 'c < 1']) with self.assertRaises(ParameterNotProvidedException): - to_test.validate_parameter_constraints(dict(a=1, b=2)) + to_test.validate_parameter_constraints(dict(a=1, b=2), set()) with self.assertRaises(ParameterNotProvidedException): - to_test.validate_parameter_constraints(dict(c=0.5)) + to_test.validate_parameter_constraints(dict(c=0.5), set()) with self.assertRaises(ParameterConstraintViolation): - to_test.validate_parameter_constraints(dict(a=1, b=0.8, c=0.5)) + to_test.validate_parameter_constraints(dict(a=1, b=0.8, c=0.5), set()) with self.assertRaises(ParameterConstraintViolation): - to_test.validate_parameter_constraints(dict(a=0.5, b=0.8, c=1)) - to_test.validate_parameter_constraints(dict(a=0.5, b=0.8, c=0.1)) + to_test.validate_parameter_constraints(dict(a=0.5, b=0.8, c=1), set()) + to_test.validate_parameter_constraints(dict(a=0.5, b=0.8, c=0.1), {'j'}) + + with self.assertWarns(ConstrainedParameterIsVolatileWarning): + to_test.validate_parameter_constraints(dict(a=0.5, b=0.8, c=0.1), {'a'}) def test_constrained_parameters(self): to_test = self.to_test_constructor() diff --git a/tests/pulses/parameters_tests.py b/tests/pulses/parameters_tests.py index e280775e9..5abf84d99 100644 --- a/tests/pulses/parameters_tests.py +++ b/tests/pulses/parameters_tests.py @@ -60,17 +60,17 @@ def test_requires_stop_and_get_value(self) -> None: hugo = DummyParameter(5.2, requires_stop=True) ilse = DummyParameter(2356.4, requires_stop=True) - p.dependencies = {'foo': foo, 'bar': bar, 'ilse': ilse} + p._namespace = {'foo': foo, 'bar': bar, 'ilse': ilse} with self.assertRaises(ParameterNotProvidedException): p.requires_stop with self.assertRaises(ParameterNotProvidedException): p.get_value() - p.dependencies = {'foo': foo, 'bar': bar, 'hugo': hugo} + p._namespace = {'foo': foo, 'bar': bar, 'hugo': hugo} self.assertTrue(p.requires_stop) hugo = DummyParameter(5.2, requires_stop=False) - p.dependencies = {'foo': foo, 'bar': bar, 'hugo': hugo, 'ilse': ilse} + p._namespace = {'foo': foo, 'bar': bar, 'hugo': hugo, 'ilse': ilse} self.assertFalse(p.requires_stop) self.assertEqual(1.5, p.get_value()) @@ -78,6 +78,14 @@ def test_repr(self) -> None: p = MappedParameter(Expression("foo + bar * hugo")) self.assertIsInstance(repr(p), str) + def test_equality(self): + p1 = MappedParameter(Expression("foo + 1"), {'foo': ConstantParameter(3)}) + p2 = MappedParameter(Expression("foo + 1"), {'foo': ConstantParameter(4)}) + p3 = MappedParameter(Expression("foo + 1"), {'foo': ConstantParameter(3)}) + + self.assertEqual(p1, p3) + self.assertNotEqual(p1, p2) + class ParameterConstraintTest(unittest.TestCase): def test_ordering(self): diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 19596ce46..11f982b7a 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -431,6 +431,9 @@ def test_internal_create_program(self) -> None: self.assertEqual(expected_program, program) + # MultiChannelProgram calls cleanup + program.cleanup() + # ensure same result as from Sequencer sequencer = Sequencer() sequencer.push(template, parameters=parameters, conditions={}, window_mapping=measurement_mapping, diff --git a/tests/pulses/repetition_pulse_template_tests.py b/tests/pulses/repetition_pulse_template_tests.py index e9ef99a09..1362814d4 100644 --- a/tests/pulses/repetition_pulse_template_tests.py +++ b/tests/pulses/repetition_pulse_template_tests.py @@ -144,7 +144,7 @@ def test_internal_create_program(self): global_transformation=global_transformation, to_single_waveform=to_single_waveform, parent_loop=program.children[0], volatile=set()) - validate_parameter_constraints.assert_called_once_with(parameters=parameters) + validate_parameter_constraints.assert_called_once_with(parameters=parameters, volatile=set()) get_repetition_count_value.assert_called_once_with(real_relevant_parameters) get_meas.assert_called_once_with(real_relevant_parameters, measurement_mapping) @@ -165,7 +165,7 @@ def test_create_program_constant_success_measurements(self) -> None: parent_loop=program, volatile=volatile) self.assertEqual(1, len(program.children)) - internal_loop = program.children[0] # type: Loop + internal_loop = program[0] # type: Loop self.assertEqual(repetitions, internal_loop.repetition_count) self.assertEqual(1, len(internal_loop)) @@ -174,6 +174,12 @@ def test_create_program_constant_success_measurements(self) -> None: self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, program.get_measurement_windows()) + # done in MultiChannelProgram + program.cleanup() + + self.assert_measurement_windows_equal({'b': ([0, 2, 4], [1, 1, 1]), 'thy': ([2], [2])}, + program.get_measurement_windows()) + # ensure same result as from Sequencer sequencer = Sequencer() sequencer.push(t, parameters=parameters, conditions={}, window_mapping=measurement_mapping, channel_mapping=channel_mapping) @@ -274,6 +280,10 @@ def test_create_program_declaration_success_measurements(self) -> None: self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, program.get_measurement_windows()) + # MultiChannelProgram calls cleanup + program.cleanup() + self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, program.get_measurement_windows()) + # ensure same result as from Sequencer sequencer = Sequencer() sequencer.push(t, parameters=parameters, conditions={}, window_mapping=measurement_mapping, diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index ab814766f..f36023bce 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -257,7 +257,7 @@ def test_internal_create_program(self): self.assertEqual(expected_program, program) - validate_parameter_constraints.assert_called_once_with(parameters=kwargs['parameters']) + validate_parameter_constraints.assert_called_once_with(parameters=kwargs['parameters'], volatile=kwargs['volatile']) get_measurement_windows.assert_called_once_with(dict(a=.1, b=.2), kwargs['measurement_mapping']) create_0.assert_called_once_with(**kwargs, parent_loop=program, volatile=set()) create_1.assert_called_once_with(**kwargs, parent_loop=program, volatile=set()) @@ -370,6 +370,11 @@ def test_internal_create_program_one_child_no_duration(self) -> None: list(loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) + # MultiChannelProgram calls cleanup + loop.cleanup() + self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, + loop.get_measurement_windows()) + # ensure same result as from Sequencer sequencer = Sequencer() sequencer.push(seq, parameters=parameters, conditions={}, window_mapping=measurement_mapping, channel_mapping=channel_mapping) @@ -392,6 +397,11 @@ def test_internal_create_program_one_child_no_duration(self) -> None: list(loop.children)) self.assert_measurement_windows_equal({'a': ([0], [1])}, loop.get_measurement_windows()) + # MultiChannelProgram calls cleanup + loop.cleanup() + self.assert_measurement_windows_equal({'fire': ([0], [7.1]), 'b': ([0, 2, 4], [1, 1, 1])}, + loop.get_measurement_windows()) + # ensure same result as from Sequencer sequencer = Sequencer() sequencer.push(seq, parameters=parameters, conditions={}, window_mapping=measurement_mapping, channel_mapping=channel_mapping)