diff --git a/qctoolkit/pulses/multi_channel_pulse_template.py b/qctoolkit/pulses/multi_channel_pulse_template.py index 2665754be..d95becab5 100644 --- a/qctoolkit/pulses/multi_channel_pulse_template.py +++ b/qctoolkit/pulses/multi_channel_pulse_template.py @@ -85,7 +85,9 @@ def duration(self) -> Expression: @property def parameter_names(self) -> Set[str]: - return set.union(*(st.parameter_names for st in self._subtemplates)) | self.constrained_parameters + return set.union(self.measurement_parameters, + self.constrained_parameters, + *(st.parameter_names for st in self._subtemplates)) @property def subtemplates(self) -> Sequence[AtomicPulseTemplate]: @@ -97,7 +99,7 @@ def defined_channels(self) -> Set[ChannelID]: @property def measurement_names(self) -> Set[str]: - return set.union(*(st.measurement_names for st in self._subtemplates)) + return super().measurement_names.union(*(st.measurement_names for st in self._subtemplates)) def build_waveform(self, parameters: Dict[str, numbers.Real], channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional['MultiChannelWaveform']: diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 0ef379902..45e7620bb 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -15,8 +15,6 @@ from tests.serialization_tests import SerializableTests - - class AtomicMultiChannelPulseTemplateTest(unittest.TestCase): def __init__(self,*args,**kwargs): super().__init__(*args,**kwargs) @@ -119,7 +117,19 @@ def test_measurement_names(self): sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, parameter_names={'a', 'b'}, measurement_names={'A', 'C'}), DummyPulseTemplate(duration='t1', defined_channels={'B'}, parameter_names={'a', 'c'}, measurement_names={'A', 'B'})] - self.assertEqual(AtomicMultiChannelPulseTemplate(*sts).measurement_names, {'A', 'B', 'C'}) + self.assertEqual(AtomicMultiChannelPulseTemplate(*sts, measurements=[('D', 1, 2)]).measurement_names, + {'A', 'B', 'C', 'D'}) + + def test_parameter_names(self): + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, parameter_names={'a', 'b'}, + measurement_names={'A', 'C'}), + DummyPulseTemplate(duration='t1', defined_channels={'B'}, parameter_names={'a', 'c'}, + measurement_names={'A', 'B'})] + pt = AtomicMultiChannelPulseTemplate(*sts, measurements=[('D', 'd', 2)], parameter_constraints=['d < e']) + + self.assertEqual(pt.parameter_names, + {'a', 'b', 'c', 'd', 'e'}) + def test_integral(self) -> None: sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'},