diff --git a/changes.d/707.fix b/changes.d/707.fix new file mode 100644 index 000000000..33e9e9f8e --- /dev/null +++ b/changes.d/707.fix @@ -0,0 +1 @@ +Fixed that single segment tables where always interpreted to be constant. diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index e9ce8f6f0..a173f3bf3 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -298,7 +298,7 @@ def _validate_input(input_waveform_table: Sequence[EntryInInit]) -> Union[Tuple[ raise ValueError('Negative time values are not allowed.') # constant_v is None <=> the waveform is constant until up to the current entry - constant_v = first_interp.constant_value((previous_t, previous_v), (t, v)) + constant_v = interp.constant_value((previous_t, previous_v), (t, v)) for next_t, next_v, next_interp in input_iter: if next_t < t: diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index e4930372b..d84b2c763 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -555,6 +555,31 @@ def test_validate_input_errors(self): TableWaveformEntry(-0.2, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.1, 0.2, HoldInterpolationStrategy())]) + def test_validate_input_const_detection(self): + constant_table = [TableWaveformEntry(0.0, 2.5, HoldInterpolationStrategy()), + (1.4, 2.5, LinearInterpolationStrategy())] + linear_table = [TableWaveformEntry(0.0, 0.0, HoldInterpolationStrategy()), + TableWaveformEntry(1.4, 2.5, LinearInterpolationStrategy())] + + self.assertEqual((1.4, 2.5), TableWaveform._validate_input(constant_table)) + self.assertEqual(linear_table, + TableWaveform._validate_input(linear_table)) + + def test_const_detection_regression(self): + # regression test 707 + from qupulse.pulses import PointPT + second_point_pt = PointPT([(0, 'v_0+v_1'), + ('t_2', 'v_0', 'linear')], + channel_names=('A',), + measurements=[('M', 0, 1)]) + parameters = dict(t=3, + t_2=2, + v_0=1, + v_1=1.4) + channel_mapping = {'A': 'A'} + wf = second_point_pt.build_waveform(parameters=parameters, channel_mapping=channel_mapping) + self.assertIsInstance(wf, TableWaveform) + def test_validate_input_duplicate_removal(self): validated = TableWaveform._validate_input([TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()), TableWaveformEntry(0.1, 0.2, LinearInterpolationStrategy()),