Skip to content

Commit

Permalink
Merge pull request #292 from qutech/feat/table_pt_concatenation
Browse files Browse the repository at this point in the history
Concatenation function for TablePulseTemplate
  • Loading branch information
terrorfisch committed Jul 6, 2018
2 parents 5501597 + 87da872 commit 3a8df53
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 3 deletions.
38 changes: 37 additions & 1 deletion qctoolkit/pulses/table_pulse_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numbers
import itertools
import warnings
import copy

import numpy as np
import sympy
Expand All @@ -26,7 +27,7 @@
from qctoolkit.expressions import ExpressionScalar, Expression
from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform

__all__ = ["TablePulseTemplate", "TableWaveform", "TableWaveformEntry"]
__all__ = ["TablePulseTemplate", "TableWaveform", "TableWaveformEntry", "concatenate"]


class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', float),
Expand Down Expand Up @@ -467,6 +468,41 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]:
return expressions


def concatenate(*table_pulse_templates: TablePulseTemplate, **kwargs) -> TablePulseTemplate:
"""Concatenate two or more table pulse templates"""
first_template, *other_templates = table_pulse_templates

entries = {channel: [] for channel in first_template.defined_channels}
duration = ExpressionScalar(0)

for i, template in enumerate(table_pulse_templates):
if not isinstance(template, TablePulseTemplate):
raise TypeError('Template number %d is not a TablePulseTemplate' % i)

new_duration = duration + template.duration

if template.defined_channels != first_template.defined_channels:
raise ValueError('Template number %d has differing defined channels' % i,
first_template.defined_channels, template.defined_channels)

for channel, channel_entries in template.entries.items():
first_t, first_v, _ = channel_entries[0]
if i > 0 and first_t != 0:
if (first_v == 0) is False:
entries[channel].append((duration, first_v, 'hold'))

for t, v, interp in channel_entries:
entries[channel].append((duration.sympified_expression + t, v, interp))

last_t, last_v, _ = channel_entries[-1]
if i < len(other_templates) and last_t != new_duration:
entries[channel].append((new_duration, last_v, TablePulseTemplate.interpolation_strategies['hold']))

duration = new_duration

return TablePulseTemplate(entries, **kwargs)


class ZeroDurationTablePulseTemplate(UserWarning):
pass

Expand Down
96 changes: 94 additions & 2 deletions tests/pulses/table_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from qctoolkit.expressions import Expression
from qctoolkit.serialization import Serializer
from qctoolkit.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry
from qctoolkit.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry, concatenate
from qctoolkit.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation
from qctoolkit.pulses.interpolation import HoldInterpolationStrategy, LinearInterpolationStrategy, JumpInterpolationStrategy
from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform

from tests.pulses.sequencing_dummies import DummyInterpolationStrategy, DummyParameter, DummyCondition
from tests.pulses.sequencing_dummies import DummyInterpolationStrategy, DummyParameter, DummyCondition,\
DummyPulseTemplate
from tests.serialization_dummies import DummySerializer, DummyStorageBackend
from tests.pulses.measurement_tests import ParameterConstrainerTest, MeasurementDefinerTest

Expand Down Expand Up @@ -704,5 +705,96 @@ def test_simple_properties(self):
self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform)


class TablePulseConcatenationTests(unittest.TestCase):
def test_simple_concatenation(self):
tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')],
'B': [(0, 2), ('b', 7)]})

tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')],
'B': [(0, 6), ('b', 8)]})

expected = TablePulseTemplate({'A': [(0, 1),
('a', 5, 'linear'),
('Max(a, b)', 5),
('Max(a, b)', 9),
('Max(a, b) + c', 9),
('Max(a, b) + a', 10, 'jump')],
'B': [(0, 2),
('b', 7),
('Max(a, b)', 7, 'hold'),
('Max(a, b)', 6),
('Max(a, b) + b', 8)]})

concatenated = concatenate(tpt_1, tpt_2)

self.assertEqual(expected.entries, concatenated.entries)

def test_triple_concatenation(self):
tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')],
'B': [(0, 2), ('b', 7)]})

tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')],
'B': [(0, 6), ('b', 8)]})

tpt_3 = TablePulseTemplate({'A': [('fg', 19), ('ab', 110, 'jump')],
'B': [('df', 16), ('ab', 18)]})

expected = TablePulseTemplate({'A': [(0, 1),
('a', 5, 'linear'),
('Max(a, b)', 5),
('Max(a, b)', 9),
('Max(a, b) + c', 9),
('Max(a, b) + a', 10, 'jump'),
('2*Max(a, b)', 10),
('2*Max(a, b)', 19),
('2*Max(a, b) + fg', 19),
('2*Max(a, b) + ab', 110, 'jump')],
'B': [(0, 2),
('b', 7),
('Max(a, b)', 7, 'hold'),
('Max(a, b)', 6),
('Max(a, b) + b', 8),
('2*Max(a, b)', 8),
('2*Max(a, b)', 16),
('2*Max(a, b) + df', 16),
('2*Max(a, b) + ab', 18)]})

concatenated = concatenate(tpt_1, tpt_2, tpt_3, identifier='asdf')

self.assertEqual(expected.entries, concatenated.entries)
self.assertEqual(concatenated.identifier, 'asdf')

def test_duplication(self):
tpt = TablePulseTemplate({'A': [(0, 1), ('a', 5)],
'B': [(0, 2), ('b', 3)]})

concatenated = concatenate(tpt, tpt)

self.assertIsNot(concatenated.entries, tpt.entries)

expected = TablePulseTemplate({'A': [(0, 1), ('a', 5), ('Max(a, b)', 5), ('Max(a, b)', 1), ('Max(a, b) + a', 5)],
'B': [(0, 2), ('b', 3), ('Max(a, b)', 3), ('Max(a, b)', 2), ('Max(a, b) + b', 3)]})

self.assertEqual(expected.entries, concatenated.entries)

def test_wrong_channels(self):
tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')],
'B': [(0, 2), ('b', 7)]})

tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')],
'C': [(0, 6), ('b', 8)]})

with self.assertRaisesRegex(ValueError, 'differing defined channels'):
concatenate(tpt_1, tpt_2)

def test_wrong_type(self):
dummy = DummyPulseTemplate()
tpt = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')],
'B': [(0, 2), ('b', 7)]})

with self.assertRaisesRegex(TypeError, 'not a TablePulseTemplate'):
concatenate(dummy, tpt)


if __name__ == "__main__":
unittest.main(verbosity=2)

0 comments on commit 3a8df53

Please sign in to comment.