Skip to content

Commit

Permalink
Fixed implementation of MappedPT._internal_create_program and tests +…
Browse files Browse the repository at this point in the history
… additional test for ForLoopPT._internal_create_program()
  • Loading branch information
lumip committed Aug 2, 2018
1 parent c025c99 commit 2580635
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 11 deletions.
12 changes: 7 additions & 5 deletions qctoolkit/pulses/mapping_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,16 @@ def build_sequence(self,
channel_mapping=self.get_updated_channel_mapping(channel_mapping),
instruction_block=instruction_block)

def _internal_create_program(self,
def _internal_create_program(self, *,
parameters: Dict[str, Parameter],
measurement_mapping: Dict[str, Optional[str]],
channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[Loop]:
channel_mapping: Dict[ChannelID, Optional[ChannelID]],
parent_loop: Loop) -> None:
# parameters are validated in map_parameters() call, no need to do it here again explicitly
return self.template.create_program(parameters=self.map_parameters(parameters),
measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping),
channel_mapping=self.get_updated_channel_mapping(channel_mapping))
self.template._internal_create_program(parameters=self.map_parameters(parameters),
measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping),
channel_mapping=self.get_updated_channel_mapping(channel_mapping),
parent_loop=parent_loop)

def build_waveform(self,
parameters: Dict[str, numbers.Real],
Expand Down
22 changes: 20 additions & 2 deletions tests/pulses/loop_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,26 @@ def test_create_program_missing_params(self) -> None:
self.assertIsNone(program._measurements)
self.assert_measurement_windows_equal({}, program.get_measurement_windows())

def test_create_program_body_none(self) -> None:
dt = DummyPulseTemplate(parameter_names={'i'}, waveform=None, duration=0,
measurements=[('b', 2, 1)])
flt = ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=('a', 'b', 'c'),
measurements=[('A', 0, 1)], parameter_constraints=['c > 1'])

parameters = {'a': ConstantParameter(1), 'b': ConstantParameter(4), 'c': ConstantParameter(2)}
measurement_mapping = dict(A='B', b='b')
channel_mapping = dict(C='D')

program = Loop()
flt._internal_create_program(parameters=parameters,
measurement_mapping=measurement_mapping,
channel_mapping=channel_mapping,
parent_loop=program)

self.assertEqual(0, len(program.children))
self.assertEqual(1, program.repetition_count)
self.assertEqual([], program.children)

def test_create_program(self) -> None:
dt = DummyPulseTemplate(parameter_names={'i'}, waveform=DummyWaveform(duration=4.0), duration=4,
measurements=[('b', 2, 1)])
Expand All @@ -300,8 +320,6 @@ def test_create_program(self) -> None:
measurement_mapping = dict(A='B', b='b')
channel_mapping = dict(C='D')

#children = [Loop(waveform=DummyWaveform(duration=2.0))]
#program = Loop(children=children)
program = Loop()
flt._internal_create_program(parameters=parameters,
measurement_mapping=measurement_mapping,
Expand Down
153 changes: 149 additions & 4 deletions tests/pulses/mapping_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from qctoolkit.pulses.mapping_pulse_template import MissingMappingException,\
UnnecessaryMappingException, MappingPulseTemplate,\
AmbiguousMappingException, MappingCollisionException
from qctoolkit.pulses.parameters import ParameterNotProvidedException
from qctoolkit.pulses.parameters import ConstantParameter, ParameterConstraintViolation, ParameterConstraint
from qctoolkit.pulses.parameters import ConstantParameter, ParameterConstraintViolation, ParameterConstraint, ParameterNotProvidedException
from qctoolkit.expressions import Expression
from qctoolkit._program._loop import Loop

from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummySequencer, DummyInstructionBlock
from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummySequencer, DummyInstructionBlock, MeasurementWindowTestCase, DummyWaveform
from tests.serialization_tests import SerializableTests
from tests.serialization_dummies import DummySerializer

Expand Down Expand Up @@ -186,7 +186,151 @@ def test_integral(self) -> None:
self.assertEqual({'default': Expression('2*f'), 'other': Expression('-3.2*f+2.3')}, pulse.integral)


class MappingPulseTemplateSequencingTests(unittest.TestCase):
class MappingPulseTemplateSequencingTest(MeasurementWindowTestCase):

def test_create_program(self) -> None:
measurement_mapping = {'meas1': 'meas2'}
parameter_mapping = {'t': 'k'}
channel_mapping = {'B': 'default'}

template = DummyPulseTemplate(measurements=[('meas1', 0, 1)], measurement_names={'meas1'}, defined_channels={'B'},
waveform=DummyWaveform(duration=2.0),
duration=2,
parameter_names={'t'})
st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping,
measurement_mapping=measurement_mapping, channel_mapping=channel_mapping)

pre_parameters = {'k': ConstantParameter(5)}
pre_measurement_mapping = {'meas2': 'meas3'}
pre_channel_mapping = {'default': 'A'}

program = Loop()
st._internal_create_program(parameters=pre_parameters,
measurement_mapping=pre_measurement_mapping,
channel_mapping=pre_channel_mapping,
parent_loop=program)

