Skip to content

Commit

Permalink
Merge 18634f7 into 07d6e71
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Jul 12, 2018
2 parents 07d6e71 + 18634f7 commit 92996e2
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 5 deletions.
19 changes: 14 additions & 5 deletions qctoolkit/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abc import ABCMeta, abstractmethod
from typing import Dict, Any, Optional, NamedTuple, Union
from typing import Dict, Any, Optional, NamedTuple, Union, Mapping, MutableMapping
import os
import zipfile
import tempfile
Expand Down Expand Up @@ -650,7 +650,7 @@ def _load_and_deserialize(self, identifier: str) -> StorageEntry:
serialization = self._storage_backend[identifier]
serializable = self._deserialize(serialization)
self._temporary_storage[identifier] = PulseStorage.StorageEntry(serialization=serialization,
serializable=serializable)
serializable=serializable)
return self._temporary_storage[identifier]

@property
Expand Down Expand Up @@ -730,7 +730,7 @@ def set_to_default_registry(self) -> None:

class JSONSerializableDecoder(json.JSONDecoder):

def __init__(self, storage, *args, **kwargs) -> None:
def __init__(self, storage: Mapping, *args, **kwargs) -> None:
super().__init__(*args, object_hook=self.filter_serializables, **kwargs)

self.storage = storage
Expand All @@ -751,14 +751,23 @@ def filter_serializables(self, obj_dict) -> Any:

else:
deserialization_callback = SerializableMeta.deserialization_callbacks[type_identifier]
return deserialization_callback(identifier=obj_identifier, **obj_dict)

# if the storage is the default registry, we would get conflicts when the Serializable tries to register
# itself on construction. Pass an empty dict as registry keyword argument in this case.
# calling PulseStorage objects will take care of registering.
# (solution to issue #301: https://github.com/qutech/qc-toolkit/issues/301 )
registry = None
if get_default_pulse_registry() is self.storage:
registry = dict()

return deserialization_callback(identifier=obj_identifier, registry=registry, **obj_dict)
return obj_dict


class JSONSerializableEncoder(json.JSONEncoder):
""""""

def __init__(self, storage, *args, **kwargs) -> None:
def __init__(self, storage: MutableMapping, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.storage = storage
Expand Down
57 changes: 57 additions & 0 deletions tests/serialization_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,63 @@ def test_delitem(self):
self.assertEqual({}, backend.stored_items)
self.assertEqual(pulse_storage.temporary_storage, {})

def test_deserialize_storage_is_default_registry(self) -> None:
backend = DummyStorageBackend()

# fill backend
serializable = DummySerializable(identifier='peter', registry=dict())
pulse_storage = PulseStorage(backend)
pulse_storage['peter'] = serializable
del pulse_storage

# try to deserialize while PulseStorage is default registry
pulse_storage = PulseStorage(backend)
with pulse_storage.as_default_registry():
deserialized = pulse_storage['peter']
self.assertEqual(deserialized, serializable)

def test_deserialize_storage_is_not_default_registry_id_free(self) -> None:
backend = DummyStorageBackend()

# fill backend
serializable = DummySerializable(identifier='peter', registry=dict())
pulse_storage = PulseStorage(backend)
pulse_storage['peter'] = serializable
del pulse_storage

pulse_storage = PulseStorage(backend)
deserialized = pulse_storage['peter']
self.assertEqual(deserialized, serializable)

def test_deserialize_storage_is_not_default_registry_id_occupied(self) -> None:
backend = DummyStorageBackend()

# fill backend
serializable = DummySerializable(identifier='peter')
pulse_storage = PulseStorage(backend)
pulse_storage['peter'] = serializable
del pulse_storage

pulse_storage = PulseStorage(backend)
with self.assertRaisesRegex(RuntimeError, "Pulse with name already exists"):
pulse_storage['peter']

def test_deserialize_twice_same_object_storage_is_default_registry(self) -> None:
backend = DummyStorageBackend()

# fill backend
serializable = DummySerializable(identifier='peter', registry=dict())
pulse_storage = PulseStorage(backend)
pulse_storage['peter'] = serializable
del pulse_storage

# try to deserialize while PulseStorage is default registry
pulse_storage = PulseStorage(backend)
with pulse_storage.as_default_registry():
deserialized_1 = pulse_storage['peter']
deserialized_2 = pulse_storage['peter']
self.assertIs(deserialized_1, deserialized_2)
self.assertEqual(deserialized_1, serializable)

class JSONSerializableDecoderTests(unittest.TestCase):
def test_filter_serializables(self):
Expand Down

0 comments on commit 92996e2

Please sign in to comment.