Skip to content

Commit

Permalink
Merge 81b7891 into 5a5b5d7
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Jul 12, 2018
2 parents 5a5b5d7 + 81b7891 commit 895f48b
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 22 deletions.
15 changes: 15 additions & 0 deletions qctoolkit/serialization.py
Expand Up @@ -496,6 +496,13 @@ def deserialize(cls, serializer: Optional['Serializer']=None, **kwargs) -> 'Seri

return cls(**kwargs)

def renamed(self, new_identifier: str, registry: Optional[Dict]=None) -> 'Serializable':
"""Returns a copy of the Serializable with its identifier set to new_identifier."""
data = self.get_serialization_data()
data.pop(Serializable.type_identifier_name)
data.pop(Serializable.identifier_name)
return self.deserialize(registry=registry, identifier=new_identifier, **data)


class AnonymousSerializable:
"""Any object that can be converted into a serialized representation for storage and back which NEVER has an
Expand Down Expand Up @@ -715,6 +722,14 @@ def __getitem__(self, identifier: str) -> Serializable:
return self._temporary_storage[identifier].serializable

def __setitem__(self, identifier: str, serializable: Serializable) -> None:
if identifier != serializable.identifier: # address issue #272: https://github.com/qutech/qc-toolkit/issues/272
raise ValueError("Storing a Serializable under a different than its own internal identifier is currently"
" not supported! If you want to rename the serializable, please use the "
"Serializable.renamed() method to obtain a renamed copy which can then be stored with "
"the new identifier.\n"
"If you think that storing under a different identifier without explicit renaming should"
"a supported feature, please contribute to our ongoing discussion about this on:\n"
"https://github.com/qutech/qc-toolkit/issues/272")
if identifier in self._temporary_storage:
if self.temporary_storage[identifier].serializable is serializable:
return
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/function_pulse_tests.py
Expand Up @@ -99,7 +99,7 @@ def make_kwargs(self):
str(ParameterConstraint('d > c'))]
}

def assert_equal_instance(self, lhs: FunctionPulseTemplate, rhs: FunctionPulseTemplate):
def assert_equal_instance_except_id(self, lhs: FunctionPulseTemplate, rhs: FunctionPulseTemplate):
self.assertIsInstance(lhs, FunctionPulseTemplate)
self.assertIsInstance(rhs, FunctionPulseTemplate)
self.assertEqual(lhs.parameter_names, rhs.parameter_names)
Expand Down
4 changes: 2 additions & 2 deletions tests/pulses/loop_pulse_template_tests.py
Expand Up @@ -241,7 +241,7 @@ def make_kwargs(self):
'measurements': [('a', 0, 1), ('b', 1, 1)]
}

def assert_equal_instance(self, lhs: ForLoopPulseTemplate, rhs: ForLoopPulseTemplate):
def assert_equal_instance_except_id(self, lhs: ForLoopPulseTemplate, rhs: ForLoopPulseTemplate):
self.assertIsInstance(lhs, ForLoopPulseTemplate)
self.assertIsInstance(rhs, ForLoopPulseTemplate)
self.assertEqual(lhs.body, rhs.body)
Expand Down Expand Up @@ -453,7 +453,7 @@ def make_kwargs(self):
'condition': 'foo_cond'
}

def assert_equal_instance(self, lhs: WhileLoopPulseTemplate, rhs: WhileLoopPulseTemplate):
def assert_equal_instance_except_id(self, lhs: WhileLoopPulseTemplate, rhs: WhileLoopPulseTemplate):
self.assertIsInstance(lhs, WhileLoopPulseTemplate)
self.assertIsInstance(rhs, WhileLoopPulseTemplate)
self.assertEqual(lhs.body, rhs.body)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/multi_channel_pulse_template_tests.py
Expand Up @@ -327,7 +327,7 @@ def make_instance(self, identifier=None, registry=None):
del kwargs['subtemplates']
return self.class_to_test(identifier=identifier, *subtemplates, **kwargs, registry=registry)

def assert_equal_instance(self, lhs: AtomicMultiChannelPulseTemplate, rhs: AtomicMultiChannelPulseTemplate):
def assert_equal_instance_except_id(self, lhs: AtomicMultiChannelPulseTemplate, rhs: AtomicMultiChannelPulseTemplate):
self.assertIsInstance(lhs, AtomicMultiChannelPulseTemplate)
self.assertIsInstance(rhs, AtomicMultiChannelPulseTemplate)
self.assertEqual(lhs.subtemplates, rhs.subtemplates)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/point_pulse_template_tests.py
Expand Up @@ -236,7 +236,7 @@ def make_kwargs(self):
'parameter_constraints': [str(ParameterConstraint('ilse>2')), str(ParameterConstraint('k>foo'))]
}

def assert_equal_instance(self, lhs: PointPulseTemplate, rhs: PointPulseTemplate):
def assert_equal_instance_except_id(self, lhs: PointPulseTemplate, rhs: PointPulseTemplate):
self.assertIsInstance(lhs, PointPulseTemplate)
self.assertIsInstance(rhs, PointPulseTemplate)
self.assertEqual(lhs.point_pulse_entries, rhs.point_pulse_entries)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/pulse_template_parameter_mapping_tests.py
Expand Up @@ -251,7 +251,7 @@ def make_instance(self, identifier=None, registry=None):
kwargs = self.make_kwargs()
return self.class_to_test(identifier=identifier, **kwargs, allow_partial_parameter_mapping=True, registry=registry)

def assert_equal_instance(self, lhs: MappingPulseTemplate, rhs: MappingPulseTemplate):
def assert_equal_instance_except_id(self, lhs: MappingPulseTemplate, rhs: MappingPulseTemplate):
self.assertIsInstance(lhs, MappingPulseTemplate)
self.assertIsInstance(rhs, MappingPulseTemplate)
self.assertEqual(lhs.template, rhs.template)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/repetition_pulse_template_tests.py
Expand Up @@ -298,7 +298,7 @@ def make_kwargs(self):
'measurements': [('m', 0, 1)]
}

def assert_equal_instance(self, lhs: RepetitionPulseTemplate, rhs: RepetitionPulseTemplate):
def assert_equal_instance_except_id(self, lhs: RepetitionPulseTemplate, rhs: RepetitionPulseTemplate):
self.assertIsInstance(lhs, RepetitionPulseTemplate)
self.assertIsInstance(rhs, RepetitionPulseTemplate)
self.assertEqual(lhs.body, rhs.body)
Expand Down
2 changes: 1 addition & 1 deletion tests/pulses/sequence_pulse_template_tests.py
Expand Up @@ -192,7 +192,7 @@ def make_instance(self, identifier=None, registry=None):
del kwargs['subtemplates']
return self.class_to_test(identifier=identifier, *subtemplates, **kwargs, registry=registry)

def assert_equal_instance(self, lhs: SequencePulseTemplate, rhs: SequencePulseTemplate):
def assert_equal_instance_except_id(self, lhs: SequencePulseTemplate, rhs: SequencePulseTemplate):
self.assertIsInstance(lhs, SequencePulseTemplate)
self.assertIsInstance(rhs, SequencePulseTemplate)
self.assertEqual(lhs.subtemplates, rhs.subtemplates)
Expand Down
3 changes: 1 addition & 2 deletions tests/pulses/table_pulse_template_tests.py
Expand Up @@ -453,10 +453,9 @@ def make_kwargs(self):
'parameter_constraints': [str(ParameterConstraint('ilse>2')), str(ParameterConstraint('k>foo'))]
}

def assert_equal_instance(self, lhs: TablePulseTemplate, rhs: TablePulseTemplate):
def assert_equal_instance_except_id(self, lhs: TablePulseTemplate, rhs: TablePulseTemplate):
self.assertIsInstance(lhs, TablePulseTemplate)
self.assertIsInstance(rhs, TablePulseTemplate)
self.assertEqual(lhs.identifier, rhs.identifier)
self.assertEqual(lhs.entries, rhs.entries)
self.assertEqual(lhs.measurement_declarations, rhs.measurement_declarations)
self.assertEqual(lhs.parameter_constraints, rhs.parameter_constraints)
Expand Down
38 changes: 26 additions & 12 deletions tests/serialization_tests.py
Expand Up @@ -56,8 +56,12 @@ def class_to_test(self) -> typing.Any:
def make_kwargs(self) -> dict:
pass

@abstractmethod
def assert_equal_instance(self, lhs, rhs):
self.assert_equal_instance_except_id(lhs, rhs)
self.assertEqual(lhs.identifier, rhs.identifier)

@abstractmethod
def assert_equal_instance_except_id(self, lhs, rhs):
pass

def make_instance(self, identifier=None, registry=None):
Expand Down Expand Up @@ -131,6 +135,12 @@ def test_duplication_error(self):
with self.assertRaises(RuntimeError):
self.make_instance('blub', registry=registry)

def test_renamed(self) -> None:
registry = dict()
instance = self.make_instance('hugo', registry=registry)
renamed_instance = instance.renamed('ilse', registry=registry)
self.assert_equal_instance_except_id(instance, renamed_instance)

def test_conversion(self):
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
Expand All @@ -155,12 +165,11 @@ def class_to_test(self):
def make_kwargs(self):
return {'data': 'blubber', 'test_dict': {'foo': 'bar', 'no': 17.3}}

def assert_equal_instance(self, lhs, rhs):
self.assertEqual(lhs.identifier, rhs.identifier)
def assert_equal_instance_except_id(self, lhs, rhs):
self.assertEqual(lhs.data, rhs.data)


class DummyPulseTemplateSerializationtests(SerializableTests, unittest.TestCase):
class DummyPulseTemplateSerializationTests(SerializableTests, unittest.TestCase):
@property
def class_to_test(self):
return DummyPulseTemplate
Expand All @@ -176,9 +185,8 @@ def make_kwargs(self):
'integrals': {'default': ExpressionScalar(19.231)}
}

def assert_equal_instance(self, lhs, rhs):
def assert_equal_instance_except_id(self, lhs, rhs):
self.assertEqual(lhs.compare_key, rhs.compare_key)
self.assertEqual(lhs.identifier, rhs.identifier)


class FileSystemBackendTest(unittest.TestCase):
Expand Down Expand Up @@ -603,8 +611,8 @@ def test_getitem(self):
self.assertIn('asdf', self.storage.temporary_storage)

def test_setitem(self):
instance_1 = DummySerializable(identifier='my_id_1')
instance_2 = DummySerializable(identifier='my_id_2')
instance_1 = DummySerializable(identifier='my_id', registry=dict())
instance_2 = DummySerializable(identifier='my_id', registry=dict())

def overwrite(identifier, serializable):
self.assertFalse(overwrite.called)
Expand All @@ -624,6 +632,11 @@ def overwrite(identifier, serializable):
with self.assertRaisesRegex(RuntimeError, 'assigned twice'):
self.storage['my_id'] = instance_2

def test_setitem_different_id(self) -> None:
serializable = DummySerializable(identifier='my_id', registry=dict())
with self.assertRaisesRegex(ValueError, "different than its own internal identifier"):
self.storage['a_totally_different_id'] = serializable

def test_overwrite(self):

encode_mock = mock.Mock(return_value='asd')
Expand All @@ -639,9 +652,9 @@ def test_overwrite(self):
self.assertEqual(self.storage._temporary_storage, {'my_id': self.storage.StorageEntry('asd', instance)})

def test_write_through(self):
instance_1 = DummySerializable(identifier='my_id_1')
inner_instance = DummySerializable(identifier='my_id_2')
outer_instance = NestedDummySerializable(inner_instance, identifier='my_id_3')
instance_1 = DummySerializable(identifier='my_id_1', registry=dict())
inner_instance = DummySerializable(identifier='my_id_2', registry=dict())
outer_instance = NestedDummySerializable(inner_instance, identifier='my_id_3', registry=dict())

def get_expected():
return {identifier: serialized
Expand Down Expand Up @@ -714,11 +727,12 @@ def test_set_to_default_registry(self) -> None:

def test_beautified_json(self) -> None:
data = {'e': 89, 'b': 151, 'c': 123515, 'a': 123, 'h': 2415}
template = DummySerializable(data=data)
template = DummySerializable(data=data, identifier="foo")
pulse_storage = PulseStorage(DummyStorageBackend())
pulse_storage['foo'] = template

expected = """{
\"#identifier\": \"foo\",
\"#type\": \"""" + DummySerializable.get_type_identifier() + """\",
\"data\": {
\"a\": 123,
Expand Down

0 comments on commit 895f48b

Please sign in to comment.