Skip to content

Commit

Permalink
Merge 0bca178 into 07d6e71
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Jul 12, 2018
2 parents 07d6e71 + 0bca178 commit 8be81f9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
21 changes: 17 additions & 4 deletions qctoolkit/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

from abc import ABCMeta, abstractmethod
from typing import Dict, Any, Optional, NamedTuple, Union
from typing import Dict, Any, Optional, NamedTuple, Union, MutableMapping
import os
import zipfile
import tempfile
Expand Down Expand Up @@ -318,13 +318,26 @@ def __new__(mcs, name, bases, dct):
return cls


default_pulse_registry = weakref.WeakValueDictionary()
default_pulse_registry = None


def get_default_pulse_registry() -> Union[weakref.WeakKeyDictionary, 'PulseStorage']:
def get_default_pulse_registry() -> MutableMapping:
return default_pulse_registry


def set_default_pulse_registry(new_default_registry: Optional[MutableMapping]) -> None:
global default_pulse_registry
default_pulse_registry = new_default_registry


def new_default_pulse_registry() -> None:
"""Sets a new empty default pulse registry.
The new registry is a newly created weakref.WeakValueDictionry().
"""
set_default_pulse_registry(weakref.WeakValueDictionary())


class Serializable(metaclass=SerializableMeta):
"""Any object that can be converted into a serialized representation for storage and back.
Expand All @@ -345,7 +358,7 @@ class Serializable(metaclass=SerializableMeta):
type_identifier_name = '#type'
identifier_name = '#identifier'

def __init__(self, identifier: Optional[str]=None, registry: Optional[dict]=None) -> None:
def __init__(self, identifier: Optional[str]=None, registry: Optional[MutableMapping]=None) -> None:
"""Initialize a Serializable.
Args:
Expand Down
36 changes: 32 additions & 4 deletions tests/serialization_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from qctoolkit.serialization import FilesystemBackend, CachingBackend, Serializable, JSONSerializableEncoder,\
ZipFileBackend, AnonymousSerializable, DictBackend, PulseStorage, JSONSerializableDecoder, Serializer,\
get_default_pulse_registry, SerializableMeta
get_default_pulse_registry, set_default_pulse_registry, new_default_pulse_registry, SerializableMeta

from qctoolkit.expressions import ExpressionScalar

Expand Down Expand Up @@ -115,17 +115,19 @@ def test_serialization_and_deserialization(self):
storage['blub'] = instance

storage.clear()
set_default_pulse_registry(dict())

other_instance = typing.cast(self.class_to_test, storage['blub'])
self.assert_equal_instance(instance, other_instance)

self.assertIs(registry['blub'], instance)
self.assertIs(get_default_pulse_registry()['blub'], other_instance)
set_default_pulse_registry(None)

def test_duplication_error(self):
registry = dict()

instance = self.make_instance('blub', registry=registry)
self.make_instance('blub', registry=registry)
with self.assertRaises(RuntimeError):
self.make_instance('blub', registry=registry)

Expand Down Expand Up @@ -481,6 +483,33 @@ def get_type_identifier(cls):
self.assertEqual(SerializableMeta.deserialization_callbacks['foo.bar.never'], NativeDeserializable)


class DefaultPulseRegistryManipulationTests(unittest.TestCase):

def test_get_set_default_pulse_registry(self) -> None:
# store previous registry
previous_registry = get_default_pulse_registry()

registry = dict()
set_default_pulse_registry(registry)
self.assertIs(get_default_pulse_registry(), registry)

# restore previous registry
set_default_pulse_registry(previous_registry)
self.assertIs(get_default_pulse_registry(), previous_registry)

def test_new_default_pulse_registry(self) -> None:
# store previous registry
previous_registry = get_default_pulse_registry()

new_default_pulse_registry()
self.assertIsNotNone(get_default_pulse_registry())
self.assertIsNot(get_default_pulse_registry(), previous_registry)

# restore previous registry
set_default_pulse_registry(previous_registry)
self.assertIs(get_default_pulse_registry(), previous_registry)


class PulseStorageTests(unittest.TestCase):
def setUp(self):
self.backend = DummyStorageBackend()
Expand Down Expand Up @@ -631,8 +660,7 @@ def test_set_to_default_registry(self) -> None:
pulse_storage.set_to_default_registry()
self.assertIs(get_default_pulse_registry(), pulse_storage)
finally:
import qctoolkit.serialization
qctoolkit.serialization.default_pulse_registry = previous_default_registry
set_default_pulse_registry(previous_default_registry)

def test_beautified_json(self) -> None:
data = {'e': 89, 'b': 151, 'c': 123515, 'a': 123, 'h': 2415}
Expand Down

0 comments on commit 8be81f9

Please sign in to comment.