Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[unitaryHACK] Add type hints to _seq_drawer.py (#16) #173

Merged
merged 12 commits into from May 29, 2021
28 changes: 19 additions & 9 deletions pulser/_seq_drawer.py
Expand Up @@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np

import pulser
from pulser.waveforms import ConstantWaveform
from pulser.pulse import Pulse
from scipy.interpolate import CubicSpline
from typing import Any, cast, Dict, Optional, Tuple, Union
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved


def gather_data(seq):
def gather_data(seq: pulser.sequence.Sequence) -> Dict:
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
"""Collects the whole sequence data for plotting.

Args:
Expand All @@ -36,7 +41,7 @@ def gather_data(seq):
time = [-1] # To not break the "time[-1]" later on
amp = []
detuning = []
target = {}
target: Dict[Union[str, Tuple[int, int]], Any] = {}
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
# phase_shift = {}
for slot in sch:
if slot.ti == -1:
Expand All @@ -52,12 +57,12 @@ def gather_data(seq):
if slot.type == 'target':
target[(slot.ti, slot.tf-1)] = slot.targets
continue
pulse = slot.type
pulse = cast(Pulse, slot.type)
if (isinstance(pulse.amplitude, ConstantWaveform) and
isinstance(pulse.detuning, ConstantWaveform)):
time += [slot.ti, slot.tf-1]
amp += [pulse.amplitude._value] * 2
detuning += [pulse.detuning._value] * 2
amp += [int(pulse.amplitude._value)] * 2
detuning += [int(pulse.detuning._value)] * 2
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
else:
time += list(range(slot.ti, slot.tf))
amp += pulse.amplitude.samples.tolist()
Expand All @@ -75,7 +80,9 @@ def gather_data(seq):
return data


def draw_sequence(seq, sampling_rate=None, draw_phase_area=False):
def draw_sequence(seq: pulser.sequence.Sequence,
sampling_rate: Optional[float] = None,
draw_phase_area: bool = False) -> None:
"""Draw the entire sequence.

Args:
Expand All @@ -89,7 +96,7 @@ def draw_sequence(seq, sampling_rate=None, draw_phase_area=False):
as text on the plot, defaults to False.
"""

def phase_str(phi):
def phase_str(phi: float) -> str:
"""Formats a phase value for printing."""
value = (((phi + np.pi) % (2*np.pi)) - np.pi) / np.pi
if value == -1:
Expand Down Expand Up @@ -280,12 +287,15 @@ def phase_str(phi):
# Terminate the last open regions
if target_regions:
target_regions[-1].append(t[-1])
for start, targets, end in target_regions:
for start, targets, end in target_regions: # type: ignore
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
q = targets[0] # All targets have the same ref, so we pick
ref = seq._phase_ref[basis][q]
if end != seq._total_duration - 1 or 'measurement' not in data[ch]:
end = cast(int, end)
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
end += 1 / time_scale
for t_, delta in ref.changes(start, end, time_scale=time_scale):
for t_, delta in ref.changes(cast(int, start),
HGSilveri marked this conversation as resolved.
Show resolved Hide resolved
cast(int, end),
time_scale=time_scale):
conf = dict(linestyle='--', linewidth=1.5, color='black')
a.axvline(t_, **conf)
b.axvline(t_, **conf)
Expand Down
1 change: 1 addition & 0 deletions pulser/sequence.py
Expand Up @@ -163,6 +163,7 @@ def __init__(self, register: Register, device: Device):
# Checks if register is compatible with the device
device.validate_register(register)

self._total_duration: int = 0
self._register: Register = register
self._device: Device = device
self._in_xy: bool = False
Expand Down