Skip to content

Commit

Permalink
Merge branch 'issues/250_serialization' into issues/272_PulseStorage_…
Browse files Browse the repository at this point in the history
…id_inconsistencies
  • Loading branch information
terrorfisch committed Jul 12, 2018
2 parents bb935b8 + 5a5b5d7 commit 81b7891
Show file tree
Hide file tree
Showing 4 changed files with 428 additions and 20 deletions.
130 changes: 117 additions & 13 deletions qctoolkit/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abc import ABCMeta, abstractmethod
from typing import Dict, Any, Optional, NamedTuple, Union
from typing import Dict, Any, Optional, NamedTuple, Union, Mapping, MutableMapping, Set
import os
import zipfile
import tempfile
Expand All @@ -23,7 +23,8 @@
from qctoolkit.utils.types import DocStringABCMeta

__all__ = ["StorageBackend", "FilesystemBackend", "ZipFileBackend", "CachingBackend", "Serializable", "Serializer",
"AnonymousSerializable", "DictBackend", "JSONSerializableEncoder", "JSONSerializableDecoder", "PulseStorage"]
"AnonymousSerializable", "DictBackend", "JSONSerializableEncoder", "JSONSerializableDecoder", "PulseStorage",
"convert_pulses_in_storage", "convert_stored_pulse_in_storage"]


class StorageBackend(metaclass=ABCMeta):
Expand Down Expand Up @@ -91,6 +92,14 @@ def delete(self, identifier: str) -> None:
def __delitem__(self, identifier: str) -> None:
    """Delete a stored element via ``del backend[identifier]``; delegates to delete()."""
    self.delete(identifier)

@abstractmethod
def list_contents(self) -> Set[str]:
    """Return a listing of all available identifiers.

    Returns:
        Set of all available identifiers.
    """


class FilesystemBackend(StorageBackend):
"""A StorageBackend implementation based on a regular filesystem.
Expand Down Expand Up @@ -141,6 +150,21 @@ def delete(self, identifier):
except FileNotFoundError as fnf:
raise KeyError(identifier) from fnf

def list_contents(self) -> Set[str]:
    """Return a listing of all available identifiers.

    Only the top level of the storage root is scanned: FilesystemBackend
    does not create subdirectories, so nested directories are deliberately
    ignored as a safeguard.

    Returns:
        Set of all available identifiers (names of '.json' files in the
        root directory, with the extension stripped).
    """
    # os.listdir replaces the previous os.walk(...)/break construct; the
    # explicit isfile check excludes any directory that happens to carry a
    # '.json' suffix (os.walk's `files` list implied this before).
    contents = set()
    for entry in os.listdir(self._root):
        name, ext = os.path.splitext(entry)
        if ext == '.json' and os.path.isfile(os.path.join(self._root, entry)):
            contents.add(name)
    return contents


class ZipFileBackend(StorageBackend):
"""A StorageBackend implementation based on a single zip file.
Expand Down Expand Up @@ -225,6 +249,12 @@ def _update(self, filename: str, data: Optional[str]) -> None:
with zipfile.ZipFile(self._root, mode='a', compression=zipfile.ZIP_DEFLATED) as zf:
zf.writestr(filename, data)

def list_contents(self) -> Set[str]:
    """Return a listing of all available identifiers.

    Returns:
        Set of all available identifiers ('.json' members of the zip
        archive, with the extension stripped).
    """
    with zipfile.ZipFile(self._root, 'r') as archive:
        identifiers = set()
        for member in archive.namelist():
            stem, extension = os.path.splitext(member)
            if extension == '.json':
                identifiers.add(stem)
        return identifiers


class CachingBackend(StorageBackend):
"""Adds naive memory caching functionality to another StorageBackend.
Expand Down Expand Up @@ -265,6 +295,9 @@ def delete(self, identifier: str) -> None:
if identifier in self._cache:
del self._cache[identifier]

def list_contents(self) -> Set[str]:
    """Return all identifiers available in the wrapped backend.

    The cache is not consulted; the listing always comes from the
    underlying backend.
    """
    wrapped_backend = self._backend
    return wrapped_backend.list_contents()


class DictBackend(StorageBackend):
"""DictBackend uses a dictionary to store the data for convenience serialization."""
Expand All @@ -289,6 +322,9 @@ def storage(self) -> Dict[str, str]:
def delete(self, identifier: str) -> None:
    """Remove the entry for identifier; raises KeyError if it is absent."""
    self._cache.pop(identifier)

def list_contents(self) -> Set[str]:
    """Return the identifiers of all entries in the backing dictionary."""
    return {identifier for identifier in self._cache}


def get_type_identifier(obj: Any) -> str:
"""Return a unique type identifier for any object.
Expand Down Expand Up @@ -318,13 +354,26 @@ def __new__(mcs, name, bases, dct):
return cls


