diff --git a/qctoolkit/serialization.py b/qctoolkit/serialization.py index a1c8e09b6..d9cc30545 100644 --- a/qctoolkit/serialization.py +++ b/qctoolkit/serialization.py @@ -318,13 +318,26 @@ def __new__(mcs, name, bases, dct): return cls -default_pulse_registry = weakref.WeakValueDictionary() +default_pulse_registry = None -def get_default_pulse_registry() -> Union[weakref.WeakKeyDictionary, 'PulseStorage']: +def get_default_pulse_registry() -> Union[Dict, 'PulseStorage']: return default_pulse_registry +def set_default_pulse_registry(new_default_registry: Optional[Union[Dict, 'PulseStorage']]) -> None: + global default_pulse_registry + default_pulse_registry = new_default_registry + + +def new_default_pulse_registry() -> None: + """Sets a new empty default pulse registry. + + The new registry is a newly created weakref.WeakValueDictionry(). + """ + set_default_pulse_registry(weakref.WeakValueDictionary()) + + class Serializable(metaclass=SerializableMeta): """Any object that can be converted into a serialized representation for storage and back. diff --git a/tests/serialization_tests.py b/tests/serialization_tests.py index 718623357..229b5dbdd 100644 --- a/tests/serialization_tests.py +++ b/tests/serialization_tests.py @@ -13,7 +13,7 @@ from qctoolkit.serialization import FilesystemBackend, CachingBackend, Serializable, JSONSerializableEncoder,\ ZipFileBackend, AnonymousSerializable, DictBackend, PulseStorage, JSONSerializableDecoder, Serializer,\ - get_default_pulse_registry, SerializableMeta + get_default_pulse_registry, set_default_pulse_registry, new_default_pulse_registry, SerializableMeta from qctoolkit.expressions import ExpressionScalar @@ -115,17 +115,19 @@ def test_serialization_and_deserialization(self): storage['blub'] = instance storage.clear() + set_default_pulse_registry(dict()) other_instance = typing.cast(self.class_to_test, storage['blub']) self.assert_equal_instance(instance, other_instance) self.assertIs(registry['blub'], instance) self.assertIs(get_default_pulse_registry()['blub'], other_instance) + set_default_pulse_registry(None) def test_duplication_error(self): registry = dict() - instance = self.make_instance('blub', registry=registry) + self.make_instance('blub', registry=registry) with self.assertRaises(RuntimeError): self.make_instance('blub', registry=registry) @@ -481,6 +483,33 @@ def get_type_identifier(cls): self.assertEqual(SerializableMeta.deserialization_callbacks['foo.bar.never'], NativeDeserializable) +class DefaultPulseRegistryManipulationTests(unittest.TestCase): + + def test_get_set_default_pulse_registry(self) -> None: + # store previous registry + previous_registry = get_default_pulse_registry() + + registry = dict() + set_default_pulse_registry(registry) + self.assertIs(get_default_pulse_registry(), registry) + + # restore previous registry + set_default_pulse_registry(previous_registry) + self.assertIs(get_default_pulse_registry(), previous_registry) + + def test_new_default_pulse_registry(self) -> None: + # store previous registry + previous_registry = get_default_pulse_registry() + + new_default_pulse_registry() + self.assertIsNotNone(get_default_pulse_registry()) + self.assertIsNot(get_default_pulse_registry(), previous_registry) + + # restore previous registry + set_default_pulse_registry(previous_registry) + self.assertIs(get_default_pulse_registry(), previous_registry) + + class PulseStorageTests(unittest.TestCase): def setUp(self): self.backend = DummyStorageBackend() @@ -631,8 +660,7 @@ def test_set_to_default_registry(self) -> None: pulse_storage.set_to_default_registry() self.assertIs(get_default_pulse_registry(), pulse_storage) finally: - import qctoolkit.serialization - qctoolkit.serialization.default_pulse_registry = previous_default_registry + set_default_pulse_registry(previous_default_registry) def test_beautified_json(self) -> None: data = {'e': 89, 'b': 151, 'c': 123515, 'a': 123, 'h': 2415}