self.assertEqual(1, len(template.create_program_calls))
self.assertEqual((st.map_parameters(pre_parameters),
st.get_updated_measurement_mapping(pre_measurement_mapping),
st.get_updated_channel_mapping(pre_channel_mapping),
program),
template.create_program_calls[-1])

self.assertEqual(1, program.repetition_count)
self.assertEqual(1, len(program.children))
self.assertIs(template.waveform, program.children[0].waveform)
self.assert_measurement_windows_equal({'meas3': ([0], [1])}, program.get_measurement_windows())

def test_create_program_invalid_measurement_mapping(self) -> None:
measurement_mapping = {'meas1': 'meas2'}
parameter_mapping = {'t': 'k'}
channel_mapping = {'B': 'default'}

template = DummyPulseTemplate(measurements=[('meas1', 0, 1)], measurement_names={'meas1'},
defined_channels={'B'},
waveform=DummyWaveform(duration=2.0),
duration=2,
parameter_names={'t'})
st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping,
measurement_mapping=measurement_mapping, channel_mapping=channel_mapping)

pre_parameters = {'k': ConstantParameter(5)}
pre_measurement_mapping = {}
pre_channel_mapping = {'default': 'A'}

program = Loop()
with self.assertRaises(KeyError):
st._internal_create_program(parameters=pre_parameters,
measurement_mapping=pre_measurement_mapping,
channel_mapping=pre_channel_mapping,
parent_loop=program)

def test_create_program_missing_params(self) -> None:
measurement_mapping = {'meas1': 'meas2'}
parameter_mapping = {'t': 'k'}
channel_mapping = {'B': 'default'}

template = DummyPulseTemplate(measurements=[('meas1', 0, 1)], measurement_names={'meas1'},
defined_channels={'B'},
waveform=DummyWaveform(duration=2.0),
duration=2,
parameter_names={'t'})
st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping,
measurement_mapping=measurement_mapping, channel_mapping=channel_mapping)

pre_parameters = {}
pre_measurement_mapping = {'meas2': 'meas3'}
pre_channel_mapping = {'default': 'A'}

program = Loop()
with self.assertRaises(ParameterNotProvidedException):
st._internal_create_program(parameters=pre_parameters,
measurement_mapping=pre_measurement_mapping,
channel_mapping=pre_channel_mapping,
parent_loop=program)

def test_create_program_parameter_constraint_violation(self) -> None:
measurement_mapping = {'meas1': 'meas2'}
parameter_mapping = {'t': 'k'}
channel_mapping = {'B': 'default'}

template = DummyPulseTemplate(measurements=[('meas1', 0, 1)], measurement_names={'meas1'},
defined_channels={'B'},
waveform=DummyWaveform(duration=2.0),
duration=2,
parameter_names={'t'})
st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping,
measurement_mapping=measurement_mapping, channel_mapping=channel_mapping,
parameter_constraints={'k > 6'})

pre_parameters = {'k': ConstantParameter(5)}
pre_measurement_mapping = {'meas2': 'meas3'}
pre_channel_mapping = {'default': 'A'}

program = Loop()
with self.assertRaises(ParameterConstraintViolation):
st._internal_create_program(parameters=pre_parameters,
measurement_mapping=pre_measurement_mapping,
channel_mapping=pre_channel_mapping,
parent_loop=program)

def test_create_program_subtemplate_none(self) -> None:
measurement_mapping = {'meas1': 'meas2'}
parameter_mapping = {'t': 'k'}
channel_mapping = {'B': 'default'}

template = DummyPulseTemplate(measurements=[('meas1', 0, 1)], measurement_names={'meas1'},
defined_channels={'B'},
waveform=None,
duration=0,
parameter_names={'t'})
st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping,
measurement_mapping=measurement_mapping, channel_mapping=channel_mapping)

pre_parameters = {'k': ConstantParameter(5)}
pre_measurement_mapping = {'meas2': 'meas3'}
pre_channel_mapping = {'default': 'A'}

program = Loop()
st._internal_create_program(parameters=pre_parameters,
measurement_mapping=pre_measurement_mapping,
channel_mapping=pre_channel_mapping,
parent_loop=program)

self.assertEqual(1, len(template.create_program_calls))
self.assertEqual((st.map_parameters(pre_parameters),
st.get_updated_measurement_mapping(pre_measurement_mapping),
st.get_updated_channel_mapping(pre_channel_mapping),
program),
template.create_program_calls[-1])

self.assertEqual(1, program.repetition_count)
self.assertEqual(0, len(program.children))
self.assertIsNone(program._measurements)


class MappingPulseTemplateOldSequencingTests(unittest.TestCase):

def test_build_sequence(self):
measurement_mapping = {'meas1': 'meas2'}
Expand Down Expand Up @@ -218,6 +362,7 @@ def test_build_sequence(self):
def test_requires_stop(self):
pass


class PulseTemplateParameterMappingExceptionsTests(unittest.TestCase):

def test_missing_mapping_exception_str(self) -> None:
Expand Down

0 comments on commit 2580635

Please sign in to comment.