From 29294659e06f5116dd66055c08d166a536883e59 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 19 May 2021 13:36:37 +0200 Subject: [PATCH 01/23] Add constant propagation check to Transformation --- qupulse/_program/transformation.py | 30 +++++++++++- tests/_program/transformation_tests.py | 63 ++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index 697b8d7c6..6df76ad77 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -39,6 +39,10 @@ def chain(self, next_transformation: 'Transformation') -> 'Transformation': else: return chain_transformations(self, next_transformation) + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return False + class IdentityTransformation(Transformation, metaclass=SingletonABCMeta): def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: @@ -60,6 +64,10 @@ def chain(self, next_transformation: Transformation) -> Transformation: def __repr__(self): return 'IdentityTransformation()' + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return True + class ChainedTransformation(Transformation): def __init__(self, *transformations: Transformation): @@ -94,6 +102,10 @@ def chain(self, next_transformation) -> 'ChainedTransformation': def __repr__(self): return 'ChainedTransformation%r' % (self._transformations,) + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return all(trafo.is_constant_invariant() for trafo in self._transformations) + class LinearTransformation(Transformation): def __init__(self, @@ -138,7 +150,7 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma transformed_data = self._matrix @ data_in for idx, out_channel in enumerate(self._output_channels): - data_out[out_channel] = transformed_data[idx, :] + data_out[out_channel] = transformed_data[idx, ...] return data_out @@ -169,6 +181,10 @@ def __repr__(self): input_channels=self._input_channels, output_channels=self._output_channels) + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return True + class OffsetTransformation(Transformation): def __init__(self, offsets: Mapping[ChannelID, Real]): @@ -198,6 +214,10 @@ def compare_key(self) -> frozenset: def __repr__(self): return 'OffsetTransformation(%r)' % self._offsets + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return True + class ScalingTransformation(Transformation): def __init__(self, factors: Mapping[ChannelID, Real]): @@ -220,6 +240,10 @@ def compare_key(self) -> frozenset: def __repr__(self): return 'ScalingTransformation(%r)' % self._factors + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return True + try: import pandas @@ -268,6 +292,10 @@ def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: def __repr__(self): return 'ParallelConstantChannelTransformation(%r)' % self._channels + def is_constant_invariant(self): + """Signals if the transformation always maps constants to constants.""" + return True + def chain_transformations(*transformations: Transformation) -> Transformation: parsed_transformations = [] diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py index 1122552fd..e75e17dc6 100644 --- a/tests/_program/transformation_tests.py +++ b/tests/_program/transformation_tests.py @@ -23,6 +23,15 @@ def compare_key(self): return id(self) +def assert_scalar_trafo_works(test_case: unittest.TestCase, trafo: Transformation, scalar_in: dict): + non_scalar = {ch: np.array([val]) for ch, val in scalar_in.items()} + + out_non_scalar = trafo(np.zeros((1,)), non_scalar) + out_scalar = trafo(0., scalar_in) + for ch in out_scalar: + test_case.assertEqual(out_non_scalar[ch][0], out_scalar[ch]) + + class TransformationTests(unittest.TestCase): def test_chain(self): trafo = TransformationStub() @@ -143,6 +152,21 @@ def test_repr(self): trafo = LinearTransformation(matrix, in_chs, out_chs) self.assertEqual(trafo, eval(repr(trafo))) + def test_scalar_trafo_works(self): + in_chs = ('a', 'b', 'c') + out_chs = ('transformed_a', 'transformed_b') + matrix = np.array([[1, -1, 0], [1, 1, 1]]) + trafo = LinearTransformation(matrix, in_chs, out_chs) + + assert_scalar_trafo_works(self, trafo, {'a': 0., 'b': 0.3, 'c': 0.6}) + + def test_constant_propagation(self): + in_chs = ('a', 'b', 'c') + out_chs = ('transformed_a', 'transformed_b') + matrix = np.array([[1, -1, 0], [1, 1, 1]]) + trafo = LinearTransformation(matrix, in_chs, out_chs) + self.assertTrue(trafo.is_constant_invariant()) + class IdentityTransformationTests(unittest.TestCase): def test_compare_key(self): @@ -172,6 +196,12 @@ def test_repr(self): trafo = IdentityTransformation() self.assertEqual(trafo, eval(repr(trafo))) + def test_scalar_trafo_works(self): + assert_scalar_trafo_works(self, IdentityTransformation(), {'a': 0., 'b': 0.3, 'c': 0.6}) + + def test_constant_propagation(self): + self.assertTrue(IdentityTransformation().is_constant_invariant()) + class ChainedTransformationTests(unittest.TestCase): def test_init_and_properties(self): @@ -249,6 +279,12 @@ def test_repr(self): trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) self.assertEqual(trafo, eval(repr(trafo))) + def test_constant_propagation(self): + trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), OffsetTransformation({'b': 6.6})) + self.assertTrue(trafo.is_constant_invariant()) + trafo = ChainedTransformation(ScalingTransformation({'a': 1.1}), TransformationStub()) + self.assertFalse(trafo.is_constant_invariant()) + class ParallelConstantChannelTransformationTests(unittest.TestCase): def test_init(self): @@ -297,6 +333,17 @@ def test_repr(self): trafo = ParallelConstantChannelTransformation(channels) self.assertEqual(trafo, eval(repr(trafo))) + def test_scalar_trafo_works(self): + channels = {'X': 2, 'Y': 4.4} + trafo = ParallelConstantChannelTransformation(channels) + + assert_scalar_trafo_works(self, trafo, {'a': 0., 'b': 0.3, 'c': 0.6}) + + def test_constant_propagation(self): + channels = {'X': 2, 'Y': 4.4} + trafo = ParallelConstantChannelTransformation(channels) + self.assertTrue(trafo.is_constant_invariant()) + class TestChaining(unittest.TestCase): def test_identity_result(self): @@ -369,6 +416,14 @@ def test_repr(self): trafo = OffsetTransformation(self.offsets) self.assertEqual(trafo, eval(repr(trafo))) + def test_scalar_trafo_works(self): + trafo = OffsetTransformation(self.offsets) + assert_scalar_trafo_works(self, trafo, {'A': 0., 'B': 0.3, 'c': 0.6}) + + def test_constant_propagation(self): + trafo = OffsetTransformation(self.offsets) + self.assertTrue(trafo.is_constant_invariant()) + class TestScalingTransformation(unittest.TestCase): def setUp(self) -> None: @@ -406,3 +461,11 @@ def test_trafo(self): def test_repr(self): trafo = OffsetTransformation(self.scales) self.assertEqual(trafo, eval(repr(trafo))) + + def test_scalar_trafo_works(self): + trafo = ScalingTransformation(self.scales) + assert_scalar_trafo_works(self, trafo, {'A': 0., 'B': 0.3, 'c': 0.6}) + + def test_constant_propagation(self): + trafo = ScalingTransformation(self.scales) + self.assertTrue(trafo.is_constant_invariant()) From 6a9f2134aca367e43644001d4bd94210e134e5d6 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 29 May 2021 13:44:52 +0200 Subject: [PATCH 02/23] Add constant value inspection functions to waveform --- qupulse/_program/waveforms.py | 68 +++++++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 61ae624dd..881c3a831 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -10,7 +10,7 @@ from numbers import Real from typing import ( AbstractSet, Any, FrozenSet, Iterable, Mapping, NamedTuple, Sequence, Set, - Tuple, Union, cast) + Tuple, Union, cast, Optional) from weakref import WeakValueDictionary, ref import numpy as np @@ -98,21 +98,29 @@ def get_sampled(self, if channel not in self.defined_channels: raise KeyError('Channel not defined in this waveform: {}'.format(channel)) - if output_array is None: - # cache the result to save memory - result = self.unsafe_sample(channel, sample_times) - result.flags.writeable = False - key = hash(bytes(result)) - if key not in self.__sampled_cache: - self.__sampled_cache[key] = result - return self.__sampled_cache[key] + constant_value = self.constant_value(channel) + if constant_value is None: + if output_array is None: + # cache the result to save memory + result = self.unsafe_sample(channel, sample_times) + result.flags.writeable = False + key = hash(bytes(result)) + if key not in self.__sampled_cache: + self.__sampled_cache[key] = result + return self.__sampled_cache[key] + else: + if len(output_array) != len(sample_times): + raise ValueError('Output array length and sample time length are different') + # use the user provided memory + return self.unsafe_sample(channel=channel, + sample_times=sample_times, + output_array=output_array) else: - if len(output_array) != len(sample_times): - raise ValueError('Output array length and sample time length are different') - # use the user provided memory - return self.unsafe_sample(channel=channel, - sample_times=sample_times, - output_array=output_array) + if output_array is None: + output_array = np.full_like(sample_times, fill_value=constant_value, dtype=float) + else: + output_array[:] = constant_value + return output_array @property @abstractmethod @@ -142,6 +150,36 @@ def get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform return self return self.unsafe_get_subset_for_channels(channels=channels) + def is_constant(self) -> bool: + """Convenience function to check if all channels are constant. The result is equal to + `all(waveform.constant_value(ch) is not None for ch in waveform.defined_channels)` but might be more performant. + + Returns: + True if all channels have constant values. + """ + return self.constant_value_dict() is None + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + result = {ch: self.constant_value(ch) for ch in self.defined_channels} + if None in result.values(): + return None + else: + return result + + def constant_value(self, channel: ChannelID) -> Optional[float]: + """Checks if the requested channel has a constant value and returns it if so. + + Guarantee that this assertion passes for every t in waveform duration: + >>> assert waveform.constant_value(channel) is None or waveform.constant_value(t) = waveform.get_sampled(channel, t) + + Args: + channel: The channel to check + + Returns: + None if there is no guarantee that the channel is constant. The value otherwise. + """ + return None + def __neg__(self): return FunctorWaveform(self, {ch: np.negative for ch in self.defined_channels}) From 6f0a68945307e7211e6e9f17d8bc4162a5c00220 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 29 May 2021 13:47:37 +0200 Subject: [PATCH 03/23] Add example implementation of constant analysis with constant detecting classemethod for TableWaveform --- qupulse/_program/waveforms.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 881c3a831..c417b8b3b 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -6,6 +6,7 @@ import itertools import operator +import warnings from abc import ABCMeta, abstractmethod from numbers import Real from typing import ( @@ -205,16 +206,20 @@ class TableWaveform(Waveform): def __init__(self, channel: ChannelID, - waveform_table: Sequence[EntryInInit]) -> None: + waveform_table: Tuple[TableWaveformEntry, ...]) -> None: """Create a new TableWaveform instance. Args: - waveform_table (ImmutableList(WaveformTableEntry)): A list of instantiated table - entries of the form (time as float, voltage as float, interpolation strategy). + waveform_table: A tuple of instantiated and validated table entries """ super().__init__() - self._table = self._validate_input(waveform_table) + if not isinstance(waveform_table, tuple): + warnings.warn("Please use a tuple of TableWaveformEntry to construct TableWaveform directly", + category=DeprecationWarning) + waveform_table = self._validate_input(waveform_table) + + self._table = waveform_table self._channel_id = channel @staticmethod @@ -256,6 +261,23 @@ def _validate_input(input_waveform_table: Sequence[EntryInInit]) -> Tuple[TableW return tuple(output_waveform_table) + def is_constant(self) -> bool: + # only correct if `from_table` is used + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if `from_table` is used + return None + + @classmethod + def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union['TableWaveform', 'ConstantWaveform']: + table = cls._validate_input(table) + v = table[0].v + if all(entry.v == v for entry in table): + return ConstantWaveform(table[-1].t, v, channel) + else: + return TableWaveform(channel, table) + @property def compare_key(self) -> Any: return self._channel_id, self._table From b7e4ffb83fcda470c76649e1edf1d591c163104a Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 29 May 2021 13:51:32 +0200 Subject: [PATCH 04/23] Constant inspection for function and constant waveform. --- qupulse/_program/waveforms.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index c417b8b3b..a10908e54 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -322,6 +322,7 @@ def __repr__(self): class ConstantWaveform(Waveform): + # TODO: remove _is_constant_waveform = True def __init__(self, duration: float, amplitude: Any, channel: ChannelID): @@ -330,6 +331,16 @@ def __init__(self, duration: float, amplitude: Any, channel: ChannelID): self._amplitude = amplitude self._channel = channel + def is_constant(self) -> bool: + return True + + def constant_value(self, channel: ChannelID) -> Optional[float]: + assert channel == self._channel + return self._amplitude + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if `from_table` is used + return {self._channel: self._amplitude} @property def duration(self) -> TimeType: @@ -379,11 +390,29 @@ def __init__(self, expression: ExpressionScalar, super().__init__() if set(expression.variables) - set('t'): raise ValueError('FunctionWaveforms may not depend on anything but "t"') + elif not expression.variables: + warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal", + category=UserWarning) self._expression = expression self._duration = TimeType.from_float(duration, absolute_error=PULSE_TO_WAVEFORM_ERROR) self._channel_id = channel + @classmethod + def from_expression(cls, expression: ExpressionScalar, duration: float, channel: ChannelID) -> Union['FunctionWaveform', ConstantWaveform]: + if expression.variables: + return cls(expression, duration, channel) + else: + return ConstantWaveform(expression.evaluate_numeric(), duration, channel) + + def is_constant(self) -> bool: + # only correct if `from_expression` is used + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if `from_table` is used + return None + @property def defined_channels(self) -> Set[ChannelID]: return {self._channel_id} From b6c43ab5872bae96ffb777ca331de2438ac97ee0 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:13:28 +0200 Subject: [PATCH 05/23] Use AbstractSet interface for transformations module --- qupulse/_program/transformation.py | 47 ++++++++++++++++-------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index 6df76ad77..6afe13962 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -1,4 +1,4 @@ -from typing import Mapping, Set, Tuple, Sequence +from typing import Mapping, Set, Tuple, Sequence, AbstractSet, Union from abc import abstractmethod from numbers import Real @@ -26,11 +26,11 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma """ @abstractmethod - def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: """Return the channel identifiers""" @abstractmethod - def get_input_channels(self, output_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: """Channels that are required for getting data for the requested output channel""" def chain(self, next_transformation: 'Transformation') -> 'Transformation': @@ -48,14 +48,14 @@ class IdentityTransformation(Transformation, metaclass=SingletonABCMeta): def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: return data - def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return input_channels @property def compare_key(self) -> None: return None - def get_input_channels(self, output_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return output_channels def chain(self, next_transformation: Transformation) -> Transformation: @@ -77,12 +77,12 @@ def __init__(self, *transformations: Transformation): def transformations(self) -> Tuple[Transformation, ...]: return self._transformations - def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: for transformation in self._transformations: input_channels = transformation.get_output_channels(input_channels) return input_channels - def get_input_channels(self, output_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: for transformation in reversed(self._transformations): output_channels = transformation.get_input_channels(output_channels) return output_channels @@ -96,7 +96,7 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma def compare_key(self) -> Tuple[Transformation, ...]: return self._transformations - def chain(self, next_transformation) -> 'ChainedTransformation': + def chain(self, next_transformation) -> Transformation: return chain_transformations(*self.transformations, next_transformation) def __repr__(self): @@ -133,6 +133,8 @@ def __init__(self, self._matrix = transformation_matrix self._input_channels = tuple(sorted(input_channels)) self._output_channels = tuple(sorted(output_channels)) + self._input_channels_set = frozenset(self._input_channels) + self._output_channels_set = frozenset(self._output_channels) def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: data_out = {forwarded_channel: data[forwarded_channel] @@ -154,20 +156,21 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma return data_out - def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: - if not input_channels.issuperset(self._input_channels): - raise KeyError('Invalid input channels', input_channels, set(self._input_channels)) + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + if not input_channels >= self._input_channels_set: + # input_channels is not a superset of the required input channels + raise KeyError('Invalid input channels', input_channels, self._input_channels_set) - return input_channels.difference(self._input_channels).union(self._output_channels) + return (input_channels - self._input_channels_set) | self._output_channels_set - def get_input_channels(self, output_channels: Set[ChannelID]) -> Set[ChannelID]: - forwarded = output_channels.difference(self._output_channels) + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: + forwarded = output_channels - self._output_channels_set if not forwarded.isdisjoint(self._input_channels): - raise KeyError('Is input channel', forwarded.intersection(self._input_channels)) + raise KeyError('Is input channel', forwarded & self._input_channels_set) elif output_channels.isdisjoint(self._output_channels): return output_channels else: - return forwarded.union(self._input_channels) + return forwarded | self._input_channels_set @property def compare_key(self) -> Tuple[Tuple[ChannelID], Tuple[ChannelID], bytes]: @@ -201,10 +204,10 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma return {channel: channel_values + self._offsets[channel] if channel in self._offsets else channel_values for channel, channel_values in data.items()} - def get_input_channels(self, output_channels: Set[ChannelID]): + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return output_channels - def get_output_channels(self, input_channels: Set[ChannelID]): + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return input_channels @property @@ -227,10 +230,10 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma return {channel: channel_values * self._factors[channel] if channel in self._factors else channel_values for channel, channel_values in data.items()} - def get_input_channels(self, output_channels: Set[ChannelID]): + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return output_channels - def get_output_channels(self, input_channels: Set[ChannelID]): + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return input_channels @property @@ -283,10 +286,10 @@ def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Ma def compare_key(self) -> Tuple[Tuple[ChannelID, float], ...]: return tuple(sorted(self._channels.items())) - def get_input_channels(self, output_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return output_channels - self._channels.keys() - def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return input_channels | self._channels.keys() def __repr__(self): From 320d014b79af048da12e907434398c936ea26324 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:15:50 +0200 Subject: [PATCH 06/23] Change transformation interface to accept scalar values --- qupulse/_program/transformation.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/qupulse/_program/transformation.py b/qupulse/_program/transformation.py index 6afe13962..4a1eb875e 100644 --- a/qupulse/_program/transformation.py +++ b/qupulse/_program/transformation.py @@ -15,7 +15,8 @@ class Transformation(Comparable): of input and output channels might differ.""" @abstractmethod - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: """Apply transformation to data Args: time: @@ -45,7 +46,8 @@ def is_constant_invariant(self): class IdentityTransformation(Transformation, metaclass=SingletonABCMeta): - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: return data def get_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: @@ -87,7 +89,8 @@ def get_input_channels(self, output_channels: AbstractSet[ChannelID]) -> Abstrac output_channels = transformation.get_input_channels(output_channels) return output_channels - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: for transformation in self._transformations: data = transformation(time, data) return data @@ -136,7 +139,8 @@ def __init__(self, self._input_channels_set = frozenset(self._input_channels) self._output_channels_set = frozenset(self._output_channels) - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: data_out = {forwarded_channel: data[forwarded_channel] for forwarded_channel in set(data.keys()).difference(self._input_channels)} @@ -200,7 +204,8 @@ def __init__(self, offsets: Mapping[ChannelID, Real]): """ self._offsets = dict(offsets.items()) - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: return {channel: channel_values + self._offsets[channel] if channel in self._offsets else channel_values for channel, channel_values in data.items()} @@ -226,7 +231,8 @@ class ScalingTransformation(Transformation): def __init__(self, factors: Mapping[ChannelID, Real]): self._factors = dict(factors.items()) - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: return {channel: channel_values * self._factors[channel] if channel in self._factors else channel_values for channel, channel_values in data.items()} @@ -277,7 +283,8 @@ def __init__(self, channels: Mapping[ChannelID, Real]): self._channels = {channel: float(value) for channel, value in channels.items()} - def __call__(self, time: np.ndarray, data: Mapping[ChannelID, np.ndarray]) -> Mapping[ChannelID, np.ndarray]: + def __call__(self, time: Union[np.ndarray, float], + data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: overwritten = {channel: np.full_like(time, fill_value=value, dtype=float) for channel, value in self._channels.items()} return {**data, **overwritten} From cead9bf29f06cd2b944f62caca473e374425ead3 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:18:24 +0200 Subject: [PATCH 07/23] Use AbstractSet for Waveform.defined_channels --- qupulse/_program/waveforms.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index a10908e54..8e76cca21 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -125,7 +125,7 @@ def get_sampled(self, @property @abstractmethod - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: """The channels this waveform should played on. Use :func:`~qupulse.pulses.instructions.get_measurement_windows` to get a waveform for a subset of these.""" @@ -310,7 +310,7 @@ def unsafe_sample(self, return output_array @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return {self._channel_id} def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': @@ -347,7 +347,7 @@ def duration(self) -> TimeType: return time_from_float(float(self._duration), absolute_error=PULSE_TO_WAVEFORM_ERROR) @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: """The channels this waveform should played on. Use :func:`~qupulse.pulses.instructions.get_measurement_windows` to get a waveform for a subset of these.""" @@ -414,7 +414,7 @@ def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: return None @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return {self._channel_id} @property @@ -468,7 +468,7 @@ def flattened_sub_waveforms() -> Iterable[Waveform]: ) @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return self._sequenced_waveforms[0].defined_channels def unsafe_sample(self, @@ -562,12 +562,13 @@ def get_sub_waveform_sort_key(waveform): self._sub_waveforms = tuple(sorted(flatten_sub_waveforms(sub_waveforms), key=get_sub_waveform_sort_key)) - self.__defined_channels = set() + defined_channels = set() for waveform in self._sub_waveforms: - if waveform.defined_channels & self.__defined_channels: + if waveform.defined_channels & defined_channels: raise ValueError('Channel may not be defined in multiple waveforms', - waveform.defined_channels & self.__defined_channels) - self.__defined_channels |= waveform.defined_channels + waveform.defined_channels & defined_channels) + defined_channels |= waveform.defined_channels + self._defined_channels = frozenset(defined_channels) if not all(isclose(waveform.duration, self._sub_waveforms[0].duration) for waveform in self._sub_waveforms[1:]): # meaningful error message: @@ -597,8 +598,8 @@ def __getitem__(self, key: ChannelID) -> Waveform: raise KeyError('Unknown channel ID: {}'.format(key), key) @property - def defined_channels(self) -> Set[ChannelID]: - return self.__defined_channels + def defined_channels(self) -> AbstractSet[ChannelID]: + return self._defined_channels @property def compare_key(self) -> Any: @@ -633,7 +634,7 @@ def __init__(self, body: Waveform, repetition_count: int): raise ValueError('Repetition count must be an integer >0') @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return self._body.defined_channels def unsafe_sample(self, @@ -685,7 +686,7 @@ def transformation(self) -> Transformation: return self._transformation @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return self.transformation.get_output_channels(self.inner_waveform.defined_channels) @property @@ -798,8 +799,8 @@ def duration(self) -> TimeType: return self._lhs.duration @property - def defined_channels(self) -> Set[ChannelID]: - return set.union(self._lhs.defined_channels, self._rhs.defined_channels) + def defined_channels(self) -> AbstractSet[ChannelID]: + return self._lhs.defined_channels | self._rhs.defined_channels def unsafe_sample(self, channel: ChannelID, @@ -853,7 +854,7 @@ def duration(self) -> TimeType: return self._inner_waveform.duration @property - def defined_channels(self) -> Set[ChannelID]: + def defined_channels(self) -> AbstractSet[ChannelID]: return self._inner_waveform.defined_channels def unsafe_sample(self, From d9cd8b8619b29c37bcfa8ec7c0a8240c7b84ec23 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:34:46 +0200 Subject: [PATCH 08/23] Add multi channel construction method for ConstantWaveform --- qupulse/_program/waveforms.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 8e76cca21..4951df526 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -331,6 +331,17 @@ def __init__(self, duration: float, amplitude: Any, channel: ChannelID): self._amplitude = amplitude self._channel = channel + @classmethod + def from_mapping(cls, duration: Real, constant_values: Mapping[ChannelID, float]) -> Waveform: + """Construct a ConstantWaveform or a MultiChannelWaveform of ConstantWaveforms with given duration and values""" + assert constant_values + if len(constant_values) == 1: + (channel, amplitude), = constant_values.items() + return cls(duration, amplitude=amplitude, channel=channel) + else: + return MultiChannelWaveform([cls(duration, amplitude=amplitude, channel=channel) + for channel, amplitude in constant_values.items()]) + def is_constant(self) -> bool: return True From b9ab4069f0db9ab62a609e5256259ce2571f4731 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:40:12 +0200 Subject: [PATCH 09/23] Add better constant interface and optimizing constructor method to SequenceWaveform --- qupulse/_program/waveforms.py | 63 ++++++++++++++++++++++++++++++----- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 4951df526..7a1851b2e 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -463,21 +463,62 @@ def __init__(self, sub_waveforms: Iterable[Waveform]): "SequenceWaveform cannot be constructed without channel waveforms." ) - def flattened_sub_waveforms() -> Iterable[Waveform]: - for sub_waveform in sub_waveforms: - if isinstance(sub_waveform, SequenceWaveform): - yield from sub_waveform._sequenced_waveforms - else: - yield sub_waveform - - self._sequenced_waveforms = tuple(flattened_sub_waveforms()) + self._sequenced_waveforms = tuple(sub_waveforms) self._duration = sum(waveform.duration for waveform in self._sequenced_waveforms) - if not all(waveform.defined_channels == self.defined_channels for waveform in self._sequenced_waveforms[1:]): + + defined_channels = self._sequenced_waveforms[0].defined_channels + if not all(waveform.defined_channels == defined_channels + for waveform in itertools.islice(self._sequenced_waveforms, 1, None)): raise ValueError( "SequenceWaveform cannot be constructed from waveforms of different" "defined channels." ) + @classmethod + def from_sequence(cls, waveforms: Sequence['Waveform']) -> 'Waveform': + """Returns a waveform the represents the given sequence of waveforms. Applies some optimizations.""" + assert waveforms, "Sequence must not be empty" + if len(waveforms) == 1: + return waveforms[0] + + flattened = [] + constant_values = waveforms[0].constant_value_dict() + for wf in waveforms: + if constant_values and constant_values != wf.constant_value_dict(): + constant_values = None + if isinstance(wf, cls): + flattened.extend(wf.sequenced_waveforms) + else: + flattened.append(wf) + if constant_values is None: + return cls(sub_waveforms=flattened) + else: + duration = sum(wf.duration for wf in flattened) + ConstantWaveform.from_mapping(duration, constant_values) + + def is_constant(self) -> bool: + # only correct if from_sequence is used for construction + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if from_sequence is used for construction + return None + + def constant_value(self, channel: ChannelID) -> Optional[float]: + v = None + for wf in self._sequenced_waveforms: + wf_cv = wf.constant_value(channel) + if wf_cv is None: + return None + elif wf_cv == v: + continue + elif v is None: + v = wf_cv + else: + assert v != wf_cv + return None + return v + @property def defined_channels(self) -> AbstractSet[ChannelID]: return self._sequenced_waveforms[0].defined_channels @@ -514,6 +555,10 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W sub_waveform.unsafe_get_subset_for_channels(channels & sub_waveform.defined_channels) for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels) + @property + def sequenced_waveforms(self) -> Sequence[Waveform]: + return self._sequenced_waveforms + class MultiChannelWaveform(Waveform): """A MultiChannelWaveform is a Waveform object that allows combining arbitrary Waveform objects From 04ed1552b1581369f1af8acb58f5d5e6795d3070 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:41:44 +0200 Subject: [PATCH 10/23] Add more optimized constant inspection and construction methods to MultiChannelWaveform --- qupulse/_program/waveforms.py | 72 +++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 7a1851b2e..9a3d9fdcb 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -187,6 +187,11 @@ def __neg__(self): def __pos__(self): return self + def _sort_key_for_channels(self) -> Sequence[Tuple[str, int]]: + """Makes reproducible sorting by defined channels possible""" + return sorted((ch, 0) if isinstance(ch, str) else ('', ch) for ch in self.defined_channels) + + class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', Real), ('v', float), @@ -579,16 +584,17 @@ class MultiChannelWaveform(Waveform): assigned more than one channel of any Waveform object it consists of """ - def __init__(self, sub_waveforms: Iterable[Waveform]) -> None: + def __init__(self, sub_waveforms: List[Waveform]) -> None: """Create a new MultiChannelWaveform instance. + Use `MultiChannelWaveform.from_parallel` for optimal construction. Requires a list of subwaveforms in the form (Waveform, List(int)) where the list defines the channel mapping, i.e., a value y at index x in the list means that channel x of the subwaveform will be mapped to channel y of this MultiChannelWaveform object. Args: - sub_waveforms (Iterable( Waveform )): The list of sub waveforms of this - MultiChannelWaveform + sub_waveforms: The list of sub waveforms of this + MultiChannelWaveform. List might get sorted! Raises: ValueError, if a channel mapping is out of bounds of the channels defined by this MultiChannelWaveform @@ -602,21 +608,12 @@ def __init__(self, sub_waveforms: Iterable[Waveform]) -> None: "MultiChannelWaveform cannot be constructed without channel waveforms." ) - # avoid unnecessary multi channel nesting - def flatten_sub_waveforms(to_flatten): - for sub_waveform in to_flatten: - if isinstance(sub_waveform, MultiChannelWaveform): - yield from sub_waveform._sub_waveforms - else: - yield sub_waveform - # sort the waveforms with their defined channels to make compare key reproducible - def get_sub_waveform_sort_key(waveform): - return tuple(sorted(tuple('{}_stringified_numeric_channel'.format(ch) if isinstance(ch, int) else ch - for ch in waveform.defined_channels))) + if not isinstance(sub_waveforms, list): + sub_waveforms = list(sub_waveforms) + sub_waveforms.sort(key=lambda wf: wf._sort_key_for_channels()) - self._sub_waveforms = tuple(sorted(flatten_sub_waveforms(sub_waveforms), - key=get_sub_waveform_sort_key)) + self._sub_waveforms = tuple(sub_waveforms) defined_channels = set() for waveform in self._sub_waveforms: @@ -643,6 +640,41 @@ def get_sub_waveform_sort_key(waveform): durations ) + @staticmethod + def from_parallel(waveforms: Sequence[Waveform]) -> Waveform: + assert waveforms, "ARgument must not be empty" + if len(waveforms) == 1: + return waveforms[0] + + # we do not look at constant values here because there is no benefit. We would need to construct a new + # MultiChannelWaveform anyways + + # avoid unnecessary multi channel nesting + flattened = [] + for waveform in waveforms: + if isinstance(waveform, MultiChannelWaveform): + flattened.extend(waveform._sub_waveforms) + else: + flattened.append(waveform) + + return MultiChannelWaveform(flattened) + + def is_constant(self) -> bool: + return all(wf.is_constant() for wf in self._sub_waveforms) + + def constant_value(self, channel: ChannelID) -> Optional[float]: + return self[channel].constant_value(channel) + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + d = {} + for wf in self._sub_waveforms: + wf_d = wf.constant_value_dict() + if wf_d is None: + return None + else: + d.update(wf_d) + return d + @property def duration(self) -> TimeType: return self._sub_waveforms[0].duration @@ -669,13 +701,13 @@ def unsafe_sample(self, return self[channel].unsafe_sample(channel, sample_times, output_array) def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': - relevant_sub_waveforms = tuple(swf for swf in self._sub_waveforms if swf.defined_channels & channels) + relevant_sub_waveforms = [swf for swf in self._sub_waveforms if swf.defined_channels & channels] if len(relevant_sub_waveforms) == 1: return relevant_sub_waveforms[0].get_subset_for_channels(channels) elif len(relevant_sub_waveforms) > 1: - return MultiChannelWaveform( - sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels) - for sub_waveform in relevant_sub_waveforms) + return MultiChannelWaveform.from_parallel( + [sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels) + for sub_waveform in relevant_sub_waveforms]) else: raise KeyError('Unknown channels: {}'.format(channels)) From 10c2c83b254c55c4df0ec207db3c856a4e349f52 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:42:42 +0200 Subject: [PATCH 11/23] Add more optimal constant propagation to RepetitionWaveform --- qupulse/_program/waveforms.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 9a3d9fdcb..d3b696c2f 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -721,6 +721,18 @@ def __init__(self, body: Waveform, repetition_count: int): if repetition_count < 1 or not isinstance(repetition_count, int): raise ValueError('Repetition count must be an integer >0') + self.is_constant = self._body.is_constant + self.constant_value = self._body.constant_value + self.constant_value_dict = self._body.constant_value_dict + + @classmethod + def from_repetition_count(cls, body: Waveform, repetition_count: int) -> Waveform: + constant_values = body.constant_value_dict() + if constant_values is None: + return RepetitionWaveform(body, repetition_count) + else: + return ConstantWaveform.from_mapping(body.duration * repetition_count, constant_values) + @property def defined_channels(self) -> AbstractSet[ChannelID]: return self._body.defined_channels From 2bbaf7423a362196637a4a2da8685325f76514d5 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:43:25 +0200 Subject: [PATCH 12/23] Add better constant propagation to transformation and functor waveform --- qupulse/_program/waveforms.py | 67 +++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index d3b696c2f..dff14be77 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -777,6 +777,34 @@ def __init__(self, inner_waveform: Waveform, transformation: Transformation): self._cached_data = None self._cached_times = lambda: None + @classmethod + def from_transformation(cls, inner_waveform: Waveform, transformation: Transformation) -> Waveform: + constant_values = inner_waveform.constant_value_dict() + + if constant_values is None or not transformation.is_constant_invariant(): + return cls(inner_waveform, transformation) + + transformed_constant_values = transformation(0., constant_values) + return ConstantWaveform.from_mapping(inner_waveform.duration, transformed_constant_values) + + def is_constant(self) -> bool: + # only true if `from_transformation` was used + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only true if `from_transformation` was used + return None + + def constant_value(self, channel: ChannelID) -> Optional[float]: + if not self._transformation.is_constant_invariant(): + return None + in_channels = self._transformation.get_input_channels({channel}) + in_values = {ch: self._inner_waveform.constant_value(ch) for ch in in_channels} + if any(val is None for val in in_values.values()): + return None + else: + return self._transformation(0., in_values)[channel][0] + @property def inner_waveform(self) -> Waveform: return self._inner_waveform @@ -831,6 +859,8 @@ def __init__(self, inner_waveform: Waveform, channel_subset: Set[ChannelID]): self._inner_waveform = inner_waveform self._channel_subset = frozenset(channel_subset) + self.constant_value = self._inner_waveform.constant_value + @property def inner_waveform(self) -> Waveform: return self._inner_waveform @@ -882,6 +912,16 @@ def __init__(self, assert np.isclose(float(self._lhs.duration), float(self._rhs.duration)) assert arithmetic_operator in self.operator_map + def constant_value(self, channel: ChannelID) -> Optional[float]: + lhs = self._lhs.constant_value(channel) + if lhs is not None and lhs == self._rhs.constant_value(channel): + return lhs + else: + return None + + def is_constant(self) -> bool: + return self.lhs.is_constant() and self.rhs.is_constant() + @property def lhs(self) -> Waveform: return self._lhs @@ -942,6 +982,8 @@ def compare_key(self) -> Tuple[str, Waveform, Waveform]: class FunctorWaveform(Waveform): """Apply a channel wise functor that works inplace to all results""" + CONSTANT_INVARIANT_FUNCTORS = (np.negative,) + def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, 'Callable']): self._inner_waveform = inner_waveform self._functor = dict(functor.items()) @@ -949,6 +991,31 @@ def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, 'Callab assert set(functor.keys()) == inner_waveform.defined_channels, ("There is no default identity mapping (yet)." "File an issue on github if you need it.") + @classmethod + def from_functor(cls, inner_waveform: Waveform, functor: Mapping[ChannelID, callable]): + constant_values = inner_waveform.constant_value_dict() + if constant_values is None or functor not in cls.CONSTANT_INVARIANT_FUNCTORS: + return FunctorWaveform(inner_waveform, functor) + + funced_constant_values = {ch: functor[ch](val) for ch, val in constant_values.items()} + return ConstantWaveform.from_mapping(inner_waveform.duration, funced_constant_values) + + def is_constant(self) -> bool: + # only correct if `from_functor` was used + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if `from_functor` was used + return None + + def constant_value(self, channel: ChannelID) -> Optional[float]: + inner = self._inner_waveform.constant_value(channel) + func = self._functor[channel] + if inner is None or func not in self.CONSTANT_INVARIANT_FUNCTORS: + return None + else: + return func(inner) + @property def duration(self) -> TimeType: return self._inner_waveform.duration From c131a77546b7e03d51bde03dc85e2aa82b787261 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:43:54 +0200 Subject: [PATCH 13/23] Fix some docstrings and missing imports --- qupulse/_program/waveforms.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index dff14be77..ea492536e 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -11,7 +11,7 @@ from numbers import Real from typing import ( AbstractSet, Any, FrozenSet, Iterable, Mapping, NamedTuple, Sequence, Set, - Tuple, Union, cast, Optional) + Tuple, Union, cast, Optional, List) from weakref import WeakValueDictionary, ref import numpy as np @@ -355,7 +355,6 @@ def constant_value(self, channel: ChannelID) -> Optional[float]: return self._amplitude def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: - # only correct if `from_table` is used return {self._channel: self._amplitude} @property @@ -370,7 +369,7 @@ def defined_channels(self) -> AbstractSet[ChannelID]: return {self._channel} @property - def compare_key(self) -> Tuple[Any]: + def compare_key(self) -> Tuple[Any, ...]: return self._duration, self._amplitude, self._channel def unsafe_sample(self, @@ -426,7 +425,7 @@ def is_constant(self) -> bool: return False def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: - # only correct if `from_table` is used + # only correct if `from_expression` is used return None @property @@ -459,7 +458,7 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Wa class SequenceWaveform(Waveform): """This class allows putting multiple PulseTemplate together in one waveform on the hardware.""" def __init__(self, sub_waveforms: Iterable[Waveform]): - """ + """Use Waveform.from_sequence for optimal construction :param subwaveforms: All waveforms must have the same defined channels """ From 2a4f746822544a82763022ed4dc57080c4cfa513 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 31 May 2021 15:44:14 +0200 Subject: [PATCH 14/23] Fix some waveform tests --- tests/_program/waveforms_tests.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 46e3502a5..8d3ad7da3 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -136,7 +136,7 @@ def test_init_several_channels(self) -> None: MultiChannelWaveform((dwf_a, dwf_a)) dwf_c_valid = DummyWaveform(duration=2.2, defined_channels={'C'}) - waveform_flat = MultiChannelWaveform((waveform, dwf_c_valid)) + waveform_flat = MultiChannelWaveform.from_parallel((waveform, dwf_c_valid)) self.assertEqual(len(waveform_flat.compare_key), 3) def test_unsafe_sample(self) -> None: @@ -281,7 +281,7 @@ def test_init(self): self.assertEqual(swf1.duration, 2*dwf_ab.duration) self.assertEqual(len(swf1.compare_key), 2) - swf2 = SequenceWaveform((swf1, dwf_ab)) + swf2 = SequenceWaveform.from_sequence((swf1, dwf_ab)) self.assertEqual(swf2.duration, 3 * dwf_ab.duration) self.assertEqual(len(swf2.compare_key), 3) @@ -386,33 +386,32 @@ def test_validate_input_duplicate_removal(self): def test_duration(self) -> None: entries = [TableWaveformEntry(0, 0, HoldInterpolationStrategy()), TableWaveformEntry(5, 1, HoldInterpolationStrategy())] - waveform = TableWaveform('A', entries) + waveform = TableWaveform.from_table('A', entries) self.assertEqual(5, waveform.duration) def test_duration_no_entries_exception(self) -> None: with self.assertRaises(ValueError): - waveform = TableWaveform('A', []) - self.assertEqual(0, waveform.duration) + TableWaveform.from_table('A', []) def test_few_entries(self) -> None: with self.assertRaises(ValueError): - TableWaveform('A', [[]]) + TableWaveform.from_table('A', []) with self.assertRaises(ValueError): - TableWaveform('A', [TableWaveformEntry(0, 0, HoldInterpolationStrategy())]) + TableWaveform.from_table('A', [TableWaveformEntry(0, 0, HoldInterpolationStrategy())]) def test_unsafe_get_subset_for_channels(self): interp = DummyInterpolationStrategy() - entries = [TableWaveformEntry(0, 0, interp), + entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), - TableWaveformEntry(5.7, 123.4, interp)] + TableWaveformEntry(5.7, 123.4, interp)) waveform = TableWaveform('A', entries) self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) def test_unsafe_sample(self) -> None: interp = DummyInterpolationStrategy() - entries = [TableWaveformEntry(0, 0, interp), + entries = (TableWaveformEntry(0, 0, interp), TableWaveformEntry(2.1, -33.2, interp), - TableWaveformEntry(5.7, 123.4, interp)] + TableWaveformEntry(5.7, 123.4, interp)) waveform = TableWaveform('A', entries) sample_times = numpy.linspace(.5, 5.5, num=11) @@ -436,7 +435,7 @@ def test_simple_properties(self): TableWaveformEntry(2.1, -33.2, interp), TableWaveformEntry(5.7, 123.4, interp)] chan = 'A' - waveform = TableWaveform(chan, entries) + waveform = TableWaveform.from_table(chan, entries) self.assertEqual(waveform.defined_channels, {chan}) self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) From 58b59a62e7f5e84bb50866f66ba837668f51d849 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 11 Jun 2021 11:07:59 +0200 Subject: [PATCH 15/23] Add and fix repr to some waveforms --- qupulse/_program/waveforms.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 5597327fd..1aa6cf493 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -201,7 +201,7 @@ def __init__(self, t: float, v: float, interp: InterpolationStrategy): raise TypeError('{} is neither callable nor of type InterpolationStrategy'.format(interp)) def __repr__(self): - return f'{type(self).__name__}(t={self.t}, v={self.v}, interp="{self.interp}")' + return f'{type(self).__name__}(t={self.t!r}, v={self.v!r}, interp={self.interp!r})' class TableWaveform(Waveform): @@ -322,7 +322,7 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W return self def __repr__(self): - return f'{type(self).__name__}(channel={self._channel_id}, waveform_table={self._table})' + return f'{type(self).__name__}(channel={self._channel_id!r}, waveform_table={self._table!r})' class ConstantWaveform(Waveform): @@ -386,6 +386,10 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: """Unsafe version of :func:`~qupulse.pulses.instructions.get_measurement_windows`.""" return self + def __repr__(self): + return f"{type(self).__name__}(duration={self.duration!r}, "\ + f"amplitude={self._amplitude!r}, channel={self._channel!r})" + class FunctionWaveform(Waveform): """Waveform obtained from instantiating a FunctionPulseTemplate.""" @@ -457,6 +461,10 @@ def unsafe_sample(self, def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform: return self + def __repr__(self): + return f"{type(self).__name__}(duration={self.duration!r}, "\ + f"expression={self._expression!r}, channel={self._channel_id!r})" + class SequenceWaveform(Waveform): """This class allows putting multiple PulseTemplate together in one waveform on the hardware.""" @@ -569,6 +577,9 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W def sequenced_waveforms(self) -> Sequence[Waveform]: return self._sequenced_waveforms + def __repr__(self): + return f"{type(self).__name__}({self._sequenced_waveforms})" + class MultiChannelWaveform(Waveform): """A MultiChannelWaveform is a Waveform object that allows combining arbitrary Waveform objects @@ -716,6 +727,9 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W else: raise KeyError('Unknown channels: {}'.format(channels)) + def __repr__(self): + return f"{type(self).__name__}({self._sub_waveforms!r})" + class RepetitionWaveform(Waveform): """This class allows putting multiple PulseTemplate together in one waveform on the hardware.""" @@ -771,6 +785,9 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'R return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels), repetition_count=self._repetition_count) + def __repr__(self): + return f"{type(self).__name__}(body={self._body!r}, repetition_count={self._repetition_count!r})" + class TransformingWaveform(Waveform): def __init__(self, inner_waveform: Waveform, transformation: Transformation): From 2fd032173bed37d8d8ec8f8dfd93a9138d0a7539 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 11 Jun 2021 11:13:38 +0200 Subject: [PATCH 16/23] Fix bugs in waveform constant propagation --- qupulse/_program/waveforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 1aa6cf493..21f3232c2 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -158,7 +158,7 @@ def is_constant(self) -> bool: Returns: True if all channels have constant values. """ - return self.constant_value_dict() is None + return self.constant_value_dict() is not None def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: result = {ch: self.constant_value(ch) for ch in self.defined_channels} @@ -422,7 +422,7 @@ def from_expression(cls, expression: ExpressionScalar, duration: float, channel: if expression.variables: return cls(expression, duration, channel) else: - return ConstantWaveform(expression.evaluate_numeric(), duration, channel) + return ConstantWaveform(amplitude=expression.evaluate_numeric(), duration=duration, channel=channel) def is_constant(self) -> bool: # only correct if `from_expression` is used @@ -512,7 +512,7 @@ def from_sequence(cls, waveforms: Sequence['Waveform']) -> 'Waveform': return cls(sub_waveforms=flattened) else: duration = sum(wf.duration for wf in flattened) - ConstantWaveform.from_mapping(duration, constant_values) + return ConstantWaveform.from_mapping(duration, constant_values) def is_constant(self) -> bool: # only correct if from_sequence is used for construction @@ -825,7 +825,7 @@ def constant_value(self, channel: ChannelID) -> Optional[float]: if any(val is None for val in in_values.values()): return None else: - return self._transformation(0., in_values)[channel][0] + return self._transformation(0., in_values)[channel] @property def inner_waveform(self) -> Waveform: From 4c94d70c7fc47d2ed9300e1f87708cba422283ee Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 11 Jun 2021 11:14:03 +0200 Subject: [PATCH 17/23] Rework constant propagation for ArithmeticWaveform --- qupulse/_program/waveforms.py | 48 +++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 21f3232c2..ef76bd9c4 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -934,15 +934,53 @@ def __init__(self, assert np.isclose(float(self._lhs.duration), float(self._rhs.duration)) assert arithmetic_operator in self.operator_map - def constant_value(self, channel: ChannelID) -> Optional[float]: - lhs = self._lhs.constant_value(channel) - if lhs is not None and lhs == self._rhs.constant_value(channel): - return lhs + @classmethod + def from_operator(cls, lhs: Waveform, arithmetic_operator: str, rhs: Waveform): + # one could optimize rhs_cv to being only created if lhs_cv is not None but this makes the code harder to read + lhs_cv = lhs.constant_value_dict() + rhs_cv = rhs.constant_value_dict() + if lhs_cv is None or rhs_cv is None: + return cls(lhs, arithmetic_operator, rhs) + else: + constant_values = dict(lhs_cv) + op = cls.operator_map[arithmetic_operator] + rhs_op = cls.rhs_only_map[arithmetic_operator] + + for ch, rhs_val in rhs_cv.items(): + if ch in constant_values: + constant_values[ch] = op(constant_values[ch], rhs_val) + else: + constant_values[ch] = rhs_op(rhs_val) + + duration = lhs.duration + assert isclose(duration, rhs.duration) + + return ConstantWaveform.from_mapping(duration, constant_values) + + def constant_value(self, channel: ChannelID) -> Optional[float]: + if channel not in self._rhs.defined_channels: + return self._lhs.constant_value(channel) + rhs = self._rhs.constant_value(channel) + if rhs is None: return None + if channel in self._lhs.defined_channels: + lhs = self._lhs.constant_value(channel) + if lhs is None: + return None + + return self.operator_map[self._arithmetic_operator](lhs, rhs) + else: + return self.rhs_only_map[self._arithmetic_operator](rhs) + def is_constant(self) -> bool: - return self.lhs.is_constant() and self.rhs.is_constant() + # only correct if from_operator is used + return False + + def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: + # only correct if from_operator is used + return None @property def lhs(self) -> Waveform: From 520f2b0dcca4929bd11651107d31617352d81742 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 11 Jun 2021 11:14:20 +0200 Subject: [PATCH 18/23] Rework FunctorWaveform --- qupulse/_program/waveforms.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index ef76bd9c4..c8e4d1d12 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -1040,11 +1040,11 @@ def compare_key(self) -> Tuple[str, Waveform, Waveform]: class FunctorWaveform(Waveform): - """Apply a channel wise functor that works inplace to all results""" + # TODO: Use Protocol to enforce that it accepts second argument has the keyword out + Functor = callable - CONSTANT_INVARIANT_FUNCTORS = (np.negative,) - - def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, 'Callable']): + """Apply a channel wise functor that works inplace to all results. The functor must accept two arguments""" + def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]): self._inner_waveform = inner_waveform self._functor = dict(functor.items()) @@ -1052,9 +1052,9 @@ def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, 'Callab "File an issue on github if you need it.") @classmethod - def from_functor(cls, inner_waveform: Waveform, functor: Mapping[ChannelID, callable]): + def from_functor(cls, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]): constant_values = inner_waveform.constant_value_dict() - if constant_values is None or functor not in cls.CONSTANT_INVARIANT_FUNCTORS: + if constant_values is None: return FunctorWaveform(inner_waveform, functor) funced_constant_values = {ch: functor[ch](val) for ch, val in constant_values.items()} @@ -1070,11 +1070,10 @@ def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: def constant_value(self, channel: ChannelID) -> Optional[float]: inner = self._inner_waveform.constant_value(channel) - func = self._functor[channel] - if inner is None or func not in self.CONSTANT_INVARIANT_FUNCTORS: + if inner is None: return None else: - return func(inner) + return self._functor[channel](inner) @property def duration(self) -> TimeType: @@ -1088,10 +1087,12 @@ def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, output_array: Union[np.ndarray, None] = None) -> np.ndarray: - return self._functor[channel](self._inner_waveform.unsafe_sample(channel, sample_times, output_array)) + inner_output = self._inner_waveform.unsafe_sample(channel, sample_times, output_array) + return self._functor[channel](inner_output, out=inner_output) def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: - return SubsetWaveform(self, channels) + return FunctorWaveform(self._inner_waveform.unsafe_get_subset_for_channels(channels), + {ch: self._functor[ch] for ch in channels}) @property def compare_key(self) -> Tuple[Waveform, FrozenSet]: From 00f9dee972d49b964ad05b772d14300647abddce Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 11 Jun 2021 11:14:52 +0200 Subject: [PATCH 19/23] Increase test coverage --- tests/_program/waveforms_tests.py | 388 ++++++++++++++++++++++++++- tests/pulses/function_pulse_tests.py | 54 +--- tests/pulses/sequencing_dummies.py | 5 +- 3 files changed, 386 insertions(+), 61 deletions(-) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 8d3ad7da3..5e5d20ece 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -8,13 +8,49 @@ from qupulse.pulses.interpolation import HoldInterpolationStrategy, LinearInterpolationStrategy,\ JumpInterpolationStrategy from qupulse._program.waveforms import MultiChannelWaveform, RepetitionWaveform, SequenceWaveform,\ - TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform -from qupulse._program.transformation import Transformation + TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform,\ + Waveform, FunctorWaveform, FunctionWaveform +from qupulse._program.transformation import LinearTransformation +from qupulse.expressions import ExpressionScalar from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy from tests._program.transformation_tests import TransformationStub +def assert_constant_consistent(test_case: unittest.TestCase, wf: Waveform): + if wf.is_constant(): + cvs = wf.constant_value_dict() + test_case.assertEqual(wf.defined_channels, cvs.keys()) + for ch in wf.defined_channels: + test_case.assertEqual(cvs[ch], wf.constant_value(ch)) + else: + test_case.assertIsNone(wf.constant_value_dict()) + test_case.assertIn(None, {wf.constant_value(ch) for ch in wf.defined_channels}) + + +class WaveformStub(Waveform): + @property + def defined_channels(self): + raise NotImplementedError() + + def unsafe_get_subset_for_channels(self, channels) -> 'Waveform': + raise NotImplementedError() + + def unsafe_sample(self, + channel, + sample_times, + output_array=None) -> np.ndarray: + raise NotImplementedError() + + @property + def compare_key(self): + raise NotImplementedError() + + @property + def duration(self) -> TimeType: + raise NotImplementedError() + + class WaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -90,6 +126,22 @@ def test_get_subset_for_channels(self): wf_sub = wf_ab.get_subset_for_channels({'A'}) self.assertEqual(wf_sub.defined_channels, {'A'}) + def test_constant_default_impl(self): + wf = DummyWaveform(defined_channels={'A', 'B'}) + self.assertFalse(wf.is_constant()) + + values = {'A': 4., 'B': 5.} + wf.constant_value = lambda ch: values[ch] + self.assertEqual(values, wf.constant_value_dict()) + assert_constant_consistent(self, wf) + + def test_negation(self): + wf = DummyWaveform(defined_channels={'A', 'B'}) + self.assertIs(wf, +wf) + + expected_neg = FunctorWaveform(wf, {'A': np.negative, 'B': np.negative}) + self.assertEqual(expected_neg, -wf) + class MultiChannelWaveformTest(unittest.TestCase): def test_init_no_args(self) -> None: @@ -98,6 +150,19 @@ def test_init_no_args(self) -> None: with self.assertRaises(ValueError): MultiChannelWaveform(None) + def test_from_parallel(self): + dwf_a = DummyWaveform(duration=2.2, defined_channels={'A'}) + dwf_b = DummyWaveform(duration=2.2, defined_channels={'B'}) + dwf_c = DummyWaveform(duration=2.2, defined_channels={'C'}) + + self.assertIs(dwf_a, MultiChannelWaveform.from_parallel([dwf_a])) + + wf_ab = MultiChannelWaveform.from_parallel([dwf_a, dwf_b]) + self.assertEqual(wf_ab, MultiChannelWaveform([dwf_a, dwf_b])) + + wf_abc = MultiChannelWaveform.from_parallel([wf_ab, dwf_c]) + self.assertEqual(wf_abc, MultiChannelWaveform([dwf_a, dwf_b, dwf_c])) + def test_get_item(self): dwf_a = DummyWaveform(duration=2.2, defined_channels={'A'}) dwf_b = DummyWaveform(duration=2.2, defined_channels={'B'}) @@ -204,6 +269,25 @@ def test_unsafe_get_subset_for_channels(self): self.assertIs(sub_ab.unsafe_get_subset_for_channels({'A'}), dwf_a) self.assertIs(sub_ab.unsafe_get_subset_for_channels({'B'}), dwf_b) + def test_constant_default_impl(self): + wf_non_const_a = DummyWaveform(defined_channels={'A'}, duration=3) + wf_non_const_b = DummyWaveform(defined_channels={'B'}, duration=3) + wf_const_c = ConstantWaveform(channel='C', amplitude=2.2, duration=3) + wf_const_d = ConstantWaveform(channel='D', amplitude=3.3, duration=3) + + wf_const = MultiChannelWaveform.from_parallel((wf_const_c, wf_const_d)) + wf_non_const = MultiChannelWaveform.from_parallel((wf_non_const_b, wf_non_const_a)) + wf_mixed = MultiChannelWaveform.from_parallel((wf_non_const_a, wf_const_c)) + + assert_constant_consistent(self, wf_const) + assert_constant_consistent(self, wf_non_const) + assert_constant_consistent(self, wf_mixed) + + self.assertEqual(wf_const.constant_value_dict(), {'C': 2.2, 'D': 3.3}) + self.assertIsNone(wf_non_const.constant_value_dict()) + self.assertIsNone(wf_mixed.constant_value_dict()) + self.assertEqual(wf_mixed.constant_value('C'), 2.2) + class RepetitionWaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -222,6 +306,17 @@ def test_init(self): self.assertIs(wf._body, body_wf) self.assertEqual(wf._repetition_count, 3) + assert_constant_consistent(self, wf) + + def test_from_repetition_count(self): + dwf = DummyWaveform() + self.assertEqual(RepetitionWaveform(dwf, 3), RepetitionWaveform.from_repetition_count(dwf, 3)) + + cwf = ConstantWaveform(duration=3, amplitude=2.2, channel='A') + with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: + self.assertIs(from_mapping.return_value, RepetitionWaveform.from_repetition_count(cwf, 5)) + from_mapping.assert_called_once_with(15, {'A': 2.2}) + def test_duration(self): wf = RepetitionWaveform(DummyWaveform(duration=2.2), 3) self.assertEqual(wf.duration, TimeType.from_float(2.2)*3) @@ -281,10 +376,35 @@ def test_init(self): self.assertEqual(swf1.duration, 2*dwf_ab.duration) self.assertEqual(len(swf1.compare_key), 2) - swf2 = SequenceWaveform.from_sequence((swf1, dwf_ab)) + swf2 = SequenceWaveform((swf1, dwf_ab)) self.assertEqual(swf2.duration, 3 * dwf_ab.duration) - self.assertEqual(len(swf2.compare_key), 3) + self.assertEqual(len(swf2.compare_key), 2) + + def test_from_sequence(self): + dwf = DummyWaveform(duration=1.1, defined_channels={'A'}) + + self.assertIs(dwf, SequenceWaveform.from_sequence((dwf,))) + + swf1 = SequenceWaveform.from_sequence((dwf, dwf)) + swf2 = SequenceWaveform.from_sequence((swf1, dwf)) + + self.assertEqual(3*(dwf,), swf2.sequenced_waveforms) + + cwf_2_a = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') + cwf_3 = ConstantWaveform(duration=1.1, amplitude=3.3, channel='A') + cwf_2_b = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') + + with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: + new_constant = SequenceWaveform.from_sequence((cwf_2_a, cwf_2_b)) + self.assertIs(from_mapping.return_value, new_constant) + from_mapping.assert_called_once_with(2*TimeType.from_float(1.1), {'A': 2.2}) + + swf3 = SequenceWaveform.from_sequence((cwf_2_a, dwf)) + self.assertEqual((cwf_2_a, dwf), swf3.sequenced_waveforms) + + swf3 = SequenceWaveform.from_sequence((cwf_2_a, cwf_3)) + self.assertEqual((cwf_2_a, cwf_3), swf3.sequenced_waveforms) def test_sample_times_type(self) -> None: with mock.patch.object(DummyWaveform, 'unsafe_sample') as unsafe_sample_patch: @@ -339,7 +459,6 @@ def test_unsafe_get_subset_for_channels(self): class ConstantWaveformTests(unittest.TestCase): - def test_waveform_duration(self): waveform = ConstantWaveform(10, 1., 'P1') self.assertEqual(waveform.duration, 10) @@ -350,6 +469,23 @@ def test_waveform_sample(self): result = waveform.unsafe_sample('P1', sample_times) self.assertTrue(np.all(result == .1)) + self.assertIs(waveform, waveform.unsafe_get_subset_for_channels({'A'})) + + def test_from_mapping(self): + from_single = ConstantWaveform.from_mapping(1., {'A': 2.}) + expected_single = ConstantWaveform(duration=1., amplitude=2., channel='A') + self.assertEqual(expected_single, from_single) + + from_multi = ConstantWaveform.from_mapping(1., {'A': 2., 'B': 3.}) + expected_from_multi = MultiChannelWaveform([ConstantWaveform(duration=1., amplitude=2., channel='A'), + ConstantWaveform(duration=1., amplitude=3., channel='B')]) + self.assertEqual(expected_from_multi, from_multi) + + def test_constness(self): + waveform = ConstantWaveform(10, .1, 'P1') + self.assertTrue(waveform.is_constant()) + assert_constant_consistent(self, waveform) + class TableWaveformTests(unittest.TestCase): @@ -365,11 +501,16 @@ def test_validate_input_errors(self): TableWaveform._validate_input([TableWaveformEntry(0.1, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.2, 0.2, HoldInterpolationStrategy())]) - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "not increasing"): TableWaveform._validate_input([TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.2, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.1, 0.2, HoldInterpolationStrategy())]) + with self.assertRaisesRegex(ValueError, "Negative"): + TableWaveform._validate_input([TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()), + TableWaveformEntry(-0.2, 0.2, HoldInterpolationStrategy()), + TableWaveformEntry(0.1, 0.2, HoldInterpolationStrategy())]) + def test_validate_input_duplicate_removal(self): validated = TableWaveform._validate_input([TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.1, 0.2, LinearInterpolationStrategy()), @@ -439,6 +580,10 @@ def test_simple_properties(self): self.assertEqual(waveform.defined_channels, {chan}) self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) + assert_constant_consistent(self, waveform) + + evaled = eval(repr(waveform)) + self.assertEqual(evaled, waveform) class WaveformEntryTest(unittest.TestCase): @@ -446,9 +591,14 @@ def test_interpolation_exception(self): with self.assertRaises(TypeError): TableWaveformEntry(1, 2, 3) + def test_repr(self): + interpolation = DummyInterpolationStrategy() + self.assertEqual(f"TableWaveformEntry(t={1.}, v={2.}, interp={interpolation})", + repr(TableWaveformEntry(t=1., v=2., interp=interpolation))) + class TransformationDummy(TransformationStub): - def __init__(self, output_channels=None, transformed=None, input_channels=None): + def __init__(self, output_channels=None, transformed=None, input_channels=None, constant_invariant=None): if output_channels: self.get_output_channels = mock.MagicMock(return_value=output_channels) @@ -458,8 +608,33 @@ def __init__(self, output_channels=None, transformed=None, input_channels=None): if transformed is not None: type(self).__call__ = mock.MagicMock(return_value=transformed) + if constant_invariant is not None: + self.is_constant_invariant = mock.MagicMock(return_value=constant_invariant) + class TransformingWaveformTest(unittest.TestCase): + def test_from_transformation(self): + const_output = {'c': 4.4, 'd': 5.5, 'e': 6.6} + trafo = TransformationDummy(output_channels=const_output.keys(), constant_invariant=False) + const_trafo = TransformationDummy(output_channels=const_output.keys(), constant_invariant=True, + transformed=const_output) + dummy_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b'}) + const_wf = ConstantWaveform.from_mapping(3, {'a': 2.2, 'b': 3.3}) + + self.assertEqual(TransformingWaveform(inner_waveform=dummy_wf, transformation=trafo), + TransformingWaveform.from_transformation(inner_waveform=dummy_wf, transformation=trafo)) + + self.assertEqual(TransformingWaveform(inner_waveform=dummy_wf, transformation=const_trafo), + TransformingWaveform.from_transformation(inner_waveform=dummy_wf, transformation=const_trafo)) + + self.assertEqual(TransformingWaveform(inner_waveform=const_wf, transformation=trafo), + TransformingWaveform.from_transformation(inner_waveform=const_wf, transformation=trafo)) + + with mock.patch.object(ConstantWaveform, 'from_mapping', return_value=mock.sentinel) as from_mapping: + self.assertIs(from_mapping.return_value, + TransformingWaveform.from_transformation(inner_waveform=const_wf, transformation=const_trafo)) + from_mapping.assert_called_once_with(const_wf.duration, const_output) + def test_simple_properties(self): output_channels = {'c', 'd', 'e'} @@ -523,6 +698,35 @@ def test_unsafe_sample(self): c_time, c_data = pos_args np.testing.assert_equal((time, expected_call_data), pos_args) + def test_const_value(self): + output_channels = {'c', 'd', 'e'} + trafo = TransformationStub() + inner_wf = WaveformStub() + + trafo_wf = TransformingWaveform(inner_wf, trafo) + + self.assertFalse(trafo_wf.is_constant()) + self.assertIsNone(trafo_wf.constant_value_dict()) + + with mock.patch.object(trafo, 'is_constant_invariant', return_value=False) as is_constant_invariant: + self.assertIsNone(trafo_wf.constant_value('A')) + is_constant_invariant.assert_called_once_with() + + with mock.patch.object(trafo, 'is_constant_invariant', return_value=True): + # all inputs constant + inner_const_values = {'A': 1.1, 'B': 2.2} + + with mock.patch.object(trafo, 'get_input_channels', return_value=inner_const_values.keys()): + with mock.patch.object(inner_wf, 'constant_value', side_effect=inner_const_values.values()) as constant_value: + with mock.patch.object(TransformationStub, '__call__', return_value={'C': mock.sentinel}) as call: + self.assertIs(trafo_wf.constant_value('C'), call.return_value['C']) + call.assert_called_once_with(0., inner_const_values) + self.assertEqual([mock.call(ch) for ch in inner_const_values], constant_value.call_args_list) + + inner_const_values['B'] = None + with mock.patch.object(inner_wf, 'constant_value', side_effect=inner_const_values.values()) as constant_value: + self.assertIsNone(trafo_wf.constant_value('C')) + class SubsetWaveformTest(unittest.TestCase): def test_simple_properties(self): @@ -565,6 +769,47 @@ def test_unsafe_sample(self): class ArithmeticWaveformTest(unittest.TestCase): + def test_from_operator(self): + lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) + rhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'd'}) + + lhs_const = ConstantWaveform.from_mapping(1.5, {'a': 1.1, 'b': 2.2, 'c': 3.3}) + rhs_const = ConstantWaveform.from_mapping(1.5, {'a': 1.2, 'b': 2.4, 'd': 3.4}) + + self.assertEqual(ArithmeticWaveform(lhs, '+', rhs), ArithmeticWaveform.from_operator(lhs, '+', rhs)) + self.assertEqual(ArithmeticWaveform(lhs_const, '+', rhs), ArithmeticWaveform.from_operator(lhs_const, '+', rhs)) + self.assertEqual(ArithmeticWaveform(lhs, '+', rhs_const), ArithmeticWaveform.from_operator(lhs, '+', rhs_const)) + + expected = ConstantWaveform.from_mapping(1.5, {'a': 1.1-1.2, 'b': 2.2-2.4, 'c': 3.3, 'd': -3.4}) + consted = ArithmeticWaveform.from_operator(lhs_const, '-', rhs_const) + self.assertEqual(expected, consted) + + def test_const_propagation(self): + lhs = MultiChannelWaveform([ + DummyWaveform(duration=1.5, defined_channels={'a', 'c', 'd'}), + ConstantWaveform.from_mapping(1.5, {'e': 1.2, 'f': 1.3, 'h': 4.6}) + ]) + rhs = MultiChannelWaveform([ + DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'e'}), + ConstantWaveform.from_mapping(1.5, {'f': 2.5, 'g': 3.5}) + ]) + + wf = ArithmeticWaveform(lhs, '-', rhs) + + expected = {'a': None, + 'b': None, + 'c': None, + 'd': None, + 'e': None, + 'f': 1.3-2.5, + 'g': -3.5, + 'h': 4.6} + + actual = {ch: wf.constant_value(ch) for ch in wf.defined_channels} + self.assertEqual(expected, actual) + + + def test_simple_properties(self): lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) rhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'd'}) @@ -620,3 +865,132 @@ def test_unsafe_sample(self): np.testing.assert_equal(data, result) if output_array is not None: self.assertIs(result, output_array) + + +class FunctionWaveformTest(unittest.TestCase): + + def test_equality(self) -> None: + wf1a = FunctionWaveform(ExpressionScalar('2*t'), 3, channel='A') + wf1b = FunctionWaveform(ExpressionScalar('2*t'), 3, channel='A') + wf3 = FunctionWaveform(ExpressionScalar('2*t+2'), 3, channel='A') + wf4 = FunctionWaveform(ExpressionScalar('2*t'), 4, channel='A') + self.assertEqual(wf1a, wf1a) + self.assertEqual(wf1a, wf1b) + self.assertNotEqual(wf1a, wf3) + self.assertNotEqual(wf1a, wf4) + + def test_defined_channels(self) -> None: + wf = FunctionWaveform(ExpressionScalar('t'), 4, channel='A') + self.assertEqual({'A'}, wf.defined_channels) + + def test_duration(self) -> None: + wf = FunctionWaveform(expression=ExpressionScalar('2*t'), duration=4/5, + channel='A') + self.assertEqual(TimeType.from_float(4/5), wf.duration) + + def test_unsafe_sample(self): + fw = FunctionWaveform(ExpressionScalar('sin(2*pi*t) + 3'), 5, channel='A') + + t = np.linspace(0, 5, dtype=float) + expected_result = np.sin(2*np.pi*t) + 3 + result = fw.unsafe_sample(channel='A', sample_times=t) + np.testing.assert_equal(result, expected_result) + + out_array = np.empty_like(t) + result = fw.unsafe_sample(channel='A', sample_times=t, output_array=out_array) + np.testing.assert_equal(result, expected_result) + self.assertIs(result, out_array) + + def test_constant_evaluation(self): + # cause for 596 + fw = FunctionWaveform(ExpressionScalar(3), 5, channel='A') + t = np.linspace(0, 5, dtype=float) + expected_result = np.full_like(t, fill_value=3.) + out_array = np.full_like(t, fill_value=np.nan) + result = fw.unsafe_sample(channel='A', sample_times=t, output_array=out_array) + self.assertIs(result, out_array) + np.testing.assert_equal(result, expected_result) + + result = fw.unsafe_sample(channel='A', sample_times=t) + np.testing.assert_equal(result, expected_result) + + assert_constant_consistent(self, fw) + + def test_unsafe_get_subset_for_channels(self): + fw = FunctionWaveform(ExpressionScalar('sin(2*pi*t) + 3'), 5, channel='A') + self.assertIs(fw.unsafe_get_subset_for_channels({'A'}), fw) + + def test_construction(self): + with self.assertRaises(ValueError): + FunctionWaveform(ExpressionScalar('sin(omega*t)'), duration=5, channel='A') + + const = FunctionWaveform.from_expression(ExpressionScalar('4.'), duration=5, channel='A') + expected_const = ConstantWaveform(duration=5, amplitude=4., channel='A') + self.assertEqual(expected_const, const) + + linear = FunctionWaveform.from_expression(ExpressionScalar('4.*t'), 5, 'A') + expected_linear = FunctionWaveform(ExpressionScalar('4.*t'), 5, 'A') + self.assertEqual(expected_linear, linear) + + +class FunctorWaveformTests(unittest.TestCase): + def test_from_functor(self): + dummy_wf = DummyWaveform(1.5, defined_channels={'A', 'B'}) + const_wf = ConstantWaveform.from_mapping(1.5, {'A': 1.1, 'B': 2.2}) + + wf = FunctorWaveform.from_functor(dummy_wf, {'A': np.negative, 'B': np.positive}) + self.assertEqual(FunctorWaveform(dummy_wf, {'A': np.negative, 'B': np.positive}), wf) + self.assertFalse(wf.is_constant()) + assert_constant_consistent(self, wf) + + wf = FunctorWaveform.from_functor(const_wf, {'A': np.negative, 'B': np.positive}) + self.assertEqual(ConstantWaveform.from_mapping(1.5, {'A': -1.1, 'B': 2.2}), wf) + assert_constant_consistent(self, wf) + + def test_const_value(self): + mixed_wf = MultiChannelWaveform([DummyWaveform(1.5, defined_channels={'A'}), + ConstantWaveform(1.5, 1.1, 'B')]) + wf = FunctorWaveform(mixed_wf, {'A': np.negative, 'B': np.negative}) + self.assertIsNone(wf.constant_value('A')) + self.assertEqual(-1.1, wf.constant_value('B')) + + def test_unsafe_sample(self): + inner_wf = DummyWaveform(defined_channels={'A', 'B'}) + functors = dict(A=mock.Mock(return_value=1.), B=mock.Mock(return_value=2.)) + wf = FunctorWaveform(inner_wf, functors) + + with mock.patch.object(inner_wf, 'unsafe_sample', return_value=mock.sentinel) as inner_sample: + self.assertEqual(wf.unsafe_sample('A', 3.14, 6.75), 1.) + inner_sample.assert_called_once_with('A', 3.14, 6.75) + functors['A'].assert_called_once_with(inner_sample.return_value, out=inner_sample.return_value) + + def test_unsafe_get_subset_for_channels(self): + inner_wf = DummyWaveform(defined_channels={'A', 'B'}) + inner_subset_wf = DummyWaveform(defined_channels={'A'}) + functors = dict(A=mock.Mock(return_value=1.), B=mock.Mock(return_value=2.)) + inner_functors = {'A': functors['A']} + wf = FunctorWaveform(inner_wf, functors) + + with mock.patch.object(inner_wf, 'unsafe_get_subset_for_channels', return_value=inner_subset_wf) as inner_subset: + self.assertEqual(FunctorWaveform(inner_subset_wf, inner_functors), + wf.unsafe_get_subset_for_channels({'A'})) + inner_subset.assert_called_once_with({'A'}) + + def test_compare_key(self): + inner_wf_1 = DummyWaveform(defined_channels={'A', 'B'}) + inner_wf_2 = DummyWaveform(defined_channels={'A', 'B'}) + functors_1 = dict(A=np.positive, B=np.negative) + functors_2 = dict(A=np.negative, B=np.negative) + + wf11 = FunctorWaveform(inner_wf_1, functors_1) + wf12 = FunctorWaveform(inner_wf_1, functors_2) + wf21 = FunctorWaveform(inner_wf_2, functors_1) + wf22 = FunctorWaveform(inner_wf_2, functors_2) + + self.assertEqual((inner_wf_1, frozenset(functors_1.items())), wf11.compare_key) + self.assertEqual(wf11, wf11) + self.assertEqual(wf11, FunctorWaveform(inner_wf_1, functors_1)) + + self.assertNotEqual(wf11, wf12) + self.assertNotEqual(wf11, wf21) + self.assertNotEqual(wf11, wf22) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index 657e4c6d9..a2fabcf2c 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -3,11 +3,11 @@ import numpy as np from qupulse.utils.types import TimeType -from qupulse.pulses.function_pulse_template import FunctionPulseTemplate,\ - FunctionWaveform +from qupulse.pulses.function_pulse_template import FunctionPulseTemplate from qupulse.serialization import Serializer, Serializable, PulseStorage from qupulse.expressions import Expression from qupulse.pulses.parameters import ParameterConstraintViolation, ParameterConstraint +from qupulse._program.waveforms import FunctionWaveform from tests.serialization_dummies import DummySerializer, DummyStorageBackend from tests.pulses.sequencing_dummies import DummyParameter @@ -216,56 +216,6 @@ def tpt_constructor(measurements=None): to_test_constructor=tpt_constructor, **kwargs) -class FunctionWaveformTest(unittest.TestCase): - - def test_equality(self) -> None: - wf1a = FunctionWaveform(Expression('2*t'), 3, channel='A') - wf1b = FunctionWaveform(Expression('2*t'), 3, channel='A') - wf3 = FunctionWaveform(Expression('2*t+2'), 3, channel='A') - wf4 = FunctionWaveform(Expression('2*t'), 4, channel='A') - self.assertEqual(wf1a, wf1a) - self.assertEqual(wf1a, wf1b) - self.assertNotEqual(wf1a, wf3) - self.assertNotEqual(wf1a, wf4) - - def test_defined_channels(self) -> None: - wf = FunctionWaveform(Expression('t'), 4, channel='A') - self.assertEqual({'A'}, wf.defined_channels) - - def test_duration(self) -> None: - wf = FunctionWaveform(expression=Expression('2*t'), duration=4/5, - channel='A') - self.assertEqual(TimeType.from_float(4/5), wf.duration) - - def test_unsafe_sample(self): - fw = FunctionWaveform(Expression('sin(2*pi*t) + 3'), 5, channel='A') - - t = np.linspace(0, 5, dtype=float) - expected_result = np.sin(2*np.pi*t) + 3 - result = fw.unsafe_sample(channel='A', sample_times=t) - np.testing.assert_equal(result, expected_result) - - out_array = np.empty_like(t) - result = fw.unsafe_sample(channel='A', sample_times=t, output_array=out_array) - np.testing.assert_equal(result, expected_result) - self.assertIs(result, out_array) - - def test_constant_evaluation(self): - # cause for 596 - fw = FunctionWaveform(Expression(3), 5, channel='A') - t = np.linspace(0, 5, dtype=float) - expected_result = np.full_like(t, fill_value=3.) - out_array = np.full_like(t, fill_value=np.nan) - result = fw.unsafe_sample(channel='A', sample_times=t, output_array=out_array) - self.assertIs(result, out_array) - np.testing.assert_equal(result, expected_result) - - result = fw.unsafe_sample(channel='A', sample_times=t) - np.testing.assert_equal(result, expected_result) - - def test_unsafe_get_subset_for_channels(self): - fw = FunctionWaveform(Expression('sin(2*pi*t) + 3'), 5, channel='A') - self.assertIs(fw.unsafe_get_subset_for_channels({'A'}), fw) class FunctionPulseMeasurementTest(unittest.TestCase): diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index cf99aa121..9fe759fd3 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -154,15 +154,16 @@ def last_value(self, channel) -> float: class DummyInterpolationStrategy(InterpolationStrategy): - def __init__(self) -> None: + def __init__(self, id_ = None) -> None: self.call_arguments = [] + self._id = id(self) if id_ is None else id_ def __call__(self, start: Tuple[float, float], end: Tuple[float, float], times: numpy.ndarray) -> numpy.ndarray: self.call_arguments.append((start, end, list(times))) return times def __repr__(self) -> str: - return "DummyInterpolationStrategy {}".format(id(self)) + return f"DummyInterpolationStrategy({id(self)})" @property def integral(self) -> ExpressionScalar: From ef4a130eec7c85e5b69601e482741e497b79940c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 12 Jun 2021 23:01:57 +0200 Subject: [PATCH 20/23] Increase test coverage --- tests/_program/waveforms_tests.py | 52 +++++++++++++++++++++++++++---- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 5e5d20ece..bb01c084b 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -11,7 +11,7 @@ TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform, ArithmeticWaveform, ConstantWaveform,\ Waveform, FunctorWaveform, FunctionWaveform from qupulse._program.transformation import LinearTransformation -from qupulse.expressions import ExpressionScalar +from qupulse.expressions import ExpressionScalar, Expression from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy from tests._program.transformation_tests import TransformationStub @@ -357,6 +357,12 @@ def test_unsafe_sample(self): self.assertIs(output_expected, output_received) np.testing.assert_equal(output_received, inner_sample_times) + def test_repr(self): + body_wf = ConstantWaveform(amplitude=1.1, duration=1.3, channel='3') + wf = RepetitionWaveform(body_wf, 3) + r = repr(wf) + self.assertEqual(wf, eval(r)) + class SequenceWaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs) -> None: @@ -389,6 +395,9 @@ def test_from_sequence(self): swf1 = SequenceWaveform.from_sequence((dwf, dwf)) swf2 = SequenceWaveform.from_sequence((swf1, dwf)) + assert_constant_consistent(self, swf1) + assert_constant_consistent(self, swf2) + self.assertEqual(3*(dwf,), swf2.sequenced_waveforms) cwf_2_a = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') @@ -402,9 +411,13 @@ def test_from_sequence(self): swf3 = SequenceWaveform.from_sequence((cwf_2_a, dwf)) self.assertEqual((cwf_2_a, dwf), swf3.sequenced_waveforms) + self.assertIsNone(swf3.constant_value('A')) + assert_constant_consistent(self, swf3) swf3 = SequenceWaveform.from_sequence((cwf_2_a, cwf_3)) self.assertEqual((cwf_2_a, cwf_3), swf3.sequenced_waveforms) + self.assertIsNone(swf3.constant_value('A')) + assert_constant_consistent(self, swf3) def test_sample_times_type(self) -> None: with mock.patch.object(DummyWaveform, 'unsafe_sample') as unsafe_sample_patch: @@ -457,6 +470,13 @@ def test_unsafe_get_subset_for_channels(self): self.assertEqual(sub_wf.compare_key[0].duration, TimeType.from_float(2.2)) self.assertEqual(sub_wf.compare_key[1].duration, TimeType.from_float(3.3)) + def test_repr(self): + cwf_2_a = ConstantWaveform(duration=1.1, amplitude=2.2, channel='A') + cwf_3 = ConstantWaveform(duration=1.1, amplitude=3.3, channel='A') + swf = SequenceWaveform([cwf_2_a, cwf_3]) + r = repr(swf) + self.assertEqual(swf, eval(r)) + class ConstantWaveformTests(unittest.TestCase): def test_waveform_duration(self): @@ -489,6 +509,15 @@ def test_constness(self): class TableWaveformTests(unittest.TestCase): + def test_from_table(self): + expected = ConstantWaveform(0.1, 0.2, 'A') + + for interp in (HoldInterpolationStrategy(), JumpInterpolationStrategy(), LinearInterpolationStrategy()): + wf = TableWaveform.from_table('A', + [TableWaveformEntry(0.0, 0.2, interp), + TableWaveformEntry(0.1, 0.2, interp)]) + self.assertEqual(expected, wf) + def test_validate_input_errors(self): with self.assertRaises(ValueError): TableWaveform._validate_input([TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy())]) @@ -786,16 +815,18 @@ def test_from_operator(self): def test_const_propagation(self): lhs = MultiChannelWaveform([ - DummyWaveform(duration=1.5, defined_channels={'a', 'c', 'd'}), + DummyWaveform(duration=1.5, defined_channels={'a', 'c', 'd', 'i'}), ConstantWaveform.from_mapping(1.5, {'e': 1.2, 'f': 1.3, 'h': 4.6}) ]) rhs = MultiChannelWaveform([ DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'e'}), - ConstantWaveform.from_mapping(1.5, {'f': 2.5, 'g': 3.5}) + ConstantWaveform.from_mapping(1.5, {'f': 2.5, 'g': 3.5, 'i': 6.4}) ]) wf = ArithmeticWaveform(lhs, '-', rhs) + assert_constant_consistent(self, wf) + expected = {'a': None, 'b': None, 'c': None, @@ -803,13 +834,12 @@ def test_const_propagation(self): 'e': None, 'f': 1.3-2.5, 'g': -3.5, - 'h': 4.6} + 'h': 4.6, + 'i': None} actual = {ch: wf.constant_value(ch) for ch in wf.defined_channels} self.assertEqual(expected, actual) - - def test_simple_properties(self): lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) rhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'd'}) @@ -932,8 +962,18 @@ def test_construction(self): expected_linear = FunctionWaveform(ExpressionScalar('4.*t'), 5, 'A') self.assertEqual(expected_linear, linear) + def test_repr(self): + wf = FunctionWaveform(ExpressionScalar('sin(2*pi*t) + 3'), 5, channel='A') + r = repr(wf) + self.assertEqual(wf, eval(r)) + class FunctorWaveformTests(unittest.TestCase): + def test_duration(self): + dummy_wf = DummyWaveform(1.5, defined_channels={'A', 'B'}) + f_wf = FunctorWaveform.from_functor(dummy_wf, {'A': np.negative, 'B': np.positive}) + self.assertIs(dummy_wf.duration, f_wf.duration) + def test_from_functor(self): dummy_wf = DummyWaveform(1.5, defined_channels={'A', 'B'}) const_wf = ConstantWaveform.from_mapping(1.5, {'A': 1.1, 'B': 2.2}) From 4b6f6d606892b5d4b41f815e27a655c623503170 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 12 Jun 2021 23:02:40 +0200 Subject: [PATCH 21/23] Add constant propagation to interpolation --- qupulse/_program/waveforms.py | 23 +++++++++++++++++------ qupulse/pulses/interpolation.py | 25 +++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index c8e4d1d12..bb5e37ecc 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -277,11 +277,19 @@ def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: @classmethod def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union['TableWaveform', 'ConstantWaveform']: table = cls._validate_input(table) - v = table[0].v - if all(entry.v == v for entry in table): - return ConstantWaveform(table[-1].t, v, channel) + v = None + for entry1, entry2 in pairwise(table): + piece = entry2.interp.constant_value(entry1[:2], entry2[:2]) + if piece is None: + break + if v is None: + v = piece + elif piece != v: + break else: - return TableWaveform(channel, table) + return ConstantWaveform(duration=table[-1].t, amplitude=v, channel=channel) + + return TableWaveform(channel, table) @property def compare_key(self) -> Any: @@ -330,7 +338,7 @@ class ConstantWaveform(Waveform): # TODO: remove _is_constant_waveform = True - def __init__(self, duration: float, amplitude: Any, channel: ChannelID): + def __init__(self, duration: Real, amplitude: Any, channel: ChannelID): """ Create a qupulse waveform corresponding to a ConstantPulseTemplate """ self._duration = duration self._amplitude = amplitude @@ -359,7 +367,10 @@ def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]: @property def duration(self) -> TimeType: - return time_from_float(float(self._duration), absolute_error=PULSE_TO_WAVEFORM_ERROR) + if isinstance(self._duration, TimeType): + return self._duration + else: + return time_from_float(float(self._duration), absolute_error=PULSE_TO_WAVEFORM_ERROR) @property def defined_channels(self) -> AbstractSet[ChannelID]: diff --git a/qupulse/pulses/interpolation.py b/qupulse/pulses/interpolation.py index 5a61f2215..da9a15f5c 100644 --- a/qupulse/pulses/interpolation.py +++ b/qupulse/pulses/interpolation.py @@ -9,7 +9,7 @@ from abc import ABCMeta, abstractmethod -from typing import Any, Tuple +from typing import Any, Tuple, Optional import numpy as np from qupulse.expressions import ExpressionScalar @@ -66,7 +66,19 @@ def __ne__(self, other: Any) -> bool: def __hash__(self) -> int: return hash(self.__repr__()) - + def constant_value(self, start: Tuple[float, float], end: Tuple[float, float]) -> Optional[float]: + """The value of the interpolation if it is constant. + + Args: + start: The start point of the interpolation as (time, value) pair. + end: The end point of the interpolation as (time, value) pair. + + Returns: + The value of the interpolation if it is constant + """ + return None + + class LinearInterpolationStrategy(InterpolationStrategy): """An InterpolationStrategy that interpolates linearly between two points.""" @@ -91,6 +103,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return "" + def constant_value(self, start: Tuple[float, float], end: Tuple[float, float]) -> Optional[float]: + return start[1] if start[1] == end[1] else None + class HoldInterpolationStrategy(InterpolationStrategy): """An InterpolationStrategy that interpolates by holding the value of the start point for the @@ -122,6 +137,9 @@ def __str__(self) -> str: def __repr__(self) -> str: return "" + def constant_value(self, start: Tuple[float, float], end: Tuple[float, float]) -> Optional[float]: + return start[1] + class JumpInterpolationStrategy(InterpolationStrategy): """An InterpolationStrategy that interpolates by holding the value of the end point for the @@ -152,3 +170,6 @@ def __str__(self) -> str: def __repr__(self) -> str: return "" + + def constant_value(self, start: Tuple[float, float], end: Tuple[float, float]) -> Optional[float]: + return end[1] From 4a16aa7739cd8188cea2f2162d0c96a54d0ea8af Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Sat, 12 Jun 2021 23:25:35 +0200 Subject: [PATCH 22/23] Use from_table --- qupulse/pulses/point_pulse_template.py | 7 ++----- qupulse/pulses/table_pulse_template.py | 14 +++++--------- tests/pulses/point_pulse_template_tests.py | 14 +++++++------- 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 43cc3d40e..c9d69b245 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -92,13 +92,10 @@ def build_waveform(self, if ch is not None] mapped_channels, waveform_entries = zip(*channel_entries) - waveforms = [PointWaveform(mapped_channel, ch_entries) + waveforms = [PointWaveform.from_table(mapped_channel, ch_entries) for mapped_channel, ch_entries in zip(mapped_channels, waveform_entries)] - if len(waveforms) == 1: - return waveforms.pop() - else: - return MultiChannelWaveform(waveforms) + return MultiChannelWaveform.from_parallel(waveforms) @property def point_pulse_entries(self) -> Sequence[PointPulseEntry]: diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 0a986f253..deee214d5 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -316,24 +316,20 @@ def build_waveform(self, MultiChannelWaveform]]: self.validate_parameter_constraints(parameters, volatile=set()) - if all(channel_mapping[channel] is None - for channel in self.defined_channels): - return None - instantiated = [(channel_mapping[channel], instantiated_channel) for channel, instantiated_channel in self.get_entries_instantiated(parameters).items() if channel_mapping[channel] is not None] + if not instantiated: + return None + if self.duration.evaluate_numeric(**parameters) == 0: return None - waveforms = [TableWaveform(*ch_instantiated) + waveforms = [TableWaveform.from_table(*ch_instantiated) for ch_instantiated in instantiated] - if len(waveforms) == 1: - return waveforms.pop() - else: - return MultiChannelWaveform(waveforms) + return MultiChannelWaveform.from_parallel(waveforms) @staticmethod def from_array(times: np.ndarray, voltages: np.ndarray, channels: List[ChannelID]) -> 'TablePulseTemplate': diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 3962af181..9316379ed 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -118,7 +118,7 @@ def test_build_waveform_single_channel(self): parameters = {'t1': 0.1, 't2': 1., 'A': 1., 'B': 2., 'C': 19.} wf = ppt.build_waveform(parameters=parameters, channel_mapping={0: 1}) - expected = PointWaveform(1, [(0, 1., HoldInterpolationStrategy()), + expected = PointWaveform.from_table(1, [(0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) @@ -132,7 +132,7 @@ def test_build_waveform_single_channel_with_measurements(self): parameters = {'t1': 0.1, 't2': 1., 'A': 1., 'B': 2., 'C': 19., 'n': 0.2} wf = ppt.build_waveform(parameters=parameters, channel_mapping={0: 1}) - expected = PointWaveform(1, [(0, 1., HoldInterpolationStrategy()), + expected = PointWaveform.from_table(1, [(0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) @@ -145,11 +145,11 @@ def test_build_waveform_multi_channel_same(self): parameters = {'t1': 0.1, 't2': 1., 'A': 1., 'B': 2., 'C': 19., 'n': 0.2} wf = ppt.build_waveform(parameters=parameters, channel_mapping={0: 1, 'A': 'A'}) - expected_1 = PointWaveform(1, [(0, 1., HoldInterpolationStrategy()), + expected_1 = PointWaveform.from_table(1, ((0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), - (1.1, 21., LinearInterpolationStrategy())]) - expected_A = PointWaveform('A', [(0, 1., HoldInterpolationStrategy()), + (1.1, 21., LinearInterpolationStrategy()))) + expected_A = PointWaveform.from_table('A', [(0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) @@ -164,11 +164,11 @@ def test_build_waveform_multi_channel_vectorized(self): parameters = {'t1': 0.1, 't2': 1., 'A': np.ones(2), 'B': np.arange(2), 'C': 19., 'n': 0.2} wf = ppt.build_waveform(parameters=parameters, channel_mapping={0: 1, 'A': 'A'}) - expected_1 = PointWaveform(1, [(0, 1., HoldInterpolationStrategy()), + expected_1 = PointWaveform.from_table(1, [(0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), (1.1, 19., LinearInterpolationStrategy())]) - expected_A = PointWaveform('A', [(0, 1., HoldInterpolationStrategy()), + expected_A = PointWaveform.from_table('A', [(0, 1., HoldInterpolationStrategy()), (0.1, 1., HoldInterpolationStrategy()), (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) From f05e0c7474af790a2104b412c6dd5d22279a3f9b Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 8 Jul 2021 11:50:30 +0200 Subject: [PATCH 23/23] Add newsfragment [skip ci] --- changes.d/588.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes.d/588.feature diff --git a/changes.d/588.feature b/changes.d/588.feature new file mode 100644 index 000000000..915df01a7 --- /dev/null +++ b/changes.d/588.feature @@ -0,0 +1 @@ +Adds the methods `is_constant`, `constant_value_dict` and `constant_value` to Waveform class to allow more efficient AWG usage.