Skip to content

Commit

Permalink
Add (de-)serialization for RepPT and ForPT measurements/parameter_con…
Browse files Browse the repository at this point in the history
…straints and their tests
  • Loading branch information
terrorfisch committed Mar 7, 2018
1 parent 052331c commit 1e92aba
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 40 deletions.
20 changes: 14 additions & 6 deletions qctoolkit/pulses/loop_pulse_template.py
Expand Up @@ -2,7 +2,7 @@
another PulseTemplate based on a condition."""


from typing import Dict, Set, Optional, Any, Union, Tuple, Generator, Sequence
from typing import Dict, Set, Optional, Any, Union, Tuple, Generator, Sequence, cast

import sympy

Expand Down Expand Up @@ -227,22 +227,30 @@ def get_serialization_data(self, serializer: Serializer) -> Dict[str, Any]:
data = dict(
body=serializer.dictify(self.body),
loop_range=self._loop_range.to_tuple(),
loop_index=self._loop_index
loop_index=self._loop_index,
)
if self.parameter_constraints:
data['parameter_constraints'] = [str(c) for c in self.parameter_constraints]
if self.measurement_declarations:
data['measurements'] = self.measurement_declarations
return data

@staticmethod
def deserialize(serializer: Serializer,
body: Dict[str, Any],
loop_range: Tuple,
loop_index: str,
identifier: Optional[str]=None) -> 'ForLoopPulseTemplate':
body = serializer.deserialize(body)
identifier: Optional[str]=None,
measurements: Optional=None,
parameter_constraints: Optional=None) -> 'ForLoopPulseTemplate':
body = cast(PulseTemplate, serializer.deserialize(body))
return ForLoopPulseTemplate(body=body,
identifier=identifier,
loop_range=loop_range,
loop_index=loop_index)

loop_index=loop_index,
measurements=measurements,
parameter_constraints=parameter_constraints
)


class WhileLoopPulseTemplate(LoopPulseTemplate):
Expand Down
4 changes: 2 additions & 2 deletions qctoolkit/pulses/pulse_template.py
Expand Up @@ -13,7 +13,7 @@

from qctoolkit.utils.types import ChannelID, DocStringABCMeta
from qctoolkit.serialization import Serializable
from qctoolkit.expressions import Expression
from qctoolkit.expressions import ExpressionScalar

from qctoolkit.pulses.conditions import Condition
from qctoolkit.pulses.parameters import Parameter
Expand Down Expand Up @@ -60,7 +60,7 @@ def is_interruptable(self) -> bool:

@property
@abstractmethod
def duration(self) -> Expression:
def duration(self) -> ExpressionScalar:
"""An expression for the duration of this PulseTemplate."""

@property
Expand Down
31 changes: 15 additions & 16 deletions qctoolkit/pulses/repetition_pulse_template.py
@@ -1,7 +1,7 @@
"""This module defines RepetitionPulseTemplate, a higher-order hierarchical pulse template that
represents the n-times repetition of another PulseTemplate."""

from typing import Dict, List, Set, Optional, Union, Any, Iterable, Tuple
from typing import Dict, List, Set, Optional, Union, Any, Iterable, Tuple, cast
from numbers import Real
from warnings import warn

Expand Down Expand Up @@ -58,14 +58,6 @@ def compare_key(self) -> Tuple[Any, int]:
def duration(self) -> float:
return self._body.duration*self._repetition_count

def get_measurement_windows(self) -> Iterable[MeasurementWindow]:
def get_measurement_window_generator(body: Waveform, repetition_count: int):
body_windows = list(body.get_measurement_windows())
for i in range(repetition_count):
for (name, begin, length) in body_windows:
yield (name, begin+i*body.duration, length)
return get_measurement_window_generator(self._body, self._repetition_count)

def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'RepetitionWaveform':
return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels),
repetition_count=self._repetition_count)
Expand Down Expand Up @@ -167,21 +159,28 @@ def requires_stop(self,
return any(parameters[v].requires_stop for v in self.repetition_count.variables)

def get_serialization_data(self, serializer: Serializer) -> Dict[str, Any]:
return dict(
data = dict(
body=serializer.dictify(self.body),
repetition_count=self.repetition_count.original_expression,
parameter_constraints=self.parameter_constraints
repetition_count=self.repetition_count.original_expression
)
if self.parameter_constraints:
data['parameter_constraints'] = [str(c) for c in self.parameter_constraints]
if self.measurement_declarations:
data['measurements'] = self.measurement_declarations
return data

@staticmethod
def deserialize(serializer: Serializer,
repetition_count: Union[str, int],
body: Dict[str, Any],
parameter_constraints: List[str],
identifier: Optional[str]=None) -> 'RepetitionPulseTemplate':
body = serializer.deserialize(body)
parameter_constraints: Optional[List[str]]=None,
identifier: Optional[str]=None,
measurements: Optional[List[MeasurementDeclaration]]=None) -> 'RepetitionPulseTemplate':
body = cast(PulseTemplate, serializer.deserialize(body))
return RepetitionPulseTemplate(body, repetition_count,
identifier=identifier, parameter_constraints=parameter_constraints)
identifier=identifier,
parameter_constraints=parameter_constraints,
measurements=measurements)


class ParameterNotIntegerException(Exception):
Expand Down
55 changes: 53 additions & 2 deletions tests/pulses/loop_pulse_template_tests.py
Expand Up @@ -186,7 +186,7 @@ def test_requires_stop(self):
parameters['A'] = DummyParameter(requires_stop=True)
self.assertTrue(flt.requires_stop(parameters, dict()))

def test_get_serialization_data(self):
def test_get_serialization_data_minimal(self):

dt = DummyPulseTemplate(parameter_names={'i'})
flt = ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=('A', 'B'))
Expand All @@ -203,7 +203,29 @@ def check_dt(to_dictify) -> str:
loop_index='i')
self.assertEqual(data, expected_data)

def test_deserialize(self):
def test_get_serialization_data_all_features(self):
measurements = [('a', 0, 1), ('b', 1, 1)]
parameter_constraints = ['foo < 3']

dt = DummyPulseTemplate(parameter_names={'i'})
flt = ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=('A', 'B'),
measurements=measurements, parameter_constraints=parameter_constraints)

def check_dt(to_dictify) -> str:
self.assertIs(to_dictify, dt)
return 'dt'

serializer = DummySerializer(serialize_callback=check_dt)

data = flt.get_serialization_data(serializer)
expected_data = dict(body='dt',
loop_range=('A', 'B', 1),
loop_index='i',
measurements=measurements,
parameter_constraints=parameter_constraints)
self.assertEqual(data, expected_data)

def test_deserialize_minimal(self):
body_str = 'dt'
dt = DummyPulseTemplate(parameter_names={'i'})

Expand All @@ -225,6 +247,35 @@ def make_dt(ident: str):
self.assertEqual(flt.loop_index, 'i')
self.assertEqual(flt.loop_range.to_tuple(), ('A', 'B', 1))

def test_deserialize_all_features(self):
body_str = 'dt'
dt = DummyPulseTemplate(parameter_names={'i'})

measurements = [('a', 0, 1), ('b', 1, 1)]
parameter_constraints = ['foo < 3']

def make_dt(ident: str):
self.assertEqual(body_str, ident)
return ident

data = dict(body=body_str,
loop_range=('A', 'B', 1),
loop_index='i',
identifier='meh',
measurements=measurements,
parameter_constraints=parameter_constraints)

serializer = DummySerializer(deserialize_callback=make_dt)
serializer.subelements['dt'] = dt

flt = ForLoopPulseTemplate.deserialize(serializer, **data)
self.assertEqual(flt.identifier, 'meh')
self.assertIs(flt.body, dt)
self.assertEqual(flt.loop_index, 'i')
self.assertEqual(flt.loop_range.to_tuple(), ('A', 'B', 1))
self.assertEqual(flt.measurement_declarations, measurements)
self.assertEqual([str(c) for c in flt.parameter_constraints], parameter_constraints)




Expand Down
34 changes: 20 additions & 14 deletions tests/pulses/repetition_pulse_template_tests.py
Expand Up @@ -229,50 +229,55 @@ def setUp(self) -> None:
self.serializer = DummySerializer(deserialize_callback=lambda x: x['name'])
self.body = DummyPulseTemplate()

def test_get_serialization_data_constant(self) -> None:
def test_get_serialization_data_minimal(self) -> None:
repetition_count = 3
template = RepetitionPulseTemplate(self.body, repetition_count)
expected_data = dict(
body=str(id(self.body)),
repetition_count=repetition_count,
parameter_constraints=[]
)
data = template.get_serialization_data(self.serializer)
self.assertEqual(expected_data, data)

def test_get_serialization_data_declaration(self) -> None:
template = RepetitionPulseTemplate(self.body, 'foo', parameter_constraints=['foo<3'])
def test_get_serialization_data_all_features(self) -> None:
repetition_count = 'foo'
measurements = [('a', 0, 1), ('b', 1, 1)]
parameter_constraints = ['foo < 3']
template = RepetitionPulseTemplate(self.body, repetition_count,
measurements=measurements,
parameter_constraints=parameter_constraints)
expected_data = dict(
body=str(id(self.body)),
repetition_count='foo',
parameter_constraints=[ParameterConstraint('foo<3')]
repetition_count=repetition_count,
measurements=measurements,
parameter_constraints=parameter_constraints
)
data = template.get_serialization_data(self.serializer)
self.assertEqual(expected_data, data)

def test_deserialize_constant(self) -> None:
def test_deserialize_minimal(self) -> None:
repetition_count = 3
data = dict(
repetition_count=repetition_count,
body=dict(name=str(id(self.body))),
identifier='foo',
parameter_constraints=['bar<3']
identifier='foo'
)
# prepare dependencies for deserialization
self.serializer.subelements[str(id(self.body))] = self.body
# deserialize
template = RepetitionPulseTemplate.deserialize(self.serializer, **data)
# compare!
self.assertEqual(self.body, template.body)
self.assertIs(self.body, template.body)
self.assertEqual(repetition_count, template.repetition_count)
self.assertEqual([str(c) for c in template.parameter_constraints], ['bar < 3'])
#self.assertEqual([str(c) for c in template.parameter_constraints], ['bar < 3'])

def test_deserialize_declaration(self) -> None:
def test_deserialize_all_features(self) -> None:
data = dict(
repetition_count='foo',
body=dict(name=str(id(self.body))),
identifier='foo',
parameter_constraints=['foo<3']
parameter_constraints=['foo < 3'],
measurements=[('a', 0, 1), ('b', 1, 1)]
)
# prepare dependencies for deserialization
self.serializer.subelements[str(id(self.body))] = self.body
Expand All @@ -281,9 +286,10 @@ def test_deserialize_declaration(self) -> None:
template = RepetitionPulseTemplate.deserialize(self.serializer, **data)

# compare!
self.assertEqual(self.body, template.body)
self.assertIs(self.body, template.body)
self.assertEqual('foo', template.repetition_count)
self.assertEqual(template.parameter_constraints, [ParameterConstraint('foo < 3')])
self.assertEqual(template.measurement_declarations, data['measurements'])


class ParameterNotIntegerExceptionTests(unittest.TestCase):
Expand Down

0 comments on commit 1e92aba

Please sign in to comment.