Skip to content

Commit

Permalink
FIX: plotting now no longer crashes on MEASInstructions; Improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Mar 9, 2018
1 parent 0fbc741 commit acc701a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 22 deletions.
49 changes: 29 additions & 20 deletions qctoolkit/pulses/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,61 @@
- plot: Plot a pulse using matplotlib.
"""

from typing import Dict, Tuple, Any
from typing import Dict, Tuple, Any, Generator, Optional

import numpy as np

from qctoolkit.utils.types import ChannelID
from qctoolkit.pulses.pulse_template import PulseTemplate
from qctoolkit.pulses.parameters import Parameter
from qctoolkit.pulses.sequencing import Sequencer
from qctoolkit.pulses.instructions import EXECInstruction, STOPInstruction, InstructionSequence, \
REPJInstruction
from qctoolkit.pulses.instructions import EXECInstruction, STOPInstruction, AbstractInstructionBlock, \
REPJInstruction, MEASInstruction, GOTOInstruction, Waveform, InstructionPointer


__all__ = ["render", "plot", "PlottingNotPossibleException"]


def render(sequence: InstructionSequence, sample_rate: int=10) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray]]:
def iter_waveforms(instruction_block: AbstractInstructionBlock,
expected_return: Optional[InstructionPointer]=None) -> Generator[Waveform, None, None]:
for i, instruction in enumerate(instruction_block):
if isinstance(instruction, EXECInstruction):
yield instruction.waveform
elif isinstance(instruction, REPJInstruction):
expected_repj_return = InstructionPointer(instruction_block, i+1)
repj_instructions = instruction.target.block.instructions[instruction.target.offset:]
for _ in range(instruction.count):
yield from iter_waveforms(repj_instructions, expected_repj_return)
elif isinstance(instruction, MEASInstruction):
continue
elif isinstance(instruction, GOTOInstruction):
if instruction.target != expected_return:
raise NotImplementedError("Instruction block contains an unexpected GOTO instruction.")
return
elif isinstance(instruction, STOPInstruction):
raise StopIteration()
else:
raise NotImplementedError('Rendering cannot handle instructions of type {}.'.format(type(instruction)))


def render(sequence: AbstractInstructionBlock, sample_rate: int=10) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray]]:
"""'Render' an instruction sequence (sample all contained waveforms into an array).
Returns:
a tuple (times, values) of numpy.ndarrays of similar size. times contains the time value
of all sample times and values the corresponding sampled value.
"""
if not all(isinstance(x, (EXECInstruction, STOPInstruction, REPJInstruction)) for x in sequence):
raise NotImplementedError('Can only plot waveforms without branching so far.')

def get_waveform_generator(instruction_block):
for instruction in instruction_block:
if isinstance(instruction, EXECInstruction):
yield instruction.waveform
elif isinstance(instruction, REPJInstruction):
for _ in range(instruction.count):
yield from get_waveform_generator(instruction.target.block[instruction.target.offset:])
else:
return

waveforms = [wf for wf in get_waveform_generator(sequence)]
waveforms = list(iter_waveforms(sequence, ))
if not waveforms:
return [], []
return np.empty(0), dict()

total_time = sum(waveform.duration for waveform in waveforms)

channels = waveforms[0].defined_channels

# add one sample to see the end of the waveform
sample_count = total_time * sample_rate + 1
times = np.linspace(0, total_time, num=sample_count)
times = np.linspace(0, total_time, num=sample_count, dtype=float)
# move the last sample inside the waveform
times[-1] = np.nextafter(times[-1], times[-2])

Expand Down
24 changes: 22 additions & 2 deletions tests/pulses/plotting_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import numpy

from qctoolkit.pulses.plotting import PlottingNotPossibleException, render
from qctoolkit.pulses.plotting import PlottingNotPossibleException, render, iter_waveforms
from qctoolkit.pulses.instructions import InstructionBlock
from qctoolkit.pulses.table_pulse_template import TablePulseTemplate
from qctoolkit.pulses.sequence_pulse_template import SequencePulseTemplate
Expand All @@ -20,14 +20,34 @@ def test_render_unsupported_instructions(self) -> None:
render(block)

def test_render_no_waveforms(self) -> None:
self.assertEqual(([], []), render(InstructionBlock()))
time, channel_data = render(InstructionBlock())
self.assertEqual(channel_data, dict())
numpy.testing.assert_equal(time, numpy.empty(0))

def test_iter_waveforms(self):
wf1 = DummyWaveform(duration=7)
wf2 = DummyWaveform(duration=5)
wf3 = DummyWaveform(duration=3)

repeated_block = InstructionBlock()
repeated_block.add_instruction_meas([('m', 1, 2)])
repeated_block.add_instruction_exec(wf2)

main_block = InstructionBlock()
main_block.add_instruction_exec(wf1)
main_block.add_instruction_repj(2, repeated_block)
main_block.add_instruction_exec(wf3)

for idx, (expected, received) in enumerate(zip([wf1, wf2, wf2, wf3], iter_waveforms(main_block))):
self.assertIs(expected, received, msg="Waveform {} is wrong".format(idx))

def test_render(self) -> None:
wf1 = DummyWaveform(duration=19)
wf2 = DummyWaveform(duration=21)

block = InstructionBlock()
block.add_instruction_exec(wf1)
block.add_instruction_meas([('asd', 0, 1)])
block.add_instruction_exec(wf2)

wf1_expected = ('A', [0, 2, 4, 6, 8, 10, 12, 14, 16, 18])
Expand Down

0 comments on commit acc701a

Please sign in to comment.