Skip to content

Commit

Permalink
Merge b938898 into d2581a9
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Oct 5, 2022
2 parents d2581a9 + b938898 commit 8617fcf
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 147 deletions.
147 changes: 4 additions & 143 deletions qupulse/pulses/loop_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from qupulse.pulses.pulse_template import PulseTemplate, ChannelID, AtomicPulseTemplate
from qupulse._program.waveforms import SequenceWaveform as ForLoopWaveform
from qupulse.pulses.measurement import MeasurementDefiner, MeasurementDeclaration
from qupulse.pulses.range import ParametrizedRange, RangeScope

__all__ = ['ForLoopPulseTemplate', 'LoopPulseTemplate', 'LoopIndexNotUsedException']

Expand All @@ -45,54 +46,6 @@ def measurement_names(self) -> Set[str]:
return self.__body.measurement_names


class ParametrizedRange:
"""Like the builtin python range but with parameters."""
def __init__(self, *args, **kwargs):
"""Positional and keyword arguments cannot be mixed.
Args:
*args: Interpreted as ``(start, )`` or ``(start, stop[, step])``
**kwargs: Expected to contain ``start``, ``stop`` and ``step``
Raises:
TypeError: If positional and keyword arguments are mixed
KeyError: If keyword arguments but one of ``start``, ``stop`` or ``step`` is missing
"""
if args and kwargs:
raise TypeError('ParametrizedRange only takes either positional or keyword arguments')
elif kwargs:
start = kwargs['start']
stop = kwargs['stop']
step = kwargs['step']
elif len(args) in (1, 2, 3):
if len(args) == 3:
start, stop, step = args
elif len(args) == 2:
(start, stop), step = args, 1
elif len(args) == 1:
start, (stop,), step = 0, args, 1
else:
raise TypeError('ParametrizedRange expected 1 to 3 arguments, got {}'.format(len(args)))

self.start = ExpressionScalar.make(start)
self.stop = ExpressionScalar.make(stop)
self.step = ExpressionScalar.make(step)

def to_tuple(self) -> Tuple[Any, Any, Any]:
"""Return a simple representation of the range which is useful for comparison and serialization"""
return (self.start.get_serialization_data(),
self.stop.get_serialization_data(),
self.step.get_serialization_data())

def to_range(self, parameters: Mapping[str, Number]) -> range:
return range(checked_int_cast(self.start.evaluate_in_scope(parameters)),
checked_int_cast(self.stop.evaluate_in_scope(parameters)),
checked_int_cast(self.step.evaluate_in_scope(parameters)))

@property
def parameter_names(self) -> Set[str]:
return set(self.start.variables) | set(self.stop.variables) | set(self.step.variables)


