diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index d291dc250..7296da625 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -98,9 +98,9 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: """Returns an expression giving the integral over the pulse.""" def create_program(self, *, - parameters: Optional[Dict[str, Union[Parameter, float, Expression, str, Real]]]=None, - measurement_mapping: Optional[Dict[str, Optional[str]]]=None, - channel_mapping: Optional[Dict[ChannelID, Optional[ChannelID]]]=None, + parameters: Optional[Mapping[str, Union[Parameter, float, Expression, str, Real]]]=None, + measurement_mapping: Optional[Mapping[str, Optional[str]]]=None, + channel_mapping: Optional[Mapping[ChannelID, Optional[ChannelID]]]=None, global_transformation: Optional[Transformation]=None, to_single_waveform: Set[Union[str, 'PulseTemplate']]=None) -> Optional['Loop']: """Translates this PulseTemplate into a program Loop. @@ -128,9 +128,9 @@ def create_program(self, *, if to_single_waveform is None: to_single_waveform = set() - for channel in self.defined_channels: - if channel not in channel_mapping: - channel_mapping[channel] = channel + # make sure all channels are mapped + complete_channel_mapping = {channel: channel for channel in self.defined_channels} + complete_channel_mapping.update(channel_mapping) non_unique_targets = {channel for channel, count in collections.Counter(channel_mapping.values()).items() @@ -139,21 +139,20 @@ def create_program(self, *, raise ValueError('The following channels are mapped to twice', non_unique_targets) # make sure all values in the parameters dict are of type Parameter - for (key, value) in parameters.items(): - if not isinstance(value, Parameter): - parameters[key] = ConstantParameter(value) + parameters = {key: value if isinstance(value, Parameter) else ConstantParameter(value) + for key, value in parameters.items()} root_loop = Loop() # call subclass specific implementation self._create_program(parameters=parameters, measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping, + channel_mapping=complete_channel_mapping, global_transformation=global_transformation, to_single_waveform=to_single_waveform, parent_loop=root_loop) if root_loop.waveform is None and len(root_loop.children) == 0: - return None # return None if no program + return None # return None if no program return root_loop @abstractmethod diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 73d66223a..304f13a79 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -165,8 +165,11 @@ class PulseTemplateTest(unittest.TestCase): def test_create_program(self) -> None: template = PulseTemplateStub(defined_channels={'A'}, parameter_names={'foo'}) parameters = {'foo': ConstantParameter(2.126), 'bar': -26.2, 'hugo': 'exp(sin(pi/2))', 'append_a_child': '1'} + previous_parameters = parameters.copy() measurement_mapping = {'M': 'N'} + previos_measurement_mapping = measurement_mapping.copy() channel_mapping = {'A': 'B'} + previous_channel_mapping = channel_mapping.copy() expected_parameters = {'foo': ConstantParameter(2.126), 'bar': ConstantParameter(-26.2), 'hugo': ConstantParameter('exp(sin(pi/2))'), 'append_a_child': ConstantParameter('1')} @@ -192,6 +195,9 @@ def test_create_program(self) -> None: global_transformation=global_transformation) _create_program.assert_called_once_with(**expected_internal_kwargs, parent_loop=program) self.assertEqual(expected_program, program) + self.assertEqual(previos_measurement_mapping, measurement_mapping) + self.assertEqual(previous_channel_mapping, channel_mapping) + self.assertEqual(previous_parameters, parameters) def test__create_program(self): parameters = {'a': ConstantParameter(.1), 'b': ConstantParameter(.2)}