From f12a056a78cd957bd8160716e759bdc32958fcd9 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 1 Jul 2020 16:20:31 +0200 Subject: [PATCH 1/3] Add failing test --- tests/pulses/table_pulse_template_tests.py | 31 ++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/pulses/table_pulse_template_tests.py b/tests/pulses/table_pulse_template_tests.py index 8fc7fd49c..efae8335c 100644 --- a/tests/pulses/table_pulse_template_tests.py +++ b/tests/pulses/table_pulse_template_tests.py @@ -561,6 +561,37 @@ def test_build_waveform_single_channel(self): self.assertEqual(waveform._channel_id, 'ch') + def test_build_waveform_time_type(self): + from qupulse.utils.types import TimeType + + table = TablePulseTemplate({0: [(0, 0), + ('foo', 'v', 'linear'), + ('bar', 0, 'jump')]}, + parameter_constraints=['foo>1'], + measurements=[('M', 'b', 'l'), + ('N', 1, 2)]) + + parameters = {'v': 2.3, + 'foo': TimeType.from_float(1.), 'bar': TimeType.from_float(4), + 'b': TimeType.from_float(2), 'l': TimeType.from_float(1)} + channel_mapping = {0: 'ch'} + + with self.assertRaises(ParameterConstraintViolation): + table.build_waveform(parameters=parameters, + channel_mapping=channel_mapping) + + parameters['foo'] = TimeType.from_float(1.1) + waveform = table.build_waveform(parameters=parameters, + channel_mapping=channel_mapping) + + self.assertIsInstance(waveform, TableWaveform) + self.assertEqual(waveform._table, + ((0, 0, HoldInterpolationStrategy()), + (TimeType.from_float(1.1), 2.3, LinearInterpolationStrategy()), + (4, 0, JumpInterpolationStrategy()))) + self.assertEqual(waveform._channel_id, + 'ch') + def test_build_waveform_multi_channel(self): table = TablePulseTemplate({0: [(0, 0), ('foo', 'v', 'linear'), From d4f215478c4f038e980e85606e9f810095c52ed6 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 1 Jul 2020 16:29:40 +0200 Subject: [PATCH 2/3] Better error message --- qupulse/utils/types.py | 10 ++++++++-- tests/utils/time_type_tests.py | 8 ++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 5e64e8742..3b19c3d3b 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -200,8 +200,14 @@ def from_float(cls, value: float, absolute_error: typing.Optional[float] = None) if type(value) in (cls, cls._InternalType, fractions.Fraction): return cls(value) else: - # .upper() is a bit faster than replace('e', 'E') which gmpy2.mpq needs - return cls(cls._to_internal(str(value).upper())) + try: + # .upper() is a bit faster than replace('e', 'E') which gmpy2.mpq needs + return cls(cls._to_internal(str(value).upper())) + except ValueError: + if isinstance(value, numbers.Number) and not numpy.isfinite(value): + raise ValueError('Cannot represent "{}" as TimeType'.format(value), value) + else: + raise elif absolute_error == 0: return cls(cls._to_internal(value)) diff --git a/tests/utils/time_type_tests.py b/tests/utils/time_type_tests.py index 0339f44c9..7800b2f53 100644 --- a/tests/utils/time_type_tests.py +++ b/tests/utils/time_type_tests.py @@ -58,6 +58,14 @@ def fallback_qutypes(self): self._fallback_qutypes = qutypes return self._fallback_qutypes + def test_non_finite_float(self): + with self.assertRaisesRegex(ValueError, 'Cannot represent'): + qutypes.TimeType.from_float(float('inf')) + with self.assertRaisesRegex(ValueError, 'Cannot represent'): + qutypes.TimeType.from_float(float('-inf')) + with self.assertRaisesRegex(ValueError, 'Cannot represent'): + qutypes.TimeType.from_float(float('nan')) + def test_fraction_fallback(self): self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType._InternalType) From 9a6c866a1f09b5facf2552e04b3b7a4bf7779297 Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Wed, 1 Jul 2020 16:43:55 +0200 Subject: [PATCH 3/3] Do not cast to TimeType for comparison to allow comparison with nonfinite --- qupulse/utils/types.py | 27 +++++++++++++------------ tests/utils/time_type_tests.py | 36 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 12 deletions(-) diff --git a/qupulse/utils/types.py b/qupulse/utils/types.py index 3b19c3d3b..2d2267e97 100644 --- a/qupulse/utils/types.py +++ b/qupulse/utils/types.py @@ -151,21 +151,17 @@ def __floordiv__(self, other: 'TimeType'): def __rfloordiv__(self, other: 'TimeType'): return self._value.__rfloordiv__(other._value) - @_with_other_as_time_type - def __le__(self, other: 'TimeType'): - return self._value.__le__(other._value) + def __le__(self, other): + return self._value <= self.as_comparable(other) - @_with_other_as_time_type - def __ge__(self, other: 'TimeType'): - return self._value.__ge__(other._value) + def __ge__(self, other): + return self._value >= self.as_comparable(other) - @_with_other_as_time_type - def __lt__(self, other: 'TimeType'): - return self._value.__lt__(other._value) + def __lt__(self, other): + return self._value < self.as_comparable(other) - @_with_other_as_time_type - def __gt__(self, other: 'TimeType'): - return self._value.__gt__(other._value) + def __gt__(self, other): + return self._value > self.as_comparable(other) def __eq__(self, other): if type(other) == type(self): @@ -173,6 +169,13 @@ def __eq__(self, other): else: return self._value == other + @classmethod + def as_comparable(cls, other: typing.Union['TimeType', typing.Any]): + if type(other) == cls: + return other._value + else: + return other + @classmethod def from_float(cls, value: float, absolute_error: typing.Optional[float] = None) -> 'TimeType': """Convert a floating point number to a TimeType using one of three modes depending on `absolute_error`. diff --git a/tests/utils/time_type_tests.py b/tests/utils/time_type_tests.py index 7800b2f53..87d201c28 100644 --- a/tests/utils/time_type_tests.py +++ b/tests/utils/time_type_tests.py @@ -142,6 +142,42 @@ def test_from_float_exceptions(self): with self.assertRaisesRegex(ValueError, '<= 1'): qutypes.time_from_float(.8, 2) + def assert_comparisons_work(self, time_type): + tt = time_type.from_float(1.1) + + self.assertLess(tt, 4) + self.assertLess(tt, 4.) + self.assertLess(tt, time_type.from_float(4.)) + self.assertLess(tt, float('inf')) + + self.assertLessEqual(tt, 4) + self.assertLessEqual(tt, 4.) + self.assertLessEqual(tt, time_type.from_float(4.)) + self.assertLessEqual(tt, float('inf')) + + self.assertGreater(tt, 1) + self.assertGreater(tt, 1.) + self.assertGreater(tt, time_type.from_float(1.)) + self.assertGreater(tt, float('-inf')) + + self.assertGreaterEqual(tt, 1) + self.assertGreaterEqual(tt, 1.) + self.assertGreaterEqual(tt, time_type.from_float(1.)) + self.assertGreaterEqual(tt, float('-inf')) + + self.assertFalse(tt == float('nan')) + self.assertFalse(tt <= float('nan')) + self.assertFalse(tt >= float('nan')) + self.assertFalse(tt < float('nan')) + self.assertFalse(tt > float('nan')) + + def test_comparisons_work(self): + self.assert_comparisons_work(qutypes.TimeType) + + @unittest.skipIf(gmpy2 is None, "fallback already tested") + def test_comparisons_work_fallback(self): + self.assert_comparisons_work(self.fallback_qutypes.TimeType) + def get_some_floats(seed=42, n=1000): rand = random.Random(seed)