Skip to content

Commit

Permalink
Do not cast to TimeType for comparison to allow comparison with nonfi…
Browse files Browse the repository at this point in the history
…nite
  • Loading branch information
terrorfisch committed Jul 1, 2020
1 parent d4f2154 commit 9a6c866
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
27 changes: 15 additions & 12 deletions qupulse/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,28 +151,31 @@ def __floordiv__(self, other: 'TimeType'):
def __rfloordiv__(self, other: 'TimeType'):
return self._value.__rfloordiv__(other._value)

@_with_other_as_time_type
def __le__(self, other: 'TimeType'):
return self._value.__le__(other._value)
def __le__(self, other):
return self._value <= self.as_comparable(other)

@_with_other_as_time_type
def __ge__(self, other: 'TimeType'):
return self._value.__ge__(other._value)
def __ge__(self, other):
return self._value >= self.as_comparable(other)

@_with_other_as_time_type
def __lt__(self, other: 'TimeType'):
return self._value.__lt__(other._value)
def __lt__(self, other):
return self._value < self.as_comparable(other)

@_with_other_as_time_type
def __gt__(self, other: 'TimeType'):
return self._value.__gt__(other._value)
def __gt__(self, other):
return self._value > self.as_comparable(other)

def __eq__(self, other):
if type(other) == type(self):
return self._value.__eq__(other._value)
else:
return self._value == other

@classmethod
def as_comparable(cls, other: typing.Union['TimeType', typing.Any]):
if type(other) == cls:
return other._value
else:
return other

@classmethod
def from_float(cls, value: float, absolute_error: typing.Optional[float] = None) -> 'TimeType':
"""Convert a floating point number to a TimeType using one of three modes depending on `absolute_error`.
Expand Down
36 changes: 36 additions & 0 deletions tests/utils/time_type_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,42 @@ def test_from_float_exceptions(self):
with self.assertRaisesRegex(ValueError, '<= 1'):
qutypes.time_from_float(.8, 2)

def assert_comparisons_work(self, time_type):
tt = time_type.from_float(1.1)

self.assertLess(tt, 4)
self.assertLess(tt, 4.)
self.assertLess(tt, time_type.from_float(4.))
self.assertLess(tt, float('inf'))

self.assertLessEqual(tt, 4)
self.assertLessEqual(tt, 4.)
self.assertLessEqual(tt, time_type.from_float(4.))
self.assertLessEqual(tt, float('inf'))

self.assertGreater(tt, 1)
self.assertGreater(tt, 1.)
self.assertGreater(tt, time_type.from_float(1.))
self.assertGreater(tt, float('-inf'))

self.assertGreaterEqual(tt, 1)
self.assertGreaterEqual(tt, 1.)
self.assertGreaterEqual(tt, time_type.from_float(1.))
self.assertGreaterEqual(tt, float('-inf'))

self.assertFalse(tt == float('nan'))
self.assertFalse(tt <= float('nan'))
self.assertFalse(tt >= float('nan'))
self.assertFalse(tt < float('nan'))
self.assertFalse(tt > float('nan'))

def test_comparisons_work(self):
self.assert_comparisons_work(qutypes.TimeType)

@unittest.skipIf(gmpy2 is None, "fallback already tested")
def test_comparisons_work_fallback(self):
self.assert_comparisons_work(self.fallback_qutypes.TimeType)


def get_some_floats(seed=42, n=1000):
rand = random.Random(seed)
Expand Down

0 comments on commit 9a6c866

Please sign in to comment.