class ForLoopPulseTemplate(LoopPulseTemplate, MeasurementDefiner, ParameterConstrainer):
"""This pulse template allows looping through an parametrized integer range and provides the loop index as a
parameter to the body. If you do not need the index in the pulse template, consider using
Expand Down Expand Up @@ -122,18 +75,7 @@ def __init__(self,
MeasurementDefiner.__init__(self, measurements=measurements)
ParameterConstrainer.__init__(self, parameter_constraints=parameter_constraints)

if isinstance(loop_range, ParametrizedRange):
self._loop_range = loop_range
elif isinstance(loop_range, (int, str)):
self._loop_range = ParametrizedRange(loop_range)
elif isinstance(loop_range, (tuple, list)):
self._loop_range = ParametrizedRange(*loop_range)
elif isinstance(loop_range, range):
self._loop_range = ParametrizedRange(start=loop_range.start,
stop=loop_range.stop,
step=loop_range.step)
else:
raise ValueError('loop_range is not valid')
self._loop_range = ParametrizedRange.from_range_like(loop_range)

if not loop_index.isidentifier():
raise InvalidParameterNameException(loop_index)
Expand Down Expand Up @@ -198,15 +140,8 @@ def _body_scope_generator(self, scope: Scope, forward=True) -> Iterator[Scope]:
loop_range = loop_range if forward else reversed(loop_range)
loop_index_name = self._loop_index

get_for_loop_scope = _get_for_loop_scope

for loop_index_value in loop_range:
try:
yield get_for_loop_scope(scope, loop_index_name, loop_index_value)
except TypeError:
# we cannot hash the scope so we will not try anymore
get_for_loop_scope = _ForLoopScope
yield get_for_loop_scope(scope, loop_index_name, loop_index_value)
yield _ForLoopScope(scope, loop_index_name, loop_index_value)

def _internal_create_program(self, *,
scope: Scope,
Expand Down Expand Up @@ -301,78 +236,4 @@ def __str__(self) -> str:
self.body_parameter_names)


class _ForLoopScope(Scope):
__slots__ = ('_index_name', '_index_value', '_inner')

def __init__(self, inner: Scope, index_name: str, index_value: int):
super().__init__()
self._inner = inner
self._index_name = index_name
self._index_value = index_value

def get_volatile_parameters(self) -> FrozenMapping[str, Expression]:
inner_volatile = self._inner.get_volatile_parameters()

if self._index_name in inner_volatile:
# TODO: use delete method of frozendict
index_name = self._index_name
return FrozenDict((name, value) for name, value in inner_volatile.items() if name != index_name)
else:
return inner_volatile

def __hash__(self):
return hash((self._inner, self._index_name, self._index_value))

def __eq__(self, other: '_ForLoopScope'):
try:
return (self._index_name == other._index_name
and self._index_value == other._index_value
and self._inner == other._inner)
except AttributeError:
return False

def __contains__(self, item):
return item == self._index_name or item in self._inner

def get_parameter(self, parameter_name: str) -> Number:
if parameter_name == self._index_name:
return self._index_value
else:
return self._inner.get_parameter(parameter_name)

__getitem__ = get_parameter

def change_constants(self, new_constants: Mapping[str, Number]) -> 'Scope':
return _get_for_loop_scope(self._inner.change_constants(new_constants), self._index_name, self._index_value)

def __len__(self) -> int:
return len(self._inner) + int(self._index_name not in self._inner)

def __iter__(self) -> Iterator:
if self._index_name in self._inner:
return iter(self._inner)
else:
return itertools.chain(self._inner, (self._index_name,))

def as_dict(self) -> FrozenMapping[str, Number]:
if self._as_dict is None:
self._as_dict = FrozenDict({**self._inner.as_dict(), self._index_name: self._index_value})
return self._as_dict

def keys(self):
return self.as_dict().keys()

def items(self):
return self.as_dict().items()

def values(self):
return self.as_dict().values()

def __repr__(self):
return f'{type(self)}(inner={self._inner!r}, index_name={self._index_name!r}, ' \
f'index_value={self._index_value!r})'


@functools.lru_cache(maxsize=10**6)
def _get_for_loop_scope(inner: Scope, index_name: str, index_value: int) -> Scope:
return _ForLoopScope(inner, index_name, index_value)
_ForLoopScope = RangeScope
156 changes: 156 additions & 0 deletions qupulse/pulses/range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
from typing import Tuple, Any, AbstractSet, Mapping, Union, Iterator
from numbers import Number
from dataclasses import dataclass
from functools import lru_cache
import itertools

from qupulse.utils import checked_int_cast, cached_property
from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, ExpressionLike, Expression
from qupulse.parameter_scope import Scope
from qupulse.utils.types import FrozenDict, FrozenMapping

RangeLike = Union[range,
ExpressionLike,
Tuple[ExpressionLike, ExpressionLike],
Tuple[ExpressionLike, ExpressionLike, ExpressionLike]]


@dataclass(frozen=True)
class ParametrizedRange:
start: ExpressionScalar
stop: ExpressionScalar
step: ExpressionScalar

def __init__(self, *args, **kwargs):
"""Like the builtin python range but with parameters. Positional and keyword arguments cannot be mixed.
Args:
*args: Interpreted as ``(start, )`` or ``(start, stop[, step])``
**kwargs: Expected to contain ``start``, ``stop`` and ``step``
Raises:
TypeError: If positional and keyword arguments are mixed
KeyError: If keyword arguments but one of ``start``, ``stop`` or ``step`` is missing
"""
if args and kwargs:
raise TypeError('ParametrizedRange only takes either positional or keyword arguments')
elif kwargs:
start = kwargs['start']
stop = kwargs['stop']
step = kwargs['step']
elif len(args) in (1, 2, 3):
if len(args) == 3:
start, stop, step = args
elif len(args) == 2:
(start, stop), step = args, 1
else:
start, (stop,), step = 0, args, 1
else:
raise TypeError('ParametrizedRange expected 1 to 3 arguments, got {}'.format(len(args)), args)

object.__setattr__(self, 'start', ExpressionScalar(start))
object.__setattr__(self, 'stop', ExpressionScalar(stop))
object.__setattr__(self, 'step', ExpressionScalar(step))

@lru_cache(maxsize=1024)
def to_tuple(self) -> Tuple[Any, Any, Any]:
"""Return a simple representation of the range which is useful for comparison and serialization"""
return (self.start.get_serialization_data(),
self.stop.get_serialization_data(),
self.step.get_serialization_data())

def to_range(self, parameters: Mapping[str, Number]) -> range:
return range(checked_int_cast(self.start.evaluate_in_scope(parameters)),
checked_int_cast(self.stop.evaluate_in_scope(parameters)),
checked_int_cast(self.step.evaluate_in_scope(parameters)))

@cached_property
def parameter_names(self) -> AbstractSet[str]:
return set(self.start.variables) | set(self.stop.variables) | set(self.step.variables)

@classmethod
def from_range_like(cls, range_like: RangeLike):
if isinstance(range_like, cls):
return range_like
elif isinstance(range_like, (tuple, list)):
return cls(*range_like)
elif isinstance(range_like, range):
return cls(range_like.start, range_like.stop, range_like.step)
elif isinstance(range_like, slice):
raise TypeError("Cannot construct a range from a slice")
else:
return cls(range_like)

def get_serialization_data(self):
return self.to_tuple()


class RangeScope(Scope):
__slots__ = ('_index_name', '_index_value', '_inner')

def __init__(self, inner: Scope, index_name: str, index_value: int):
super().__init__()
self._inner = inner
self._index_name = index_name
self._index_value = index_value

def get_volatile_parameters(self) -> FrozenMapping[str, Expression]:
inner_volatile = self._inner.get_volatile_parameters()

if self._index_name in inner_volatile:
# TODO: use delete method of frozendict
index_name = self._index_name
return FrozenDict((name, value) for name, value in inner_volatile.items() if name != index_name)
else:
return inner_volatile

def __hash__(self):
return hash((self._inner, self._index_name, self._index_value))

def __eq__(self, other: 'RangeScope'):
try:
return (self._index_name == other._index_name
and self._index_value == other._index_value
and self._inner == other._inner)
except AttributeError:
return NotImplemented

def __contains__(self, item):
return item == self._index_name or item in self._inner

def get_parameter(self, parameter_name: str) -> Number:
if parameter_name == self._index_name:
return self._index_value
else:
return self._inner.get_parameter(parameter_name)

__getitem__ = get_parameter

def change_constants(self, new_constants: Mapping[str, Number]) -> 'Scope':
return RangeScope(self._inner.change_constants(new_constants), self._index_name, self._index_value)

def __len__(self) -> int:
return len(self._inner) + int(self._index_name not in self._inner)

def __iter__(self) -> Iterator:
if self._index_name in self._inner:
return iter(self._inner)
else:
return itertools.chain(self._inner, (self._index_name,))

def as_dict(self) -> FrozenMapping[str, Number]:
if self._as_dict is None:
self._as_dict = FrozenDict({**self._inner.as_dict(), self._index_name: self._index_value})
return self._as_dict

def keys(self):
return self.as_dict().keys()

def items(self):
return self.as_dict().items()

def values(self):
return self.as_dict().values()

def __repr__(self):
return f'{type(self)}(inner={self._inner!r}, index_name={self._index_name!r}, ' \
f'index_value={self._index_value!r})'
8 changes: 4 additions & 4 deletions tests/pulses/loop_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from qupulse.expressions import Expression, ExpressionScalar
from qupulse.pulses.loop_pulse_template import ForLoopPulseTemplate, ParametrizedRange,\
LoopIndexNotUsedException, LoopPulseTemplate, _get_for_loop_scope, _ForLoopScope
LoopIndexNotUsedException, LoopPulseTemplate, _ForLoopScope, _ForLoopScope
from qupulse.pulses.parameters import ConstantParameter, InvalidParameterNameException, ParameterConstraintViolation,\
ParameterNotProvidedException, ParameterConstraint

Expand Down Expand Up @@ -100,7 +100,7 @@ def test_init(self):
with self.assertRaises(InvalidParameterNameException):
ForLoopPulseTemplate(body=dt, loop_index='1i', loop_range=6)

with self.assertRaises(ValueError):
with self.assertRaises(TypeError):
ForLoopPulseTemplate(body=dt, loop_index='i', loop_range=slice(None))

with self.assertRaises(LoopIndexNotUsedException):
Expand Down Expand Up @@ -363,7 +363,7 @@ def test_create_program(self) -> None:
to_single_waveform=to_single_waveform,
parent_loop=program)
expected_create_program_calls = [mock.call(**expected_create_program_kwargs,
scope=_get_for_loop_scope(scope, 'i', i))
scope=_ForLoopScope(scope, 'i', i))
for i in (1, 3)]

with mock.patch.object(flt, 'validate_scope') as validate_scope:
Expand Down Expand Up @@ -435,7 +435,7 @@ def assert_equal_instance_except_id(self, lhs: ForLoopPulseTemplate, rhs: ForLoo
self.assertIsInstance(rhs, ForLoopPulseTemplate)
self.assertEqual(lhs.body, rhs.body)
self.assertEqual(lhs.loop_index, rhs.loop_index)
self.assertEqual(lhs.loop_range.to_tuple(), rhs.loop_range.to_tuple())
self.assertEqual(lhs.loop_range, rhs.loop_range)
self.assertEqual(lhs.parameter_constraints, rhs.parameter_constraints)
self.assertEqual(lhs.measurement_declarations, rhs.measurement_declarations)

Expand Down

0 comments on commit 8617fcf

Please sign in to comment.