From b6e50f210082ac3a45ed4810e82dc7991c78db5e Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 20 Nov 2020 12:05:34 +0100 Subject: [PATCH 1/8] Avoid returning ndarrays from expression if possible --- qupulse/expressions.py | 10 ++++++++++ tests/expression_tests.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/qupulse/expressions.py b/qupulse/expressions.py index 90594dc69..476aa357f 100644 --- a/qupulse/expressions.py +++ b/qupulse/expressions.py @@ -347,6 +347,16 @@ def get_serialization_data(self) -> Union[str, float, int]: def is_nan(self) -> bool: return sympy.sympify('nan') == self._sympified_expression + def _parse_evaluate_numeric_result(self, + result: Union[Number, numpy.ndarray], + call_arguments: Any) -> Number: + """Overwrite super class method because we do not want to return a scalar numpy.ndarray""" + parsed = super()._parse_evaluate_numeric_result(result, call_arguments) + if isinstance(parsed, numpy.ndarray): + return parsed[()] + else: + return parsed + class ExpressionVariableMissingException(Exception): """An exception indicating that a variable value was not provided during expression evaluation. diff --git a/tests/expression_tests.py b/tests/expression_tests.py index 3c69437d3..acc1735eb 100644 --- a/tests/expression_tests.py +++ b/tests/expression_tests.py @@ -150,6 +150,14 @@ def test_evaluate_numpy(self): } np.testing.assert_equal((2 * 1.5 - 7) * np.ones(4), e.evaluate_numeric(**params)) + e = ExpressionScalar('a * b + c') + params = { + 'a': np.array(2), + 'b': np.array(1.5), + 'c': np.array(-7) + } + np.testing.assert_equal((2 * 1.5 - 7), e.evaluate_numeric(**params)) + def test_indexing(self): e = ExpressionScalar('a[i] * c') From d6736076363a23ed3b5c24ee6ae02b273470034c Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 20 Nov 2020 12:33:19 +0100 Subject: [PATCH 2/8] Plan deprecation of substitute_with_eval Improve analysis behavior of Broadcast Use a numpy compatible add on substitution --- qupulse/utils/sympy.py | 45 +++++++++++++++++++++++++++++++------- tests/utils/sympy_tests.py | 29 +++++++++++++++++++++++- 2 files changed, 65 insertions(+), 9 deletions(-) diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 1e3350842..9038a437f 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -87,16 +87,25 @@ class Broadcast(sympy.Function): >>> assert bc.subs({'a': 2}) == sympy.Array([2, 2, 2]) >>> assert bc.subs({'a': (1, 2, 3)}) == sympy.Array([1, 2, 3]) """ + nargs = (2,) @classmethod def eval(cls, x, shape) -> Optional[sympy.Array]: - if hasattr(shape, 'free_symbols') and shape.free_symbols: + if getattr(shape, 'free_symbols', None): # cannot do anything return None if hasattr(x, '__len__') or not x.free_symbols: return sympy.Array(numpy.broadcast_to(x, shape)) + def _eval_Integral(self, *symbols, **assumptions): + x, shape = self.args + return Broadcast(sympy.Integral(x, *symbols, **assumptions), shape) + + def _eval_derivative(self, sym): + x, shape = self.args + return Broadcast(sympy.diff(x, sym), shape) + class Len(sympy.Function): nargs = 1 @@ -125,6 +134,22 @@ def numpy_compatible_mul(*args) -> Union[sympy.Mul, sympy.Array]: return sympy.Mul(*args) +def numpy_compatible_add(*args) -> Union[sympy.Add, sympy.Array]: + if any(isinstance(a, sympy.NDimArray) for a in args): + result = 0 + for a in args: + result = result + (numpy.array(a.tolist()) if isinstance(a, sympy.NDimArray) else a) + return sympy.Array(result) + else: + return sympy.Add(*args) + + +_NUMPY_COMPATIBLE = { + sympy.Add: numpy_compatible_add, + sympy.Mul: numpy_compatible_mul +} + + def numpy_compatible_ceiling(input_value: Any) -> Any: if isinstance(input_value, numpy.ndarray): return numpy.ceil(input_value).astype(numpy.int64) @@ -154,6 +179,8 @@ def sympify(expr: Union[str, Number, sympy.Expr, numpy.str_], **kwargs) -> sympy # putting numpy.str_ in sympy.sympify behaves unexpected in version 1.1.1 # It seems to ignore the locals argument expr = str(expr) + if isinstance(expr, (tuple, list)): + expr = numpy.array(expr) try: return sympy.sympify(expr, **kwargs, locals=sympify_namespace) except TypeError as err: @@ -192,6 +219,9 @@ def get_variables(expression: sympy.Expr) -> Sequence[str]: def substitute_with_eval(expression: sympy.Expr, substitutions: Dict[str, Union[sympy.Expr, numpy.ndarray, str]]) -> sympy.Expr: """Substitutes only sympy.Symbols. Workaround for numpy like array behaviour. ~Factor 3 slower compared to subs""" + warnings.warn("substitute_with_eval does not handle dummy symbols correctly and is planned to be removed", + FutureWarning) + substitutions = {k: v if isinstance(v, sympy.Expr) else sympify(v) for k, v in substitutions.items()} @@ -202,27 +232,26 @@ def substitute_with_eval(expression: sympy.Expr, string_representation = sympy.srepr(expression) return eval(string_representation, sympy.__dict__, {'Symbol': substitutions.__getitem__, - 'Mul': numpy_compatible_mul}) + 'Mul': numpy_compatible_mul, + 'Add': numpy_compatible_add}) def _recursive_substitution(expression: sympy.Expr, substitutions: Dict[sympy.Symbol, sympy.Expr]) -> sympy.Expr: if not expression.free_symbols: return expression - elif expression.func is sympy.Symbol: + elif expression.func in (sympy.Symbol, sympy.Dummy): return substitutions.get(expression, expression) - elif expression.func is sympy.Mul: - func = numpy_compatible_mul - else: - func = expression.func + func = _NUMPY_COMPATIBLE.get(expression.func, expression.func) substitutions = {s: substitutions.get(s, s) for s in get_free_symbols(expression)} return func(*(_recursive_substitution(arg, substitutions) for arg in expression.args)) def recursive_substitution(expression: sympy.Expr, substitutions: Dict[str, Union[sympy.Expr, numpy.ndarray, str]]) -> sympy.Expr: - substitutions = {sympy.Symbol(k): sympify(v) for k, v in substitutions.items()} + substitutions = {k if isinstance(k, (sympy.Symbol, sympy.Dummy)) else sympy.Symbol(k): sympify(v) + for k, v in substitutions.items()} for s in get_free_symbols(expression): substitutions.setdefault(s, s) return _recursive_substitution(expression, substitutions) diff --git a/tests/utils/sympy_tests.py b/tests/utils/sympy_tests.py index def2f9d2b..d489a7e51 100644 --- a/tests/utils/sympy_tests.py +++ b/tests/utils/sympy_tests.py @@ -14,6 +14,7 @@ a_ = IndexedBase(a) b_ = IndexedBase(b) +dummy_a = sympy.Dummy('a') from qupulse.utils.sympy import sympify as qc_sympify, substitute_with_eval, recursive_substitution, Len,\ evaluate_lambdified, evaluate_compiled, get_most_simple_representation, get_variables, get_free_symbols,\ @@ -44,12 +45,19 @@ vector_valued_cases = [ (a*b, {'a': sympy.Array([1, 2, 3])}, sympy.Array([1, 2, 3])*b), (a*b, {'a': sympy.Array([1, 2, 3]), 'b': sympy.Array([4, 5, 6])}, sympy.Array([4, 10, 18])), + (a + b, {'a': sympy.Array([1, 2, 3])}, sympy.Array([1 + b, 2 + b, 3 + b])), + (a + b, {'a': sympy.Array([1, 2, 3]), 'b': sympy.Array([4, 5, 6])}, sympy.Array([5, 7, 9])), ] full_featured_cases = [ (Sum(a_[i], (i, 0, Len(a) - 1)), {'a': sympy.Array([1, 2, 3])}, 6), ] +dummy_substitution_cases = [ + (a * dummy_a + sympy.exp(dummy_a), {'a': b}, b * dummy_a + sympy.exp(dummy_a)), + (a * dummy_a + sympy.exp(dummy_a), {dummy_a: b}, a * b + sympy.exp(b)), +] + ##################################################### SYMPIFY ########################################################## simple_sympify = [ @@ -199,10 +207,16 @@ def test_full_featured_cases(self): result = self.substitute(expr, subs) self.assertEqual(result, expected) + def test_dummy_subs(self): + for expr, subs, expected in dummy_substitution_cases: + result = self.substitute(expr, subs) + self.assertEqual(result, expected) + class SubstituteWithEvalTests(SubstitutionTests): def substitute(self, expression: sympy.Expr, substitutions: dict): - return substitute_with_eval(expression, substitutions) + with self.assertWarns(FutureWarning): + return substitute_with_eval(expression, substitutions) @unittest.expectedFailure def test_sum_substitution_cases(self): @@ -212,6 +226,10 @@ def test_sum_substitution_cases(self): def test_full_featured_cases(self): super().test_full_featured_cases() + @unittest.expectedFailure + def test_dummy_subs(self): + super().test_dummy_subs() + class RecursiveSubstitutionTests(SubstitutionTests): def substitute(self, expression: sympy.Expr, substitutions: dict): @@ -429,6 +447,15 @@ def test_expression_equality(self): test_numeric_equal = unittest.expectedFailure(test_expression_equality) if distutils.version.StrictVersion(sympy.__version__) >= distutils.version.StrictVersion('1.5') else test_expression_equality + def test_integral(self): + symbolic = Broadcast(a, (3,)) + + integ = sympy.Integral(symbolic, (a, 0, b)) + self.assertEqual(integ, Broadcast(sympy.Integral(a, (a, 0, b)), (3,))) + + diffed = sympy.diff(integ, b).subs({b: a}) + self.assertEqual(symbolic, diffed) + class IndexedBasedFinderTests(unittest.TestCase): def test_isinstance(self): From afb4f3a8b33c57dc9588603f20f894c4181c9d0b Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 20 Nov 2020 12:40:26 +0100 Subject: [PATCH 3/8] Add _as_expression to atomic PTs except AMCPT --- qupulse/pulses/arithmetic_pulse_template.py | 27 +++-- qupulse/pulses/function_pulse_template.py | 5 +- qupulse/pulses/mapping_pulse_template.py | 14 ++- qupulse/pulses/point_pulse_template.py | 51 ++++++--- qupulse/pulses/pulse_template.py | 9 ++ qupulse/pulses/table_pulse_template.py | 107 +++++++++++++++--- .../pulses/arithmetic_pulse_template_tests.py | 30 +++++ tests/pulses/function_pulse_tests.py | 5 + tests/pulses/mapping_pulse_template_tests.py | 21 ++++ tests/pulses/point_pulse_template_tests.py | 97 +++++++++++++--- tests/pulses/pulse_template_tests.py | 7 +- tests/pulses/sequencing_dummies.py | 31 ++++- tests/pulses/table_pulse_template_tests.py | 92 +++++++++++++-- 13 files changed, 422 insertions(+), 74 deletions(-) diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index 284e426a6..ae685f26f 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -18,6 +18,18 @@ IdentityTransformation +def _apply_operation_to_channel_dict(operator: str, + lhs: Mapping[ChannelID, Any], + rhs: Mapping[ChannelID, Any]) -> Dict[ChannelID, Any]: + result = dict(lhs) + for channel, rhs_value in rhs.items(): + if channel in result: + result[channel] = ArithmeticWaveform.operator_map[operator](result[channel], rhs_value) + else: + result[channel] = ArithmeticWaveform.rhs_only_map[operator](rhs_value) + return result + + class ArithmeticAtomicPulseTemplate(AtomicPulseTemplate): def __init__(self, lhs: AtomicPulseTemplate, @@ -96,17 +108,12 @@ def duration(self) -> ExpressionScalar: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - lhs = self.lhs.integral - rhs = self.rhs.integral + return _apply_operation_to_channel_dict(self._arithmetic_operator, self.lhs.integral, self.rhs.integral) - result = lhs.copy() - - for channel, rhs_value in rhs.items(): - if channel in result: - result[channel] = ArithmeticWaveform.operator_map[self._arithmetic_operator](result[channel], rhs_value) - else: - result[channel] = ArithmeticWaveform.rhs_only_map[self._arithmetic_operator](rhs_value) - return result + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + return _apply_operation_to_channel_dict(self._arithmetic_operator, + self.lhs._as_expression(), + self.rhs._as_expression()) def build_waveform(self, parameters: Dict[str, Real], diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index 9df064681..145863a4c 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List, Set, Optional, Union import numbers -import numpy as np import sympy from qupulse.expressions import ExpressionScalar @@ -148,4 +147,8 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: sympy.integrate(self.__expression.sympified_expression, ('t', 0, self.duration.sympified_expression)) )} + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + expr = ExpressionScalar.make(self.__expression.underlying_expression.subs({'t': self._AS_EXPRESSION_TIME})) + return {self.__channel: expr} + diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index a9ff5a0f1..af5b64e00 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -354,18 +354,28 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: # todo: make Expressions compatible with sympy.subs() parameter_mapping = {parameter_name: expression.underlying_expression for parameter_name, expression in self.__parameter_mapping.items()} - for channel, ch_integral in internal_integral.items(): channel_out = self.__channel_mapping.get(channel, channel) if channel_out is None: continue expressions[channel_out] = ExpressionScalar( - ch_integral.sympified_expression.subs(parameter_mapping) + ch_integral.sympified_expression.subs(parameter_mapping, simultaneous=True) ) return expressions + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + parameter_mapping = {parameter_name: expression.underlying_expression + for parameter_name, expression in self.__parameter_mapping.items()} + inner = self.__template._as_expression() + return { + self.__channel_mapping.get(ch, ch): ExpressionScalar(ch_expr.sympified_expression.subs(parameter_mapping, + simultaneous=True)) + for ch, ch_expr in inner.items() + if self.__channel_mapping.get(ch, ch) is not None + } + class MissingMappingException(Exception): """Indicates that no mapping was specified for some parameter declaration of a diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 40abcde4c..0d41abedf 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Union, Set, Dict, Sequence, Any +from typing import Optional, List, Union, Set, Dict, Sequence, Any, Tuple from numbers import Real import itertools import numbers @@ -64,7 +64,8 @@ def defined_channels(self) -> Set[ChannelID]: def build_waveform(self, parameters: Dict[str, Real], - channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[TableWaveform]: + channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> Optional[Union[TableWaveform, + MultiChannelWaveform]]: self.validate_parameter_constraints(parameters=parameters, volatile=set()) if all(channel_mapping[channel] is None @@ -136,21 +137,39 @@ def parameter_names(self) -> Set[str]: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: - expressions = {channel: 0 for channel in self._channels} - for first_entry, second_entry in zip(self._entries[:-1], self._entries[1:]): - substitutions = {'t0': first_entry.t.sympified_expression, - 't1': second_entry.t.sympified_expression} - - v0 = sympy.IndexedBase(Broadcast(first_entry.v.underlying_expression, (len(self.defined_channels),))) - v1 = sympy.IndexedBase(Broadcast(second_entry.v.underlying_expression, (len(self.defined_channels),))) - - for i, channel in enumerate(self._channels): - substitutions['v0'] = v0[i] - substitutions['v1'] = v1[i] - - expressions[channel] += first_entry.interp.integral.sympified_expression.subs(substitutions) + expressions = {} + shape = (len(self.defined_channels),) + + for i, channel in enumerate(self._channels): + def value_trafo(v): + try: + return v.underlying_expression[i] + except TypeError: + return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + pre_entry = TableEntry(0, self._entries[0].v, None) + entries = [pre_entry] + self._entries + expressions[channel] = TableEntry._sequence_integral(entries, expression_extractor=value_trafo) + return expressions - expressions = {c: ExpressionScalar(expressions[c]) for c in expressions} + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + t = self._AS_EXPRESSION_TIME + shape = (len(self.defined_channels),) + expressions = {} + + for i, channel in enumerate(self._channels): + def value_trafo(v): + try: + return v.underlying_expression[i] + except TypeError: + return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + pre_value = value_trafo(self._entries[0].v) + post_value = value_trafo(self._entries[-1].v) + pw = TableEntry._sequence_as_expression(self._entries, + expression_extractor=value_trafo, + t=t, + post_value=post_value, + pre_value=pre_value) + expressions[channel] = pw return expressions diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index d546e6614..aa558f561 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -12,6 +12,8 @@ import collections from numbers import Real, Number +import sympy + from qupulse.utils.types import ChannelID, DocStringABCMeta, FrozenDict from qupulse.serialization import Serializable from qupulse.expressions import ExpressionScalar, Expression, ExpressionLike @@ -290,6 +292,8 @@ class AtomicPulseTemplate(PulseTemplate, MeasurementDefiner): Implies that no AtomicPulseTemplate object is interruptable. """ + _AS_EXPRESSION_TIME = sympy.Dummy('_t', positive=True) + def __init__(self, *, identifier: Optional[str], measurements: Optional[List[MeasurementDeclaration]]): @@ -345,6 +349,11 @@ def build_waveform(self, does not represent a valid waveform of finite length. """ + @abstractmethod + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + """Helper function to allow integral calculation in case of truncation. AtomicPulseTemplate._AS_EXPRESSION_TIME + is by convention the time variable.""" + class DoubleParameterNameException(Exception): diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 22e56b28d..9b4fc70d1 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -7,7 +7,7 @@ declared parameters. """ -from typing import Union, Dict, List, Set, Optional, Any, Tuple, Sequence, NamedTuple +from typing import Union, Dict, List, Set, Optional, Any, Tuple, Sequence, NamedTuple, Callable import numbers import itertools import warnings @@ -15,6 +15,7 @@ import numpy as np import sympy from sympy.logic.boolalg import BooleanAtom +import more_itertools from qupulse.utils.types import ChannelID from qupulse.serialization import Serializer, PulseRegistryType @@ -38,13 +39,13 @@ class TableEntry(NamedTuple('TableEntry', [('t', ExpressionScalar), ('v', Expression), - ('interp', InterpolationStrategy)])): + ('interp', Optional[InterpolationStrategy])])): __slots__ = () - def __new__(cls, t: ValueInInit, v: ValueInInit, interp: Union[str, InterpolationStrategy]='default'): + def __new__(cls, t: ValueInInit, v: ValueInInit, interp: Optional[Union[str, InterpolationStrategy]]='default'): if interp in TablePulseTemplate.interpolation_strategies: interp = TablePulseTemplate.interpolation_strategies[interp] - if not isinstance(interp, InterpolationStrategy): + if interp is not None and not isinstance(interp, InterpolationStrategy): raise KeyError(interp, 'is not a valid interpolation strategy') return super().__new__(cls, ExpressionScalar.make(t), @@ -57,7 +58,73 @@ def instantiate(self, parameters: Dict[str, numbers.Real]) -> TableWaveformEntry self.interp) def get_serialization_data(self) -> tuple: - return self.t.get_serialization_data(), self.v.get_serialization_data(), str(self.interp) + interp = None if self.interp is None else str(self.interp) + return self.t.get_serialization_data(), self.v.get_serialization_data(), interp + + @classmethod + def _sequence_integral(cls, entry_sequence: Sequence['TableEntry'], + expression_extractor: Callable[[Expression], sympy.Expr]) -> ExpressionScalar: + """Returns an expression for the time integral over the complete sequence of table entries. + + Args: + entry_sequence: Sequence of table entries. Assumed to be ordered by time. + expression_extractor: Convert each entry's voltage into a sympy expression. Can be used to select single + channels from a vectorized expression. + + Returns: + Scalar expression for the integral. + """ + expr = 0 + for first_entry, second_entry in more_itertools.pairwise(entry_sequence): + substitutions = {'t0': first_entry.t.sympified_expression, + 'v0': expression_extractor(first_entry.v), + 't1': second_entry.t.sympified_expression, + 'v1': expression_extractor(second_entry.v)} + expr += second_entry.interp.integral.sympified_expression.subs(substitutions, simultaneous=True) + return ExpressionScalar(expr) + + @classmethod + def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], + expression_extractor: Callable[[Expression], sympy.Expr], + t: sympy.Dummy, + pre_value: Optional[sympy.Expr], + post_value: Optional[sympy.Expr]) -> ExpressionScalar: + """Create an expression out of a sequence of table entries. + + Args: + entry_sequence: Table entries to be represented as an expression. They are assumed to be ordered by time. + expression_extractor: Convert each entry's voltage into a sympy expression. Can be used to select single + channels from a vectorized expression. + t: Time variable + pre_value: If not None all t values smaller than the first entry's time give this value + post_value: If not None all t values larger than the last entry's time give this value + + Returns: + Scalar expression that covers the complete sequence and is zero outside. + """ + + # args are tested in order + piecewise_args = [] + for first_entry, second_entry in more_itertools.pairwise(entry_sequence): + t0, t1 = first_entry.t.sympified_expression, second_entry.t.sympified_expression + substitutions = {'t0': t0, + 'v0': expression_extractor(first_entry.v), + 't1': t1, + 'v1': expression_extractor(second_entry.v), + 't': t} + time_gate = sympy.And(t0 <= t, t < t1) + + interpolation_expr = second_entry.interp.expression.underlying_expression.subs(substitutions, + simultaneous=True) + + piecewise_args.append((interpolation_expr, time_gate)) + + if pre_value is not None: + piecewise_args.append((pre_value, t < entry_sequence[0].t.sympified_expression)) + if post_value is not None: + piecewise_args.append((post_value, t >= entry_sequence[-1].t.sympified_expression)) + + return ExpressionScalar(sympy.Piecewise(*piecewise_args)) class TablePulseTemplate(AtomicPulseTemplate, ParameterConstrainer): @@ -142,16 +209,17 @@ def __init__(self, entries: Dict[ChannelID, Sequence[EntryInInit]], self._register(registry=registry) def _add_entry(self, channel, new_entry: TableEntry) -> None: + ch_entries = self._entries[channel] # comparisons with Expression can yield None -> use 'is True' and 'is False' if (new_entry.t < 0) is True: raise ValueError('Time parameter number {} of channel {} is negative.'.format( - len(self._entries[channel]), channel)) + len(ch_entries), channel)) - for previous_entry in self._entries[channel]: + for previous_entry in ch_entries: if (new_entry.t < previous_entry.t) is True: raise ValueError('Time parameter number {} of channel {} is smaller than a previous one'.format( - len(self._entries[channel]), channel)) + len(ch_entries), channel)) self._entries[channel].append(new_entry) @@ -348,15 +416,24 @@ def is_valid_interpolation_strategy(inter): def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = dict() for channel, channel_entries in self._entries.items(): + pre_entry = TableEntry(0, channel_entries[0].v, None) + post_entry = TableEntry(self.duration, channel_entries[-1].v, 'hold') + channel_entries = [pre_entry] + channel_entries + [post_entry] + expressions[channel] = TableEntry._sequence_integral(channel_entries, lambda v: v.sympified_expression) - expr = 0 - for first_entry, second_entry in zip(channel_entries[:-1], channel_entries[1:]): - substitutions = {'t0': ExpressionScalar(first_entry.t).sympified_expression, 'v0': ExpressionScalar(first_entry.v).sympified_expression, - 't1': ExpressionScalar(second_entry.t).sympified_expression, 'v1': ExpressionScalar(second_entry.v).sympified_expression} - - expr += first_entry.interp.integral.sympified_expression.subs(substitutions) - expressions[channel] = ExpressionScalar(expr) + return expressions + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + expressions = dict() + for channel, channel_entries in self._entries.items(): + pre_value = channel_entries[0].v.sympified_expression + post_value = channel_entries[-1].v.sympified_expression + + expressions[channel] = TableEntry._sequence_as_expression(channel_entries, + lambda v: v.sympified_expression, + t=self._AS_EXPRESSION_TIME, + pre_value=pre_value, + post_value=post_value) return expressions diff --git a/tests/pulses/arithmetic_pulse_template_tests.py b/tests/pulses/arithmetic_pulse_template_tests.py index df2929ca3..11b7d65d7 100644 --- a/tests/pulses/arithmetic_pulse_template_tests.py +++ b/tests/pulses/arithmetic_pulse_template_tests.py @@ -118,6 +118,36 @@ def test_integral(self): self.assertEqual(expected_plus, (lhs + rhs).integral) self.assertEqual(expected_minus, (lhs - rhs).integral) + def test_as_expression(self): + integrals_lhs = dict(a=ExpressionScalar('a_lhs'), b=ExpressionScalar('b')) + integrals_rhs = dict(a=ExpressionScalar('a_rhs'), c=ExpressionScalar('c')) + + duration = 4 + t = DummyPulseTemplate._AS_EXPRESSION_TIME + expr_lhs = {ch: i * t / duration for ch, i in integrals_lhs.items()} + expr_rhs = {ch: i * t / duration for ch, i in integrals_rhs.items()} + + lhs = DummyPulseTemplate(duration=duration, defined_channels={'a', 'b'}, + parameter_names={'x', 'y'}, integrals=integrals_lhs) + rhs = DummyPulseTemplate(duration=duration, defined_channels={'a', 'c'}, + parameter_names={'x', 'z'}, integrals=integrals_rhs) + + expected_added = { + 'a': expr_lhs['a'] + expr_rhs['a'], + 'b': expr_lhs['b'], + 'c': expr_rhs['c'] + } + added_expr = (lhs + rhs)._as_expression() + self.assertEqual(expected_added, added_expr) + + subs_expr = (lhs - rhs)._as_expression() + expected_subs = { + 'a': expr_lhs['a'] - expr_rhs['a'], + 'b': expr_lhs['b'], + 'c': -expr_rhs['c'] + } + self.assertEqual(expected_subs, subs_expr) + def test_duration(self): lhs = DummyPulseTemplate(duration=ExpressionScalar('x'), defined_channels={'a', 'b'}, parameter_names={'x', 'y'}) rhs = DummyPulseTemplate(duration=ExpressionScalar('y'), defined_channels={'a', 'c'}, parameter_names={'x', 'z'}) diff --git a/tests/pulses/function_pulse_tests.py b/tests/pulses/function_pulse_tests.py index b33c91e2e..ba4d214da 100644 --- a/tests/pulses/function_pulse_tests.py +++ b/tests/pulses/function_pulse_tests.py @@ -84,6 +84,11 @@ def test_integral(self) -> None: pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') self.assertEqual({'default': Expression('2.0*cos(b) - 2.0*cos(1.0*Tmax+b)')}, pulse.integral) + def test_as_expression(self): + pulse = FunctionPulseTemplate('sin(0.5*t+b)', '2*Tmax') + expr = sympy.sin(0.5 * pulse._AS_EXPRESSION_TIME + sympy.sympify('b')) + self.assertEqual({'default': Expression.make(expr)}, pulse._as_expression()) + class FunctionPulseSerializationTest(SerializableTests, unittest.TestCase): diff --git a/tests/pulses/mapping_pulse_template_tests.py b/tests/pulses/mapping_pulse_template_tests.py index 2a70c78da..2c9d4355f 100644 --- a/tests/pulses/mapping_pulse_template_tests.py +++ b/tests/pulses/mapping_pulse_template_tests.py @@ -253,6 +253,26 @@ def test_integral(self) -> None: self.assertEqual({'a': Expression('2*f'), 'B': Expression('-3.2*f+2.3')}, pulse.integral) + def test_as_expression(self): + from sympy.abc import f, k, b + duration = 5 + dummy = DummyPulseTemplate(defined_channels={'A', 'B', 'C'}, + parameter_names={'k', 'f', 'b'}, + integrals={'A': Expression(2 * k), + 'B': Expression(-3.2*f+b), + 'C': Expression(1)}, duration=duration) + t = DummyPulseTemplate._AS_EXPRESSION_TIME + dummy_expr = {ch: i * t / duration for ch, i in dummy._integrals.items()} + pulse = MappingPulseTemplate(dummy, parameter_mapping={'k': 'f', 'b': 2.3}, channel_mapping={'A': 'a', + 'C': None}, + allow_partial_parameter_mapping=True) + + expected = { + 'a': Expression(2*f*t/duration), + 'B': Expression((-3.2*f + 2.3)*t/duration), + } + self.assertEqual(expected, pulse._as_expression()) + def test_duration(self): seconds2ns = 1e9 pulse_duration = 1.0765001496284785e-07 @@ -507,6 +527,7 @@ def test_deserialize(self) -> None: self.assertEqual(data['parameter_constraints'], [str(pc) for pc in deserialized.parameter_constraints]) self.assertIs(deserialized.template, dummy_pt) + class MappingPulseTemplateRegressionTests(unittest.TestCase): def test_issue_451(self): from qupulse.pulses import TablePT, SequencePT, AtomicMultiChannelPT diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 969a240c5..07e858243 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -76,27 +76,34 @@ def test_parameter_names(self): def test_integral(self) -> None: pulse = PointPulseTemplate( - [(1, (2, 'b'), 'linear'), - (3, (0, 0), 'jump'), - (4, (2, 'c'), 'hold'), + [(1, (2, 'b'), 'hold'), + (3, (0, 0), 'linear'), + (4, (2, 'c'), 'jump'), (5, (8, 'd'), 'hold')], [0, 'other_channel'] ) - self.assertEqual({0: ExpressionScalar('6'), - 'other_channel': ExpressionScalar('b + 2*c')}, + self.assertEqual({0: ExpressionScalar('2 + 6'), + 'other_channel': ExpressionScalar('b + b + 2*c')}, pulse.integral) pulse = PointPulseTemplate( - [(1, ('2', 'b'), 'linear'), ('t0', (0, 0), 'jump'), (4, (2.0, 'c'), 'hold'), ('g', (8, 'd'), 'hold')], + [(1, ('2', 'b'), 'hold'), ('t0', (0, 0), 'linear'), (4, (2.0, 'c'), 'jump'), ('g', (8, 'd'), 'hold')], ['symbolic', 1] ) - self.assertEqual({'symbolic': ExpressionScalar('2.0*g - 1.0*t0 - 1.0'), - 1: ExpressionScalar('b*(t0 - 1) / 2 + c*(g - 4) + c*(-t0 + 4)')}, + self.assertEqual({'symbolic': ExpressionScalar('2 + 2.0*g - 1.0*t0 - 1.0'), + 1: ExpressionScalar('b + b*(t0 - 1) / 2 + c*(g - 4) + c*(-t0 + 4)')}, pulse.integral) ppt = PointPulseTemplate([(0, 0), ('t_init', 0)], ['X', 'Y']) self.assertEqual(ppt.integral, {'X': 0, 'Y': 0}) + ppt = PointPulseTemplate([(0., 'a'), ('t_1', 'b', 'linear'), ('t_2', (0, 0))], ('X', 'Y')) + parameters = {'a': (3.4, 4.1), 'b': 4, 't_1': 2, 't_2': 5} + integral = {ch: v.evaluate_in_scope(parameters) for ch, v in ppt.integral.items()} + self.assertEqual({'X': 2 * (3.4 + 4) / 2 + (5 - 2) * 4, + 'Y': 2 * (4.1 + 4) / 2 + (5 - 2) * 4}, + integral) + class PointPulseTemplateSequencingTests(unittest.TestCase): def test_build_waveform_empty(self): @@ -147,10 +154,10 @@ def test_build_waveform_multi_channel_same(self): (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0].defined_channels, {1}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1].defined_channels, {'A'}) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) + self.assertEqual(wf._wf_pad[1][0], expected_1) + self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) + self.assertEqual(wf._wf_pad['A'][0], expected_A) def test_build_waveform_multi_channel_vectorized(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -168,10 +175,10 @@ def test_build_waveform_multi_channel_vectorized(self): (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._sub_waveforms[0].defined_channels, {1}) - self.assertEqual(wf._sub_waveforms[0], expected_1) - self.assertEqual(wf._sub_waveforms[1].defined_channels, {'A'}) - self.assertEqual(wf._sub_waveforms[1], expected_A) + self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) + self.assertEqual(wf._wf_pad[1][0], expected_1) + self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) + self.assertEqual(wf._wf_pad['A'][0], expected_A) def test_build_waveform_none_channel(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -290,3 +297,61 @@ def test_serializer_integration_old(self): self.assertEqual(template.point_pulse_entries, self.template.point_pulse_entries) self.assertEqual(template.measurement_declarations, self.template.measurement_declarations) self.assertEqual(template.parameter_constraints, self.template.parameter_constraints) + + +class PointPulseExpressionIntegralTests(unittest.TestCase): + def setUp(self): + self.template = PointPulseTemplate(**PointPulseTemplateSerializationTests().make_kwargs()) + self.parameter_sets = [ + {'foo': 1., 'hugo': 2., 'sudo': 3., 'A': 4., 'B': 5., 'a': 6., 'ilse': 7., 'k': 8.}, + {'foo': 1.1, 'hugo': 2.6, 'sudo': 2.7, 'A': np.array([3., 4.]), 'B': 5., 'a': 6., 'ilse': 7., 'k': 8.}, + ] + + def test_integral_as_expression_compatible(self): + import sympy + + t = self.template._AS_EXPRESSION_TIME + as_expression = self.template._as_expression() + integral = self.template.integral + duration = self.template.duration.underlying_expression + + self.assertEqual(self.template.defined_channels, integral.keys()) + self.assertEqual(self.template.defined_channels, as_expression.keys()) + + for channel in self.template.defined_channels: + ch_expr = as_expression[channel].underlying_expression + ch_int = integral[channel].underlying_expression + + symbolic = sympy.integrate(ch_expr, (t, 0, duration)) + symbolic = sympy.simplify(symbolic) + + for parameters in self.parameter_sets: + num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) + num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) + np.testing.assert_almost_equal(num_from_in, num_from_expr) + + # TODO: the following fails even with a lot of assumptions in sympy 1.6 + # self.assertEqual(ch_int, symbolic) + + def test_as_expression_wf_and_sample_compatible(self): + as_expression = self.template._as_expression() + + for parameters in self.parameter_sets: + wf = self.template.build_waveform(parameters, {c: c for c in self.template.defined_channels}) + + ts = np.linspace(0, float(wf.duration), num=33) + sampled = {ch: wf.get_sampled(ch, ts) for ch in self.template.defined_channels} + + from_expr = {} + for ch, expected_vs in sampled.items(): + ch_expr = as_expression[ch] + + ch_from_expr = [] + for t, expected in zip(ts, expected_vs): + result_expr = ch_expr.evaluate_symbolic({**parameters, self.template._AS_EXPRESSION_TIME: t}) + ch_from_expr.append(result_expr.sympified_expression) + from_expr[ch] = ch_from_expr + + np.testing.assert_almost_equal(expected_vs, ch_from_expr) + + diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index 6ef448118..f78793685 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -3,6 +3,7 @@ from unittest import mock from typing import Optional, Dict, Set, Any, Union +import sympy from qupulse.parameter_scope import Scope, DictScope from qupulse.utils.types import ChannelID @@ -134,6 +135,9 @@ def duration(self) -> Expression: def integral(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError() + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + raise NotImplementedError() + class PulseTemplateTest(unittest.TestCase): @@ -352,7 +356,7 @@ class AtomicPulseTemplateTests(unittest.TestCase): def test_internal_create_program(self) -> None: measurement_windows = [('M', 0, 5)] single_wf = DummyWaveform(duration=6, defined_channels={'A'}) - wf = MultiChannelWaveform([single_wf]) + wf = MultiChannelWaveform.from_iterable([single_wf]) template = AtomicPulseTemplateStub(measurements=measurement_windows, parameter_names={'foo'}) scope = DictScope.from_kwargs(foo=7.2, volatile={'gutes_zeuch'}) @@ -437,3 +441,4 @@ def test_internal_create_program_volatile(self): to_single_waveform=set(), global_transformation=None) self.assertEqual(Loop(), program) + diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py index 549935a32..d22952bed 100644 --- a/tests/pulses/sequencing_dummies.py +++ b/tests/pulses/sequencing_dummies.py @@ -75,9 +75,9 @@ def __hash__(self): class DummyWaveform(Waveform): - def __init__(self, duration: float=0, sample_output: Union[numpy.ndarray, dict]=None, defined_channels=None) -> None: + def __init__(self, duration: Union[float, TimeType]=0, sample_output: Union[numpy.ndarray, dict]=None, defined_channels=None) -> None: super().__init__() - self.duration_ = TimeType.from_float(duration) + self.duration_ = duration if isinstance(duration, TimeType) else TimeType.from_float(duration) self.sample_output = sample_output if defined_channels is None: if isinstance(sample_output, dict): @@ -142,6 +142,15 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'Waveform' def defined_channels(self): return self.defined_channels_ + def last_value(self, channel) -> float: + if self.sample_output is None: + return 0. + elif isinstance(self.sample_output, dict): + sample_output = self.sample_output[channel] + else: + sample_output = self.sample_output + return sample_output[-1] + class DummyInterpolationStrategy(InterpolationStrategy): @@ -168,13 +177,13 @@ class DummyPulseTemplate(AtomicPulseTemplate): def __init__(self, requires_stop: bool=False, - parameter_names: Set[str]={}, - defined_channels: Set[ChannelID]={'default'}, + parameter_names: Set[str]=set(), + defined_channels: Set[ChannelID]=None, duration: Any=0, waveform: Waveform=tuple(), measurement_names: Set[str] = set(), measurements: list=list(), - integrals: Dict[ChannelID, ExpressionScalar]={'default': ExpressionScalar(0)}, + integrals: Dict[ChannelID, ExpressionScalar]=None, program: Optional[Loop]=None, identifier=None, registry=None) -> None: @@ -182,6 +191,11 @@ def __init__(self, self.requires_stop_ = requires_stop self.requires_stop_arguments = [] + if defined_channels is None: + defined_channels = {'default'} + if integrals is None: + integrals = {ch: ExpressionScalar(0) for ch in defined_channels} + self.parameter_names_ = parameter_names self.defined_channels_ = defined_channels self._duration = Expression(duration) @@ -252,3 +266,10 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: def compare_key(self) -> Tuple[Any, ...]: return (self.requires_stop_, self.parameter_names, self.defined_channels, self.duration, self.waveform, self.measurement_names, self.integral) + + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + assert self.duration != 0 + t = self._AS_EXPRESSION_TIME + duration = self.duration.underlying_expression + return {ch: ExpressionScalar(integral.underlying_expression*t/duration) + for ch, integral in self.integral.items()} diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index efae8335c..580fd9b7a 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -2,8 +2,9 @@ import warnings import numpy +import sympy -from qupulse.expressions import Expression +from qupulse.expressions import Expression, ExpressionScalar from qupulse.serialization import Serializer from qupulse.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry, concatenate from qupulse.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation, ParameterConstraint @@ -38,6 +39,54 @@ def test_unknown_interpolation_strategy(self): with self.assertRaises(KeyError): TableEntry(0, 0, 'foo') + def test_sequence_integral(self): + def get_sympy(v): + return v.sympified_expression + + entries = [TableEntry(0, 0, 'hold'), TableEntry(1, 0, 'hold')] + self.assertEqual(ExpressionScalar(0), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry(0, 1, 'hold'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(1), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry(0, 0, 'linear'), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(.5), TableEntry._sequence_integral(entries, get_sympy)) + + entries = [TableEntry('t0', 'a', 'linear'), TableEntry('t1', 'b', 'hold'), TableEntry('t2', 'c', 'hold')] + self.assertEqual(ExpressionScalar('(t1-t0)*(a+b)/2 + (t2-t1)*b'), + TableEntry._sequence_integral(entries, get_sympy)) + + def test_sequence_as_expression(self): + def get_sympy(v): + return v.sympified_expression + + t = sympy.Dummy('t') + + times = { + t: 0.5, + 't0': 0.3, + 't1': 0.7, + 't2': 1.3, + } + + entries = [TableEntry(0, 0, None), TableEntry(1, 0, 'hold')] + self.assertEqual(ExpressionScalar(0), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry(0, 1, None), TableEntry(1, 1, 'hold')] + self.assertEqual(ExpressionScalar(1), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry(0, 0, None), TableEntry(1, 1, 'linear')] + self.assertEqual(ExpressionScalar(.5), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + + entries = [TableEntry('t0', 'a', 'linear'), + TableEntry('t1', 'b', 'hold'), + TableEntry('t2', 'c', 'hold')] + self.assertEqual(ExpressionScalar('(a+b)*.5'), + TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + class TablePulseTemplateTest(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -417,15 +466,42 @@ def test_identifier(self) -> None: self.assertEqual(pulse.identifier, identifier) def test_integral(self) -> None: - pulse = TablePulseTemplate(entries={0: [(1, 2, 'linear'), (3, 0, 'jump'), (4, 2, 'hold'), (5, 8, 'hold')], - 'other_channel': [(0, 7, 'linear'), (2, 0, 'hold'), (10, 0)], - 'symbolic': [(3, 'a', 'hold'), ('b', 4, 'linear'), ('c', Expression('d'), 'hold')]}) - expected = {0: Expression('6'), + pulse = TablePulseTemplate(entries={0: [(1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], + 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], + 'symbolic': [(3, 'a'), ('b', 4, 'hold'), ('c', Expression('d'), 'linear')]}) + expected = {0: Expression('2 + 2 + 2 + 2 + (Max(c, 10) - 5) * 8'), 'other_channel': Expression(7), - 'symbolic': Expression('(b-3.)*a + (c-b)*(d+4.) / 2')} + 'symbolic': Expression('3 * a + (b-3)*a + (c-b)*(d+4) / 2 + (Max(10, c) - c) * d')} self.assertEqual(expected, pulse.integral) + def test_as_expression(self): + pulse = TablePulseTemplate(entries={0: [(0, 0), (1, 2), (3, 0, 'linear'), (4, 2, 'jump'), (5, 8, 'hold')], + 'other_channel': [(0, 7), (2, 0, 'linear'), (10, 0)], + 'symbolic': [(3, 'a'), ('b', 4, 'hold'), + ('c', Expression('d'), 'linear')]}) + parameters = dict(a=2., b=4, c=9, d=8) + wf = pulse.build_waveform(parameters, channel_mapping={0: 0, + 'other_channel': 'other_channel', + 'symbolic': 'symbolic'}) + expr = pulse._as_expression() + ts = numpy.linspace(0, float(wf.duration), num=33) + sampled = {ch: wf.get_sampled(ch, ts) for ch in pulse.defined_channels} + + from_expr = {} + for ch, expected_vs in sampled.items(): + ch_expr = expr[ch] + + ch_from_expr = [] + for t, expected in zip(ts, expected_vs): + params = {**parameters, TablePulseTemplate._AS_EXPRESSION_TIME: t} + result = ch_expr.sympified_expression.subs(params, simultaneous=True) + ch_from_expr.append(result) + from_expr[ch] = ch_from_expr + + numpy.testing.assert_almost_equal(expected_vs, ch_from_expr) + + class TablePulseTemplateConstraintTest(ParameterConstrainerTest): def __init__(self, *args, **kwargs): @@ -614,10 +690,10 @@ def test_build_waveform_multi_channel(self): channel_mapping=channel_mapping) self.assertIsInstance(waveform, MultiChannelWaveform) - self.assertEqual(len(waveform._sub_waveforms), 2) + self.assertEqual(len(waveform._wf_pad), 2) channels = {'oh', 'ch'} - for wf in waveform._sub_waveforms: + for wf, _ in waveform._wf_pad.values(): self.assertIsInstance(wf, TableWaveform) self.assertIn(wf._channel_id, channels) channels.remove(wf._channel_id) From 907f456c0d5c55ee82e3311b26ad8023fd39320b Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 25 Nov 2020 15:16:54 +0100 Subject: [PATCH 4/8] Improve Broadcast performance by translating to lambda directly and specialcasing indexing --- qupulse/pulses/point_pulse_template.py | 9 ++- qupulse/utils/sympy.py | 83 ++++++++++++++++++++++++-- tests/utils/sympy_tests.py | 26 ++++++-- 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 0d41abedf..43cc3d40e 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -6,12 +6,11 @@ import sympy import numpy as np -from qupulse.utils.sympy import Broadcast +from qupulse.utils.sympy import IndexedBroadcast from qupulse.utils.types import ChannelID from qupulse.expressions import Expression, ExpressionScalar from qupulse._program.waveforms import TableWaveform, TableWaveformEntry -from qupulse.pulses.parameters import Parameter, ParameterNotProvidedException, ParameterConstraint,\ - ParameterConstrainer +from qupulse.pulses.parameters import ParameterConstraint, ParameterConstrainer from qupulse.pulses.pulse_template import AtomicPulseTemplate, MeasurementDeclaration from qupulse.pulses.table_pulse_template import TableEntry, EntryInInit from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform @@ -145,7 +144,7 @@ def value_trafo(v): try: return v.underlying_expression[i] except TypeError: - return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + return IndexedBroadcast(v.underlying_expression, shape, i) pre_entry = TableEntry(0, self._entries[0].v, None) entries = [pre_entry] + self._entries expressions[channel] = TableEntry._sequence_integral(entries, expression_extractor=value_trafo) @@ -161,7 +160,7 @@ def value_trafo(v): try: return v.underlying_expression[i] except TypeError: - return sympy.IndexedBase(Broadcast(v.underlying_expression, shape))[i] + return IndexedBroadcast(v.underlying_expression, shape, i) pre_value = value_trafo(self._entries[0].v) post_value = value_trafo(self._entries[-1].v) pw = TableEntry._sequence_as_expression(self._entries, diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 9038a437f..bfa98f0e8 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -2,6 +2,7 @@ from numbers import Number from types import CodeType import warnings +import functools import builtins import math @@ -80,7 +81,7 @@ def __contains__(self, k) -> bool: class Broadcast(sympy.Function): - """Broadcast x to the specified shape using numpy.broadcast_to + """Broadcast x to the specified shape using numpy.broadcast_to. The shape must not be symbolic. Examples: >>> bc = Broadcast('a', (3,)) @@ -90,14 +91,20 @@ class Broadcast(sympy.Function): nargs = (2,) @classmethod - def eval(cls, x, shape) -> Optional[sympy.Array]: - if getattr(shape, 'free_symbols', None): - # cannot do anything + def eval(cls, x, shape: Tuple[int]) -> Optional[sympy.Array]: + shape = _parse_broadcast_shape(shape, user=cls) + if shape is None: return None if hasattr(x, '__len__') or not x.free_symbols: return sympy.Array(numpy.broadcast_to(x, shape)) + def __getitem__(self, item: Union): + return IndexedBroadcast(*self.args, item) + + # Not iterable. If not set to None __getitem__ would be used for iterating + __iter__ = None + def _eval_Integral(self, *symbols, **assumptions): x, shape = self.args return Broadcast(sympy.Integral(x, *symbols, **assumptions), shape) @@ -106,6 +113,45 @@ def _eval_derivative(self, sym): x, shape = self.args return Broadcast(sympy.diff(x, sym), shape) + def _numpycode(self, printer, **kwargs): + x, shape = map(functools.partial(printer._print, **kwargs), self.args) + return f'broadcast_to({x}, {shape})' + + +class IndexedBroadcast(sympy.Function): + """Broadcast x to the specified shape using numpy.broadcast_to and index in the result.""" + nargs = (3,) + + @classmethod + def eval(cls, x, shape: Tuple[int], idx: int) -> Optional[sympy.Expr]: + shape = _parse_broadcast_shape(shape, user=cls) + idx = _parse_broadcast_index(idx, user=cls) + if shape is None or idx is None: + return None + + if hasattr(x, '__len__') or not x.free_symbols: + return sympy.Array(numpy.broadcast_to(x, shape))[idx] + + def _eval_Integral(self, *symbols, **assumptions): + x, shape, idx = self.args + return IndexedBroadcast(sympy.Integral(x, *symbols, **assumptions), shape, idx) + + def _eval_derivative(self, sym): + x, shape, idx = self.args + return IndexedBroadcast(sympy.diff(x, sym), shape, idx) + + def _eval_is_commutative(self): + x, shape, idx = self.args + result = self.eval(*self.args) + if result is None: + return x.is_commutative + else: + return result.is_commutative + + def _numpycode(self, printer, **kwargs): + x, shape, idx = map(functools.partial(printer._print, **kwargs), self.args) + return f'broadcast_to({x}, {shape})[{idx}]' + class Len(sympy.Function): nargs = 1 @@ -121,7 +167,8 @@ def eval(cls, arg) -> Optional[sympy.Integer]: sympify_namespace = {'len': Len, 'Len': Len, - 'Broadcast': Broadcast} + 'Broadcast': Broadcast, + 'IndexedBroadcast': IndexedBroadcast} def numpy_compatible_mul(*args) -> Union[sympy.Mul, sympy.Array]: @@ -187,7 +234,7 @@ def sympify(expr: Union[str, Number, sympy.Expr, numpy.str_], **kwargs) -> sympy if True:#err.args[0] == "'Symbol' object is not subscriptable": indexed_base = get_subscripted_symbols(expr) - return sympy.sympify(expr, **kwargs, locals={**{k: sympy.IndexedBase(k) + return sympy.sympify(expr, **kwargs, locals={**{k: k if isinstance(k, Broadcast) else sympy.IndexedBase(k) for k in indexed_base}, **sympify_namespace}) @@ -302,3 +349,27 @@ def almost_equal(lhs: sympy.Expr, rhs: sympy.Expr, epsilon: float=1e-15) -> Opti return False else: return None + + +class UnsupportedBroadcastArgumentWarning(RuntimeWarning): + pass + + +def _parse_broadcast_shape(shape: Tuple[int], user: type) -> Optional[Tuple[int]]: + try: + return tuple(map(int, shape)) + except TypeError as err: + warnings.warn(f"The shape passed to {user.__module__}.{user.__name__} is not convertible to a tuple of integers: {err}\n" + "Be aware that using a symbolic shape can lead to unexpected behaviour.", + category=UnsupportedBroadcastArgumentWarning) + return None + + +def _parse_broadcast_index(idx: int, user: type) -> Optional[int]: + try: + return int(idx) + except TypeError as err: + warnings.warn(f"The index passed to {user.__module__}.{user.__name__} is not convertible to an integer: {err}\n" + "Be aware that using a symbolic index can lead to unexpected behaviour.", + category=UnsupportedBroadcastArgumentWarning) + return None diff --git a/tests/utils/sympy_tests.py b/tests/utils/sympy_tests.py index d489a7e51..3851a8264 100644 --- a/tests/utils/sympy_tests.py +++ b/tests/utils/sympy_tests.py @@ -18,7 +18,7 @@ from qupulse.utils.sympy import sympify as qc_sympify, substitute_with_eval, recursive_substitution, Len,\ evaluate_lambdified, evaluate_compiled, get_most_simple_representation, get_variables, get_free_symbols,\ - almost_equal, Broadcast, IndexedBasedFinder + almost_equal, Broadcast, IndexedBasedFinder, IndexedBroadcast ################################################### SUBSTITUTION ####################################################### @@ -76,7 +76,9 @@ ] index_sympify = [ - ('a[i]', a_[i]) + ('a[i]', a_[i]), + ('Broadcast(a, (3,))[0]', Broadcast(a, (3,))[0]), + ('IndexedBroadcast(a, (3,), 1)', IndexedBroadcast(a, (3,), 1)) ] @@ -448,13 +450,25 @@ def test_expression_equality(self): test_numeric_equal = unittest.expectedFailure(test_expression_equality) if distutils.version.StrictVersion(sympy.__version__) >= distutils.version.StrictVersion('1.5') else test_expression_equality def test_integral(self): - symbolic = Broadcast(a, (3,)) + symbolic = Broadcast(a*c, (3,)) + indexed = symbolic[1] integ = sympy.Integral(symbolic, (a, 0, b)) - self.assertEqual(integ, Broadcast(sympy.Integral(a, (a, 0, b)), (3,))) + idx_integ = sympy.Integral(indexed, (a, 0, b)) + self.assertEqual(integ, Broadcast(sympy.Integral(a*c, (a, 0, b)), (3,))) + self.assertEqual(idx_integ, Broadcast(sympy.Integral(a*c, (a, 0, b)), (3,))[1]) + + diffed = sympy.diff(symbolic, a) + idx_diffed = sympy.diff(indexed, a) + self.assertEqual(symbolic.subs(a, 1), diffed) + self.assertEqual(indexed.subs(a, 1), idx_diffed) + + def test_indexing(self): + symbolic = Broadcast(a, (3,)) + indexed = symbolic[1] - diffed = sympy.diff(integ, b).subs({b: a}) - self.assertEqual(symbolic, diffed) + self.assertEqual(7, indexed.subs(a, 7)) + self.assertEqual(7, indexed.subs(a, (6, 7, 8))) class IndexedBasedFinderTests(unittest.TestCase): From de7d5b86011db2be190c380191fa50ae83bed01b Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 25 Nov 2020 15:18:41 +0100 Subject: [PATCH 5/8] Reorder piecewise args for sequence entry translation --- qupulse/pulses/table_pulse_template.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index 9b4fc70d1..59a9dfd53 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -105,6 +105,14 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], # args are tested in order piecewise_args = [] + + # first define out of sequence values. Otherwise integration might produce strange results + if pre_value is not None: + piecewise_args.append((pre_value, t < entry_sequence[0].t.sympified_expression)) + + if post_value is not None: + piecewise_args.append((post_value, t >= entry_sequence[-1].t.sympified_expression)) + for first_entry, second_entry in more_itertools.pairwise(entry_sequence): t0, t1 = first_entry.t.sympified_expression, second_entry.t.sympified_expression substitutions = {'t0': t0, @@ -119,10 +127,8 @@ def _sequence_as_expression(cls, entry_sequence: Sequence['TableEntry'], piecewise_args.append((interpolation_expr, time_gate)) - if pre_value is not None: - piecewise_args.append((pre_value, t < entry_sequence[0].t.sympified_expression)) - if post_value is not None: - piecewise_args.append((post_value, t >= entry_sequence[-1].t.sympified_expression)) + if post_value is None and pre_value is None: + piecewise_args.append((0, True)) return ExpressionScalar(sympy.Piecewise(*piecewise_args)) From 7f3e165015415294e936fce38b6c8328a781ba12 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 25 Nov 2020 15:19:34 +0100 Subject: [PATCH 6/8] Fix tests --- .../pulses/multi_channel_pulse_template.py | 6 +++ tests/pulses/point_pulse_template_tests.py | 19 ++++----- tests/pulses/pulse_template_tests.py | 2 +- tests/pulses/table_pulse_template_tests.py | 42 ++++++++----------- 4 files changed, 34 insertions(+), 35 deletions(-) diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 5b8c2a50e..264b342c7 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -199,6 +199,12 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions.update(subtemplate.integral) return expressions + def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: + expressions = dict() + for subtemplate in self._subtemplates: + expressions.update(subtemplate._as_expression()) + return expressions + class ParallelConstantChannelPulseTemplate(PulseTemplate): def __init__(self, diff --git a/tests/pulses/point_pulse_template_tests.py b/tests/pulses/point_pulse_template_tests.py index 07e858243..3962af181 100644 --- a/tests/pulses/point_pulse_template_tests.py +++ b/tests/pulses/point_pulse_template_tests.py @@ -154,10 +154,8 @@ def test_build_waveform_multi_channel_same(self): (1., 0., HoldInterpolationStrategy()), (1.1, 21., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) - self.assertEqual(wf._wf_pad[1][0], expected_1) - self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) - self.assertEqual(wf._wf_pad['A'][0], expected_A) + self.assertEqual(wf._sub_waveforms[0], expected_1) + self.assertEqual(wf._sub_waveforms[1], expected_A) def test_build_waveform_multi_channel_vectorized(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -175,10 +173,8 @@ def test_build_waveform_multi_channel_vectorized(self): (1., 0., HoldInterpolationStrategy()), (1.1, 20., LinearInterpolationStrategy())]) self.assertEqual(wf.defined_channels, {1, 'A'}) - self.assertEqual(wf._wf_pad[1][0].defined_channels, {1}) - self.assertEqual(wf._wf_pad[1][0], expected_1) - self.assertEqual(wf._wf_pad['A'][0].defined_channels, {'A'}) - self.assertEqual(wf._wf_pad['A'][0], expected_A) + self.assertEqual(wf._sub_waveforms[0], expected_1) + self.assertEqual(wf._sub_waveforms[1], expected_A) def test_build_waveform_none_channel(self): ppt = PointPulseTemplate([('t1', 'A'), @@ -325,9 +321,12 @@ def test_integral_as_expression_compatible(self): symbolic = sympy.integrate(ch_expr, (t, 0, duration)) symbolic = sympy.simplify(symbolic) + scalar_from_as_expr = ExpressionScalar(symbolic) + scalar_from_integral = ExpressionScalar(ch_int) + for parameters in self.parameter_sets: - num_from_expr = ExpressionScalar(symbolic).evaluate_in_scope(parameters) - num_from_in = ExpressionScalar(ch_int).evaluate_in_scope(parameters) + num_from_expr = scalar_from_as_expr.evaluate_in_scope(parameters) + num_from_in = scalar_from_integral.evaluate_in_scope(parameters) np.testing.assert_almost_equal(num_from_in, num_from_expr) # TODO: the following fails even with a lot of assumptions in sympy 1.6 diff --git a/tests/pulses/pulse_template_tests.py b/tests/pulses/pulse_template_tests.py index f78793685..d765ea8ef 100644 --- a/tests/pulses/pulse_template_tests.py +++ b/tests/pulses/pulse_template_tests.py @@ -356,7 +356,7 @@ class AtomicPulseTemplateTests(unittest.TestCase): def test_internal_create_program(self) -> None: measurement_windows = [('M', 0, 5)] single_wf = DummyWaveform(duration=6, defined_channels={'A'}) - wf = MultiChannelWaveform.from_iterable([single_wf]) + wf = MultiChannelWaveform([single_wf]) template = AtomicPulseTemplateStub(measurements=measurement_windows, parameter_names={'foo'}) scope = DictScope.from_kwargs(foo=7.2, volatile={'gutes_zeuch'}) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 580fd9b7a..14b9f0d1f 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -43,16 +43,16 @@ def test_sequence_integral(self): def get_sympy(v): return v.sympified_expression - entries = [TableEntry(0, 0, 'hold'), TableEntry(1, 0, 'hold')] + entries = [TableEntry(0, 0), TableEntry(1, 0, 'hold')] self.assertEqual(ExpressionScalar(0), TableEntry._sequence_integral(entries, get_sympy)) - entries = [TableEntry(0, 1, 'hold'), TableEntry(1, 1, 'hold')] + entries = [TableEntry(0, 1), TableEntry(1, 1, 'hold')] self.assertEqual(ExpressionScalar(1), TableEntry._sequence_integral(entries, get_sympy)) - entries = [TableEntry(0, 0, 'linear'), TableEntry(1, 1, 'hold')] + entries = [TableEntry(0, 0), TableEntry(1, 1, 'linear')] self.assertEqual(ExpressionScalar(.5), TableEntry._sequence_integral(entries, get_sympy)) - entries = [TableEntry('t0', 'a', 'linear'), TableEntry('t1', 'b', 'hold'), TableEntry('t2', 'c', 'hold')] + entries = [TableEntry('t0', 'a', 'linear'), TableEntry('t1', 'b', 'linear'), TableEntry('t2', 'c', 'hold')] self.assertEqual(ExpressionScalar('(t1-t0)*(a+b)/2 + (t2-t1)*b'), TableEntry._sequence_integral(entries, get_sympy)) @@ -71,21 +71,21 @@ def get_sympy(v): entries = [TableEntry(0, 0, None), TableEntry(1, 0, 'hold')] self.assertEqual(ExpressionScalar(0), - TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + TableEntry._sequence_as_expression(entries, get_sympy, t, pre_value=None, post_value=None).sympified_expression.subs(times)) entries = [TableEntry(0, 1, None), TableEntry(1, 1, 'hold')] self.assertEqual(ExpressionScalar(1), - TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + TableEntry._sequence_as_expression(entries, get_sympy, t, pre_value=None, post_value=None).sympified_expression.subs(times)) entries = [TableEntry(0, 0, None), TableEntry(1, 1, 'linear')] self.assertEqual(ExpressionScalar(.5), - TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + TableEntry._sequence_as_expression(entries, get_sympy, t, pre_value=None, post_value=None).sympified_expression.subs(times)) entries = [TableEntry('t0', 'a', 'linear'), - TableEntry('t1', 'b', 'hold'), + TableEntry('t1', 'b', 'linear'), TableEntry('t2', 'c', 'hold')] self.assertEqual(ExpressionScalar('(a+b)*.5'), - TableEntry._sequence_as_expression(entries, get_sympy, t).sympified_expression.subs(times)) + TableEntry._sequence_as_expression(entries, get_sympy, t, pre_value=None, post_value=None).sympified_expression.subs(times)) class TablePulseTemplateTest(unittest.TestCase): @@ -690,23 +690,17 @@ def test_build_waveform_multi_channel(self): channel_mapping=channel_mapping) self.assertIsInstance(waveform, MultiChannelWaveform) - self.assertEqual(len(waveform._wf_pad), 2) - - channels = {'oh', 'ch'} - for wf, _ in waveform._wf_pad.values(): - self.assertIsInstance(wf, TableWaveform) - self.assertIn(wf._channel_id, channels) - channels.remove(wf._channel_id) - if wf.defined_channels == {'ch'}: - self.assertEqual(wf._table, - ((0, 0, HoldInterpolationStrategy()), + + expected_waveforms = [ + TableWaveform('ch', ((0, 0, HoldInterpolationStrategy()), (1.1, 2.3, LinearInterpolationStrategy()), (4, 0, JumpInterpolationStrategy()), - (5.1, 0, HoldInterpolationStrategy()))) - elif wf.defined_channels == {'oh'}: - self.assertEqual(wf._table, - ((0, 1, HoldInterpolationStrategy()), - (5.1, 0, LinearInterpolationStrategy()))) + (5.1, 0, HoldInterpolationStrategy()))), + TableWaveform('oh', ((0, 1, HoldInterpolationStrategy()), + (5.1, 0, LinearInterpolationStrategy()))), + ] + + self.assertEqual(waveform._sub_waveforms, tuple(expected_waveforms)) def test_build_waveform_none(self) -> None: table = TablePulseTemplate({0: [(0, 0), From e92a1a09f98860200175c948baaccdc69a5d9760 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 25 Nov 2020 15:21:09 +0100 Subject: [PATCH 7/8] Improve error message and repr for some waveform code --- qupulse/_program/waveforms.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index fcfed0810..25552b63c 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -81,8 +81,9 @@ def get_sampled(self, if np.any(sample_times[:-1] >= sample_times[1:]): raise ValueError('The sample times are not monotonously increasing') - if sample_times[0] < 0 or sample_times[-1] > self.duration: - raise ValueError('The sample times are not in the range [0, duration]') + if sample_times[0] < 0 or sample_times[-1] > float(self.duration): + raise ValueError(f'The sample times [{sample_times[0]}, ..., {sample_times[-1]}] are not in the range' + f' [0, duration={float(self.duration)}]') if channel not in self.defined_channels: raise KeyError('Channel not defined in this waveform: {}'.format(channel)) @@ -144,6 +145,9 @@ def __init__(self, t: float, v: float, interp: InterpolationStrategy): if not callable(interp): raise TypeError('{} is neither callable nor of type InterpolationStrategy'.format(interp)) + def __repr__(self): + return f'{type(self).__name__}(t={self.t}, v={self.v}, interp="{self.interp}")' + class TableWaveform(Waveform): EntryInInit = Union[TableWaveformEntry, Tuple[float, float, InterpolationStrategy]] @@ -231,6 +235,9 @@ def defined_channels(self) -> Set[ChannelID]: def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform': return self + def __repr__(self): + return f'{type(self).__name__}(channel={self._channel_id}, waveform_table={self._table})' + class FunctionWaveform(Waveform): """Waveform obtained from instantiating a FunctionPulseTemplate.""" From 4d3729d7537a90432592d2c38705f9afc1319cab Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 25 Nov 2020 16:06:33 +0100 Subject: [PATCH 8/8] Add test for AMCPT._as_expression --- tests/pulses/multi_channel_pulse_template_tests.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/pulses/multi_channel_pulse_template_tests.py b/tests/pulses/multi_channel_pulse_template_tests.py index 3e920c7d3..8e8be1318 100644 --- a/tests/pulses/multi_channel_pulse_template_tests.py +++ b/tests/pulses/multi_channel_pulse_template_tests.py @@ -10,6 +10,7 @@ from qupulse.pulses.parameters import ParameterConstraint, ParameterConstraintViolation, ConstantParameter from qupulse.expressions import ExpressionScalar, Expression from qupulse._program.transformation import LinearTransformation, chain_transformations +from qupulse.utils.sympy import sympify from tests.pulses.sequencing_dummies import DummyPulseTemplate, DummyWaveform from tests.serialization_dummies import DummySerializer @@ -193,6 +194,17 @@ def test_integral(self) -> None: 'C': ExpressionScalar('l')}, pulse.integral) + def test_as_expression(self): + sts = [DummyPulseTemplate(duration='t1', defined_channels={'A'}, + integrals={'A': ExpressionScalar('2+k')}), + DummyPulseTemplate(duration='t1', defined_channels={'B', 'C'}, + integrals={'B': ExpressionScalar('t1-t0*3.1'), 'C': ExpressionScalar('l')})] + pulse = AtomicMultiChannelPulseTemplate(*sts) + self.assertEqual({'A': ExpressionScalar(sympify('(2+k) / t1') * pulse._AS_EXPRESSION_TIME), + 'B': ExpressionScalar(sympify('(t1-t0*3.1)/t1') * pulse._AS_EXPRESSION_TIME), + 'C': ExpressionScalar(sympify('l/t1') * pulse._AS_EXPRESSION_TIME)}, + pulse._as_expression()) + class MultiChannelPulseTemplateSequencingTests(unittest.TestCase): def test_build_waveform(self):