Skip to content

Commit

Permalink
Merge pull request #585 from qutech/issues/584_uninitialized_table_sa…
Browse files Browse the repository at this point in the history
…mple

Issues/584 uninitialized table sample
  • Loading branch information
terrorfisch committed May 16, 2021
2 parents 625b07f + 9b3ad35 commit 25e0f71
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
42 changes: 29 additions & 13 deletions qupulse/_program/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,19 @@
from qupulse.pulses.interpolation import InterpolationStrategy
from qupulse.utils import checked_int_cast, isclose
from qupulse.utils.types import TimeType, time_from_float
from qupulse._program.transformation import Transformation
from qupulse.utils import pairwise


__all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform",
"MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform"]

PULSE_TO_WAVEFORM_ERROR = None # error margin in pulse template to waveform conversion

# these are private because there probably will be changes here
_ALLOCATION_FUNCTION = np.full_like # pre_allocated = ALLOCATION_FUNCTION(sample_times, **ALLOCATION_FUNCTION_KWARGS)
_ALLOCATION_FUNCTION_KWARGS = dict(fill_value=np.nan, dtype=float)


class Waveform(Comparable, metaclass=ABCMeta):
"""Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage
Expand Down Expand Up @@ -224,9 +231,16 @@ def unsafe_sample(self,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
if output_array is None:
output_array = np.empty_like(sample_times)
output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)

for entry1, entry2 in zip(self._table[:-1], self._table[1:]):
if PULSE_TO_WAVEFORM_ERROR:
# we need to replace the last entry's t with self.duration
*entries, last = self._table
entries.append(TableWaveformEntry(float(self.duration), last.v, last.interp))
else:
entries = self._table

for entry1, entry2 in pairwise(entries):
indices = slice(np.searchsorted(sample_times, entry1.t, 'left'),
np.searchsorted(sample_times, entry2.t, 'right'))
output_array[indices] = \
Expand All @@ -247,9 +261,9 @@ def __repr__(self):


class ConstantWaveform(Waveform):

_is_constant_waveform = True

def __init__(self, duration: float, amplitude: Any, channel: ChannelID):
""" Create a qupulse waveform corresponding to a ConstantPulseTemplate """
self._duration = duration
Expand Down Expand Up @@ -277,9 +291,10 @@ def unsafe_sample(self,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
if output_array is None:
output_array = np.empty_like(sample_times, dtype=float)
output_array[:] = self._amplitude
return output_array
return np.full_like(sample_times, fill_value=self._amplitude, dtype=float)
else:
output_array[:] = self._amplitude
return output_array

def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
"""Unsafe version of :func:`~qupulse.pulses.instructions.get_measurement_windows`."""
Expand Down Expand Up @@ -325,18 +340,19 @@ def unsafe_sample(self,
channel: ChannelID,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
evaluated = self._expression.evaluate_numeric(t=sample_times)
if output_array is None:
output_array = np.empty(len(sample_times))
output_array[:] = self._expression.evaluate_numeric(t=sample_times)
return output_array
return evaluated.astype(float)
else:
output_array[:] = evaluated
return output_array

def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
return self


class SequenceWaveform(Waveform):
"""This class allows putting multiple PulseTemplate together in one waveform on the hardware."""

def __init__(self, sub_waveforms: Iterable[Waveform]):
"""
Expand Down Expand Up @@ -371,7 +387,7 @@ def unsafe_sample(self,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
if output_array is None:
output_array = np.empty_like(sample_times)
output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
time = 0
for subwaveform in self._sequenced_waveforms:
# before you change anything here, make sure to understand the difference between basic and advanced
Expand Down Expand Up @@ -536,7 +552,7 @@ def unsafe_sample(self,
sample_times: np.ndarray,
output_array: Union[np.ndarray, None] = None) -> np.ndarray:
if output_array is None:
output_array = np.empty_like(sample_times)
output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
body_duration = self._body.duration
time = 0
for _ in range(self._repetition_count):
Expand Down
3 changes: 2 additions & 1 deletion qupulse/pulses/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import operator
import itertools

from qupulse._program import waveforms
from qupulse.utils.types import ChannelID, MeasurementWindow, has_type_interface
from qupulse.pulses.pulse_template import PulseTemplate
from qupulse.pulses.parameters import Parameter
Expand Down Expand Up @@ -86,7 +87,7 @@ def render(program: Union[Loop],
times = np.linspace(float(start_time), float(end_time), num=int(sample_count), dtype=float)
times[-1] = np.nextafter(times[-1], times[-2])

voltages = {ch: np.empty_like(times)
voltages = {ch: waveforms._ALLOCATION_FUNCTION(times, **waveforms._ALLOCATION_FUNCTION_KWARGS)
for ch in channels}
for ch, ch_voltage in voltages.items():
waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage)
Expand Down
30 changes: 29 additions & 1 deletion tests/pulses/bug_tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import unittest
from unittest import mock

import numpy as np

from qupulse.pulses.table_pulse_template import TablePulseTemplate
from qupulse.pulses.function_pulse_template import FunctionPulseTemplate
Expand All @@ -9,6 +12,8 @@

from qupulse.pulses.plotting import plot

from qupulse._program._loop import to_waveform
from qupulse.utils import isclose

class BugTests(unittest.TestCase):

Expand Down Expand Up @@ -65,4 +70,27 @@ def test_plot_with_parameter_value_being_expression_string(self) -> None:

parameter_values = dict(omega=1.0, a=1.0, t_duration="2*pi")

_ = plot(both, parameters=parameter_values, sample_rate=100)
_ = plot(both, parameters=parameter_values, sample_rate=100)

def test_issue_584_uninitialized_table_sample(self):
"""issue 584"""
d = 598.3333333333334 - 480
tpt = TablePulseTemplate(entries={'P': [(0, 1.0, 'hold'), (d, 1.0, 'hold')]})
with mock.patch('qupulse._program.waveforms.PULSE_TO_WAVEFORM_ERROR', 1e-6):
wf = to_waveform(tpt.create_program())
self.assertTrue(isclose(d, wf.duration, abs_tol=1e-6))

start_time = 0.
end_time = wf.duration
sample_rate = 3.

sample_count = (end_time - start_time) * sample_rate + 1

times = np.linspace(float(start_time), float(wf.duration), num=int(sample_count), dtype=float)
times[-1] = np.nextafter(times[-1], times[-2])

out = np.full_like(times, fill_value=np.nan)
sampled = wf.get_sampled(channel='P', sample_times=times, output_array=out)

expected = np.full_like(times, fill_value=1.)
np.testing.assert_array_equal(expected, sampled)

0 comments on commit 25e0f71

Please sign in to comment.