From d910a6ea18e4f2806923cc54075d85db9e642ac7 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Thu, 5 Jul 2018 19:29:11 +0200 Subject: [PATCH 1/3] Concatenation function for TablePulseTemplate --- qctoolkit/pulses/table_pulse_template.py | 32 ++++++++- tests/pulses/table_pulse_template_tests.py | 75 +++++++++++++++++++++- 2 files changed, 105 insertions(+), 2 deletions(-) diff --git a/qctoolkit/pulses/table_pulse_template.py b/qctoolkit/pulses/table_pulse_template.py index b18eaa8fa..6eb2dfcf6 100644 --- a/qctoolkit/pulses/table_pulse_template.py +++ b/qctoolkit/pulses/table_pulse_template.py @@ -11,6 +11,7 @@ import numbers import itertools import warnings +import copy import numpy as np import sympy @@ -26,7 +27,7 @@ from qctoolkit.expressions import ExpressionScalar, Expression from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform -__all__ = ["TablePulseTemplate", "TableWaveform", "TableWaveformEntry"] +__all__ = ["TablePulseTemplate", "TableWaveform", "TableWaveformEntry", "concatenate"] class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', float), @@ -467,6 +468,35 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return expressions +def concatenate(first_table_pulse_template: TablePulseTemplate, *table_pulse_templates: TablePulseTemplate, **kwargs): + """Concatenate two or more table pulse templates""" + entries = {channel: [] for channel in first_table_pulse_template.defined_channels} + duration = ExpressionScalar(0) + + for i, template in enumerate((first_table_pulse_template,) + table_pulse_templates): + new_duration = duration + template.duration + + if template.defined_channels != first_table_pulse_template.defined_channels: + raise ValueError() + + for channel, channel_entries in template.entries.items(): + first_t, first_v, _ = channel_entries[0] + if i > 0 and first_t != 0: + if (first_v == 0) is False: + entries[channel].append((duration, first_v, 'hold')) + + for t, v, interp in channel_entries: + entries[channel].append((duration.sympified_expression + t, v, interp)) + + last_t, last_v, _ = channel_entries[-1] + if i < len(table_pulse_templates) and last_t != new_duration: + entries[channel].append((new_duration, last_v, TablePulseTemplate.interpolation_strategies['hold'])) + + duration = new_duration + + return TablePulseTemplate(entries, **kwargs) + + class ZeroDurationTablePulseTemplate(UserWarning): pass diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 71de919ca..58867057e 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -5,7 +5,7 @@ from qctoolkit.expressions import Expression from qctoolkit.serialization import Serializer -from qctoolkit.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry +from qctoolkit.pulses.table_pulse_template import TablePulseTemplate, TableWaveform, TableEntry, TableWaveformEntry, ZeroDurationTablePulseTemplate, AmbiguousTablePulseEntry, concatenate from qctoolkit.pulses.parameters import ParameterNotProvidedException, ParameterConstraintViolation from qctoolkit.pulses.interpolation import HoldInterpolationStrategy, LinearInterpolationStrategy, JumpInterpolationStrategy from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform @@ -704,5 +704,78 @@ def test_simple_properties(self): self.assertIs(waveform.unsafe_get_subset_for_channels({'A'}), waveform) +class TablePulseConcatenationTests(unittest.TestCase): + def test_simple_concatenation(self): + tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')], + 'B': [(0, 2), ('b', 7)]}) + + tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')], + 'B': [(0, 6), ('b', 8)]}) + + expected = TablePulseTemplate({'A': [(0, 1), + ('a', 5, 'linear'), + ('Max(a, b)', 5), + ('Max(a, b)', 9), + ('Max(a, b) + c', 9), + ('Max(a, b) + a', 10, 'jump')], + 'B': [(0, 2), + ('b', 7), + ('Max(a, b)', 7, 'hold'), + ('Max(a, b)', 6), + ('Max(a, b) + b', 8)]}) + + concatenated = concatenate(tpt_1, tpt_2) + + self.assertEqual(expected.entries, concatenated.entries) + + def test_triple_concatenation(self): + tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')], + 'B': [(0, 2), ('b', 7)]}) + + tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')], + 'B': [(0, 6), ('b', 8)]}) + + tpt_3 = TablePulseTemplate({'A': [('fg', 19), ('ab', 110, 'jump')], + 'B': [('df', 16), ('ab', 18)]}) + + expected = TablePulseTemplate({'A': [(0, 1), + ('a', 5, 'linear'), + ('Max(a, b)', 5), + ('Max(a, b)', 9), + ('Max(a, b) + c', 9), + ('Max(a, b) + a', 10, 'jump'), + ('2*Max(a, b)', 10), + ('2*Max(a, b)', 19), + ('2*Max(a, b) + fg', 19), + ('2*Max(a, b) + ab', 110, 'jump')], + 'B': [(0, 2), + ('b', 7), + ('Max(a, b)', 7, 'hold'), + ('Max(a, b)', 6), + ('Max(a, b) + b', 8), + ('2*Max(a, b)', 8), + ('2*Max(a, b)', 16), + ('2*Max(a, b) + df', 16), + ('2*Max(a, b) + ab', 18)]}) + + concatenated = concatenate(tpt_1, tpt_2, tpt_3, identifier='asdf') + + self.assertEqual(expected.entries, concatenated.entries) + self.assertEqual(concatenated.identifier, 'asdf') + + def test_duplication(self): + tpt = TablePulseTemplate({'A': [(0, 1), ('a', 5)], + 'B': [(0, 2), ('b', 3)]}) + + concatenated = concatenate(tpt, tpt) + + self.assertIsNot(concatenated.entries, tpt.entries) + + expected = TablePulseTemplate({'A': [(0, 1), ('a', 5), ('Max(a, b)', 5), ('Max(a, b)', 1), ('Max(a, b) + a', 5)], + 'B': [(0, 2), ('b', 3), ('Max(a, b)', 3), ('Max(a, b)', 2), ('Max(a, b) + b', 3)]}) + + self.assertEqual(expected.entries, concatenated.entries) + + if __name__ == "__main__": unittest.main(verbosity=2) From 89a3bc619ef85417f85d7d1da23f2465a60ed4e0 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 6 Jul 2018 09:44:08 +0200 Subject: [PATCH 2/3] Add type annotation and simplify function signature --- qctoolkit/pulses/table_pulse_template.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/qctoolkit/pulses/table_pulse_template.py b/qctoolkit/pulses/table_pulse_template.py index 6eb2dfcf6..89816e205 100644 --- a/qctoolkit/pulses/table_pulse_template.py +++ b/qctoolkit/pulses/table_pulse_template.py @@ -468,15 +468,17 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return expressions -def concatenate(first_table_pulse_template: TablePulseTemplate, *table_pulse_templates: TablePulseTemplate, **kwargs): +def concatenate(*table_pulse_templates: TablePulseTemplate, **kwargs) -> TablePulseTemplate: """Concatenate two or more table pulse templates""" - entries = {channel: [] for channel in first_table_pulse_template.defined_channels} + first_template, *other_templates = table_pulse_templates + + entries = {channel: [] for channel in first_template.defined_channels} duration = ExpressionScalar(0) - for i, template in enumerate((first_table_pulse_template,) + table_pulse_templates): + for i, template in enumerate(table_pulse_templates): new_duration = duration + template.duration - if template.defined_channels != first_table_pulse_template.defined_channels: + if template.defined_channels != first_template.defined_channels: raise ValueError() for channel, channel_entries in template.entries.items(): @@ -489,7 +491,7 @@ def concatenate(first_table_pulse_template: TablePulseTemplate, *table_pulse_tem entries[channel].append((duration.sympified_expression + t, v, interp)) last_t, last_v, _ = channel_entries[-1] - if i < len(table_pulse_templates) and last_t != new_duration: + if i < len(other_templates) and last_t != new_duration: entries[channel].append((new_duration, last_v, TablePulseTemplate.interpolation_strategies['hold'])) duration = new_duration From 87da8728e7f4157d8b815824ce5b2fe9438988dd Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Fri, 6 Jul 2018 10:01:55 +0200 Subject: [PATCH 3/3] Raise better exceptions and test them --- qctoolkit/pulses/table_pulse_template.py | 6 +++++- tests/pulses/table_pulse_template_tests.py | 21 ++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/qctoolkit/pulses/table_pulse_template.py b/qctoolkit/pulses/table_pulse_template.py index 89816e205..8c58c917c 100644 --- a/qctoolkit/pulses/table_pulse_template.py +++ b/qctoolkit/pulses/table_pulse_template.py @@ -476,10 +476,14 @@ def concatenate(*table_pulse_templates: TablePulseTemplate, **kwargs) -> TablePu duration = ExpressionScalar(0) for i, template in enumerate(table_pulse_templates): + if not isinstance(template, TablePulseTemplate): + raise TypeError('Template number %d is not a TablePulseTemplate' % i) + new_duration = duration + template.duration if template.defined_channels != first_template.defined_channels: - raise ValueError() + raise ValueError('Template number %d has differing defined channels' % i, + first_template.defined_channels, template.defined_channels) for channel, channel_entries in template.entries.items(): first_t, first_v, _ = channel_entries[0] diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 58867057e..8f1459c8a 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -10,7 +10,8 @@ from qctoolkit.pulses.interpolation import HoldInterpolationStrategy, LinearInterpolationStrategy, JumpInterpolationStrategy from qctoolkit.pulses.multi_channel_pulse_template import MultiChannelWaveform -from tests.pulses.sequencing_dummies import DummyInterpolationStrategy, DummyParameter, DummyCondition +from tests.pulses.sequencing_dummies import DummyInterpolationStrategy, DummyParameter, DummyCondition,\ + DummyPulseTemplate from tests.serialization_dummies import DummySerializer, DummyStorageBackend from tests.pulses.measurement_tests import ParameterConstrainerTest, MeasurementDefinerTest @@ -776,6 +777,24 @@ def test_duplication(self): self.assertEqual(expected.entries, concatenated.entries) + def test_wrong_channels(self): + tpt_1 = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')], + 'B': [(0, 2), ('b', 7)]}) + + tpt_2 = TablePulseTemplate({'A': [('c', 9), ('a', 10, 'jump')], + 'C': [(0, 6), ('b', 8)]}) + + with self.assertRaisesRegex(ValueError, 'differing defined channels'): + concatenate(tpt_1, tpt_2) + + def test_wrong_type(self): + dummy = DummyPulseTemplate() + tpt = TablePulseTemplate({'A': [(0, 1), ('a', 5, 'linear')], + 'B': [(0, 2), ('b', 7)]}) + + with self.assertRaisesRegex(TypeError, 'not a TablePulseTemplate'): + concatenate(dummy, tpt) + if __name__ == "__main__": unittest.main(verbosity=2)