Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 68 additions & 24 deletions qupulse/_program/_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ class Loop(Node):
MAX_REPR_SIZE = 2000
__slots__ = ('_waveform', '_measurements', '_repetition_count', '_cached_body_duration', '_repetition_parameter')

"""Build a loop tree. The leaves of the tree are loops with one element."""
"""Build a loop tree. The leaves of the tree are loops with one element.

Loop objects are equal if all children are/the waveform is equal, the repetition count is equal
"""
def __init__(self,
parent: Union['Loop', None] = None,
children: Iterable['Loop'] = (),
Expand All @@ -39,7 +42,7 @@ def __init__(self,
waveform:
measurements:
repetition_count:
repetition_expression:
repetition_parameter:
"""
super().__init__(parent=parent, children=children)

Expand All @@ -49,16 +52,19 @@ def __init__(self,
self._repetition_parameter = repetition_parameter
self._cached_body_duration = None

if abs(self._repetition_count - repetition_count) > 1e-10:
raise ValueError('Repetition count was not an integer')

if not isinstance(waveform, (type(None), Waveform)):
raise Exception()

@property
def compare_key(self) -> Tuple:
return self._waveform, self.repetition_count, self._measurements if self._measurements else None,\
super().compare_key
assert self._repetition_count == repetition_count, "Repetition count was not an integer: %r" % repetition_count
assert isinstance(waveform, (type(None), Waveform))

def __eq__(self, other: 'Loop') -> bool:
if type(self) == type(other):
return (self._repetition_count == other._repetition_count and
self.waveform == other.waveform and
(self._measurements or None) == (other._measurements or None) and
self._repetition_parameter == other._repetition_parameter and
len(self) == len(other) and
all(self_child == other_child for self_child, other_child in zip(self, other)))
else:
return NotImplemented

def append_child(self, loop: Optional['Loop']=None, **kwargs) -> None:
# do not invalidate but update cached duration
Expand Down Expand Up @@ -118,7 +124,7 @@ def body_duration(self) -> TimeType:

@property
def duration(self) -> TimeType:
return self.repetition_count*self.body_duration
return self.body_duration * self.repetition_count

@property
def repetition_parameter(self) -> Optional[MappedParameter]:
Expand All @@ -138,6 +144,8 @@ def repetition_count(self, val) -> None:
def unroll(self) -> None:
if self.is_leaf():
raise RuntimeError('Leaves cannot be unrolled')
if self.repetition_parameter is not None:
warnings.warn("Unrolling a Loop with volatile repetition count", VolatileModificationWarning)

i = self.parent_index
self.parent[i:i+1] = (child.copy_tree_structure(new_parent=self.parent)
Expand All @@ -150,19 +158,24 @@ def __setitem__(self, idx, value):
self._invalidate_duration()

def unroll_children(self) -> None:
if self._repetition_parameter is not None:
warnings.warn("Unrolling a Loop with volatile repetition count", VolatileModificationWarning)
old_children = self.children
self[:] = (child.copy_tree_structure()
for _ in range(self.repetition_count)
for child in old_children)
self.repetition_count = 1
self._repetition_parameter = None
self.assert_tree_integrity()

def encapsulate(self) -> None:
self[:] = [Loop(children=self,
repetition_count=self.repetition_count,
repetition_parameter=self._repetition_parameter,
waveform=self._waveform,
measurements=self._measurements)]
self.repetition_count = 1
self._repetition_parameter = None
self._waveform = None
self._measurements = None
self.assert_tree_integrity()
Expand Down Expand Up @@ -197,6 +210,7 @@ def copy_tree_structure(self, new_parent: Union['Loop', bool]=False) -> 'Loop':
return type(self)(parent=self.parent if new_parent is False else new_parent,
waveform=self._waveform,
repetition_count=self.repetition_count,
repetition_parameter=self._repetition_parameter,
measurements=self._measurements,
children=(child.copy_tree_structure() for child in self))

Expand Down Expand Up @@ -243,18 +257,30 @@ def get_measurement_windows(self) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
def split_one_child(self, child_index=None) -> None:
"""Take the last child that has a repetition count larger one, decrease it's repetition count and insert a copy
with repetition cout one after it"""
if child_index:
if child_index is not None:
if self[child_index].repetition_count < 2:
raise ValueError('Cannot split child {} as the repetition count is not larger 1')

else:
try:
child_index = next(i for i in reversed(range(len(self)))
if self[i].repetition_count > 1)
except StopIteration:
raise RuntimeError('There is no child with repetition count > 1')
for i, child in enumerate(reversed(self)):
if child.repetition_count > 1:
if child.repetition_parameter is None:
child_index = i
break
elif child_index is None:
child_index = i
else:
if child_index is None:
raise RuntimeError('There is no child with repetition count > 1')

if self[child_index]._repetition_parameter is not None:
warnings.warn("Splitting a child with volatile repetition count", VolatileModificationWarning)
self[child_index]._repetition_parameter = MappedParameter(expression=self[child_index]._repetition_parameter.expression - 1,
dependencies=self[child_index]._repetition_parameter.dependencies)

new_child = self[child_index].copy_tree_structure()
new_child.repetition_count = 1
new_child._repetition_parameter = None

self[child_index].repetition_count -= 1

Expand All @@ -281,12 +307,30 @@ def flatten_and_balance(self, depth: int) -> None:
elif sub_program.depth() == depth - 1:
i += 1

elif len(sub_program) == 1 and len(sub_program[0]) == 1:
elif len(sub_program) == 1 and len(sub_program[0]) == 1 and not sub_program._measurements:
sub_sub_program = cast(Loop, sub_program[0])

sub_program.repetition_count = sub_program.repetition_count * sub_sub_program.repetition_count
measurements = sub_sub_program._measurements
repetition_count = sub_program.repetition_count * sub_sub_program.repetition_count
if sub_program._repetition_parameter is None and sub_sub_program._repetition_parameter is None:
repetition_parameter = None
else:
if sub_program._repetition_parameter is None:
repetition_parameter = MappedParameter(expression=sub_sub_program._repetition_parameter.expression * sub_program.repetition_count,
dependencies=sub_sub_program._repetition_parameter.dependencies)
elif sub_sub_program._repetition_parameter is None:
repetition_parameter = MappedParameter(expression=sub_program._repetition_parameter.expression * sub_sub_program.repetition_count,
dependencies=sub_program._repetition_parameter.dependencies)
else:
# TODO: possible but requires complicated code elsewhere
repetition_parameter = None

sub_program[:] = sub_sub_program[:]
sub_program.waveform = sub_sub_program.waveform
sub_program._waveform = sub_sub_program._waveform
sub_program._repetition_parameter = repetition_parameter
sub_program._repetition_count = repetition_count
sub_program._measurements = measurements
sub_program._invalidate_duration()

elif not sub_program.is_leaf():
sub_program.unroll()
Expand Down Expand Up @@ -332,7 +376,7 @@ def cleanup(self):
elif child._measurements:
warnings.warn("Dropping measurement since there is no waveform in children")

if len(new_children) == 1 and not self._measurements:
if len(new_children) == 1 and not self._measurements and not self._repetition_parameter:
assert not self._waveform
only_child = new_children[0]

Expand All @@ -344,7 +388,7 @@ def cleanup(self):
elif len(self) != len(new_children):
self[:] = new_children

def get_duration_structure(self) -> Tuple[int, Union[int, tuple]]:
def get_duration_structure(self) -> Tuple[int, Union[TimeType, tuple]]:
if self.is_leaf():
return self.repetition_count, self.waveform.duration
else:
Expand Down
4 changes: 2 additions & 2 deletions qupulse/comparable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module defines the abstract Comparable class."""
from abc import abstractmethod
from typing import Any
from typing import Hashable, Any

from qupulse.utils.types import DocStringABCMeta

Expand All @@ -20,7 +20,7 @@ class Comparable(metaclass=DocStringABCMeta):

@property
@abstractmethod
def compare_key(self) -> Any:
def compare_key(self) -> Hashable:
"""Return a unique key used in comparison and hashing operations.

The key must describe the essential properties of the object.
Expand Down
8 changes: 6 additions & 2 deletions qupulse/pulses/loop_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def _internal_create_program(self, *,
channel_mapping: Dict[ChannelID, Optional[ChannelID]],
global_transformation: Optional['Transformation'],
to_single_waveform: Set[Union[str, 'PulseTemplate']],
parent_loop: Loop) -> None:
parent_loop: Loop,
volatile: Set[str]) -> None:
self.validate_parameter_constraints(parameters=parameters)

try:
Expand All @@ -245,6 +246,8 @@ def _internal_create_program(self, *,
for parameter_name in self.duration.variables}
except KeyError as e:
raise ParameterNotProvidedException(str(e)) from e
assert not volatile.intersection(measurement_parameters.keys())
assert not volatile.intersection(duration_parameters.keys())

if self.duration.evaluate_numeric(**duration_parameters) > 0:
measurements = self.get_measurement_windows(measurement_parameters, measurement_mapping)
Expand All @@ -257,7 +260,8 @@ def _internal_create_program(self, *,
channel_mapping=channel_mapping,
global_transformation=global_transformation,
to_single_waveform=to_single_waveform,
parent_loop=parent_loop)
parent_loop=parent_loop,
volatile=volatile)

def build_waveform(self, parameters: Dict[str, Parameter]) -> ForLoopWaveform:
return ForLoopWaveform([self.body.build_waveform(local_parameters)
Expand Down
26 changes: 24 additions & 2 deletions qupulse/pulses/mapping_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,26 @@ def map_parameters(self,
else:
raise TypeError('Values of parameter dict are neither all Parameter nor Real')

def map_volatile(self, volatile: Set[str]) -> Set[str]:
"""Deduce set of inner volatile parameters.

TODO: Does not handle the case of dropped dependencies i.e.:
x is volatile but a == 0 => y is actually not volatile
y = a * x + m

Args:
volatile: a set of outer volatile parameters

Returns:
Set of inner volatile parameters
"""
if volatile:
return {parameter
for parameter, mapping_function in self.__parameter_mapping.items()
if volatile.intersection(mapping_function.variables)}
else:
return volatile

def get_updated_measurement_mapping(self, measurement_mapping: Dict[str, str]) -> Dict[str, str]:
return {k: measurement_mapping[v] for k, v in self.__measurement_mapping.items()}

Expand Down Expand Up @@ -328,14 +348,16 @@ def _internal_create_program(self, *,
channel_mapping: Dict[ChannelID, Optional[ChannelID]],
global_transformation: Optional['Transformation'],
to_single_waveform: Set[Union[str, 'PulseTemplate']],
parent_loop: Loop) -> None:
parent_loop: Loop,
volatile: Set[str]) -> None:
# parameters are validated in map_parameters() call, no need to do it here again explicitly
self.template._create_program(parameters=self.map_parameter_objects(parameters),
measurement_mapping=self.get_updated_measurement_mapping(measurement_mapping),
channel_mapping=self.get_updated_channel_mapping(channel_mapping),
global_transformation=global_transformation,
to_single_waveform=to_single_waveform,
parent_loop=parent_loop)
parent_loop=parent_loop,
volatile=self.map_volatile(volatile))

def build_waveform(self,
parameters: Dict[str, numbers.Real],
Expand Down
38 changes: 22 additions & 16 deletions qupulse/pulses/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
"""

from abc import abstractmethod
from typing import Optional, Union, Dict, Any, Iterable, Set, List
from typing import Optional, Union, Dict, Any, Iterable, Set, List, Mapping
from numbers import Real
import types

import sympy
import numpy
Expand Down Expand Up @@ -47,12 +48,8 @@ def requires_stop(self) -> bool:
True, if evaluating this Parameter instance requires an interruption.
"""

@abstractmethod
def __hash__(self) -> int:
"""Returns a hash value of the parameter. Must be implemented."""

def __eq__(self, other) -> bool:
return type(self) is type(other) and hash(self) == hash(other)
def __eq__(self, other: 'Parameter') -> bool:
return numpy.array_equal(self.get_value(), other.get_value())


class ConstantParameter(Parameter):
Expand Down Expand Up @@ -103,7 +100,7 @@ class MappedParameter(Parameter):

def __init__(self,
expression: Expression,
dependencies: Optional[Dict[str, Parameter]]=None) -> None:
dependencies: Optional[Mapping[str, Parameter]]=None) -> None:
"""Create a MappedParameter instance.

Args:
Expand All @@ -114,8 +111,17 @@ def __init__(self,
"""
super().__init__()
self._expression = expression
self.dependencies = dict() if dependencies is None else dependencies
self._cached_value = (None, None)
self._dependencies = dict() if dependencies is None else dependencies
self._cached_value = None

@property
def dependencies(self):
return types.MappingProxyType(self._dependencies)

@dependencies.setter
def dependencies(self, new_dependencies):
self._dependencies = new_dependencies
self._cached_value = None

def _collect_dependencies(self) -> Dict[str, float]:
# filter only real dependencies from the dependencies dictionary
Expand All @@ -127,13 +133,13 @@ def _collect_dependencies(self) -> Dict[str, float]:

def get_value(self) -> Union[Real, numpy.ndarray]:
"""Does not check explicitly if a parameter requires to stop."""
current_hash = hash(self)
if current_hash != self._cached_value[0]:
self._cached_value = (current_hash, self._expression.evaluate_numeric(**self._collect_dependencies()))
return self._cached_value[1]
if self._cached_value is None:
self._cached_value = self._expression.evaluate_numeric(**self._collect_dependencies())
return self._cached_value

def __hash__(self):
return hash(tuple(self.dependencies.items()))
@property
def expression(self):
return self._expression

@property
def requires_stop(self) -> bool:
Expand Down
7 changes: 5 additions & 2 deletions qupulse/pulses/pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def _create_program(self, *,
channel_mapping=channel_mapping,
to_single_waveform=to_single_waveform,
global_transformation=global_transformation,
parent_loop=parent_loop)
parent_loop=parent_loop,
volatile=volatile)


class AtomicPulseTemplate(PulseTemplate, MeasurementDefiner):
Expand Down Expand Up @@ -271,7 +272,8 @@ def _internal_create_program(self, *,
channel_mapping: Dict[ChannelID, Optional[ChannelID]],
global_transformation: Optional[Transformation],
to_single_waveform: Set[Union[str, 'PulseTemplate']],
parent_loop: Loop) -> None:
parent_loop: Loop,
volatile: Set[str]) -> None:
"""Parameter constraints are validated in build_waveform because build_waveform is guaranteed to be called
during sequencing"""
### current behavior (same as previously): only adds EXEC Loop and measurements if a waveform exists.
Expand All @@ -285,6 +287,7 @@ def _internal_create_program(self, *,
except KeyError as e:
raise ParameterNotProvidedException(str(e)) from e

assert not volatile.intersection(parameters), "not supported"
waveform = self.build_waveform(parameters=parameters,
channel_mapping=channel_mapping)
if waveform:
Expand Down
Loading