diff --git a/qctoolkit/_program/_loop.py b/qctoolkit/_program/_loop.py index ddb4cce6a..6cbbd8be3 100644 --- a/qctoolkit/_program/_loop.py +++ b/qctoolkit/_program/_loop.py @@ -15,8 +15,7 @@ from qctoolkit.utils.types import MeasurementWindow from qctoolkit.utils import is_integer -from qctoolkit.pulses.sequence_pulse_template import SequenceWaveform -from qctoolkit.pulses.repetition_pulse_template import RepetitionWaveform +from qctoolkit._program.waveforms import SequenceWaveform, RepetitionWaveform __all__ = ['Loop', 'MultiChannelProgram', 'make_compatible'] diff --git a/qctoolkit/_program/transformation.py b/qctoolkit/_program/transformation.py new file mode 100644 index 000000000..a626e6ae7 --- /dev/null +++ b/qctoolkit/_program/transformation.py @@ -0,0 +1,55 @@ +from typing import Mapping, Set, Dict +from abc import abstractmethod + +import numpy as np +import pandas as pd + +from qctoolkit import ChannelID +from qctoolkit.comparable import Comparable + + +class Transformation(Comparable): + """Transforms numeric time-voltage values for multiple channels to other time-voltage values. The number and names + of input and output channels might differ.""" + + @abstractmethod + def __call__(self, time: np.ndarray, data: pd.DataFrame) -> pd.DataFrame: + """Apply transformation to data + Args: + time: + data: + + Returns: + transformed: A DataFrame that has been transformed with index == output_channels + """ + + @abstractmethod + def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + """Return the channel identifiers""" + + +class LinearTransformation(Transformation): + def __init__(self, transformation_matrix: pd.DataFrame): + """ + + Args: + transformation_matrix: columns are input and index are output channels + """ + self._matrix = transformation_matrix + + def __call__(self, time: np.ndarray, data: pd.DataFrame) -> Mapping[ChannelID, np.ndarray]: + data_in = pd.DataFrame(data) + if set(data_in.index) != set(self._matrix.columns): + raise KeyError('Invalid input channels', set(data_in.index), set(self._matrix.columns)) + + return self._matrix @ data_in + + def get_output_channels(self, input_channels: Set[ChannelID]) -> Set[ChannelID]: + if input_channels != set(self._matrix.columns): + raise KeyError('Invalid input channels', input_channels, set(self._matrix.columns)) + + return set(self._matrix.index) + + @property + def compare_key(self) -> Dict[ChannelID, Dict[ChannelID, float]]: + return self._matrix.to_dict() diff --git a/qctoolkit/_program/waveforms.py b/qctoolkit/_program/waveforms.py index 6c92768bd..112c597e0 100644 --- a/qctoolkit/_program/waveforms.py +++ b/qctoolkit/_program/waveforms.py @@ -6,10 +6,11 @@ import itertools from abc import ABCMeta, abstractmethod -from weakref import WeakValueDictionary -from typing import Union, Set, Sequence, NamedTuple, Tuple, Any, List, Iterable +from weakref import WeakValueDictionary, ref +from typing import Union, Set, Sequence, NamedTuple, Tuple, Any, Iterable, FrozenSet, Optional import numpy as np +import pandas as pd from qctoolkit import ChannelID from qctoolkit.utils import checked_int_cast @@ -17,10 +18,11 @@ from qctoolkit.comparable import Comparable from qctoolkit.expressions import ExpressionScalar from qctoolkit.pulses.interpolation import InterpolationStrategy +from qctoolkit._program.transformation import Transformation __all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform", - "MultiChannelWaveform", "RepetitionWaveform"] + "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform"] class Waveform(Comparable, metaclass=ABCMeta): @@ -481,3 +483,94 @@ def duration(self) -> TimeType: def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'RepetitionWaveform': return RepetitionWaveform(body=self._body.unsafe_get_subset_for_channels(channels), repetition_count=self._repetition_count) + + +class TransformingWaveform(Waveform): + def __init__(self, inner_waveform: Waveform, transformation: Transformation): + """""" + self._inner_waveform = inner_waveform + self._transformation = transformation + + # cache data of inner channels based identified and invalidated by the sample times + self._cached_data = None + self._cached_times = lambda: None + + @property + def inner_waveform(self) -> Waveform: + return self._inner_waveform + + @property + def transformation(self) -> Transformation: + return self._transformation + + @property + def defined_channels(self) -> Set[ChannelID]: + return self.transformation.get_output_channels(self.inner_waveform.defined_channels) + + @property + def compare_key(self) -> Tuple[Waveform, Transformation]: + return self.inner_waveform, self.transformation + + @property + def duration(self) -> TimeType: + return self.inner_waveform.duration + + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform': + return SubsetWaveform(self, channel_subset=channels) + + def unsafe_sample(self, + channel: ChannelID, + sample_times: np.ndarray, + output_array: Union[np.ndarray, None] = None) -> np.ndarray: + if self._cached_times() is not sample_times: + inner_channels = tuple(self.inner_waveform.defined_channels) + inner_data = np.empty((len(inner_channels), sample_times.size)) + + for idx, inner_channel in enumerate(inner_channels): + self.inner_waveform.unsafe_sample(inner_channel, sample_times, + output_array=inner_data[idx, :]) + + inner_data = pd.DataFrame(inner_data, index=inner_channels) + + outer_data = self.transformation(sample_times, inner_data) + + self._cached_data = outer_data + self._cached_times = ref(sample_times) + + if output_array is None: + output_array = self._cached_data.loc[channel].values + else: + output_array[:] = self._cached_data.loc[channel].values + + return output_array + + +class SubsetWaveform(Waveform): + def __init__(self, inner_waveform: Waveform, channel_subset: Set[ChannelID]): + self._inner_waveform = inner_waveform + self._channel_subset = frozenset(channel_subset) + + @property + def inner_waveform(self) -> Waveform: + return self._inner_waveform + + @property + def defined_channels(self) -> FrozenSet[ChannelID]: + return self._channel_subset + + @property + def duration(self) -> TimeType: + return self.inner_waveform.duration + + @property + def compare_key(self) -> Tuple[frozenset, Waveform]: + return self.defined_channels, self.inner_waveform + + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: + return self.inner_waveform.get_subset_for_channels(channels) + + def unsafe_sample(self, + channel: ChannelID, + sample_times: np.ndarray, + output_array: Union[np.ndarray, None]=None) -> np.ndarray: + return self.inner_waveform.unsafe_sample(channel, sample_times, output_array) diff --git a/qctoolkit/pulses/__init__.py b/qctoolkit/pulses/__init__.py index 253e1c63f..b51690d71 100644 --- a/qctoolkit/pulses/__init__.py +++ b/qctoolkit/pulses/__init__.py @@ -4,7 +4,7 @@ from qctoolkit.pulses.function_pulse_template import FunctionPulseTemplate as FunctionPT from qctoolkit.pulses.loop_pulse_template import ForLoopPulseTemplate as ForLoopPT from qctoolkit.pulses.multi_channel_pulse_template import AtomicMultiChannelPulseTemplate as AtomicMultiChannelPT -from qctoolkit.pulses.pulse_template_parameter_mapping import MappingPulseTemplate as MappingPT +from qctoolkit.pulses.mapping_pulse_template import MappingPulseTemplate as MappingPT from qctoolkit.pulses.repetition_pulse_template import RepetitionPulseTemplate as RepetitionPT from qctoolkit.pulses.sequence_pulse_template import SequencePulseTemplate as SequencePT from qctoolkit.pulses.table_pulse_template import TablePulseTemplate as TablePT diff --git a/qctoolkit/pulses/mapping_pulse_template.py b/qctoolkit/pulses/mapping_pulse_template.py new file mode 100644 index 000000000..83e9d5ac9 --- /dev/null +++ b/qctoolkit/pulses/mapping_pulse_template.py @@ -0,0 +1,361 @@ + +from typing import Optional, Set, Dict, Union, List, Any, Tuple +import itertools +import numbers + +from qctoolkit.utils.types import ChannelID +from qctoolkit.expressions import Expression, ExpressionScalar +from qctoolkit.pulses.pulse_template import PulseTemplate, MappingTuple +from qctoolkit.pulses.parameters import Parameter, MappedParameter, ParameterNotProvidedException, ParameterConstrainer +from qctoolkit.pulses.sequencing import Sequencer +from qctoolkit._program.instructions import InstructionBlock +from qctoolkit._program.waveforms import Waveform +from qctoolkit.pulses.conditions import Condition +from qctoolkit.serialization import Serializer, PulseRegistryType + +__all__ = [ + "MappingPulseTemplate", + "MissingMappingException", + "UnnecessaryMappingException", +] + + +class MappingPulseTemplate(PulseTemplate, ParameterConstrainer): + """This class can be used to remap parameters, the names of measurement windows and the names of channels. Besides + the standard constructor, there is a static member function from_tuple for convenience. The class also allows + constraining parameters by deriving from ParameterConstrainer""" + def __init__(self, template: PulseTemplate, *, + identifier: Optional[str]=None, + parameter_mapping: Optional[Dict[str, str]]=None, + measurement_mapping: Optional[Dict[str, str]] = None, + channel_mapping: Optional[Dict[ChannelID, ChannelID]] = None, + parameter_constraints: Optional[List[str]]=None, + allow_partial_parameter_mapping: bool=False, + registry: PulseRegistryType=None) -> None: + """Standard constructor for the MappingPulseTemplate. + + Mappings that are not specified are defaulted to identity mappings. Channels and measurement names of the + encapsulated template can be mapped partially by default. F.i. if channel_mapping only contains one of two + channels the other channel name is mapped to itself. + However, if a parameter mapping is specified and one or more parameters are not mapped a MissingMappingException + is raised. To allow partial mappings and enable the same behaviour as for the channel and measurement name + mapping allow_partial_parameter_mapping must be set to True. + Furthermore parameter constrains can be specified. + + :param template: The encapsulated pulse template whose parameters, measurement names and channels are mapped + :param parameter_mapping: if not none, mappings for all parameters must be specified + :param measurement_mapping: mappings for other measurement names are inserted + :param channel_mapping: mappings for other channels are auto inserted + :param parameter_constraints: + :param allow_partial_parameter_mapping: + """ + PulseTemplate.__init__(self, identifier=identifier) + ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) + + if parameter_mapping is None: + parameter_mapping = dict((par, par) for par in template.parameter_names) + else: + mapped_internal_parameters = set(parameter_mapping.keys()) + internal_parameters = template.parameter_names + missing_parameter_mappings = internal_parameters - mapped_internal_parameters + if mapped_internal_parameters - internal_parameters: + raise UnnecessaryMappingException(template, mapped_internal_parameters - internal_parameters) + elif missing_parameter_mappings: + if allow_partial_parameter_mapping: + parameter_mapping.update({p: p for p in missing_parameter_mappings}) + else: + raise MissingMappingException(template, internal_parameters - mapped_internal_parameters) + parameter_mapping = dict((k, Expression(v)) for k, v in parameter_mapping.items()) + + measurement_mapping = dict() if measurement_mapping is None else measurement_mapping + internal_names = template.measurement_names + mapped_internal_names = set(measurement_mapping.keys()) + if mapped_internal_names - internal_names: + raise UnnecessaryMappingException(template, mapped_internal_names - internal_names) + missing_name_mappings = internal_names - mapped_internal_names + measurement_mapping = dict(itertools.chain(((name, name) for name in missing_name_mappings), + measurement_mapping.items())) + + channel_mapping = dict() if channel_mapping is None else channel_mapping + internal_channels = template.defined_channels + mapped_internal_channels = set(channel_mapping.keys()) + if mapped_internal_channels - internal_channels: + raise UnnecessaryMappingException(template,mapped_internal_channels - internal_channels) + missing_channel_mappings = internal_channels - mapped_internal_channels + channel_mapping = dict(itertools.chain(((name, name) for name in missing_channel_mappings), + channel_mapping.items())) + + if isinstance(template, MappingPulseTemplate) and template.identifier is None: + # avoid nested mappings + parameter_mapping = {p: Expression(expr.evaluate_symbolic(parameter_mapping)) + for p, expr in template.parameter_mapping.items()} + measurement_mapping = {k: measurement_mapping[v] + for k, v in template.measurement_mapping.items()} + channel_mapping = {k: channel_mapping[v] + for k, v in template.channel_mapping.items()} + template = template.template + + self.__template = template + self.__parameter_mapping = parameter_mapping + self.__external_parameters = set(itertools.chain(*(expr.variables for expr in self.__parameter_mapping.values()))) + self.__external_parameters |= self.constrained_parameters + self.__measurement_mapping = measurement_mapping + self.__channel_mapping = channel_mapping + self._register(registry=registry) + + @staticmethod + def from_tuple(mapping_tuple: MappingTuple) -> 'MappingPulseTemplate': + """Construct a MappingPulseTemplate from a tuple of mappings. The mappings are automatically assigned to the + mapped elements based on their content. + :param mapping_tuple: A tuple of mappings + :return: Constructed MappingPulseTemplate + """ + template, *mappings = mapping_tuple + + parameter_mapping = None + measurement_mapping = None + channel_mapping = None + + for mapping in mappings: + if len(mapping) == 0: + continue + + mapped = set(mapping.keys()) + if sum((mapped <= template.parameter_names, + mapped <= template.measurement_names, + mapped <= template.defined_channels)) > 1: + raise AmbiguousMappingException(template, mapping) + + if mapped == template.parameter_names: + if parameter_mapping: + raise MappingCollisionException(template, object_type='parameter', + mapped=template.parameter_names, + mappings=(parameter_mapping, mapping)) + parameter_mapping = mapping + elif mapped <= template.measurement_names: + if measurement_mapping: + raise MappingCollisionException(template, object_type='measurement', + mapped=template.measurement_names, + mappings=(measurement_mapping, mapping)) + measurement_mapping = mapping + elif mapped <= template.defined_channels: + if channel_mapping: + raise MappingCollisionException(template, object_type='channel', + mapped=template.defined_channels, + mappings=(channel_mapping, mapping)) + channel_mapping = mapping + else: + raise ValueError('Could not match mapping to mapped objects: {}'.format(mapping)) + return MappingPulseTemplate(template, + parameter_mapping=parameter_mapping, + measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping) + + @property + def template(self) -> PulseTemplate: + return self.__template + + @property + def measurement_mapping(self) -> Dict[str, str]: + return self.__measurement_mapping + + @property + def parameter_mapping(self) -> Dict[str, Expression]: + return self.__parameter_mapping + + @property + def channel_mapping(self) -> Dict[ChannelID, ChannelID]: + return self.__channel_mapping + + @property + def parameter_names(self) -> Set[str]: + return self.__external_parameters + + @property + def measurement_names(self) -> Set[str]: + return set(self.__measurement_mapping.values()) + + @property + def is_interruptable(self) -> bool: + return self.template.is_interruptable # pragma: no cover + + @property + def defined_channels(self) -> Set[ChannelID]: + return {self.__channel_mapping[k] for k in self.template.defined_channels} + + @property + def duration(self) -> Expression: + return self.__template.duration.evaluate_symbolic(self.__parameter_mapping) + + def get_serialization_data(self) -> Dict[str, Any]: + data = super().get_serialization_data() + + data['template'] = self.template + if self.__parameter_mapping: + data['parameter_mapping'] = self.__parameter_mapping + if self.__measurement_mapping: + data['measurement_mapping'] = self.__measurement_mapping + if self.__channel_mapping: + data['channel_mapping'] = self.__channel_mapping + + if self.parameter_constraints: + data['parameter_constraints'] = [str(c) for c in self.parameter_constraints] + + return data + + @classmethod + def deserialize(cls, + serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated + **kwargs) -> 'MappingPulseTemplate': + if serializer: # compatibility to old serialization routines, deprecated + kwargs['template'] = serializer.deserialize(kwargs["template"]) + return cls(**kwargs, allow_partial_parameter_mapping=True) + # return MappingPulseTemplate(template=serializer.deserialize(template), + # **kwargs) + + def map_parameters(self, + parameters: Dict[str, Union[Parameter, numbers.Real]]) -> Dict[str, Parameter]: + """Map parameter values according to the defined mappings. + + Args: + parameters (Dict(str -> Parameter)): A mapping of parameter names to Parameter + objects/values. + Returns: + A new dictionary which maps parameter names to parameter values which have been + mapped according to the mappings defined for template. + """ + missing = set(self.__external_parameters) - set(parameters.keys()) + if missing: + raise ParameterNotProvidedException(missing.pop()) + + self.validate_parameter_constraints(parameters=parameters) + if all(isinstance(parameter, Parameter) for parameter in parameters.values()): + return {parameter: MappedParameter(mapping_function, {name: parameters[name] + for name in mapping_function.variables}) + for (parameter, mapping_function) in self.__parameter_mapping.items()} + if all(isinstance(parameter, numbers.Real) for parameter in parameters.values()): + return {parameter: mapping_function.evaluate_numeric(**parameters) + for parameter, mapping_function in self.__parameter_mapping.items()} + raise TypeError('Values of parameter dict are neither all Parameter nor Real') + + def get_updated_measurement_mapping(self, measurement_mapping: Dict[str, str]) -> Dict[str, str]: + return {k: measurement_mapping[v] for k, v in self.__measurement_mapping.items()} + + def get_updated_channel_mapping(self, channel_mapping: Dict[ChannelID, ChannelID]) -> Dict[ChannelID, ChannelID]: + return {inner_ch: channel_mapping[outer_ch] for inner_ch, outer_ch in self.__channel_mapping.items()} + + def build_sequence(self, + sequencer: Sequencer, + parameters: Dict[str, Parameter], + conditions: Dict[str, Condition], + measurement_mapping: Dict[str, str], + channel_mapping: Dict[ChannelID, ChannelID], + instruction_block: InstructionBlock) -> None: + self.template.build_sequence(sequencer, + parameters=self.map_parameters(parameters), + conditions=conditions, + measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping), + channel_mapping=self.get_updated_channel_mapping(channel_mapping), + instruction_block=instruction_block) + + def build_waveform(self, + parameters: Dict[str, numbers.Real], + channel_mapping: Dict[ChannelID, ChannelID]) -> Waveform: + """This gets called if the parent is atomic""" + return self.template.build_waveform( + parameters=self.map_parameters(parameters), + channel_mapping=self.get_updated_channel_mapping(channel_mapping)) + + def requires_stop(self, + parameters: Dict[str, Parameter], + conditions: Dict[str, Condition]) -> bool: + return self.template.requires_stop( + self.map_parameters(parameters), + conditions + ) + + @property + def integral(self) -> Dict[ChannelID, ExpressionScalar]: + internal_integral = self.__template.integral + expressions = dict() + + # sympy.subs() does not work if one of the mappings in the provided dict is an Expression object + # the following is an ugly workaround + # todo: make Expressions compatible with sympy.subs() + parameter_mapping = self.__parameter_mapping.copy() + for i in parameter_mapping: + if isinstance(parameter_mapping[i], ExpressionScalar): + parameter_mapping[i] = parameter_mapping[i].sympified_expression + + for channel in internal_integral: + expr = ExpressionScalar( + internal_integral[channel].sympified_expression.subs(parameter_mapping) + ) + channel_out = channel + if channel in self.__channel_mapping: + channel_out = self.__channel_mapping[channel] + expressions[channel_out] = expr + + return expressions + + +class MissingMappingException(Exception): + """Indicates that no mapping was specified for some parameter declaration of a + SequencePulseTemplate's subtemplate.""" + + def __init__(self, template: PulseTemplate, key: Union[str,Set[str]]) -> None: + super().__init__() + self.key = key + self.template = template + + def __str__(self) -> str: + return "The template {} needs a mapping function for parameter(s) {}".\ + format(self.template, self.key) + + +class UnnecessaryMappingException(Exception): + """Indicates that a mapping was provided that does not correspond to any of a + SequencePulseTemplate's subtemplate's parameter declarations and is thus obsolete.""" + + def __init__(self, template: PulseTemplate, key: Union[str, Set[str]]) -> None: + super().__init__() + self.template = template + self.key = key + + def __str__(self) -> str: + return "Mapping function for parameter(s) '{}', which template {} does not need"\ + .format(self.key, self.template) + + +class AutoMappingMatchingException(Exception): + """Indicates that the auto match of mappings to mapped objects by the keys failed""" + + def __init__(self, template: PulseTemplate): + super().__init__() + self.template = template + + +class AmbiguousMappingException(AutoMappingMatchingException): + """Indicates that a mapping may apply to multiple objects""" + + def __init__(self, template: PulseTemplate, mapping: Dict): + super().__init__(template) + self.mapping = mapping + + def __str__(self) -> str: + return "Could not match mapping uniquely to object type: {}\nParameters: {}\nChannels: {}\nMeasurements: {}"\ + .format(self.mapping, self.template.parameter_names, self.template.defined_channels, + self.template.measurement_names) + + +class MappingCollisionException(AutoMappingMatchingException): + """Indicates that multiple mappings are fitting for the same parameter type""" + def __init__(self, template: PulseTemplate, object_type: str, mapped: Set, mappings: Tuple[Dict, ...]): + super().__init__(template) + self.parameter_type = object_type + self.mappings = mappings + self.message = 'Got multiple candidates for the {type} mapping.\nMapped: {mapped}\nCandidates:\n'\ + .format(type=object_type, mapped=mapped) + + def __str__(self) -> str: + return self.message + '\n'.join(str(mapping) for mapping in self.mappings) diff --git a/qctoolkit/pulses/multi_channel_pulse_template.py b/qctoolkit/pulses/multi_channel_pulse_template.py index 11223876f..e302b8ae6 100644 --- a/qctoolkit/pulses/multi_channel_pulse_template.py +++ b/qctoolkit/pulses/multi_channel_pulse_template.py @@ -19,9 +19,9 @@ from qctoolkit.utils.types import ChannelID, TimeType from qctoolkit._program.waveforms import MultiChannelWaveform from qctoolkit.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate -from qctoolkit.pulses.pulse_template_parameter_mapping import MappingPulseTemplate, MappingTuple +from qctoolkit.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qctoolkit.pulses.parameters import Parameter, ParameterConstrainer -from qctoolkit.pulses.measurement import MeasurementDeclaration +from qctoolkit.pulses.measurement import MeasurementDeclaration, MeasurementWindow from qctoolkit.expressions import Expression, ExpressionScalar __all__ = ["AtomicMultiChannelPulseTemplate"] @@ -85,7 +85,9 @@ def duration(self) -> Expression: @property def parameter_names(self) -> Set[str]: - return set.union(*(st.parameter_names for st in self._subtemplates)) | self.constrained_parameters + return set.union(self.measurement_parameters, + self.constrained_parameters, + *(st.parameter_names for st in self._subtemplates)) @property def subtemplates(self) -> Sequence[AtomicPulseTemplate]: @@ -97,7 +99,7 @@ def defined_channels(self) -> Set[ChannelID]: @property def measurement_names(self) -> Set[str]: - return set.union(*(st.measurement_names for st in self._subtemplates)) + return super().measurement_names.union(*(st.measurement_names for st in self._subtemplates)) def build_waveform(self, parameters: Dict[str, numbers.Real], channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional['MultiChannelWaveform']: @@ -117,6 +119,16 @@ def build_waveform(self, parameters: Dict[str, numbers.Real], else: return MultiChannelWaveform(sub_waveforms) + def get_measurement_windows(self, + parameters: Dict[str, numbers.Real], + measurement_mapping: Dict[str, Optional[str]]) -> List[MeasurementWindow]: + measurements = super().get_measurement_windows(parameters=parameters, + measurement_mapping=measurement_mapping) + for st in self.subtemplates: + measurements.extend(st.get_measurement_windows(parameters=parameters, + measurement_mapping=measurement_mapping)) + return measurements + def requires_stop(self, parameters: Dict[str, Parameter], conditions: Dict[str, 'Condition']) -> bool: diff --git a/qctoolkit/pulses/pulse_template_parameter_mapping.py b/qctoolkit/pulses/pulse_template_parameter_mapping.py index 83e9d5ac9..06365b41e 100644 --- a/qctoolkit/pulses/pulse_template_parameter_mapping.py +++ b/qctoolkit/pulses/pulse_template_parameter_mapping.py @@ -1,361 +1,11 @@ +from qctoolkit.pulses.mapping_pulse_template import MappingPulseTemplate -from typing import Optional, Set, Dict, Union, List, Any, Tuple -import itertools -import numbers +__all__ = ["MappingPulseTemplate"] -from qctoolkit.utils.types import ChannelID -from qctoolkit.expressions import Expression, ExpressionScalar -from qctoolkit.pulses.pulse_template import PulseTemplate, MappingTuple -from qctoolkit.pulses.parameters import Parameter, MappedParameter, ParameterNotProvidedException, ParameterConstrainer -from qctoolkit.pulses.sequencing import Sequencer -from qctoolkit._program.instructions import InstructionBlock -from qctoolkit._program.waveforms import Waveform -from qctoolkit.pulses.conditions import Condition -from qctoolkit.serialization import Serializer, PulseRegistryType +import warnings +warnings.warn("MappingPulseTemplate was moved from qctoolkit.pulses.pulse_template_parameter_mapping to " + "qctoolkit.pulses.mapping_pulse_template. Please consider fixing your stored pulse templates by loading " + "and storing them anew.", DeprecationWarning) -__all__ = [ - "MappingPulseTemplate", - "MissingMappingException", - "UnnecessaryMappingException", -] - - -class MappingPulseTemplate(PulseTemplate, ParameterConstrainer): - """This class can be used to remap parameters, the names of measurement windows and the names of channels. Besides - the standard constructor, there is a static member function from_tuple for convenience. The class also allows - constraining parameters by deriving from ParameterConstrainer""" - def __init__(self, template: PulseTemplate, *, - identifier: Optional[str]=None, - parameter_mapping: Optional[Dict[str, str]]=None, - measurement_mapping: Optional[Dict[str, str]] = None, - channel_mapping: Optional[Dict[ChannelID, ChannelID]] = None, - parameter_constraints: Optional[List[str]]=None, - allow_partial_parameter_mapping: bool=False, - registry: PulseRegistryType=None) -> None: - """Standard constructor for the MappingPulseTemplate. - - Mappings that are not specified are defaulted to identity mappings. Channels and measurement names of the - encapsulated template can be mapped partially by default. F.i. if channel_mapping only contains one of two - channels the other channel name is mapped to itself. - However, if a parameter mapping is specified and one or more parameters are not mapped a MissingMappingException - is raised. To allow partial mappings and enable the same behaviour as for the channel and measurement name - mapping allow_partial_parameter_mapping must be set to True. - Furthermore parameter constrains can be specified. - - :param template: The encapsulated pulse template whose parameters, measurement names and channels are mapped - :param parameter_mapping: if not none, mappings for all parameters must be specified - :param measurement_mapping: mappings for other measurement names are inserted - :param channel_mapping: mappings for other channels are auto inserted - :param parameter_constraints: - :param allow_partial_parameter_mapping: - """ - PulseTemplate.__init__(self, identifier=identifier) - ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints) - - if parameter_mapping is None: - parameter_mapping = dict((par, par) for par in template.parameter_names) - else: - mapped_internal_parameters = set(parameter_mapping.keys()) - internal_parameters = template.parameter_names - missing_parameter_mappings = internal_parameters - mapped_internal_parameters - if mapped_internal_parameters - internal_parameters: - raise UnnecessaryMappingException(template, mapped_internal_parameters - internal_parameters) - elif missing_parameter_mappings: - if allow_partial_parameter_mapping: - parameter_mapping.update({p: p for p in missing_parameter_mappings}) - else: - raise MissingMappingException(template, internal_parameters - mapped_internal_parameters) - parameter_mapping = dict((k, Expression(v)) for k, v in parameter_mapping.items()) - - measurement_mapping = dict() if measurement_mapping is None else measurement_mapping - internal_names = template.measurement_names - mapped_internal_names = set(measurement_mapping.keys()) - if mapped_internal_names - internal_names: - raise UnnecessaryMappingException(template, mapped_internal_names - internal_names) - missing_name_mappings = internal_names - mapped_internal_names - measurement_mapping = dict(itertools.chain(((name, name) for name in missing_name_mappings), - measurement_mapping.items())) - - channel_mapping = dict() if channel_mapping is None else channel_mapping - internal_channels = template.defined_channels - mapped_internal_channels = set(channel_mapping.keys()) - if mapped_internal_channels - internal_channels: - raise UnnecessaryMappingException(template,mapped_internal_channels - internal_channels) - missing_channel_mappings = internal_channels - mapped_internal_channels - channel_mapping = dict(itertools.chain(((name, name) for name in missing_channel_mappings), - channel_mapping.items())) - - if isinstance(template, MappingPulseTemplate) and template.identifier is None: - # avoid nested mappings - parameter_mapping = {p: Expression(expr.evaluate_symbolic(parameter_mapping)) - for p, expr in template.parameter_mapping.items()} - measurement_mapping = {k: measurement_mapping[v] - for k, v in template.measurement_mapping.items()} - channel_mapping = {k: channel_mapping[v] - for k, v in template.channel_mapping.items()} - template = template.template - - self.__template = template - self.__parameter_mapping = parameter_mapping - self.__external_parameters = set(itertools.chain(*(expr.variables for expr in self.__parameter_mapping.values()))) - self.__external_parameters |= self.constrained_parameters - self.__measurement_mapping = measurement_mapping - self.__channel_mapping = channel_mapping - self._register(registry=registry) - - @staticmethod - def from_tuple(mapping_tuple: MappingTuple) -> 'MappingPulseTemplate': - """Construct a MappingPulseTemplate from a tuple of mappings. The mappings are automatically assigned to the - mapped elements based on their content. - :param mapping_tuple: A tuple of mappings - :return: Constructed MappingPulseTemplate - """ - template, *mappings = mapping_tuple - - parameter_mapping = None - measurement_mapping = None - channel_mapping = None - - for mapping in mappings: - if len(mapping) == 0: - continue - - mapped = set(mapping.keys()) - if sum((mapped <= template.parameter_names, - mapped <= template.measurement_names, - mapped <= template.defined_channels)) > 1: - raise AmbiguousMappingException(template, mapping) - - if mapped == template.parameter_names: - if parameter_mapping: - raise MappingCollisionException(template, object_type='parameter', - mapped=template.parameter_names, - mappings=(parameter_mapping, mapping)) - parameter_mapping = mapping - elif mapped <= template.measurement_names: - if measurement_mapping: - raise MappingCollisionException(template, object_type='measurement', - mapped=template.measurement_names, - mappings=(measurement_mapping, mapping)) - measurement_mapping = mapping - elif mapped <= template.defined_channels: - if channel_mapping: - raise MappingCollisionException(template, object_type='channel', - mapped=template.defined_channels, - mappings=(channel_mapping, mapping)) - channel_mapping = mapping - else: - raise ValueError('Could not match mapping to mapped objects: {}'.format(mapping)) - return MappingPulseTemplate(template, - parameter_mapping=parameter_mapping, - measurement_mapping=measurement_mapping, - channel_mapping=channel_mapping) - - @property - def template(self) -> PulseTemplate: - return self.__template - - @property - def measurement_mapping(self) -> Dict[str, str]: - return self.__measurement_mapping - - @property - def parameter_mapping(self) -> Dict[str, Expression]: - return self.__parameter_mapping - - @property - def channel_mapping(self) -> Dict[ChannelID, ChannelID]: - return self.__channel_mapping - - @property - def parameter_names(self) -> Set[str]: - return self.__external_parameters - - @property - def measurement_names(self) -> Set[str]: - return set(self.__measurement_mapping.values()) - - @property - def is_interruptable(self) -> bool: - return self.template.is_interruptable # pragma: no cover - - @property - def defined_channels(self) -> Set[ChannelID]: - return {self.__channel_mapping[k] for k in self.template.defined_channels} - - @property - def duration(self) -> Expression: - return self.__template.duration.evaluate_symbolic(self.__parameter_mapping) - - def get_serialization_data(self) -> Dict[str, Any]: - data = super().get_serialization_data() - - data['template'] = self.template - if self.__parameter_mapping: - data['parameter_mapping'] = self.__parameter_mapping - if self.__measurement_mapping: - data['measurement_mapping'] = self.__measurement_mapping - if self.__channel_mapping: - data['channel_mapping'] = self.__channel_mapping - - if self.parameter_constraints: - data['parameter_constraints'] = [str(c) for c in self.parameter_constraints] - - return data - - @classmethod - def deserialize(cls, - serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated - **kwargs) -> 'MappingPulseTemplate': - if serializer: # compatibility to old serialization routines, deprecated - kwargs['template'] = serializer.deserialize(kwargs["template"]) - return cls(**kwargs, allow_partial_parameter_mapping=True) - # return MappingPulseTemplate(template=serializer.deserialize(template), - # **kwargs) - - def map_parameters(self, - parameters: Dict[str, Union[Parameter, numbers.Real]]) -> Dict[str, Parameter]: - """Map parameter values according to the defined mappings. - - Args: - parameters (Dict(str -> Parameter)): A mapping of parameter names to Parameter - objects/values. - Returns: - A new dictionary which maps parameter names to parameter values which have been - mapped according to the mappings defined for template. - """ - missing = set(self.__external_parameters) - set(parameters.keys()) - if missing: - raise ParameterNotProvidedException(missing.pop()) - - self.validate_parameter_constraints(parameters=parameters) - if all(isinstance(parameter, Parameter) for parameter in parameters.values()): - return {parameter: MappedParameter(mapping_function, {name: parameters[name] - for name in mapping_function.variables}) - for (parameter, mapping_function) in self.__parameter_mapping.items()} - if all(isinstance(parameter, numbers.Real) for parameter in parameters.values()): - return {parameter: mapping_function.evaluate_numeric(**parameters) - for parameter, mapping_function in self.__parameter_mapping.items()} - raise TypeError('Values of parameter dict are neither all Parameter nor Real') - - def get_updated_measurement_mapping(self, measurement_mapping: Dict[str, str]) -> Dict[str, str]: - return {k: measurement_mapping[v] for k, v in self.__measurement_mapping.items()} - - def get_updated_channel_mapping(self, channel_mapping: Dict[ChannelID, ChannelID]) -> Dict[ChannelID, ChannelID]: - return {inner_ch: channel_mapping[outer_ch] for inner_ch, outer_ch in self.__channel_mapping.items()} - - def build_sequence(self, - sequencer: Sequencer, - parameters: Dict[str, Parameter], - conditions: Dict[str, Condition], - measurement_mapping: Dict[str, str], - channel_mapping: Dict[ChannelID, ChannelID], - instruction_block: InstructionBlock) -> None: - self.template.build_sequence(sequencer, - parameters=self.map_parameters(parameters), - conditions=conditions, - measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping), - channel_mapping=self.get_updated_channel_mapping(channel_mapping), - instruction_block=instruction_block) - - def build_waveform(self, - parameters: Dict[str, numbers.Real], - channel_mapping: Dict[ChannelID, ChannelID]) -> Waveform: - """This gets called if the parent is atomic""" - return self.template.build_waveform( - parameters=self.map_parameters(parameters), - channel_mapping=self.get_updated_channel_mapping(channel_mapping)) - - def requires_stop(self, - parameters: Dict[str, Parameter], - conditions: Dict[str, Condition]) -> bool: - return self.template.requires_stop( - self.map_parameters(parameters), - conditions - ) - - @property - def integral(self) -> Dict[ChannelID, ExpressionScalar]: - internal_integral = self.__template.integral - expressions = dict() - - # sympy.subs() does not work if one of the mappings in the provided dict is an Expression object - # the following is an ugly workaround - # todo: make Expressions compatible with sympy.subs() - parameter_mapping = self.__parameter_mapping.copy() - for i in parameter_mapping: - if isinstance(parameter_mapping[i], ExpressionScalar): - parameter_mapping[i] = parameter_mapping[i].sympified_expression - - for channel in internal_integral: - expr = ExpressionScalar( - internal_integral[channel].sympified_expression.subs(parameter_mapping) - ) - channel_out = channel - if channel in self.__channel_mapping: - channel_out = self.__channel_mapping[channel] - expressions[channel_out] = expr - - return expressions - - -class MissingMappingException(Exception): - """Indicates that no mapping was specified for some parameter declaration of a - SequencePulseTemplate's subtemplate.""" - - def __init__(self, template: PulseTemplate, key: Union[str,Set[str]]) -> None: - super().__init__() - self.key = key - self.template = template - - def __str__(self) -> str: - return "The template {} needs a mapping function for parameter(s) {}".\ - format(self.template, self.key) - - -class UnnecessaryMappingException(Exception): - """Indicates that a mapping was provided that does not correspond to any of a - SequencePulseTemplate's subtemplate's parameter declarations and is thus obsolete.""" - - def __init__(self, template: PulseTemplate, key: Union[str, Set[str]]) -> None: - super().__init__() - self.template = template - self.key = key - - def __str__(self) -> str: - return "Mapping function for parameter(s) '{}', which template {} does not need"\ - .format(self.key, self.template) - - -class AutoMappingMatchingException(Exception): - """Indicates that the auto match of mappings to mapped objects by the keys failed""" - - def __init__(self, template: PulseTemplate): - super().__init__() - self.template = template - - -class AmbiguousMappingException(AutoMappingMatchingException): - """Indicates that a mapping may apply to multiple objects""" - - def __init__(self, template: PulseTemplate, mapping: Dict): - super().__init__(template) - self.mapping = mapping - - def __str__(self) -> str: - return "Could not match mapping uniquely to object type: {}\nParameters: {}\nChannels: {}\nMeasurements: {}"\ - .format(self.mapping, self.template.parameter_names, self.template.defined_channels, - self.template.measurement_names) - - -class MappingCollisionException(AutoMappingMatchingException): - """Indicates that multiple mappings are fitting for the same parameter type""" - def __init__(self, template: PulseTemplate, object_type: str, mapped: Set, mappings: Tuple[Dict, ...]): - super().__init__(template) - self.parameter_type = object_type - self.mappings = mappings - self.message = 'Got multiple candidates for the {type} mapping.\nMapped: {mapped}\nCandidates:\n'\ - .format(type=object_type, mapped=mapped) - - def __str__(self) -> str: - return self.message + '\n'.join(str(mapping) for mapping in self.mappings) +from qctoolkit.serialization import SerializableMeta +SerializableMeta.deserialization_callbacks["qctoolkit.pulses.pulse_template_parameter_mapping.MappingPulseTemplate"] = SerializableMeta.deserialization_callbacks[MappingPulseTemplate.get_type_identifier()] diff --git a/qctoolkit/pulses/sequence_pulse_template.py b/qctoolkit/pulses/sequence_pulse_template.py index 793055fb3..5f54f2481 100644 --- a/qctoolkit/pulses/sequence_pulse_template.py +++ b/qctoolkit/pulses/sequence_pulse_template.py @@ -14,7 +14,7 @@ from qctoolkit.pulses.parameters import Parameter, ParameterConstrainer from qctoolkit.pulses.sequencing import InstructionBlock, Sequencer from qctoolkit.pulses.conditions import Condition -from qctoolkit.pulses.pulse_template_parameter_mapping import MappingPulseTemplate, MappingTuple +from qctoolkit.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qctoolkit._program.waveforms import SequenceWaveform from qctoolkit.pulses.measurement import MeasurementDeclaration, MeasurementDefiner from qctoolkit.expressions import Expression, ExpressionScalar diff --git a/setup.py b/setup.py index 8e479ea4b..81a4e49a3 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ package_dir={'qctoolkit': 'qctoolkit'}, packages=packages, tests_require=['pytest'], - install_requires=['sympy>=1.1.1', 'numpy'] + requires_typing, + install_requires=['sympy>=1.1.1', 'numpy', 'pandas'] + requires_typing, extras_require={ 'testing': ['pytest'], 'plotting': ['matplotlib'], diff --git a/tests/_program/transformation_tests.py b/tests/_program/transformation_tests.py new file mode 100644 index 000000000..bfb19b1a9 --- /dev/null +++ b/tests/_program/transformation_tests.py @@ -0,0 +1,46 @@ +import unittest + +import pandas as pd +import numpy as np + +from qctoolkit._program.transformation import LinearTransformation + + + +class LinearTransformationTests(unittest.TestCase): + def test_compare_key(self): + trafo_dict = {'transformed_a': {'a': 1, 'b': -1, 'c': 0}, 'transformed_b': {'a': 1, 'b': 1, 'c': 1}} + trafo_matrix = pd.DataFrame(trafo_dict).T + trafo = LinearTransformation(trafo_matrix) + + self.assertEqual(trafo_matrix.to_dict(), trafo.compare_key) + + def test_get_output_channels(self): + trafo_dict = {'transformed_a': {'a': 1, 'b': -1, 'c': 0}, 'transformed_b': {'a': 1, 'b': 1, 'c': 1}} + trafo_matrix = pd.DataFrame(trafo_dict).T + trafo = LinearTransformation(trafo_matrix) + + self.assertEqual(trafo.get_output_channels({'a', 'b', 'c'}), {'transformed_a', 'transformed_b'}) + with self.assertRaisesRegex(KeyError, 'Invalid input channels'): + trafo.get_output_channels({'a', 'b'}) + + def test_call(self): + trafo_dict = {'transformed_a': {'a': 1., 'b': -1., 'c': 0.}, 'transformed_b': {'a': 1., 'b': 1., 'c': 1.}} + trafo_matrix = pd.DataFrame(trafo_dict).T + trafo = LinearTransformation(trafo_matrix) + + data = (np.arange(12.) + 1).reshape((3, 4)) + data = pd.DataFrame(data, index=list('abc')) + + transformed = trafo(np.full(4, np.NaN), data) + + expected = np.empty((2, 4)) + expected[0, :] = data.loc['a'] - data.loc['b'] + expected[1, :] = np.sum(data.values, axis=0) + + expected = pd.DataFrame(expected, index=['transformed_a', 'transformed_b']) + + pd.testing.assert_frame_equal(expected, transformed) + + with self.assertRaisesRegex(KeyError, 'Invalid input channels'): + trafo(np.full(4, np.NaN), data.loc[['a', 'b']]) diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index 8427a5721..4ead7898e 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -1,13 +1,16 @@ import unittest +from unittest import mock import numpy import numpy as np +import pandas as pd from qctoolkit.utils.types import time_from_float from qctoolkit.pulses.interpolation import HoldInterpolationStrategy, LinearInterpolationStrategy,\ JumpInterpolationStrategy from qctoolkit._program.waveforms import MultiChannelWaveform, RepetitionWaveform, SequenceWaveform,\ - TableWaveformEntry, TableWaveform + TableWaveformEntry, TableWaveform, TransformingWaveform, SubsetWaveform +from qctoolkit._program.transformation import Transformation from tests.pulses.sequencing_dummies import DummyWaveform, DummyInterpolationStrategy @@ -413,4 +416,127 @@ def test_simple_properties(self): class WaveformEntryTest(unittest.TestCase): def test_interpolation_exception(self): with self.assertRaises(TypeError): - TableWaveformEntry(1, 2, 3) \ No newline at end of file + TableWaveformEntry(1, 2, 3) + + +class TransformationDummy(Transformation): + def __init__(self, output_channels=None, transformed=None): + if output_channels: + self.get_output_channels = mock.MagicMock(return_value=output_channels) + + if transformed is not None: + type(self).__call__ = mock.MagicMock(return_value=transformed) + + def __call__(self, *args, **kwargs): + raise NotImplementedError() + + get_output_channels = () + + @property + def compare_key(self): + return id(self) + + +class TransformingWaveformTest(unittest.TestCase): + def test_simple_properties(self): + output_channels = {'c', 'd', 'e'} + + trafo = TransformationDummy(output_channels=output_channels) + + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b'}) + trafo_wf = TransformingWaveform(inner_waveform=inner_wf, transformation=trafo) + + self.assertIs(trafo_wf.inner_waveform, inner_wf) + self.assertIs(trafo_wf.transformation, trafo) + self.assertEqual(trafo_wf.compare_key, (inner_wf, trafo)) + self.assertIs(trafo_wf.duration, inner_wf.duration) + self.assertIs(trafo_wf.defined_channels, output_channels) + trafo.get_output_channels.assert_called_once_with(inner_wf.defined_channels) + + def test_get_subset_for_channels(self): + output_channels = {'c', 'd', 'e'} + + trafo = TransformationDummy(output_channels=output_channels) + + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b'}) + trafo_wf = TransformingWaveform(inner_waveform=inner_wf, transformation=trafo) + + subset_wf = trafo_wf.get_subset_for_channels({'c', 'd'}) + self.assertIsInstance(subset_wf, SubsetWaveform) + self.assertIs(subset_wf.inner_waveform, trafo_wf) + self.assertEqual(subset_wf.defined_channels, {'c', 'd'}) + + def test_unsafe_sample(self): + time = np.linspace(10, 20, num=25) + ch_a = np.exp(time) + ch_b = np.exp(-time) + ch_c = np.sinh(time) + ch_d = np.cosh(time) + ch_e = np.arctan(time) + + sample_output = {'a': ch_a, 'b': ch_b} + expected_call_data = pd.DataFrame(sample_output).T + + transformed = pd.DataFrame({'c': ch_c, 'd': ch_d, 'e': ch_e}).T + + trafo = TransformationDummy(transformed=transformed) + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b'}, sample_output=sample_output) + trafo_wf = TransformingWaveform(inner_waveform=inner_wf, transformation=trafo) + + ch_d_out = trafo_wf.unsafe_sample('d', time) + np.testing.assert_equal(ch_d_out, ch_d) + + output = np.empty_like(time) + ch_d_out = trafo_wf.unsafe_sample('d', time, output_array=output) + self.assertIs(output, ch_d_out) + np.testing.assert_equal(ch_d_out, ch_d) + + call_list = TransformationDummy.__call__.call_args_list + self.assertEqual(len(call_list), 1) + + (pos_args, kw_args), = call_list + self.assertEqual(len(kw_args), 0) + + c_time, c_data = pos_args + np.testing.assert_equal(time, c_time) + pd.testing.assert_frame_equal(expected_call_data.sort_index(), c_data.sort_index()) + + +class SubsetWaveformTest(unittest.TestCase): + def test_simple_properties(self): + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) + + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + self.assertIs(subset_wf.inner_waveform, inner_wf) + self.assertEqual(subset_wf.compare_key, (frozenset(['a', 'c']), inner_wf)) + self.assertIs(subset_wf.duration, inner_wf.duration) + self.assertEqual(subset_wf.defined_channels, {'a', 'c'}) + + def test_get_subset_for_channels(self): + subsetted = DummyWaveform(defined_channels={'a'}) + with mock.patch.object(DummyWaveform, + 'get_subset_for_channels', + mock.Mock(return_value=subsetted)) as get_subset_for_channels: + inner_wf = DummyWaveform(defined_channels={'a', 'b', 'c'}) + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + actual_subsetted = subset_wf.get_subset_for_channels({'a'}) + get_subset_for_channels.assert_called_once_with({'a'}) + self.assertIs(subsetted, actual_subsetted) + + def test_unsafe_sample(self): + """Test perfect forwarding""" + time = {'time'} + output = {'output'} + expected_data = {'data'} + + with mock.patch.object(DummyWaveform, + 'unsafe_sample', + mock.Mock(return_value=expected_data)) as unsafe_sample: + inner_wf = DummyWaveform(defined_channels={'a', 'b', 'c'}) + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + actual_data = subset_wf.unsafe_sample('g', time, output) + self.assertIs(expected_data, actual_data) + unsafe_sample.assert_called_once_with('g', time, output) diff --git a/tests/pulses/__init__.py b/tests/pulses/__init__.py index 92cc63621..54923565f 100644 --- a/tests/pulses/__init__.py +++ b/tests/pulses/__init__.py @@ -8,7 +8,7 @@ 'multi_channel_pulse_template_tests', 'parameters_tests', 'plotting_tests', - 'pulse_template_parameter_mapping_tests', + 'mapping_pulse_template_tests.py', 'pulse_template_tests', 'repetition_pulse_template_tests', 'sample_pulse_generator', diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index 89236e967..cb5f2a544 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -82,6 +82,10 @@ def test_parameter_names_and_declarations_string_input(self) -> None: expected_parameter_names = {'foo', 'bar', 'hugo'} self.assertEqual(expected_parameter_names, template.parameter_names) + def test_integral(self) -> None: + pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') + self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) + class FunctionPulseSerializationTest(SerializableTests, unittest.TestCase): @@ -157,10 +161,6 @@ def test_requires_stop(self) -> None: def test_build_waveform_none(self): self.assertIsNone(self.fpt.build_waveform(self.valid_par_vals, channel_mapping={'A': None})) - def test_integral(self) -> None: - pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') - self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) - class TablePulseTemplateConstraintTest(ParameterConstrainerTest): def __init__(self, *args, **kwargs): diff --git a/tests/pulses/loop_pulse_template_tests.py b/tests/pulses/loop_pulse_template_tests.py index dbad68aa1..98354b64c 100644 --- a/tests/pulses/loop_pulse_template_tests.py +++ b/tests/pulses/loop_pulse_template_tests.py @@ -154,6 +154,20 @@ def test_parameter_names_param_only_in_constraint(self) -> None: loop_range=('a', 'b', 'c',), parameter_constraints=['k<=f']) self.assertEqual(flt.parameter_names, {'k', 'a', 'b', 'c', 'f'}) + def test_integral(self) -> None: + dummy = DummyPulseTemplate(defined_channels={'A', 'B'}, + parameter_names={'t1', 'i'}, + integrals={'A': ExpressionScalar('t1-i*3.1'), 'B': ExpressionScalar('i')}) + + pulse = ForLoopPulseTemplate(dummy, 'i', (1, 8, 2)) + + expected = {'A': ExpressionScalar('Sum(t1-3.1*(1+2*i), (i, 0, 3))'), + 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))') } + self.assertEqual(expected, pulse.integral) + + +class ForLoopTemplateSequencingTests(unittest.TestCase): + def test_build_sequence_constraint_on_loop_var_exception(self): """This test is to assure the status-quo behavior of ForLoopPT handling parameter constraints affecting the loop index variable. Please see https://github.com/qutech/qc-toolkit/issues/232 .""" @@ -212,17 +226,6 @@ def test_requires_stop(self): parameters['A'] = DummyParameter(requires_stop=True) self.assertTrue(flt.requires_stop(parameters, dict())) - def test_integral(self) -> None: - dummy = DummyPulseTemplate(defined_channels={'A', 'B'}, - parameter_names={'t1', 'i'}, - integrals={'A': ExpressionScalar('t1-i*3.1'), 'B': ExpressionScalar('i')}) - - pulse = ForLoopPulseTemplate(dummy, 'i', (1, 8, 2)) - - expected = {'A': ExpressionScalar('Sum(t1-3.1*(1+2*i), (i, 0, 3))'), - 'B': ExpressionScalar('Sum((1+2*i), (i, 0, 3))') } - self.assertEqual(expected, pulse.integral) - class ForLoopPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/mapping_pulse_template_tests.py b/tests/pulses/mapping_pulse_template_tests.py new file mode 100644 index 000000000..b49ebc050 --- /dev/null +++ b/tests/pulses/mapping_pulse_template_tests.py @@ -0,0 +1,290 @@ +import unittest +import itertools + +from qctoolkit.pulses.mapping_pulse_template import MissingMappingException,\ + UnnecessaryMappingException, MappingPulseTemplate,\ + AmbiguousMappingException, MappingCollisionException +from qctoolkit.pulses.parameters import ParameterNotProvidedException +from qctoolkit.pulses.parameters import ConstantParameter, ParameterConstraintViolation, ParameterConstraint +from qctoolkit.expressions import Expression + +from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummySequencer, DummyInstructionBlock +from tests.serialization_tests import SerializableTests +from tests.serialization_dummies import DummySerializer + + +class MappingTemplateTests(unittest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def test_init_exceptions(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, defined_channels={'A'}, measurement_names={'B'}) + parameter_mapping = {'foo': 't*k', 'bar': 't*l'} + + with self.assertRaises(MissingMappingException): + MappingPulseTemplate(template, parameter_mapping={}) + with self.assertRaises(MissingMappingException): + MappingPulseTemplate(template, parameter_mapping={'bar': 'kneipe'}) + with self.assertRaises(UnnecessaryMappingException): + MappingPulseTemplate(template, parameter_mapping=dict(**parameter_mapping, foobar='asd')) + + with self.assertRaises(UnnecessaryMappingException): + MappingPulseTemplate(template, parameter_mapping=parameter_mapping, measurement_mapping=dict(a='b')) + with self.assertRaises(UnnecessaryMappingException): + MappingPulseTemplate(template, parameter_mapping=parameter_mapping, channel_mapping=dict(a='b')) + + with self.assertRaises(TypeError): + MappingPulseTemplate(template, parameter_mapping) + + MappingPulseTemplate(template, parameter_mapping=parameter_mapping) + + def test_from_tuple_exceptions(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, + measurement_names={'foo', 'foobar'}, + defined_channels={'bar', 'foobar'}) + + with self.assertRaises(ValueError): + MappingPulseTemplate.from_tuple((template, {'A': 'B'})) + with self.assertRaises(AmbiguousMappingException): + MappingPulseTemplate.from_tuple((template, {'foo': 'foo'})) + with self.assertRaises(AmbiguousMappingException): + MappingPulseTemplate.from_tuple((template, {'bar': 'bar'})) + with self.assertRaises(AmbiguousMappingException): + MappingPulseTemplate.from_tuple((template, {'foobar': 'foobar'})) + + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + with self.assertRaises(MappingCollisionException): + MappingPulseTemplate.from_tuple((template, {'foo': '1', 'bar': 2}, {'foo': '1', 'bar': 4})) + + template = DummyPulseTemplate(defined_channels={'A'}) + with self.assertRaises(MappingCollisionException): + MappingPulseTemplate.from_tuple((template, {'A': 'N'}, {'A': 'C'})) + + template = DummyPulseTemplate(measurement_names={'M'}) + with self.assertRaises(MappingCollisionException): + MappingPulseTemplate.from_tuple((template, {'M': 'N'}, {'M': 'N'})) + + def test_from_tuple(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, + measurement_names={'m1', 'm2'}, + defined_channels={'c1', 'c2'}) + + def test_mapping_permutations(template: DummyPulseTemplate, + pmap, mmap, cmap): + direct = MappingPulseTemplate(template, + parameter_mapping=pmap, + measurement_mapping=mmap, + channel_mapping=cmap) + + mappings = [m for m in [pmap, mmap, cmap] if m is not None] + + for current_mapping_order in itertools.permutations(mappings): + mapper = MappingPulseTemplate.from_tuple((template, *current_mapping_order)) + self.assertEqual(mapper.measurement_mapping, direct.measurement_mapping) + self.assertEqual(mapper.channel_mapping, direct.channel_mapping) + self.assertEqual(mapper.parameter_mapping, direct.parameter_mapping) + + test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1'}, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, {'foo': 1, 'bar': 2}, None, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1'}) + test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, None) + test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, None, {'m1': 'n1'}, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, None, None, {'c1': 'd1', 'c2': 'd2'}) + test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1'}) + test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, None) + + def test_external_params(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}) + external_params = {'t', 'l', 'k'} + self.assertEqual(st.parameter_names, external_params) + + def test_constrained(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}, parameter_constraints=['t < m']) + external_params = {'t', 'l', 'k', 'm'} + self.assertEqual(st.parameter_names, external_params) + + with self.assertRaises(ParameterConstraintViolation): + st.map_parameters(dict(t=1, l=2, k=3, m=0)) + + def test_map_parameters(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}) + + parameters = {'t': ConstantParameter(3), 'k': ConstantParameter(2), 'l': ConstantParameter(7)} + values = {'foo': 6, 'bar': 21} + for k, v in st.map_parameters(parameters).items(): + self.assertEqual(v.get_value(), values[k]) + parameters.popitem() + with self.assertRaises(ParameterNotProvidedException): + st.map_parameters(parameters) + + parameters = dict(t=3, k=2, l=7) + values = {'foo': 6, 'bar': 21} + for k, v in st.map_parameters(parameters).items(): + self.assertEqual(v, values[k]) + + def test_partial_parameter_mapping(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k'}, allow_partial_parameter_mapping=True) + + self.assertEqual(st.parameter_mapping, {'foo': 't*k', 'bar': 'bar'}) + + def test_nested_mapping_avoidance(self): + template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) + st_1 = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k'}, allow_partial_parameter_mapping=True) + st_2 = MappingPulseTemplate(st_1, parameter_mapping={'bar': 't*l'}, allow_partial_parameter_mapping=True) + + self.assertIs(st_2.template, template) + self.assertEqual(st_2.parameter_mapping, {'foo': 't*k', 'bar': 't*l'}) + + st_3 = MappingPulseTemplate(template, + parameter_mapping={'foo': 't*k'}, + allow_partial_parameter_mapping=True, + identifier='käse') + st_4 = MappingPulseTemplate(st_3, parameter_mapping={'bar': 't*l'}, allow_partial_parameter_mapping=True) + self.assertIs(st_4.template, st_3) + self.assertEqual(st_4.parameter_mapping, {'t': 't', 'k': 'k', 'bar': 't*l'}) + + def test_get_updated_channel_mapping(self): + template = DummyPulseTemplate(defined_channels={'foo', 'bar'}) + st = MappingPulseTemplate(template, channel_mapping={'bar': 'kneipe'}) + with self.assertRaises(KeyError): + st.get_updated_channel_mapping(dict()) + self.assertEqual(st.get_updated_channel_mapping({'kneipe': 'meas1', 'foo': 'meas2', 'troet': 'meas3'}), + {'foo': 'meas2', 'bar': 'meas1'}) + + def test_measurement_names(self): + template = DummyPulseTemplate(measurement_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, measurement_mapping={'foo': 'froop', 'bar': 'kneipe'}) + self.assertEqual( st.measurement_names, {'froop','kneipe'} ) + + def test_defined_channels(self): + mapping = {'asd': 'A', 'fgh': 'B'} + template = DummyPulseTemplate(defined_channels=set(mapping.keys())) + st = MappingPulseTemplate(template, channel_mapping=mapping) + self.assertEqual(st.defined_channels, set(mapping.values())) + + def test_get_updated_measurement_mapping(self): + template = DummyPulseTemplate(measurement_names={'foo', 'bar'}) + st = MappingPulseTemplate(template, measurement_mapping={'bar': 'kneipe'}) + with self.assertRaises(KeyError): + st.get_updated_measurement_mapping(dict()) + self.assertEqual(st.get_updated_measurement_mapping({'kneipe': 'meas1', 'foo': 'meas2', 'troet': 'meas3'}), + {'foo': 'meas2', 'bar': 'meas1'}) + + def test_integral(self) -> None: + dummy = DummyPulseTemplate(defined_channels={'A', 'B'}, + parameter_names={'k', 'f', 'b'}, + integrals={'A': Expression('2*k'), 'other': Expression('-3.2*f+b')}) + pulse = MappingPulseTemplate(dummy, parameter_mapping={'k': 'f', 'b': 2.3}, channel_mapping={'A': 'default'}, + allow_partial_parameter_mapping=True) + + self.assertEqual({'default': Expression('2*f'), 'other': Expression('-3.2*f+2.3')}, pulse.integral) + + +class MappingPulseTemplateSequencingTests(unittest.TestCase): + + def test_build_sequence(self): + measurement_mapping = {'meas1': 'meas2'} + parameter_mapping = {'t': 'k'} + + template = DummyPulseTemplate(measurement_names=set(measurement_mapping.keys()), + parameter_names=set(parameter_mapping.keys())) + st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping, measurement_mapping=measurement_mapping) + sequencer = DummySequencer() + block = DummyInstructionBlock() + pre_parameters = {'k': ConstantParameter(5)} + pre_measurement_mapping = {'meas2': 'meas3'} + pre_channel_mapping = {'default': 'A'} + conditions = dict(a=True) + st.build_sequence(sequencer, pre_parameters, conditions, pre_measurement_mapping, pre_channel_mapping, block) + + self.assertEqual(template.build_sequence_calls, 1) + forwarded_args = template.build_sequence_arguments[0] + self.assertEqual(forwarded_args[0], sequencer) + self.assertEqual(forwarded_args[1], st.map_parameters(pre_parameters)) + self.assertEqual(forwarded_args[2], conditions) + self.assertEqual(forwarded_args[3], + st.get_updated_measurement_mapping(pre_measurement_mapping)) + self.assertEqual(forwarded_args[4], + st.get_updated_channel_mapping(pre_channel_mapping)) + self.assertEqual(forwarded_args[5], block) + + @unittest.skip("Extend of dummy template for argument checking needed.") + def test_requires_stop(self): + pass + +class PulseTemplateParameterMappingExceptionsTests(unittest.TestCase): + + def test_missing_mapping_exception_str(self) -> None: + dummy = DummyPulseTemplate() + exception = MissingMappingException(dummy, 'foo') + self.assertIsInstance(str(exception), str) + + def test_unnecessary_mapping_exception_str(self) -> None: + dummy = DummyPulseTemplate() + exception = UnnecessaryMappingException(dummy, 'foo') + self.assertIsInstance(str(exception), str) + + +class MappingPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): + + @property + def class_to_test(self): + return MappingPulseTemplate + + def make_kwargs(self): + return { + 'template': DummyPulseTemplate(defined_channels={'foo'}, + measurement_names={'meas'}, + parameter_names={'hugo', 'herbert', 'ilse'}), + 'parameter_mapping': {'hugo': Expression('2*k+c'), 'herbert': Expression('c-1.5'), 'ilse': Expression('ilse')}, + 'measurement_mapping': {'meas': 'seam'}, + 'channel_mapping': {'foo': 'default_channel'}, + 'parameter_constraints': [str(ParameterConstraint('c > 0'))] + } + + def make_instance(self, identifier=None, registry=None): + kwargs = self.make_kwargs() + return self.class_to_test(identifier=identifier, **kwargs, allow_partial_parameter_mapping=True, registry=registry) + + def assert_equal_instance_except_id(self, lhs: MappingPulseTemplate, rhs: MappingPulseTemplate): + self.assertIsInstance(lhs, MappingPulseTemplate) + self.assertIsInstance(rhs, MappingPulseTemplate) + self.assertEqual(lhs.template, rhs.template) + self.assertEqual(lhs.parameter_constraints, rhs.parameter_constraints) + self.assertEqual(lhs.channel_mapping, rhs.channel_mapping) + self.assertEqual(lhs.measurement_mapping, rhs.measurement_mapping) + self.assertEqual(lhs.parameter_mapping, rhs.parameter_mapping) + + +class MappingPulseTemplateOldSerializationTests(unittest.TestCase): + + def test_deserialize(self) -> None: + # test for deprecated version during transition period, remove after final switch + with self.assertWarnsRegex(DeprecationWarning, "deprecated", + msg="SequencePT does not issue warning for old serialization routines."): + dummy_pt = DummyPulseTemplate(defined_channels={'foo'}, + measurement_names={'meas'}, + parameter_names={'hugo', 'herbert', 'ilse'}) + serializer = DummySerializer() + data = { + 'template': serializer.dictify(dummy_pt), + 'parameter_mapping': {'hugo': str(Expression('2*k+c')), 'herbert': str(Expression('c-1.5')), + 'ilse': str(Expression('ilse'))}, + 'measurement_mapping': {'meas': 'seam'}, + 'channel_mapping': {'foo': 'default_channel'}, + 'parameter_constraints': [str(ParameterConstraint('c > 0'))] + } + deserialized = MappingPulseTemplate.deserialize(serializer=serializer, **data) + + self.assertIsInstance(deserialized, MappingPulseTemplate) + self.assertEqual(data['parameter_mapping'], deserialized.parameter_mapping) + self.assertEqual(data['channel_mapping'], deserialized.channel_mapping) + self.assertEqual(data['measurement_mapping'], deserialized.measurement_mapping) + self.assertEqual(data['parameter_constraints'], [str(pc) for pc in deserialized.parameter_constraints]) + self.assertIs(deserialized.template, dummy_pt) \ No newline at end of file diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index fd4a7dd2b..266942a0d 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -5,8 +5,9 @@ from qctoolkit.utils.types import time_from_float from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform, MappingPulseTemplate, ChannelMappingException, AtomicMultiChannelPulseTemplate -from qctoolkit.pulses.parameters import ParameterConstraint, ParameterConstraintViolation +from qctoolkit.pulses.parameters import ParameterConstraint, ParameterConstraintViolation, ConstantParameter from qctoolkit.expressions import ExpressionScalar, Expression +from qctoolkit._program.instructions import InstructionBlock from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummyWaveform from tests.serialization_dummies import DummySerializer @@ -14,8 +15,6 @@ from tests.serialization_tests import SerializableTests - - class AtomicMultiChannelPulseTemplateTest(unittest.TestCase): def __init__(self,*args,**kwargs): super().__init__(*args,**kwargs) @@ -118,7 +117,33 @@ def test_measurement_names(self): sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, parameter_names={'a', 'b'}, measurement_names={'A', 'C'}), DummyPulseTemplate(duration='t1', defined_channels={'B'}, parameter_names={'a', 'c'}, measurement_names={'A', 'B'})] - self.assertEqual(AtomicMultiChannelPulseTemplate(*sts).measurement_names, {'A', 'B', 'C'}) + self.assertEqual(AtomicMultiChannelPulseTemplate(*sts, measurements=[('D', 1, 2)]).measurement_names, + {'A', 'B', 'C', 'D'}) + + def test_parameter_names(self): + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, parameter_names={'a', 'b'}, + measurement_names={'A', 'C'}), + DummyPulseTemplate(duration='t1', defined_channels={'B'}, parameter_names={'a', 'c'}, + measurement_names={'A', 'B'})] + pt = AtomicMultiChannelPulseTemplate(*sts, measurements=[('D', 'd', 2)], parameter_constraints=['d < e']) + + self.assertEqual(pt.parameter_names, + {'a', 'b', 'c', 'd', 'e'}) + + + def test_integral(self) -> None: + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, + integrals={'A': ExpressionScalar('2+k')}), + DummyPulseTemplate(duration='t1', defined_channels={'B', 'C'}, + integrals={'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')})] + pulse = AtomicMultiChannelPulseTemplate(*sts) + self.assertEqual({'A': ExpressionScalar('2+k'), + 'B': ExpressionScalar('t1-t0*3.1'), + 'C': ExpressionScalar('l')}, + pulse.integral) + + +class MultiChannelPulseTemplateSequencingTests(unittest.TestCase): def test_requires_stop(self): sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, parameter_names={'a', 'b'}, requires_stop=False), @@ -184,16 +209,45 @@ def test_build_waveform_none(self): wf = pt.build_waveform(parameters, channel_mapping=channel_mapping) self.assertIsNone(wf) - def test_integral(self) -> None: - sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, - integrals={'A': ExpressionScalar('2+k')}), - DummyPulseTemplate(duration='t1', defined_channels={'B', 'C'}, - integrals={'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')})] - pulse = AtomicMultiChannelPulseTemplate(*sts) - self.assertEqual({'A': ExpressionScalar('2+k'), - 'B': ExpressionScalar('t1-t0*3.1'), - 'C': ExpressionScalar('l')}, - pulse.integral) + def test_build_sequence(self): + wfs = [DummyWaveform(duration=1.1, defined_channels={'A'}), DummyWaveform(duration=1.1, defined_channels={'B'})] + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, waveform=wfs[0], measurements=[('m', 0, 1)]), + DummyPulseTemplate(duration='t1', defined_channels={'B'}, waveform=wfs[1]), + DummyPulseTemplate(duration='t1', defined_channels={'C'}, waveform=None)] + + pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b'], measurements=[('n', .1, .2)]) + + params = dict(a=ConstantParameter(1.0), b=ConstantParameter(1.1)) + measurement_mapping = dict(m='foo', n='bar') + channel_mapping = {'A': 'A', 'B': 'B', 'C': None} + + block = InstructionBlock() + pt.build_sequence(None, parameters=params, conditions={}, measurement_mapping=measurement_mapping, + channel_mapping=channel_mapping, instruction_block=block) + + expected_waveform = MultiChannelWaveform(wfs) + + expected_block = InstructionBlock() + measurements = [('bar', .1, .2), ('foo', 0, 1)] + expected_block.add_instruction_meas(measurements) + expected_block.add_instruction_exec(waveform=expected_waveform) + + self.assertEqual(len(block.instructions), len(expected_block.instructions)) + self.assertEqual(block.instructions[0].compare_key, expected_block.instructions[0].compare_key) + self.assertEqual(block.instructions[1].compare_key, expected_block.instructions[1].compare_key) + + def test_get_measurement_windows(self): + wfs = [DummyWaveform(duration=1.1, defined_channels={'A'}), DummyWaveform(duration=1.1, defined_channels={'B'})] + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, waveform=wfs[0], measurements=[('m', 0, 1), + ('n', 0.3, 0.4)]), + DummyPulseTemplate(duration='t1', defined_channels={'B'}, waveform=wfs[1], measurements=[('m', 0.1, .2)])] + + pt = AtomicMultiChannelPulseTemplate(*sts, parameter_constraints=['a < b'], measurements=[('n', .1, .2)]) + + measurement_mapping = dict(m='foo', n='bar') + expected = [('bar', .1, .2), ('foo', 0, 1), ('bar', .3, .4), ('foo', .1, .2)] + meas_windows = pt.get_measurement_windows({}, measurement_mapping) + self.assertEqual(expected, meas_windows) class AtomicMultiChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 464e86715..243d83623 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -74,6 +74,26 @@ def test_parameter_names(self): parameter_constraints=['a < b']).parameter_names, {'a', 'b', 'n', 'A', 'B', 't', 'C'}) + def test_integral(self) -> None: + pulse = PointPulseTemplate( + [(1, (2, 'b'), 'linear'), (3, (0, 0), 'jump'), (4, (2, 'c'), 'hold'), (5, (8, 'd'), 'hold')], + [0, 'other_channel'] + ) + self.assertEqual({0: ExpressionScalar(6), + 'other_channel': ExpressionScalar('1.0*b + 2.0*c')}, + pulse.integral) + + pulse = PointPulseTemplate( + [(1, ('2', 'b'), 'linear'), ('t0', (0, 0), 'jump'), (4, (2, 'c'), 'hold'), ('g', (8, 'd'), 'hold')], + ['symbolic', 1] + ) + self.assertEqual({'symbolic': ExpressionScalar('2.0*g - t0 - 1.0'), + 1: ExpressionScalar('b*(0.5*t0 - 0.5) + c*(g - 4.0) + c*(-t0 + 4.0)')}, + pulse.integral) + + +class PointPulseTemplateSequencingTests(unittest.TestCase): + def test_requires_stop_missing_param(self) -> None: table = PointPulseTemplate([('foo', 'v')], [0]) with self.assertRaises(ParameterNotProvidedException): @@ -182,23 +202,6 @@ def test_build_waveform_none_channel(self): self.assertIsInstance(wf, MultiChannelWaveform) self.assertEqual(wf.defined_channels, {1, 2}) - def test_integral(self) -> None: - pulse = PointPulseTemplate( - [(1, (2, 'b'), 'linear'), (3, (0, 0), 'jump'), (4, (2, 'c'), 'hold'), (5, (8, 'd'), 'hold')], - [0, 'other_channel'] - ) - self.assertEqual({0: ExpressionScalar(6), - 'other_channel': ExpressionScalar('1.0*b + 2.0*c')}, - pulse.integral) - - pulse = PointPulseTemplate( - [(1, ('2', 'b'), 'linear'), ('t0', (0, 0), 'jump'), (4, (2, 'c'), 'hold'), ('g', (8, 'd'), 'hold')], - ['symbolic', 1] - ) - self.assertEqual({'symbolic': ExpressionScalar('2.0*g - t0 - 1.0'), - 1: ExpressionScalar('b*(0.5*t0 - 0.5) + c*(g - 4.0) + c*(-t0 + 4.0)')}, - pulse.integral) - class TablePulseTemplateConstraintTest(ParameterConstrainerTest): def __init__(self, *args, **kwargs): diff --git a/tests/pulses/pulse_template_parameter_mapping_tests.py b/tests/pulses/pulse_template_parameter_mapping_tests.py index f71b3cf10..bff3f9e65 100644 --- a/tests/pulses/pulse_template_parameter_mapping_tests.py +++ b/tests/pulses/pulse_template_parameter_mapping_tests.py @@ -1,289 +1,21 @@ import unittest -import itertools +import warnings -from qctoolkit.pulses.pulse_template_parameter_mapping import MissingMappingException,\ - UnnecessaryMappingException, MappingPulseTemplate,\ - AmbiguousMappingException, MappingCollisionException -from qctoolkit.pulses.parameters import ParameterNotProvidedException -from qctoolkit.pulses.parameters import ConstantParameter, ParameterConstraintViolation, ParameterConstraint -from qctoolkit.expressions import Expression +from qctoolkit.serialization import Serializer +from tests.pulses.sequencing_dummies import DummyPulseTemplate +from tests.serialization_dummies import DummyStorageBackend -from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummySequencer, DummyInstructionBlock -from tests.serialization_tests import SerializableTests -from tests.serialization_dummies import DummySerializer +class TestPulseTemplateParameterMappingFileTests(unittest.TestCase): -class MappingTemplateTests(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + # ensure that a MappingPulseTemplate imported from pulse_template_parameter_mapping serializes as from mapping_pulse_template + def test_pulse_template_parameter_include(self) -> None: + with warnings.catch_warnings(record=True): + warnings.simplefilter('ignore', DeprecationWarning) + from qctoolkit.pulses.pulse_template_parameter_mapping import MappingPulseTemplate + dummy_t = DummyPulseTemplate() + map_t = MappingPulseTemplate(dummy_t) + serializer = Serializer(DummyStorageBackend()) + type_str = serializer.get_type_identifier(map_t) + self.assertEqual("qctoolkit.pulses.mapping_pulse_template.MappingPulseTemplate", type_str) - def test_init_exceptions(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, defined_channels={'A'}, measurement_names={'B'}) - parameter_mapping = {'foo': 't*k', 'bar': 't*l'} - - with self.assertRaises(MissingMappingException): - MappingPulseTemplate(template, parameter_mapping={}) - with self.assertRaises(MissingMappingException): - MappingPulseTemplate(template, parameter_mapping={'bar': 'kneipe'}) - with self.assertRaises(UnnecessaryMappingException): - MappingPulseTemplate(template, parameter_mapping=dict(**parameter_mapping, foobar='asd')) - - with self.assertRaises(UnnecessaryMappingException): - MappingPulseTemplate(template, parameter_mapping=parameter_mapping, measurement_mapping=dict(a='b')) - with self.assertRaises(UnnecessaryMappingException): - MappingPulseTemplate(template, parameter_mapping=parameter_mapping, channel_mapping=dict(a='b')) - - with self.assertRaises(TypeError): - MappingPulseTemplate(template, parameter_mapping) - - MappingPulseTemplate(template, parameter_mapping=parameter_mapping) - - def test_from_tuple_exceptions(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, - measurement_names={'foo', 'foobar'}, - defined_channels={'bar', 'foobar'}) - - with self.assertRaises(ValueError): - MappingPulseTemplate.from_tuple((template, {'A': 'B'})) - with self.assertRaises(AmbiguousMappingException): - MappingPulseTemplate.from_tuple((template, {'foo': 'foo'})) - with self.assertRaises(AmbiguousMappingException): - MappingPulseTemplate.from_tuple((template, {'bar': 'bar'})) - with self.assertRaises(AmbiguousMappingException): - MappingPulseTemplate.from_tuple((template, {'foobar': 'foobar'})) - - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - with self.assertRaises(MappingCollisionException): - MappingPulseTemplate.from_tuple((template, {'foo': '1', 'bar': 2}, {'foo': '1', 'bar': 4})) - - template = DummyPulseTemplate(defined_channels={'A'}) - with self.assertRaises(MappingCollisionException): - MappingPulseTemplate.from_tuple((template, {'A': 'N'}, {'A': 'C'})) - - template = DummyPulseTemplate(measurement_names={'M'}) - with self.assertRaises(MappingCollisionException): - MappingPulseTemplate.from_tuple((template, {'M': 'N'}, {'M': 'N'})) - - def test_from_tuple(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}, - measurement_names={'m1', 'm2'}, - defined_channels={'c1', 'c2'}) - - def test_mapping_permutations(template: DummyPulseTemplate, - pmap, mmap, cmap): - direct = MappingPulseTemplate(template, - parameter_mapping=pmap, - measurement_mapping=mmap, - channel_mapping=cmap) - - mappings = [m for m in [pmap, mmap, cmap] if m is not None] - - for current_mapping_order in itertools.permutations(mappings): - mapper = MappingPulseTemplate.from_tuple((template, *current_mapping_order)) - self.assertEqual(mapper.measurement_mapping, direct.measurement_mapping) - self.assertEqual(mapper.channel_mapping, direct.channel_mapping) - self.assertEqual(mapper.parameter_mapping, direct.parameter_mapping) - - test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1'}, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, {'foo': 1, 'bar': 2}, None, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1'}) - test_mapping_permutations(template, {'foo': 1, 'bar': 2}, {'m1': 'n1', 'm2': 'n2'}, None) - test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, None, {'m1': 'n1'}, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, None, None, {'c1': 'd1', 'c2': 'd2'}) - test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, {'c1': 'd1'}) - test_mapping_permutations(template, None, {'m1': 'n1', 'm2': 'n2'}, None) - - def test_external_params(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}) - external_params = {'t', 'l', 'k'} - self.assertEqual(st.parameter_names, external_params) - - def test_constrained(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}, parameter_constraints=['t < m']) - external_params = {'t', 'l', 'k', 'm'} - self.assertEqual(st.parameter_names, external_params) - - with self.assertRaises(ParameterConstraintViolation): - st.map_parameters(dict(t=1, l=2, k=3, m=0)) - - def test_map_parameters(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k', 'bar': 't*l'}) - - parameters = {'t': ConstantParameter(3), 'k': ConstantParameter(2), 'l': ConstantParameter(7)} - values = {'foo': 6, 'bar': 21} - for k, v in st.map_parameters(parameters).items(): - self.assertEqual(v.get_value(), values[k]) - parameters.popitem() - with self.assertRaises(ParameterNotProvidedException): - st.map_parameters(parameters) - - parameters = dict(t=3, k=2, l=7) - values = {'foo': 6, 'bar': 21} - for k, v in st.map_parameters(parameters).items(): - self.assertEqual(v, values[k]) - - def test_partial_parameter_mapping(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k'}, allow_partial_parameter_mapping=True) - - self.assertEqual(st.parameter_mapping, {'foo': 't*k', 'bar': 'bar'}) - - def test_nested_mapping_avoidance(self): - template = DummyPulseTemplate(parameter_names={'foo', 'bar'}) - st_1 = MappingPulseTemplate(template, parameter_mapping={'foo': 't*k'}, allow_partial_parameter_mapping=True) - st_2 = MappingPulseTemplate(st_1, parameter_mapping={'bar': 't*l'}, allow_partial_parameter_mapping=True) - - self.assertIs(st_2.template, template) - self.assertEqual(st_2.parameter_mapping, {'foo': 't*k', 'bar': 't*l'}) - - st_3 = MappingPulseTemplate(template, - parameter_mapping={'foo': 't*k'}, - allow_partial_parameter_mapping=True, - identifier='käse') - st_4 = MappingPulseTemplate(st_3, parameter_mapping={'bar': 't*l'}, allow_partial_parameter_mapping=True) - self.assertIs(st_4.template, st_3) - self.assertEqual(st_4.parameter_mapping, {'t': 't', 'k': 'k', 'bar': 't*l'}) - - - def test_get_updated_channel_mapping(self): - template = DummyPulseTemplate(defined_channels={'foo', 'bar'}) - st = MappingPulseTemplate(template, channel_mapping={'bar': 'kneipe'}) - with self.assertRaises(KeyError): - st.get_updated_channel_mapping(dict()) - self.assertEqual(st.get_updated_channel_mapping({'kneipe': 'meas1', 'foo': 'meas2', 'troet': 'meas3'}), - {'foo': 'meas2', 'bar': 'meas1'}) - - def test_measurement_names(self): - template = DummyPulseTemplate(measurement_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, measurement_mapping={'foo': 'froop', 'bar': 'kneipe'}) - self.assertEqual( st.measurement_names, {'froop','kneipe'} ) - - def test_defined_channels(self): - mapping = {'asd': 'A', 'fgh': 'B'} - template = DummyPulseTemplate(defined_channels=set(mapping.keys())) - st = MappingPulseTemplate(template, channel_mapping=mapping) - self.assertEqual(st.defined_channels, set(mapping.values())) - - def test_get_updated_measurement_mapping(self): - template = DummyPulseTemplate(measurement_names={'foo', 'bar'}) - st = MappingPulseTemplate(template, measurement_mapping={'bar': 'kneipe'}) - with self.assertRaises(KeyError): - st.get_updated_measurement_mapping(dict()) - self.assertEqual(st.get_updated_measurement_mapping({'kneipe': 'meas1', 'foo': 'meas2', 'troet': 'meas3'}), - {'foo': 'meas2', 'bar': 'meas1'}) - - def test_build_sequence(self): - measurement_mapping = {'meas1': 'meas2'} - parameter_mapping = {'t': 'k'} - - template = DummyPulseTemplate(measurement_names=set(measurement_mapping.keys()), - parameter_names=set(parameter_mapping.keys())) - st = MappingPulseTemplate(template, parameter_mapping=parameter_mapping, measurement_mapping=measurement_mapping) - sequencer = DummySequencer() - block = DummyInstructionBlock() - pre_parameters = {'k': ConstantParameter(5)} - pre_measurement_mapping = {'meas2': 'meas3'} - pre_channel_mapping = {'default': 'A'} - conditions = dict(a=True) - st.build_sequence(sequencer, pre_parameters, conditions, pre_measurement_mapping, pre_channel_mapping, block) - - self.assertEqual(template.build_sequence_calls, 1) - forwarded_args = template.build_sequence_arguments[0] - self.assertEqual(forwarded_args[0], sequencer) - self.assertEqual(forwarded_args[1], st.map_parameters(pre_parameters)) - self.assertEqual(forwarded_args[2], conditions) - self.assertEqual(forwarded_args[3], - st.get_updated_measurement_mapping(pre_measurement_mapping)) - self.assertEqual(forwarded_args[4], - st.get_updated_channel_mapping(pre_channel_mapping)) - self.assertEqual(forwarded_args[5], block) - - @unittest.skip("Extend of dummy template for argument checking needed.") - def test_requires_stop(self): - pass - - def test_integral(self) -> None: - dummy = DummyPulseTemplate(defined_channels={'A', 'B'}, - parameter_names={'k', 'f', 'b'}, - integrals={'A': Expression('2*k'), 'other': Expression('-3.2*f+b')}) - pulse = MappingPulseTemplate(dummy, parameter_mapping={'k': 'f', 'b': 2.3}, channel_mapping={'A': 'default'}, - allow_partial_parameter_mapping=True) - - self.assertEqual({'default': Expression('2*f'), 'other': Expression('-3.2*f+2.3')}, pulse.integral) - - -class PulseTemplateParameterMappingExceptionsTests(unittest.TestCase): - - def test_missing_mapping_exception_str(self) -> None: - dummy = DummyPulseTemplate() - exception = MissingMappingException(dummy, 'foo') - self.assertIsInstance(str(exception), str) - - def test_unnecessary_mapping_exception_str(self) -> None: - dummy = DummyPulseTemplate() - exception = UnnecessaryMappingException(dummy, 'foo') - self.assertIsInstance(str(exception), str) - - -class MappingPulseTemplateSerializationTests(SerializableTests, unittest.TestCase): - - @property - def class_to_test(self): - return MappingPulseTemplate - - def make_kwargs(self): - return { - 'template': DummyPulseTemplate(defined_channels={'foo'}, - measurement_names={'meas'}, - parameter_names={'hugo', 'herbert', 'ilse'}), - 'parameter_mapping': {'hugo': Expression('2*k+c'), 'herbert': Expression('c-1.5'), 'ilse': Expression('ilse')}, - 'measurement_mapping': {'meas': 'seam'}, - 'channel_mapping': {'foo': 'default_channel'}, - 'parameter_constraints': [str(ParameterConstraint('c > 0'))] - } - - def make_instance(self, identifier=None, registry=None): - kwargs = self.make_kwargs() - return self.class_to_test(identifier=identifier, **kwargs, allow_partial_parameter_mapping=True, registry=registry) - - def assert_equal_instance_except_id(self, lhs: MappingPulseTemplate, rhs: MappingPulseTemplate): - self.assertIsInstance(lhs, MappingPulseTemplate) - self.assertIsInstance(rhs, MappingPulseTemplate) - self.assertEqual(lhs.template, rhs.template) - self.assertEqual(lhs.parameter_constraints, rhs.parameter_constraints) - self.assertEqual(lhs.channel_mapping, rhs.channel_mapping) - self.assertEqual(lhs.measurement_mapping, rhs.measurement_mapping) - self.assertEqual(lhs.parameter_mapping, rhs.parameter_mapping) - - -class MappingPulseTemplateOldSerializationTests(unittest.TestCase): - - def test_deserialize(self) -> None: - # test for deprecated version during transition period, remove after final switch - with self.assertWarnsRegex(DeprecationWarning, "deprecated", - msg="SequencePT does not issue warning for old serialization routines."): - dummy_pt = DummyPulseTemplate(defined_channels={'foo'}, - measurement_names={'meas'}, - parameter_names={'hugo', 'herbert', 'ilse'}) - serializer = DummySerializer() - data = { - 'template': serializer.dictify(dummy_pt), - 'parameter_mapping': {'hugo': str(Expression('2*k+c')), 'herbert': str(Expression('c-1.5')), - 'ilse': str(Expression('ilse'))}, - 'measurement_mapping': {'meas': 'seam'}, - 'channel_mapping': {'foo': 'default_channel'}, - 'parameter_constraints': [str(ParameterConstraint('c > 0'))] - } - deserialized = MappingPulseTemplate.deserialize(serializer=serializer, **data) - - self.assertIsInstance(deserialized, MappingPulseTemplate) - self.assertEqual(data['parameter_mapping'], deserialized.parameter_mapping) - self.assertEqual(data['channel_mapping'], deserialized.channel_mapping) - self.assertEqual(data['measurement_mapping'], deserialized.measurement_mapping) - self.assertEqual(data['parameter_constraints'], [str(pc) for pc in deserialized.parameter_constraints]) - self.assertIs(deserialized.template, dummy_pt) \ No newline at end of file diff --git a/tests/pulses/sequence_pulse_template_tests.py b/tests/pulses/sequence_pulse_template_tests.py index bffddecab..5bcb94eab 100644 --- a/tests/pulses/sequence_pulse_template_tests.py +++ b/tests/pulses/sequence_pulse_template_tests.py @@ -6,7 +6,7 @@ from qctoolkit.expressions import Expression, ExpressionScalar from qctoolkit.pulses.table_pulse_template import TablePulseTemplate from qctoolkit.pulses.sequence_pulse_template import SequencePulseTemplate, SequenceWaveform -from qctoolkit.pulses.pulse_template_parameter_mapping import MappingPulseTemplate +from qctoolkit.pulses.mapping_pulse_template import MappingPulseTemplate from qctoolkit.pulses.parameters import ConstantParameter, ParameterConstraint, ParameterConstraintViolation from qctoolkit._program.instructions import MEASInstruction diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index b9d743217..cced1b28e 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -1,5 +1,5 @@ """STANDARD LIBRARY IMPORTS""" -from typing import Tuple, List, Dict, Optional, Set, Any +from typing import Tuple, List, Dict, Optional, Set, Any, Union import copy import numpy @@ -136,7 +136,7 @@ def add_instruction(self, instruction: Instruction) -> None: class DummyWaveform(Waveform): - def __init__(self, duration: float=0, sample_output: numpy.ndarray=None, defined_channels={'A'}) -> None: + def __init__(self, duration: float=0, sample_output: Union[numpy.ndarray, dict]=None, defined_channels={'A'}) -> None: super().__init__() self.duration_ = time_from_float(duration) self.sample_output = sample_output @@ -166,7 +166,10 @@ def unsafe_sample(self, if output_array is None: output_array = numpy.empty_like(sample_times) if self.sample_output is not None: - output_array[:] = self.sample_output + if isinstance(self.sample_output, dict): + output_array[:] = self.sample_output[channel] + else: + output_array[:] = self.sample_output else: output_array[:] = sample_times return output_array @@ -324,10 +327,6 @@ def duration(self): def parameter_names(self) -> Set[str]: return set(self.parameter_names_) - def get_measurement_windows(self, parameters: Dict[str, Parameter] = None) -> List[MeasurementWindow]: - """Return all measurement windows defined in this PulseTemplate.""" - raise NotImplementedError() - @property def build_sequence_calls(self): return len(self.build_sequence_arguments) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index d4659dd56..d312be6a2 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -411,6 +411,19 @@ def test_measurement_names(self): tpt = TablePulseTemplate({0: [(10, 1)]}, measurements=[('A', 2, 3), ('AB', 0, 1)]) self.assertEqual(tpt.measurement_names, {'A', 'AB'}) + def test_identifier(self) -> None: + identifier = 'some name' + pulse = TablePulseTemplate(entries={0: [(1, 0)]}, identifier=identifier) + self.assertEqual(pulse.identifier, identifier) + + def test_integral(self) -> None: + pulse = TablePulseTemplate(entries={0: [(1, 2, 'linear'), (3, 0, 'jump'), (4, 2, 'hold'), (5, 8, 'hold')], + 'other_channel': [(0, 7, 'linear'), (2, 0, 'hold'), (10, 0)], + 'symbolic': [(3, 'a', 'hold'), ('b', 4, 'linear'), ('c', Expression('d'), 'hold')]}) + self.assertEqual(pulse.integral, {0: Expression('6'), + 'other_channel': Expression(7), + 'symbolic': Expression('(b-3)*a + 0.5 * (c-b)*(d+4)')}) + class TablePulseTemplateConstraintTest(ParameterConstrainerTest): def __init__(self, *args, **kwargs): @@ -608,19 +621,6 @@ def test_requires_stop(self) -> None: for expected_result, parameter_set, condition_set in test_sets: self.assertEqual(expected_result, table.requires_stop(parameter_set, condition_set)) - def test_identifier(self) -> None: - identifier = 'some name' - pulse = TablePulseTemplate(entries={0: [(1, 0)]}, identifier=identifier) - self.assertEqual(pulse.identifier, identifier) - - def test_integral(self) -> None: - pulse = TablePulseTemplate(entries={0: [(1, 2, 'linear'), (3, 0, 'jump'), (4, 2, 'hold'), (5, 8, 'hold')], - 'other_channel': [(0, 7, 'linear'), (2, 0, 'hold'), (10, 0)], - 'symbolic': [(3, 'a', 'hold'), ('b', 4, 'linear'), ('c', Expression('d'), 'hold')]}) - self.assertEqual(pulse.integral, {0: Expression('6'), - 'other_channel': Expression(7), - 'symbolic': Expression('(b-3)*a + 0.5 * (c-b)*(d+4)')}) - class TablePulseConcatenationTests(unittest.TestCase):