From 3ca0a1dabc4dcd9ca1fd27018f18bbea1d95b02f Mon Sep 17 00:00:00 2001 From: Patrick Bethke Date: Fri, 1 Jan 2016 13:55:49 +0100 Subject: [PATCH] identity mappings are now automatically added if explicit mappings are omitted. also tests. --- qctoolkit/pulses/sequence_pulse_template.py | 9 +++++++- tests/pulses/sequence_pulse_template_tests.py | 22 ++++++++++++++----- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/qctoolkit/pulses/sequence_pulse_template.py b/qctoolkit/pulses/sequence_pulse_template.py index 019869ea3..8247dafcc 100644 --- a/qctoolkit/pulses/sequence_pulse_template.py +++ b/qctoolkit/pulses/sequence_pulse_template.py @@ -43,7 +43,7 @@ def __len__(self): class SequencePulseTemplate(PulseTemplate): """A sequence of different PulseTemplates. - + SequencePulseTemplate allows to group smaller PulseTemplates (subtemplates) into on larger sequence, i.e., when instantiating a pulse from a SequencePulseTemplate @@ -64,6 +64,13 @@ class SequencePulseTemplate(PulseTemplate): def __init__(self, subtemplates: List[Subtemplate], external_parameters: List[str], identifier: Optional[str]=None) -> None: super().__init__(identifier) self.__parameter_names = frozenset(external_parameters) + # insert identity mappings for entries without explicit mapping + for i, entry in enumerate(subtemplates): + if type(entry) != tuple: + subtemplates[i] = (entry, IdentityMapping(entry)) + elif type(entry) == tuple and len(entry) == 1: + subtemplates[i] = (entry[0], IdentityMapping(entry[0])) + # convert all mapping strings to expressions for i, (template, mappings) in enumerate(subtemplates): subtemplates[i] = (template, {k: Expression(v) for k, v in mappings.items()}) diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index d09a8d426..5554f5f68 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -37,7 +37,7 @@ def __init__(self, *args, **kwargs): self.sequence = SequencePulseTemplate([(self.square, self.mapping1)], self.outer_parameters) def test_missing_mapping(self): - mapping = self.mapping1 + mapping = copy.deepcopy(self.mapping1) mapping.pop('v') subtemplates = [(self.square, mapping)] @@ -45,7 +45,7 @@ def test_missing_mapping(self): SequencePulseTemplate(subtemplates, self.outer_parameters) def test_unnecessary_mapping(self): - mapping = self.mapping1 + mapping = copy.deepcopy(self.mapping1) mapping['unnecessary'] = 'voltage' subtemplates = [(self.square, mapping)] @@ -57,6 +57,18 @@ def test_identifier(self): pulse = SequencePulseTemplate([], [], identifier=identifier) self.assertEqual(identifier, pulse.identifier) + def test_identity_mapping_tuple(self): + mapping = copy.deepcopy(self.mapping1) + subtemplates = [(self.square,)] + pulse = SequencePulseTemplate(subtemplates, self.square.parameter_names) + self.assertEqual(self.square.parameter_names, pulse.parameter_names) + + def test_identity_mapping_direct(self): + mapping = copy.deepcopy(self.mapping1) + subtemplates = [self.square] + pulse = SequencePulseTemplate(subtemplates, self.square.parameter_names) + self.assertEqual(self.square.parameter_names, pulse.parameter_names) + class SequencePulseTemplateSerializationTests(unittest.TestCase): @@ -147,7 +159,7 @@ def test_str(self): a = [UnnecessaryMappingException(T,"b"), MissingMappingException(T,"b"), MissingParameterDeclarationException(T, "c")] - + b = [x.__str__() for x in a] for s in b: self.assertIsInstance(s, str) @@ -158,11 +170,11 @@ def test_is_interruptable(self): self.assertTrue(self.sequence.is_interruptable) self.sequence.is_interruptable = False self.assertFalse(self.sequence.is_interruptable) - + def test_parameter_declarations(self): decl = self.sequence.parameter_declarations self.assertEqual(decl, set([ParameterDeclaration(i) for i in self.outer_parameters])) - + def test_requires_stop(self): seq = SequencePulseTemplate([],[]) self.assertFalse(seq.requires_stop({}, {}))