From e3f334d42f391f3a247610b2c0a3c2226679a1d9 Mon Sep 17 00:00:00 2001 From: Lukas Prediger Date: Sun, 17 Jun 2018 19:34:48 +0200 Subject: [PATCH 1/2] Fix for pulse registration in Serialziable.__init__ --- qctoolkit/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qctoolkit/serialization.py b/qctoolkit/serialization.py index dcc57d529..ef903dde1 100644 --- a/qctoolkit/serialization.py +++ b/qctoolkit/serialization.py @@ -358,7 +358,7 @@ def __init__(self, identifier: Optional[str]=None, registration: weakref.WeakVal raise ValueError("Identifier must not be empty.") self.__identifier = identifier - if identifier and registration: + if identifier and registration is not None: if identifier in registration: raise RuntimeError('Pulse with name already exists', identifier) else: From 789309a1683278086ae17f8d18a21da3c12e6d48 Mon Sep 17 00:00:00 2001 From: Lukas Prediger Date: Sun, 17 Jun 2018 19:40:49 +0200 Subject: [PATCH 2/2] PulseStorage can be conveniently replace default pulse registration. Added get_default_pulse_registration() method to query current default pulse registration. Added method PulseStorage.as_default_registry(), a context manager method that causes the PulseStorage object to replace the current default pulse registration for the duration of a "with:" block. Added method PulseStorage.set_to_default_registry() which permanently sets the PulseStorage object as the default pulse registration. --- qctoolkit/serialization.py | 19 +++++++++++++++++++ tests/serialization_tests.py | 15 ++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/qctoolkit/serialization.py b/qctoolkit/serialization.py index ef903dde1..409df9906 100644 --- a/qctoolkit/serialization.py +++ b/qctoolkit/serialization.py @@ -17,6 +17,7 @@ import json import weakref import warnings +from contextlib import contextmanager from qctoolkit.utils.types import DocStringABCMeta @@ -319,6 +320,10 @@ def __new__(mcs, name, bases, dct): default_pulse_registration = weakref.WeakValueDictionary() +def get_default_pulse_registration() -> Union[weakref.WeakKeyDictionary, 'PulseStorage']: + return default_pulse_registration + + class Serializable(metaclass=SerializableMeta): """Any object that can be converted into a serialized representation for storage and back. @@ -672,6 +677,20 @@ def clear(self) -> None: def __del__(self) -> None: self.flush() + @contextmanager + def as_default_registry(self) -> Any: + global default_pulse_registration + previous_registration = default_pulse_registration + default_pulse_registration = self + try: + yield self + finally: + default_pulse_registration = previous_registration + + def set_to_default_registry(self) -> None: + global default_pulse_registration + default_pulse_registration = self + class JSONSerializableDecoder(json.JSONDecoder): diff --git a/tests/serialization_tests.py b/tests/serialization_tests.py index 319faf22a..38d5378fd 100644 --- a/tests/serialization_tests.py +++ b/tests/serialization_tests.py @@ -12,7 +12,8 @@ from typing import Optional, Any from qctoolkit.serialization import FilesystemBackend, CachingBackend, Serializable, JSONSerializableEncoder,\ - ZipFileBackend, AnonymousSerializable, DictBackend, PulseStorage, JSONSerializableDecoder, Serializer + ZipFileBackend, AnonymousSerializable, DictBackend, PulseStorage, JSONSerializableDecoder, Serializer, get_default_pulse_registration + from tests.serialization_dummies import DummyStorageBackend @@ -506,6 +507,18 @@ def test_flush_on_destroy_object(self) -> None: self.assertIn('my_id_1', backend.stored_items) + def test_as_default_registry(self) -> None: + prev_reg = get_default_pulse_registration() + pulse_storage = PulseStorage(DummyStorageBackend()) + with pulse_storage.as_default_registry(): + self.assertIs(get_default_pulse_registration(), pulse_storage) + self.assertIs(get_default_pulse_registration(), prev_reg) + + def test_set_to_default_registry(self) -> None: + pulse_storage = PulseStorage(DummyStorageBackend()) + pulse_storage.set_to_default_registry() + self.assertIs(get_default_pulse_registration(), pulse_storage) + class JSONSerializableDecoderTests(unittest.TestCase): def test_filter_serializables(self):