Skip to content

Commit

Permalink
Some docstrings for Expression. Other templates now delegate to Seria…
Browse files Browse the repository at this point in the history
…lizer._serialize_subpulse when serializing Expression subjects instead of simply converting it to strings.
  • Loading branch information
lumip committed May 6, 2016
1 parent 6085724 commit 9fe200f
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 28 deletions.
2 changes: 1 addition & 1 deletion qctoolkit/comparable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Comparable(metaclass=ABCMeta):

@abstractproperty
def _compare_key(self) -> Any:
"""Return a unique key used in comparison and hashing operations.
"""Returns a unique key used in comparison and hashing operations.
The key must describe the essential properties of the object. Two objects are equal iff their keys are identical.
"""
Expand Down
65 changes: 56 additions & 9 deletions qctoolkit/expressions.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,57 @@
from py_expression_eval import Parser
import py_expression_eval
from typing import Any, Dict, Iterable, Optional

from qctoolkit.serialization import Serializable
from qctoolkit.comparable import Comparable
from qctoolkit.serialization import Serializable, Serializer

__all__ = ["Expression"]


class Expression(Serializable):
class Expression(Serializable, Comparable):
"""A mathematical expression."""

def __init__(self, ex: str) -> None:
"""Creates an Expression object.
Receives the mathematical expression which shall be represented by the object as a string which will be parsed
using py_expression_eval. For available operators, functions and constants see
https://github.com/AxiaCore/py-expression-eval/#available-operators-constants-and-functions . In addition,
the ** operator may be used for exponentiation instead of the ^ operator.
Args:
ex (string): The mathematical expression represented as a string
"""
self.__string = str(ex) # type: str
self.__expression = Parser().parse(ex.replace('**', '^')) # type: py_expression_eval.Expression
self.__expression = py_expression_eval.Parser().parse(ex.replace('**', '^')) # type: py_expression_eval.Expression

@property
def string(self) -> str:
def __str__(self) -> str:
"""Returns a string representation of this expression."""
return self.__string

@property
def _compare_key(self) -> Any:
"""Returns the string representation of this expression as unique key used in comparison and hashing operations."""
return str(self)

def variables(self) -> Iterable[str]:
"""Returns all variables occurring in the expression."""
return self.__expression.variables()

def evaluate(self, **kwargs) -> float:
return self.__expression.evaluate(kwargs)
"""Evaluates the expression with the required variables passed in as kwargs.
Keyword Args:
<variable name> float: Values for the free variables of the expression.
Raises:
ExpressionVariableMissingException if a value for a variable is not provided.
"""
try:
return self.__expression.evaluate(kwargs)
except Exception as e:
raise ExpressionVariableMissingException(str(e).split(' ')[2], self) from e

def get_serialization_data(self, serializer: 'Serializer') -> Dict[str, Any]:
return dict(type='Expression', expression=self.__string)
def get_serialization_data(self, serializer: Serializer) -> Dict[str, Any]:
return dict(type=serializer.get_type_identifier(self), expression=str(self))

@staticmethod
def deserialize(serializer: 'Serializer', **kwargs) -> Serializable:
Expand All @@ -32,3 +60,22 @@ def deserialize(serializer: 'Serializer', **kwargs) -> Serializable:
@property
def identifier(self) -> Optional[str]:
return None


class ExpressionVariableMissingException(Exception):

def __init__(self, variable: str, expression: Expression) -> None:
super().__init__()
self.__variable = variable
self.__expression = expression

@property
def expression(self) -> Expression:
return self.__expression

@property
def variable(self) -> str:
return self.__variable

def __str__(self) -> str:
return "Could not evaluate <{}>: A value for variable <{}> is missing!".format(str(self.expression), self.variable)
5 changes: 3 additions & 2 deletions qctoolkit/pulses/function_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ def get_serialization_data(self, serializer: Serializer) -> None:
root = dict()
root['type'] = 'FunctionPulseTemplate'
root['parameter_names'] = self.__parameter_names
root['duration_expression'] = self.__duration_expression.string
root['expression'] = self.__expression.string
root['duration_expression'] = serializer._serialize_subpulse(self.__duration_expression)
#self.__duration_expression.string
root['expression'] = serializer._serialize_subpulse(self.__expression)
root['measurement'] = self.__is_measurement_pulse
return root

Expand Down
6 changes: 3 additions & 3 deletions qctoolkit/pulses/sequence_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def get_serialization_data(self, serializer: Serializer) -> Dict[str, Any]:

subtemplates = []
for (subtemplate, mapping_functions) in self.subtemplates:
mapping_functions_strings = {k: m.string for k, m in mapping_functions.items()}
mapping_functions_strings = {k: serializer._serialize_subpulse(m) for k, m in mapping_functions.items()}
subtemplate = serializer._serialize_subpulse(subtemplate)
subtemplates.append(dict(template=subtemplate, mappings=copy.deepcopy(mapping_functions_strings)))
subtemplates.append(dict(template=subtemplate, mappings=mapping_functions_strings))
data['subtemplates'] = subtemplates

data['type'] = serializer.get_type_identifier(self)
Expand All @@ -134,7 +134,7 @@ def deserialize(serializer: Serializer,
identifier: Optional[str]=None) -> 'SequencePulseTemplate':
subtemplates = \
[(serializer.deserialize(d['template']),
{k: m for k, m in d['mappings'].items()}) for d in subtemplates]
{k: str(serializer.deserialize(m)) for k, m in d['mappings'].items()}) for d in subtemplates]

template = SequencePulseTemplate(subtemplates, external_parameters, identifier=identifier)
template.is_interruptable = is_interruptable
Expand Down
15 changes: 13 additions & 2 deletions tests/expression_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest

from qctoolkit.expressions import Expression
from qctoolkit.expressions import Expression, ExpressionVariableMissingException


class ExpressionTests(unittest.TestCase):
Expand All @@ -12,4 +12,15 @@ def test_evaluate(self) -> None:
'b': 1.5,
'c': -7
}
self.assertEqual(2*1.5 - 7, e.evaluate(**params))
self.assertEqual(2*1.5 - 7, e.evaluate(**params))

def test_variables(self) -> None:
e = Expression('4 ** PI + x * foo')
self.assertEqual(sorted(['foo','x']), sorted(e.variables()))

def test_evaluate_variable_missing(self) -> None:
e = Expression('a * b + c')
params = {
'b': 1.5
}
self.assertRaises(ExpressionVariableMissingException, e.evaluate, **params)
7 changes: 4 additions & 3 deletions tests/pulses/function_pulse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ def test_get_measurement_windows(self):
def test_serialization_data(self):
expected_data = dict(type='FunctionPulseTemplate',
parameter_names=set(['a', 'b', 'c']),
duration_expression=self.s2,
expression=self.s,
duration_expression=str(self.s2),
expression=str(self.s),
measurement=False)
self.assertEqual(expected_data, self.fpt.get_serialization_data(DummySerializer()))
self.assertEqual(expected_data, self.fpt.get_serialization_data(DummySerializer(serialize_callback=lambda x: str(x))))


class FunctionPulseSequencingTest(unittest.TestCase):
def setUp(self):
Expand Down
21 changes: 13 additions & 8 deletions tests/pulses/sequence_pulse_template_tests.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import unittest
import copy

from qctoolkit.expressions import Expression
from qctoolkit.pulses.table_pulse_template import TablePulseTemplate
from qctoolkit.pulses.sequence_pulse_template import SequencePulseTemplate, MissingMappingException, UnnecessaryMappingException, MissingParameterDeclarationException
from qctoolkit.pulses.parameters import ParameterDeclaration, ParameterNotProvidedException, ConstantParameter

from tests.pulses.sequencing_dummies import DummySequencer, DummyInstructionBlock, DummyPulseTemplate, DummyParameter, DummyNoValueParameter
from tests.serialization_dummies import DummySerializer


class SequencePulseTemplateTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -71,28 +71,30 @@ def setUp(self) -> None:
self.foo_mappings = dict(hugo='ilse', albert='albert', voltage='voltage')

def test_get_serialization_data(self) -> None:
serializer = DummySerializer(serialize_callback=lambda x: str(x))
foo_mappings = {k: Expression(v) for k, v in self.foo_mappings.items()}
sequence = SequencePulseTemplate([(self.table_foo, self.foo_mappings), (self.table, {})],
['ilse', 'albert', 'voltage'],
identifier='foo')

expected_data = dict(
type=self.serializer.get_type_identifier(sequence),
type=serializer.get_type_identifier(sequence),
external_parameters=['albert', 'ilse', 'voltage'],
is_interruptable=True,
subtemplates = [
dict(template=str(id(self.table_foo)), mappings=self.foo_mappings),
dict(template=str(id(self.table)), mappings=dict())
dict(template=str(self.table_foo), mappings={k: str(v) for k, v in foo_mappings.items()}),
dict(template=str(self.table), mappings=dict())
]
)
data = sequence.get_serialization_data(self.serializer)
data = sequence.get_serialization_data(serializer)
self.assertEqual(expected_data, data)

def test_deserialize(self) -> None:
foo_mappings = {k: Expression(v) for k, v in self.foo_mappings.items()}
data = dict(
external_parameters={'ilse', 'albert', 'voltage'},
is_interruptable=True,
subtemplates = [
dict(template=str(id(self.table_foo)), mappings=self.foo_mappings),
dict(template=str(id(self.table_foo)), mappings={k: str(id(v)) for k, v in foo_mappings.items()}),
dict(template=str(id(self.table)), mappings=dict())
],
identifier='foo'
Expand All @@ -101,6 +103,8 @@ def test_deserialize(self) -> None:
# prepare dependencies for deserialization
self.serializer.subelements[str(id(self.table_foo))] = self.table_foo
self.serializer.subelements[str(id(self.table))] = self.table
for v in foo_mappings.values():
self.serializer.subelements[str(id(v))] = v

# deserialize
sequence = SequencePulseTemplate.deserialize(self.serializer, **data)
Expand All @@ -111,7 +115,8 @@ def test_deserialize(self) -> None:
sequence.parameter_declarations)
self.assertIs(self.table_foo, sequence.subtemplates[0][0])
self.assertIs(self.table, sequence.subtemplates[1][0])
self.assertEqual(self.foo_mappings, {k: m.string for k,m in sequence.subtemplates[0][1].items()})
#self.assertEqual(self.foo_mappings, {k: m.string for k,m in sequence.subtemplates[0][1].items()})
self.assertEqual(foo_mappings, sequence.subtemplates[0][1])
self.assertEqual(dict(), sequence.subtemplates[1][1])
self.assertEqual(data['identifier'], sequence.identifier)

Expand Down

0 comments on commit 9fe200f

Please sign in to comment.