diff --git a/ReleaseNotes.txt b/ReleaseNotes.txt index 1905f359e..32605fff2 100644 --- a/ReleaseNotes.txt +++ b/ReleaseNotes.txt @@ -16,6 +16,9 @@ - Make ExpressionScalar hashable - Fix bug that prevented evaluation of expressions containing some special functions (`erfc`, `factorial`, etc.) +- Parameters: + - `ConstantParameter` now accepts a `Expression` without free variables as value (given as `Expression` or string) + ## 0.2 ## - General: diff --git a/qupulse/pulses/parameters.py b/qupulse/pulses/parameters.py index 13b2478c4..f4fbbc94c 100644 --- a/qupulse/pulses/parameters.py +++ b/qupulse/pulses/parameters.py @@ -16,7 +16,7 @@ import numpy from qupulse.serialization import AnonymousSerializable -from qupulse.expressions import Expression +from qupulse.expressions import Expression, ExpressionVariableMissingException from qupulse.utils.types import HashableNumpyArray, DocStringABCMeta __all__ = ["Parameter", "ConstantParameter", @@ -49,7 +49,7 @@ def requires_stop(self) -> bool: @abstractmethod def __hash__(self) -> int: - pass + """Returns a hash value of the parameter. Must be implemented.""" def __eq__(self, other) -> bool: return type(self) is type(other) and hash(self) == hash(other) @@ -58,17 +58,22 @@ def __eq__(self, other) -> bool: class ConstantParameter(Parameter): """A pulse parameter with a constant value.""" - def __init__(self, value: Union[Real, numpy.ndarray]) -> None: + def __init__(self, value: Union[Real, numpy.ndarray, Expression, str, sympy.Expr]) -> None: """Create a ConstantParameter instance. Args: value (Real): The value of the parameter """ super().__init__() - if isinstance(value, Real): - self._value = value - else: - self._value = numpy.array(value).view(HashableNumpyArray) + try: + if isinstance(value, Real): + self._value = value + elif isinstance(value, (str, Expression, sympy.Expr)): + self._value = Expression(value).evaluate_numeric() + else: + self._value = numpy.array(value).view(HashableNumpyArray) + except ExpressionVariableMissingException: + raise RuntimeError("Expressions passed into ConstantParameter may not have free variables.") def get_value(self) -> Union[Real, numpy.ndarray]: return self._value diff --git a/tests/pulses/bug_tests.py b/tests/pulses/bug_tests.py index d632f844c..3f66482e0 100644 --- a/tests/pulses/bug_tests.py +++ b/tests/pulses/bug_tests.py @@ -49,9 +49,7 @@ def test_plotting_two_channel_function_pulse_after_two_channel_table_pulse_crash _ = plot(sequence_template, parameters=sequence_parameters, sample_rate=100, show=False) - @unittest.expectedFailure def test_plot_with_parameter_value_being_expression_string(self) -> None: - """This is currently not supported but probably should be?""" sine_measurements = [('M', 't_duration/2', 't_duration')] sine = FunctionPulseTemplate('a*sin(omega*t)', 't_duration', measurements=sine_measurements) sine_channel_mapping = dict(default='sin_channel') diff --git a/tests/pulses/parameters_tests.py b/tests/pulses/parameters_tests.py index 3e9956045..ac3b78cb2 100644 --- a/tests/pulses/parameters_tests.py +++ b/tests/pulses/parameters_tests.py @@ -25,6 +25,26 @@ def test_repr(self) -> None: constant_parameter = ConstantParameter(0.2) self.assertEqual("", repr(constant_parameter)) + def test_expression_value(self) -> None: + expression_str = "exp(4)*sin(pi/2)" + expression_obj = Expression(expression_str) + expression_val = expression_obj.evaluate_numeric() + param = ConstantParameter(expression_str) + self.assertEqual(expression_val, param.get_value()) + param = ConstantParameter(expression_obj) + self.assertEqual(expression_val, param.get_value()) + + def test_invalid_expression_value(self) -> None: + expression_obj = Expression("sin(pi/2*t)") + with self.assertRaises(RuntimeError): + ConstantParameter(expression_obj) + + def test_numpy_value(self) -> None: + import numpy as np + arr = np.array([6, 7, 8]) + param = ConstantParameter(arr) + np.array_equal(arr, param.get_value()) + class MappedParameterTest(unittest.TestCase): @@ -87,10 +107,12 @@ def test_no_relation(self): ParameterConstraint('a*b') ParameterConstraint('1 < 2') - def test_str(self): + def test_str_and_serialization(self): self.assertEqual(str(ParameterConstraint('a < b')), 'a < b') + self.assertEqual(ParameterConstraint('a < b').get_serialization_data(), 'a < b') self.assertEqual(str(ParameterConstraint('a==b')), 'a==b') + self.assertEqual(ParameterConstraint('a==b').get_serialization_data(), 'a==b') class ParameterNotProvidedExceptionTests(unittest.TestCase): diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 6fe7a7dc9..73d66223a 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -164,12 +164,12 @@ class PulseTemplateTest(unittest.TestCase): def test_create_program(self) -> None: template = PulseTemplateStub(defined_channels={'A'}, parameter_names={'foo'}) - parameters = {'foo': ConstantParameter(2.126), 'bar': -26.2, 'hugo': '2*x+b', 'append_a_child': '1'} + parameters = {'foo': ConstantParameter(2.126), 'bar': -26.2, 'hugo': 'exp(sin(pi/2))', 'append_a_child': '1'} measurement_mapping = {'M': 'N'} channel_mapping = {'A': 'B'} expected_parameters = {'foo': ConstantParameter(2.126), 'bar': ConstantParameter(-26.2), - 'hugo': ConstantParameter('2*x+b'), 'append_a_child': ConstantParameter('1')} + 'hugo': ConstantParameter('exp(sin(pi/2))'), 'append_a_child': ConstantParameter('1')} to_single_waveform = {'voll', 'toggo'} global_transformation = TransformationStub() @@ -308,11 +308,11 @@ def test_create_program_channel_mapping(self): def test_create_program_none(self) -> None: template = PulseTemplateStub(defined_channels={'A'}, parameter_names={'foo'}) - parameters = {'foo': ConstantParameter(2.126), 'bar': -26.2, 'hugo': '2*x+b'} + parameters = {'foo': ConstantParameter(2.126), 'bar': -26.2, 'hugo': 'exp(sin(pi/2))'} measurement_mapping = {'M': 'N'} channel_mapping = {'A': 'B'} expected_parameters = {'foo': ConstantParameter(2.126), 'bar': ConstantParameter(-26.2), - 'hugo': ConstantParameter('2*x+b')} + 'hugo': ConstantParameter('exp(sin(pi/2))')} expected_internal_kwargs = dict(parameters=expected_parameters, measurement_mapping=measurement_mapping, channel_mapping=channel_mapping,