Skip to content

Commit

Permalink
Merge pull request #581 from qutech/fix/repetition_pt_integral
Browse files Browse the repository at this point in the history
Fix RepetitionPT.integral return type
  • Loading branch information
terrorfisch committed May 10, 2021
2 parents 9a1a0a9 + ded16a2 commit b98df08
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion qupulse/pulses/repetition_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def deserialize(cls, serializer: Optional[Serializer]=None, **kwargs) -> 'Repeti
@property
def integral(self) -> Dict[ChannelID, ExpressionScalar]:
body_integral = self.body.integral
return [self.repetition_count * c for c in body_integral]
return {channel: self.repetition_count * value for channel, value in body_integral.items()}


class ParameterNotIntegerException(Exception):
Expand Down
10 changes: 5 additions & 5 deletions tests/pulses/repetition_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from qupulse.utils.types import FrozenDict

from qupulse._program._loop import Loop
from qupulse.expressions import Expression
from qupulse.expressions import Expression, ExpressionScalar
from qupulse.pulses.repetition_pulse_template import RepetitionPulseTemplate,ParameterNotIntegerException
from qupulse.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation, ConstantParameter, \
ParameterConstraint
Expand Down Expand Up @@ -77,15 +77,15 @@ def test_duration(self):
self.assertEqual(t.duration, Expression('foo*bar'))

def test_integral(self) -> None:
dummy = DummyPulseTemplate(integrals=['foo+2', 'k*3+x**2'])
dummy = DummyPulseTemplate(integrals={'A': ExpressionScalar('foo+2'), 'B': ExpressionScalar('k*3+x**2')})
template = RepetitionPulseTemplate(dummy, 7)
self.assertEqual([Expression('7*(foo+2)'), Expression('7*(k*3+x**2)')], template.integral)
self.assertEqual({'A': Expression('7*(foo+2)'), 'B': Expression('7*(k*3+x**2)')}, template.integral)

template = RepetitionPulseTemplate(dummy, '2+m')
self.assertEqual([Expression('(2+m)*(foo+2)'), Expression('(2+m)*(k*3+x**2)')], template.integral)
self.assertEqual({'A': Expression('(2+m)*(foo+2)'), 'B': Expression('(2+m)*(k*3+x**2)')}, template.integral)

template = RepetitionPulseTemplate(dummy, Expression('2+m'))
self.assertEqual([Expression('(2+m)*(foo+2)'), Expression('(2+m)*(k*3+x**2)')], template.integral)
self.assertEqual({'A': Expression('(2+m)*(foo+2)'), 'B': Expression('(2+m)*(k*3+x**2)')}, template.integral)

def test_parameter_names_param_only_in_constraint(self) -> None:
pt = RepetitionPulseTemplate(DummyPulseTemplate(parameter_names={'a'}), 'n', parameter_constraints=['a<c'])
Expand Down
5 changes: 4 additions & 1 deletion tests/pulses/sequencing_dummies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""STANDARD LIBRARY IMPORTS"""
from typing import Tuple, List, Dict, Optional, Set, Any, Union
from typing import Tuple, List, Dict, Optional, Set, Any, Union, Mapping
import copy

import numpy
Expand Down Expand Up @@ -207,6 +207,9 @@ def __init__(self,
self._program = program
self._register(registry=registry)

if integrals is not None:
assert isinstance(integrals, Mapping)

@property
def duration(self):
return self._duration
Expand Down

0 comments on commit b98df08

Please sign in to comment.