Skip to content

Commit

Permalink
Tests + fixes for RepetitionPT.create_program()
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Jul 19, 2018
1 parent 854199b commit 4b8f432
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 14 deletions.
10 changes: 8 additions & 2 deletions qctoolkit/_program/_loop.py
Expand Up @@ -46,9 +46,15 @@ def __init__(self,
def compare_key(self) -> Tuple:
return self._waveform, self.repetition_count, tuple(c.compare_key for c in self)

def append_child(self, **kwargs) -> None:
def append_child(self, loop: Optional['Loop']=None, **kwargs) -> None:
# do not invalidate but update cached duration
super().__setitem__(slice(len(self), len(self)), (kwargs, ))
if loop is not None:
if kwargs:
raise ValueError("Cannot pass a Loop object and Loop constructor arguments at the same time in append_child")
arg = (loop,)
else:
arg = (kwargs,)
super().__setitem__(slice(len(self), len(self)), arg)
self._invalidate_duration(body_duration_increment=self[-1].duration)

def _invalidate_duration(self, body_duration_increment=None):
Expand Down
6 changes: 4 additions & 2 deletions qctoolkit/pulses/repetition_pulse_template.py
Expand Up @@ -144,10 +144,12 @@ def _internal_create_program(self,
volatile_parameters,
measurement_mapping,
channel_mapping)
if subprogram:
if subprogram is not None:
measurements = self.get_measurement_windows(parameters, measurement_mapping)
program = Loop(measurements=measurements, repetition_count=repetition_count)
program.append_child(subprogram)
program.append_child(loop=subprogram)
# todo (2018-07-19): could in some circumstances possibly just multiply subprogram repetition count
# could be tricky if any repetition count is volatile ? check later and optimize if necessary
return program
return None

Expand Down
1 change: 1 addition & 0 deletions tests/pulses/pulse_template_tests.py
Expand Up @@ -207,6 +207,7 @@ def test_internal_create_program(self) -> None:
template = AtomicPulseTemplateStub(waveform=wf, measurements=measurement_windows, parameter_names={'foo'})
parameters = {'foo': ConstantParameter(7.2)}
channel_mapping = {'B': 'A'}
# todo (2018-07-12): test for volatile paramters
program = template._internal_create_program(parameters=parameters,
volatile_parameters=dict(),
measurement_mapping={'M': 'N'},
Expand Down
162 changes: 158 additions & 4 deletions tests/pulses/repetition_pulse_template_tests.py
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from qctoolkit._program._loop import Loop
from qctoolkit.expressions import Expression
from qctoolkit.pulses.repetition_pulse_template import RepetitionPulseTemplate,ParameterNotIntegerException, RepetitionWaveform
from qctoolkit.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation, ConstantParameter, \
Expand Down Expand Up @@ -91,6 +92,10 @@ def test_integral(self) -> None:
template = RepetitionPulseTemplate(dummy, Expression('2+m'))
self.assertEqual([Expression('(2+m)*(foo+2)'), Expression('(2+m)*(k*3+x**2)')], template.integral)

def test_parameter_names_param_only_in_constraint(self) -> None:
pt = RepetitionPulseTemplate(DummyPulseTemplate(parameter_names={'a'}), 'n', parameter_constraints=['a<c'])
self.assertEqual(pt.parameter_names, {'a','c', 'n'})


class RepetitionPulseTemplateSequencingTests(unittest.TestCase):

Expand Down Expand Up @@ -118,6 +123,159 @@ def test_requires_stop_declaration(self) -> None:
parameter.requires_stop_ = parameter_requires_stop
self.assertEqual(parameter_requires_stop, t.requires_stop(parameters, conditions))


def setUp(self) -> None:
self.body = DummyPulseTemplate()
self.repetitions = 'foo'
self.template = RepetitionPulseTemplate(self.body, self.repetitions, parameter_constraints=['foo<9'])
self.sequencer = DummySequencer()
self.block = DummyInstructionBlock()

def test_create_program_constant(self) -> None:
repetitions = 3
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)
t = RepetitionPulseTemplate(body, repetitions, parameter_constraints=['foo<9'])
parameters = {'foo': 8}
volatile_parameters = {}
measurement_mapping = {'my': 'thy'}
channel_mapping = {}
program = t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)

self.assertEqual(repetitions, program.repetition_count)
self.assertEqual((parameters, volatile_parameters, measurement_mapping, channel_mapping), body.create_program_calls[-1])
self.assertEqual([body_program], program.children)

def test_create_program_declaration_success(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)
t = RepetitionPulseTemplate(body, repetitions, parameter_constraints=['foo<9'])
parameters = dict(foo=ConstantParameter(3))
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')
program = t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)

self.assertEqual(parameters[repetitions].get_value(), program.repetition_count)
self.assertEqual((parameters, volatile_parameters, measurement_mapping, channel_mapping),
body.create_program_calls[-1])
self.assertEqual([body_program], program.children)

def test_create_program_declaration_exceeds_bounds(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)
t = RepetitionPulseTemplate(body, repetitions, parameter_constraints=['foo<9'])
parameters = dict(foo=ConstantParameter(9))
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')
with self.assertRaises(ParameterConstraintViolation):
t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)

def test_create_program_declaration_parameter_not_provided(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)
t = RepetitionPulseTemplate(body, repetitions, parameter_constraints=['foo<9'])
parameters = {}
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')
with self.assertRaises(ParameterNotProvidedException):
t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)

def test_create_program_declaration_parameter_value_not_whole(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)
t = RepetitionPulseTemplate(body, repetitions, parameter_constraints=['foo<9'])
parameters = dict(foo=ConstantParameter(3.3))
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')
with self.assertRaises(ParameterNotIntegerException):
t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)

