Skip to content

Commit

Permalink
Merge 9a6c866 into a223bd6
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Jul 1, 2020
2 parents a223bd6 + 9a6c866 commit a885dc8
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
37 changes: 23 additions & 14 deletions qupulse/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,28 +151,31 @@ def __floordiv__(self, other: 'TimeType'):
def __rfloordiv__(self, other: 'TimeType'):
return self._value.__rfloordiv__(other._value)

@_with_other_as_time_type
def __le__(self, other: 'TimeType'):
return self._value.__le__(other._value)
def __le__(self, other):
return self._value <= self.as_comparable(other)

@_with_other_as_time_type
def __ge__(self, other: 'TimeType'):
return self._value.__ge__(other._value)
def __ge__(self, other):
return self._value >= self.as_comparable(other)

@_with_other_as_time_type
def __lt__(self, other: 'TimeType'):
return self._value.__lt__(other._value)
def __lt__(self, other):
return self._value < self.as_comparable(other)

@_with_other_as_time_type
def __gt__(self, other: 'TimeType'):
return self._value.__gt__(other._value)
def __gt__(self, other):
return self._value > self.as_comparable(other)

def __eq__(self, other):
if type(other) == type(self):
return self._value.__eq__(other._value)
else:
return self._value == other

@classmethod
def as_comparable(cls, other: typing.Union['TimeType', typing.Any]):
if type(other) == cls:
return other._value
else:
return other

@classmethod
def from_float(cls, value: float, absolute_error: typing.Optional[float] = None) -> 'TimeType':
"""Convert a floating point number to a TimeType using one of three modes depending on `absolute_error`.
Expand Down Expand Up @@ -200,8 +203,14 @@ def from_float(cls, value: float, absolute_error: typing.Optional[float] = None)
if type(value) in (cls, cls._InternalType, fractions.Fraction):
return cls(value)
else:
# .upper() is a bit faster than replace('e', 'E') which gmpy2.mpq needs
return cls(cls._to_internal(str(value).upper()))
try:
# .upper() is a bit faster than replace('e', 'E') which gmpy2.mpq needs
return cls(cls._to_internal(str(value).upper()))
except ValueError:
if isinstance(value, numbers.Number) and not numpy.isfinite(value):
raise ValueError('Cannot represent "{}" as TimeType'.format(value), value)
else:
raise

elif absolute_error == 0:
return cls(cls._to_internal(value))
Expand Down
31 changes: 31 additions & 0 deletions tests/pulses/table_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,37 @@ def test_build_waveform_single_channel(self):
self.assertEqual(waveform._channel_id,
'ch')

def test_build_waveform_time_type(self):
from qupulse.utils.types import TimeType

table = TablePulseTemplate({0: [(0, 0),
('foo', 'v', 'linear'),
('bar', 0, 'jump')]},
parameter_constraints=['foo>1'],
measurements=[('M', 'b', 'l'),
('N', 1, 2)])

parameters = {'v': 2.3,
'foo': TimeType.from_float(1.), 'bar': TimeType.from_float(4),
'b': TimeType.from_float(2), 'l': TimeType.from_float(1)}
channel_mapping = {0: 'ch'}

with self.assertRaises(ParameterConstraintViolation):
table.build_waveform(parameters=parameters,
channel_mapping=channel_mapping)

parameters['foo'] = TimeType.from_float(1.1)
waveform = table.build_waveform(parameters=parameters,
channel_mapping=channel_mapping)

self.assertIsInstance(waveform, TableWaveform)
self.assertEqual(waveform._table,
((0, 0, HoldInterpolationStrategy()),
(TimeType.from_float(1.1), 2.3, LinearInterpolationStrategy()),
(4, 0, JumpInterpolationStrategy())))
self.assertEqual(waveform._channel_id,
'ch')

def test_build_waveform_multi_channel(self):
table = TablePulseTemplate({0: [(0, 0),
('foo', 'v', 'linear'),
Expand Down
44 changes: 44 additions & 0 deletions tests/utils/time_type_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def fallback_qutypes(self):
self._fallback_qutypes = qutypes
return self._fallback_qutypes

def test_non_finite_float(self):
with self.assertRaisesRegex(ValueError, 'Cannot represent'):
qutypes.TimeType.from_float(float('inf'))
with self.assertRaisesRegex(ValueError, 'Cannot represent'):
qutypes.TimeType.from_float(float('-inf'))
with self.assertRaisesRegex(ValueError, 'Cannot represent'):
qutypes.TimeType.from_float(float('nan'))

def test_fraction_fallback(self):
self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType._InternalType)

Expand Down Expand Up @@ -134,6 +142,42 @@ def test_from_float_exceptions(self):
with self.assertRaisesRegex(ValueError, '<= 1'):
qutypes.time_from_float(.8, 2)

def assert_comparisons_work(self, time_type):
tt = time_type.from_float(1.1)

self.assertLess(tt, 4)
self.assertLess(tt, 4.)
self.assertLess(tt, time_type.from_float(4.))
self.assertLess(tt, float('inf'))

self.assertLessEqual(tt, 4)
self.assertLessEqual(tt, 4.)
self.assertLessEqual(tt, time_type.from_float(4.))
self.assertLessEqual(tt, float('inf'))

self.assertGreater(tt, 1)
self.assertGreater(tt, 1.)
self.assertGreater(tt, time_type.from_float(1.))
self.assertGreater(tt, float('-inf'))

self.assertGreaterEqual(tt, 1)
self.assertGreaterEqual(tt, 1.)
self.assertGreaterEqual(tt, time_type.from_float(1.))
self.assertGreaterEqual(tt, float('-inf'))

self.assertFalse(tt == float('nan'))
self.assertFalse(tt <= float('nan'))
self.assertFalse(tt >= float('nan'))
self.assertFalse(tt < float('nan'))
self.assertFalse(tt > float('nan'))

def test_comparisons_work(self):
self.assert_comparisons_work(qutypes.TimeType)

@unittest.skipIf(gmpy2 is None, "fallback already tested")
def test_comparisons_work_fallback(self):
self.assert_comparisons_work(self.fallback_qutypes.TimeType)


def get_some_floats(seed=42, n=1000):
rand = random.Random(seed)
Expand Down

0 comments on commit a885dc8

Please sign in to comment.