# The module-wide default registry; None means automatic registration of
# Serializables is disabled until a registry is installed.
default_pulse_registry = None


def get_default_pulse_registry() -> MutableMapping:
    """Return the current default pulse registry (may be None)."""
    return default_pulse_registry


def set_default_pulse_registry(new_default_registry: Optional[MutableMapping]) -> None:
    """Install new_default_registry as the default pulse registry.

    Args:
        new_default_registry: Any mutable mapping to use as registry, or
            None to disable the default registry.
    """
    global default_pulse_registry
    default_pulse_registry = new_default_registry


def new_default_pulse_registry() -> None:
    """Sets a new empty default pulse registry.

    The new registry is a newly created weakref.WeakValueDictionary().
    """
    set_default_pulse_registry(weakref.WeakValueDictionary())


class Serializable(metaclass=SerializableMeta):
"""Any object that can be converted into a serialized representation for storage and back.
Expand All @@ -345,7 +394,7 @@ class Serializable(metaclass=SerializableMeta):
type_identifier_name = '#type'
identifier_name = '#identifier'

def __init__(self, identifier: Optional[str]=None, registry: Optional[dict]=None) -> None:
def __init__(self, identifier: Optional[str]=None, registry: Optional[MutableMapping]=None) -> None:
"""Initialize a Serializable.
Args:
Expand Down Expand Up @@ -616,7 +665,7 @@ def deserialize(self, representation: Union[str, Dict[str, Any]]) -> Serializabl
if isinstance(representation, str):
if representation in self.__subpulses:
return self.__subpulses[representation].serializable

if isinstance(representation, str):
repr_ = json.loads(self.__storage_backend.get(representation))
repr_['identifier'] = representation
Expand All @@ -626,12 +675,12 @@ def deserialize(self, representation: Union[str, Dict[str, Any]]) -> Serializabl
module_name, class_name = repr_['type'].rsplit('.', 1)
module = __import__(module_name, fromlist=[class_name])
class_ = getattr(module, class_name)

repr_to_store = repr_.copy()
repr_.pop('type')

serializable = class_.deserialize(self, **repr_)

if 'identifier' in repr_:
identifier = repr_['identifier']
self.__subpulses[identifier] = self.__FileEntry(repr_, serializable)
Expand All @@ -657,7 +706,7 @@ def _load_and_deserialize(self, identifier: str) -> StorageEntry:
serialization = self._storage_backend[identifier]
serializable = self._deserialize(serialization)
self._temporary_storage[identifier] = PulseStorage.StorageEntry(serialization=serialization,
serializable=serializable)
serializable=serializable)
return self._temporary_storage[identifier]

@property
Expand Down Expand Up @@ -745,7 +794,7 @@ def set_to_default_registry(self) -> None:

class JSONSerializableDecoder(json.JSONDecoder):

def __init__(self, storage, *args, **kwargs) -> None:
def __init__(self, storage: Mapping, *args, **kwargs) -> None:
    """A JSONDecoder that runs every decoded JSON object through filter_serializables.

    Args:
        storage: A mapping of identifiers to Serializables; kept as
            self.storage for use by filter_serializables (e.g. compared
            against the default pulse registry there).
        *args: Forwarded to json.JSONDecoder.
        **kwargs: Forwarded to json.JSONDecoder.
    """
    super().__init__(*args, object_hook=self.filter_serializables, **kwargs)

    self.storage = storage
Expand All @@ -766,14 +815,23 @@ def filter_serializables(self, obj_dict) -> Any:

else:
deserialization_callback = SerializableMeta.deserialization_callbacks[type_identifier]
return deserialization_callback(identifier=obj_identifier, **obj_dict)

# if the storage is the default registry, we would get conflicts when the Serializable tries to register
# itself on construction. Pass an empty dict as registry keyword argument in this case.
# calling PulseStorage objects will take care of registering.
# (solution to issue #301: https://github.com/qutech/qc-toolkit/issues/301 )
registry = None
if get_default_pulse_registry() is self.storage:
registry = dict()

return deserialization_callback(identifier=obj_identifier, registry=registry, **obj_dict)
return obj_dict


class JSONSerializableEncoder(json.JSONEncoder):
""""""