def test_create_program_rep_count_zero_constant(self) -> None:
repetitions = 0
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)

# suppress warning about 0 repetitions on construction here, we are only interested in correct behavior during sequencing (i.e., do nothing)
with warnings.catch_warnings(record=True):
t = RepetitionPulseTemplate(body, repetitions)

parameters = {}
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')

program = t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)
self.assertIsNone(program)

def test_create_program_rep_count_zero_declaration(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)

# suppress warning about 0 repetitions on construction here, we are only interested in correct behavior during sequencing (i.e., do nothing)
with warnings.catch_warnings(record=True):
t = RepetitionPulseTemplate(body, repetitions)

parameters = dict(foo=ConstantParameter(0))
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')

program = t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)
self.assertIsNone(program)

def test_rep_count_neg_declaration(self) -> None:
repetitions = "foo"
body_program = Loop()
body = DummyPulseTemplate(duration=2.0, program=body_program)

# suppress warning about 0 repetitions on construction here, we are only interested in correct behavior during sequencing (i.e., do nothing)
with warnings.catch_warnings(record=True):
t = RepetitionPulseTemplate(body, repetitions)

parameters = dict(foo=ConstantParameter(-1))
volatile_parameters = {}
measurement_mapping = dict(moth='fire')
channel_mapping = dict(asd='f')

program = t.create_program(parameters=parameters,
volatile_parameters=volatile_parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping)
self.assertIsNone(program)


class RepetitionPulseTemplateOldSequencingTests(unittest.TestCase):

def setUp(self) -> None:
self.body = DummyPulseTemplate()
self.repetitions = 'foo'
Expand Down Expand Up @@ -187,10 +345,6 @@ def test_build_sequence_declaration_parameter_value_not_whole(self) -> None:
self.template.build_sequence(self.sequencer, parameters, conditions, {}, {}, self.block)
self.assertFalse(self.sequencer.sequencing_stacks)

def test_parameter_names_param_only_in_constraint(self) -> None:
pt = RepetitionPulseTemplate(DummyPulseTemplate(parameter_names={'a'}), 'n', parameter_constraints=['a<c'])
self.assertEqual(pt.parameter_names, {'a','c', 'n'})

def test_rep_count_zero_constant(self) -> None:
repetitions = 0
parameters = {}
Expand Down
20 changes: 14 additions & 6 deletions tests/pulses/sequencing_dummies.py
Expand Up @@ -5,6 +5,7 @@
import numpy

"""LOCAL IMPORTS"""
from qctoolkit._program._loop import Loop
from qctoolkit.utils.types import MeasurementWindow, ChannelID, TimeType, time_from_float
from qctoolkit.serialization import Serializer
from qctoolkit._program.waveforms import Waveform
Expand Down Expand Up @@ -136,7 +137,7 @@ def add_instruction(self, instruction: Instruction) -> None:

class DummyWaveform(Waveform):

def __init__(self, duration: float=0, sample_output: numpy.ndarray=None, defined_channels={'A'}) -> None:
def __init__(self, duration: float=0.0, sample_output: numpy.ndarray=None, defined_channels={'A'}) -> None:
super().__init__()
self.duration_ = time_from_float(duration)
self.sample_output = sample_output
Expand Down Expand Up @@ -299,6 +300,7 @@ def __init__(self,
measurement_names: Set[str] = set(),
measurements: list=list(),
integrals: Dict[ChannelID, ExpressionScalar]={'default': ExpressionScalar(0)},
program: Optional[Loop]=None,
identifier=None,
registry=None) -> None:
super().__init__(identifier=identifier, measurements=measurements)
Expand All @@ -314,6 +316,8 @@ def __init__(self,
self.build_waveform_calls = []
self.measurement_names_ = set(measurement_names)
self._integrals = integrals
self.create_program_calls = []
self._program = program
self._register(registry=registry)

@property
Expand All @@ -324,10 +328,6 @@ def duration(self):
def parameter_names(self) -> Set[str]:
return set(self.parameter_names_)

def get_measurement_windows(self, parameters: Dict[str, Parameter] = None) -> List[MeasurementWindow]:
"""Return all measurement windows defined in this PulseTemplate."""
raise NotImplementedError()

@property
def build_sequence_calls(self):
return len(self.build_sequence_arguments)
Expand All @@ -353,13 +353,21 @@ def build_sequence(self,
instruction_block: InstructionBlock):
self.build_sequence_arguments.append((sequencer,parameters,conditions, measurement_mapping, channel_mapping, instruction_block))

def create_program(self,
parameters: Dict[str, Parameter],
volatile_parameters: Set[str],
measurement_mapping: Dict[str, Optional[str]],
channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[Loop]:
self.create_program_calls.append((parameters, volatile_parameters, measurement_mapping, channel_mapping))
return self._program

def build_waveform(self,
parameters: Dict[str, Parameter],
channel_mapping: Dict[ChannelID, ChannelID]):
self.build_waveform_calls.append((parameters, channel_mapping))
if self.waveform or self.waveform is None:
return self.waveform
return DummyWaveform(duration=self.duration, defined_channels=self.defined_channels)
return DummyWaveform(duration=self.duration.evaluate_numeric(**parameters), defined_channels=self.defined_channels)

def requires_stop(self, parameters: Dict[str, Parameter], conditions: Dict[str, Condition]) -> bool:
self.requires_stop_arguments.append((parameters,conditions))
Expand Down

0 comments on commit 4b8f432

Please sign in to comment.