Skip to content

Commit

Permalink
Tests for FunctionPulseTemplate.
Browse files Browse the repository at this point in the history
  • Loading branch information
lumip committed Jul 29, 2016
1 parent 557016a commit 4c35e44
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 30 deletions.
9 changes: 4 additions & 5 deletions qctoolkit/pulses/function_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def parameter_names(self) -> Set[str]:

@property
def parameter_declarations(self) -> Set[ParameterDeclaration]:
return [ParameterDeclaration(param_name) for param_name in self.parameter_names]
return {ParameterDeclaration(param_name) for param_name in self.parameter_names}

def get_pulse_length(self, parameters: Dict[str, Parameter]) -> float:
"""Return the length of this pulse for the given parameters.
Expand Down Expand Up @@ -118,8 +118,7 @@ def requires_stop(self,
conditions: Dict[str, 'Condition']) -> bool:
return any(
parameters[name].requires_stop
for name in parameters.keys()
if (name in self.parameter_names) and not isinstance(parameters[name], numbers.Number)
for name in parameters.keys() if (name in self.parameter_names)
)

def get_serialization_data(self, serializer: Serializer) -> None:
Expand All @@ -136,7 +135,7 @@ def deserialize(serializer: 'Serializer', **kwargs) -> 'Serializable':
return FunctionPulseTemplate(
kwargs['expression'],
kwargs['duration_expression'],
kwargs['Measurement']
kwargs['measurement']
)


Expand Down Expand Up @@ -173,7 +172,7 @@ def __evaluate_partially(self, t):

@property
def compare_key(self) -> Any:
return self.__expression
return self.__expression, self.__duration, self.__parameters

@property
def duration(self) -> float:
Expand Down
109 changes: 84 additions & 25 deletions tests/pulses/function_pulse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,58 +2,117 @@

from qctoolkit.pulses.function_pulse_template import FunctionPulseTemplate,\
FunctionWaveform
from qctoolkit.pulses.sequencing import Sequencer
from qctoolkit.pulses.instructions import InstructionBlock
from qctoolkit.pulses.parameters import ConstantParameter
from qctoolkit.pulses.parameters import ParameterDeclaration, ParameterNotProvidedException
from qctoolkit.expressions import Expression
import numpy as np

from tests.serialization_dummies import DummySerializer
from qctoolkit.expressions import Expression
from tests.pulses.sequencing_dummies import DummyParameter

import numpy as np

class FunctionPulseTest(unittest.TestCase):
def setUp(self):
self.s = 'a + b'

def setUp(self) -> None:
self.maxDiff = None
self.s = 'a + b * t'
self.s2 = 'c'
self.pars = dict(a=ConstantParameter(1), b=ConstantParameter(2), c=ConstantParameter(3))
self.fpt = FunctionPulseTemplate(self.s, self.s2)
self.pars = dict(a=DummyParameter(1), b=DummyParameter(2), c=DummyParameter(136.78))

def test_get_pulse_length(self) -> None:
self.assertEqual(136.78, self.fpt.get_pulse_length(self.pars))

def test_get_pulse_length(self):
self.assertEqual(self.fpt.get_pulse_length(self.pars), 3)
def test_get_pulse_length(self) -> None:
with self.assertRaises(ParameterNotProvidedException):
self.fpt.get_pulse_length(dict(b=DummyParameter(26.3267)))

# def test_get_measurement_windows(self):
# self.assertEqual(self.fpt.get_measurement_windows(self.pars), None)
#
# fpt2 = FunctionPulseTemplate(self.s, self.s2, measurement=True)
# self.assertEqual(fpt2.get_measurement_windows(self.pars), [(0, 3)])
def test_is_interruptable(self) -> None:
self.assertFalse(self.fpt.is_interruptable)

def test_serialization_data(self):
def test_num_channels(self) -> None:
self.assertEqual(1, self.fpt.num_channels)

def test_serialization_data(self) -> None:
expected_data = dict(type='FunctionPulseTemplate',
parameter_names=set(['a', 'b', 'c']),
duration_expression=str(self.s2),
expression=str(self.s),
measurement=False)
self.assertEqual(expected_data, self.fpt.get_serialization_data(DummySerializer(serialize_callback=lambda x: str(x))))

def test_deserialize(self) -> None:
basic_data = dict(type='FunctionPulseTemplate',
parameter_names=set(['a', 'b', 'c']),
duration_expression=str(self.s2),
expression=str(self.s),
measurement=False)
serializer = DummySerializer(serialize_callback=lambda x: str(x))
template = FunctionPulseTemplate.deserialize(serializer, **basic_data)
self.assertEqual(basic_data['parameter_names'], template.parameter_names)
self.assertEqual({ParameterDeclaration(name) for name in basic_data['parameter_names']}, template.parameter_declarations)
serialized_data = template.get_serialization_data(serializer)
self.assertEqual(basic_data, serialized_data)

def test_parameter_names_and_declarations_expression_input(self) -> None:
template = FunctionPulseTemplate(Expression("3 * foo + bar * t"), Expression("5 * hugo"))
expected_parameter_names = {'foo', 'bar', 'hugo'}
self.assertEqual(expected_parameter_names, template.parameter_names)
self.assertEqual({ParameterDeclaration(name) for name in expected_parameter_names}, template.parameter_declarations)

def test_parameter_names_and_declarations_string_input(self) -> None:
template = FunctionPulseTemplate("3 * foo + bar * t", "5 * hugo")
expected_parameter_names = {'foo', 'bar', 'hugo'}
self.assertEqual(expected_parameter_names, template.parameter_names)
self.assertEqual({ParameterDeclaration(name) for name in expected_parameter_names},
template.parameter_declarations)


class FunctionPulseSequencingTest(unittest.TestCase):
def setUp(self):

def setUp(self) -> None:
unittest.TestCase.setUp(self)
self.f = "a * t"
self.duration = "y"
self.args = dict(a=ConstantParameter(3),y=ConstantParameter(1))
self.args = dict(a=DummyParameter(3),y=DummyParameter(1))
self.fpt = FunctionPulseTemplate(self.f, self.duration)

def test_build_sequence(self):
ib = InstructionBlock()
seq = Sequencer()
cond = None
self.fpt.build_sequence(seq, self.args, cond, ib)

def test_build_waveform(self) -> None:
wf = self.fpt.build_waveform(self.args)
self.assertIsNotNone(wf)
self.assertIsInstance(wf, FunctionWaveform)
expected_waveform = FunctionWaveform(dict(a=3, y=1), Expression(self.f), Expression(self.duration))
self.assertEqual(expected_waveform, wf)

def test_requires_stop(self) -> None:
parameters = dict(a=DummyParameter(36.126), y=DummyParameter(247.9543))
self.assertFalse(self.fpt.requires_stop(parameters, dict()))
parameters = dict(a=DummyParameter(36.126), y=DummyParameter(247.9543, requires_stop=True))
self.assertTrue(self.fpt.requires_stop(parameters, dict()))


class FunctionWaveformTest(unittest.TestCase):

def test_sample(self):
def test_equality(self) -> None:
wf1a = FunctionWaveform(dict(a=2, b=1), Expression('a*t'), Expression('b'))
wf1b = FunctionWaveform(dict(a=2, b=1), Expression('a*t'), Expression('b'))
wf2 = FunctionWaveform(dict(a=3, b=1), Expression('a*t'), Expression('b'))
wf3 = FunctionWaveform(dict(a=2, b=1), Expression('a*t+2'), Expression('b'))
wf4 = FunctionWaveform(dict(a=2, c=2), Expression('a*t'), Expression('c'))
self.assertEqual(wf1a, wf1a)
self.assertEqual(wf1a, wf1b)
self.assertNotEqual(wf1a, wf2)
self.assertNotEqual(wf1a, wf3)
self.assertNotEqual(wf1a, wf4)

def test_num_channels(self) -> None:
wf = FunctionWaveform(dict(), Expression('t'), Expression('4'))
self.assertEqual(1, wf.num_channels)

def test_duration(self) -> None:
wf = FunctionWaveform(dict(foo=2.5), Expression('2*t'), Expression('4*foo/5'))
self.assertEqual(2, wf.duration)

def test_sample(self) -> None:
f = Expression("(t+1)**b")
length = Expression("c**b")
par = {"b":2,"c":10}
Expand Down

0 comments on commit 4c35e44

Please sign in to comment.