def __init__(self, storage, *args, **kwargs) -> None:
def __init__(self, storage: MutableMapping, *args, **kwargs) -> None:
    """A JSONEncoder aware of Serializable objects.

    Args:
        storage: A mutable mapping kept as self.storage.
            # NOTE(review): presumably consulted by default() when
            # encountering Serializables — confirm against full class body.
        *args: Forwarded to json.JSONEncoder.
        **kwargs: Forwarded to json.JSONEncoder.
    """
    super().__init__(*args, **kwargs)

    self.storage = storage
Expand Down Expand Up @@ -817,3 +875,49 @@ def default(self, o: Any) -> Any:
return list(o)
else:
return super().default(o)


def convert_stored_pulse_in_storage(identifier: str, source_storage: StorageBackend, dest_storage: StorageBackend) -> None:
    """Converts a pulse from the old to the new serialization format.

    The pulse with the given identifier is completely (including subpulses)
    converted from the old serialization format read from a given source
    storage to the new serialization format and written to a given
    destination storage.

    Args:
        identifier: The identifier of the pulse to convert.
        source_storage: A StorageBackend containing the pulse identified by
            the identifier argument in the old serialization format.
        dest_storage: A StorageBackend the converted pulse will be written
            to in the new serialization format.
    Raises:
        ValueError: if the dest_storage StorageBackend contains identifiers
            also assigned in source_storage.
    """
    duplicate_ids = dest_storage.list_contents() & source_storage.list_contents()
    if duplicate_ids:
        raise ValueError("dest_storage already contains pulses with the same ids. "
                         "Aborting to prevent inconsistencies for duplicate keys.")
    with warnings.catch_warnings():
        # Serializer is deprecated but still required to read the old format.
        warnings.simplefilter('ignore', DeprecationWarning)
        legacy_serializer = Serializer(source_storage)
        new_storage = PulseStorage(dest_storage)
        pulse = legacy_serializer.deserialize(identifier)
        new_storage.overwrite(identifier, pulse)


def convert_pulses_in_storage(source_storage: StorageBackend, dest_storage: StorageBackend) -> None:
    """Converts all pulses from the old to the new serialization format.

    All pulses in a given source storage are completely (including
    subpulses) converted from the old serialization format to the new
    serialization format and written to a given destination storage.

    Args:
        source_storage: A StorageBackend containing pulses in the old
            serialization format.
        dest_storage: A StorageBackend the converted pulses will be
            written to in the new serialization format.
    Raises:
        ValueError: if the dest_storage StorageBackend contains identifiers
            also assigned in source_storage.
    """
    duplicate_ids = dest_storage.list_contents() & source_storage.list_contents()
    if duplicate_ids:
        raise ValueError("dest_storage already contains pulses with the same ids. "
                         "Aborting to prevent inconsistencies for duplicate keys.")
    with warnings.catch_warnings():
        # Serializer is deprecated but still required to read the old format.
        warnings.simplefilter('ignore', DeprecationWarning)
        legacy_serializer = Serializer(source_storage)
        new_storage = PulseStorage(dest_storage)
        for pulse_id in source_storage.list_contents():
            new_storage.overwrite(pulse_id, legacy_serializer.deserialize(pulse_id))
2 changes: 2 additions & 0 deletions tests/pulses/sequencing_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,8 @@ def requires_stop(self, parameters: Dict[str, Parameter], conditions: Dict[str,

def get_serialization_data(self, serializer: Optional['Serializer']=None) -> Dict[str, Any]:
data = super().get_serialization_data(serializer=serializer)
if serializer: # compatibility with old serialization routines
data = dict()
data['requires_stop'] = self.requires_stop_
data['is_interruptable'] = self.is_interruptable
data['parameter_names'] = self.parameter_names
Expand Down
7 changes: 5 additions & 2 deletions tests/serialization_dummies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Dict, Any, Callable
from typing import Union, Dict, Any, Callable, Set

from qctoolkit.serialization import Serializer, Serializable, StorageBackend

Expand Down Expand Up @@ -27,9 +27,12 @@ def exists(self, identifier: str) -> bool:
self.times_exists_called += 1
return identifier in self.stored_items

def delete(self, identifier: str):
def delete(self, identifier: str) -> None:
    """Remove the stored item for identifier; raises KeyError if absent."""
    self.stored_items.pop(identifier)

def list_contents(self) -> Set[str]:
    """Return the identifiers of all stored items."""
    return {identifier for identifier in self.stored_items}


class DummySerializer(Serializer):

Expand Down

0 comments on commit 81b7891

Please sign in to comment.