Skip to content

Commit

Permalink
Added unit tests for forgotten code paths. Redid total ordering since…
Browse files Browse the repository at this point in the history
… in Python 2.7 functools.total_ordering does not properly support NotImplemented, which resulted in an infinite loop in some edge cases.
  • Loading branch information
runfalk committed Dec 22, 2015
1 parent ee7b72a commit 0bfa3b5
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 21 deletions.
3 changes: 3 additions & 0 deletions doc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Not released yet
(`Michael Krate <https://github.com/der-michik>`_)
- Added `Sphinx <http://sphinx-doc.org/>`_ style doc strings to all methods
- Added proper Sphinx documentation
- Added unit tests for uncovered parts, mostly error checking
- Fixed a potential bug where comparing ranges of different types would result
in an infinite loop
- Changed meta class implementation for range sets to allow more mixins for
custom range sets

Expand Down
49 changes: 49 additions & 0 deletions spans/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,52 @@ def __getstate__(self):
def __setstate__(self, data):
for attr, value in data.items():
setattr(self, attr, value)

def sane_total_ordering(cls):
def __ge__(self, other):
lt = self.__lt__(other)
if lt is NotImplemented:
return NotImplemented

return not lt

def __le__(self, other):
lt = self.__lt__(other)
if lt is NotImplemented:
return NotImplemented

eq = self.__eq__(other)
if eq is NotImplemented:
return NotImplemented

return lt or eq

def __gt__(self, other):
le = __le__(self, other)
if le is NotImplemented:
return NotImplemented

return not le

def __ne__(self, other):
eq = self.__eq__(other)
if eq is NotImplemented:
return NotImplemented

return not eq

ops = [(f.__name__, f) for f in [__ge__, __le__, __gt__]]
predefined = set(dir(cls))

if "__lt__" not in predefined:
raise ValueError("Must define __lt__")

for func in [__ge__, __le__, __gt__, __ne__]:
name = func.__name__

# Test if class actually has overridden the default rich comparison
# implementation
if name not in predefined or getattr(cls, name) is getattr(object, name):
setattr(cls, name, func)

return cls
4 changes: 2 additions & 2 deletions spans/settypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import total_ordering
from itertools import chain

from ._compat import add_metaclass
from ._utils import sane_total_ordering
from .types import range_
from .types import *
from .types import discreterange, offsetablerange
Expand Down Expand Up @@ -91,7 +91,7 @@ def offset(self, offset):
metarangeset.add(offsetablerange, offsetablerangeset)


@total_ordering
@sane_total_ordering
@add_metaclass(metarangeset)
class rangeset(object):
"""
Expand Down
100 changes: 97 additions & 3 deletions spans/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

class TestIntRange(TestCase):
def test_empty(self):
self.assertFalse(intrange.empty())
range = intrange.empty()

self.assertFalse(range)

self.assertIsNone(range.lower)
self.assertIsNone(range.upper)

def test_non_empty(self):
self.assertTrue(intrange())
Expand Down Expand Up @@ -65,6 +70,9 @@ def test_offset(self):
self.assertEqual(low_range.offset(5), high_range)
self.assertEqual(low_range, high_range.offset(-5))

with self.assertRaises(TypeError):
low_range.offset(5.0)

def test_offset_unbounded(self):
range = intrange()

Expand All @@ -75,6 +83,8 @@ def test_equality(self):
self.assertEqual(intrange.empty(), intrange.empty())
self.assertNotEqual(intrange(1, 5), intrange(1, 5, upper_inc=True))

self.assertFalse(intrange() == None)

def test_less_than(self):
self.assertTrue(intrange(1, 5) < intrange(2, 5))
self.assertTrue(intrange(1, 4) < intrange(1, 5))
Expand All @@ -86,6 +96,12 @@ def test_less_than(self):
self.assertTrue(intrange(1, 4) <= intrange(1, 5))
self.assertFalse(intrange(2, 5) <= intrange(1, 5))

# Hack used to work around version differences between Python 2 and 3
# Python 2 has its own idea of how objects compare to each other.
# Python 3 raises type error when an operation is not implemented
self.assertIs(intrange().__lt__(floatrange()), NotImplemented)
self.assertIs(intrange().__le__(floatrange()), NotImplemented)

def test_greater_than(self):
self.assertTrue(intrange(2, 5) > intrange(1, 5))
self.assertTrue(intrange(1, 5) > intrange(1, 4))
Expand All @@ -97,18 +113,28 @@ def test_greater_than(self):
self.assertTrue(intrange(1, 5) >= intrange(1, 4))
self.assertFalse(intrange(1, 5) >= intrange(2, 5))

# Hack used to work around version differences between Python 2 and 3.
# Python 2 has its own idea of how objects compare to each other.
# Python 3 raises type error when an operation is not implemented
self.assertIs(intrange().__gt__(floatrange()), NotImplemented)
self.assertIs(intrange().__ge__(floatrange()), NotImplemented)

def test_left_of(self):
self.assertTrue(intrange(1, 5).left_of(intrange(5, 10)))
self.assertTrue(intrange(1, 5).left_of(intrange(10, 15)))
self.assertFalse(intrange(1, 5, upper_inc=True).left_of(intrange(5, 10)))
self.assertFalse(intrange(5, 10).left_of(intrange(1, 5)))

self.assertFalse(intrange.empty().left_of(intrange.empty()))

def test_right_of(self):
self.assertTrue(intrange(5, 10).right_of(intrange(1, 5)))
self.assertTrue(intrange(5, 10).right_of(intrange(1, 5)))
self.assertFalse(intrange(5, 10).right_of(intrange(1, 5, upper_inc=True)))
self.assertFalse(intrange(1, 5).right_of(intrange(5, 10)))

self.assertFalse(intrange.empty().right_of(intrange.empty()))

def test_startsafter(self):
self.assertTrue(intrange(1, 5).startsafter(intrange(1, 5)))
self.assertTrue(intrange(1, 5).startsafter(intrange(1, 10)))
Expand All @@ -121,6 +147,11 @@ def test_startsafter(self):
self.assertFalse(intrange(1, 10).startsafter(intrange(5)))
self.assertTrue(intrange(1, 10).startsafter(intrange(upper=5)))

self.assertTrue(intrange(1, 5).startsafter(0))

with self.assertRaises(TypeError):
intrange(1, 5).startsafter(1.0)

def test_endsbefore(self):
self.assertTrue(intrange(1, 5).endsbefore(intrange(1, 5)))
self.assertTrue(intrange(5, 10).endsbefore(intrange(1, 10)))
Expand All @@ -133,6 +164,11 @@ def test_endsbefore(self):
self.assertTrue(intrange(1, 10).endsbefore(intrange(5)))
self.assertFalse(intrange(1, 10).endsbefore(intrange(upper=5)))

self.assertTrue(intrange(1, 5).endsbefore(5))

with self.assertRaises(TypeError):
intrange(1, 5).endsbefore(5.0)

def test_startswith(self):
self.assertTrue(intrange(1, 5).startswith(intrange(1, 5)))
self.assertTrue(intrange(1, 5).startswith(intrange(1, 10)))
Expand All @@ -143,6 +179,9 @@ def test_startswith(self):
self.assertTrue(intrange(1, 5).startswith(1))
self.assertFalse(intrange(1, 5, lower_inc=False).startswith(1))

with self.assertRaises(TypeError):
intrange(1, 5).startswith(1.0)

def test_endswith(self):
self.assertTrue(intrange(5, 10).endswith(intrange(5, 10)))
self.assertTrue(intrange(1, 10).endswith(intrange(5, 10)))
Expand All @@ -153,6 +192,9 @@ def test_endswith(self):
self.assertFalse(intrange(1, 5).endswith(5))
self.assertTrue(intrange(1, 5, upper_inc=True).endswith(5))

with self.assertRaises(TypeError):
intrange(1, 5).endswith(5.0)

def test_contains(self):
# Test ranges
self.assertTrue(intrange(1, 5).contains(intrange(1, 5)))
Expand All @@ -168,7 +210,7 @@ def test_contains(self):
self.assertFalse(intrange(1, 5).contains(5))

with self.assertRaises(TypeError):
intrange.contains(True)
intrange(1, 5).contains(None)

def test_within(self):
# Test ranges
Expand All @@ -179,7 +221,7 @@ def test_within(self):
self.assertTrue(intrange(5, 10).within(intrange(1, 10)))

with self.assertRaises(TypeError):
intrange.within(True)
intrange(1, 5).within(1)

def test_overlap(self):
self.assertFalse(intrange(1, 5).overlap(intrange(5, 10)))
Expand All @@ -198,12 +240,22 @@ def test_adjacent(self):
self.assertFalse(intrange(1, 5).adjacent(intrange(3, 8)))
self.assertFalse(intrange(3, 8).adjacent(intrange(1, 5)))

# Test that empty range is not adjacent to a range
self.assertFalse(intrange.empty().adjacent(intrange(0, 5)))

with self.assertRaises(TypeError):
intrange(1, 5).adjacent(floatrange(5.0, 10.0))

def test_union(self):
self.assertEqual(intrange(1, 5).union(intrange(5, 10)), intrange(1, 10))
self.assertEqual(intrange(1, 5).union(intrange(3, 10)), intrange(1, 10))
self.assertEqual(intrange(5, 10).union(intrange(1, 5)), intrange(1, 10))
self.assertEqual(intrange(3, 10).union(intrange(1, 5)), intrange(1, 10))

# Test interaction with empty ranges
self.assertEqual(intrange.empty().union(intrange(1, 5)), intrange(1, 5))
self.assertEqual(intrange(1, 5).union(intrange.empty()), intrange(1, 5))

with self.assertRaises(ValueError):
intrange(1, 5).union(intrange(5, 10, lower_inc=False))

Expand All @@ -226,6 +278,7 @@ def test_difference(self):
intrange(5, 8, lower_inc=False))
self.assertEqual(intrange(1, 5).difference(intrange(1, 3)), intrange(3, 5))
self.assertEqual(intrange(1, 5).difference(intrange(3, 5)), intrange(1, 3))
self.assertEqual(intrange(1, 5).difference(intrange(1, 10)), intrange.empty())

with self.assertRaises(ValueError):
intrange(1, 15).difference(intrange(5, 10))
Expand Down Expand Up @@ -253,10 +306,38 @@ def test_iteration(self):

class TestFloatRange(TestCase):
def test_invalid_bounds(self):
with self.assertRaises(TypeError):
intrange("foo")

with self.assertRaises(TypeError):
intrange(upper="foo")

with self.assertRaises(ValueError):
floatrange(10.0, 5.0)

def test_contains(self):
self.assertTrue(floatrange(1.0, 5.0).contains(1.0))
self.assertTrue(floatrange(1.0, 5.0).contains(3.0))
self.assertFalse(floatrange(1.0, 5.0, lower_inc=False).contains(1.0))
self.assertFalse(floatrange(1.0, 5.0).contains(5.0))

def test_startswith(self):
# Special case that discrete ranges can't cover
self.assertFalse(floatrange(1.0, lower_inc=False).startswith(1.0))

def test_endswith(self):
# Special case that discrete ranges can't cover
self.assertFalse(floatrange(upper=5.0).endswith(5.0))
self.assertTrue(floatrange(upper=5.0, upper_inc=True).endswith(5.0))

class TestDateRange(TestCase):
def test_datetime(self):
with self.assertRaises(TypeError):
daterange(datetime(2000, 1, 1))

with self.assertRaises(TypeError):
daterange(upper=datetime(2000, 1, 1))

def test_offset(self):
range_low = daterange(date(2000, 1, 1), date(2000, 1, 6))
range_high = daterange(date(2000, 1, 5), date(2000, 1, 10))
Expand All @@ -282,7 +363,20 @@ def test_datetime_input(self):
with self.assertRaises(TypeError):
daterange(datetime(2000, 1, 1))

def test_len(self):
with self.assertRaises(ValueError):
len(daterange())


class TestStrRange(TestCase):
def test_last(self):
self.assertEqual(strrange(u"a", u"c").last, u"b")
self.assertEqual(strrange(u"aa", u"cc").last, u"cb")

def text_prev(self):
self.assertEqual(strrange.prev(u""), u"")
self.assertEqual(strrange.prev(u"b"), u"a")

def text_next(self):
self.assertEqual(strrange.next(u""), u"")
self.assertEqual(strrange.next(u"a"), u"b")
29 changes: 13 additions & 16 deletions spans/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

from collections import namedtuple
from datetime import date, datetime, timedelta
from functools import total_ordering

from ._compat import *
from ._utils import PicklableSlotMixin
from ._utils import PicklableSlotMixin, sane_total_ordering

__all__ = [
"intrange",
Expand All @@ -20,7 +19,7 @@
"_internal_range", ["lower", "upper", "lower_inc", "upper_inc", "empty"])
_empty_internal_range = _internal_range(None, None, False, False, True)

@total_ordering
@sane_total_ordering
class range_(PicklableSlotMixin):
"""
Abstract base class of all ranges.
Expand Down Expand Up @@ -238,18 +237,6 @@ def __lt__(self, other):
def __nonzero__(self):
return not self._range.empty

def __contains__(self, item):
try:
return self.contains(item)
except TypeError:
return NotImplemented

def __lshift__(self, other):
return self.left_of(other)

def __rshift__(self, other):
return self.right_of(other)

def contains(self, other):
"""
Return True if this contains other. Other may be either range of same
Expand Down Expand Up @@ -357,10 +344,15 @@ def adjacent(self, other):
:param other: Range to test against.
:return: ``True`` if this range is adjacent with `other`, otherwise
``False``.
:raises TypeError: If given argument is of invalid type
"""

if not isinstance(other, self.__class__):
raise TypeError(
"Unsupported type to test for inclusion '{0.__class__.__name__}'".format(
other))
# Must return False if either is an empty set
if not self or not other:
elif not self or not other:
return False
return (
(self.lower == other.upper and self.lower_inc != other.upper_inc) or
Expand Down Expand Up @@ -677,6 +669,11 @@ def right_of(self, other):

return other.left_of(self)

# TODO: Properly implement NotImplemented
__contains__ = contains
__lshift__ = left_of
__rshift__ = right_of

# Python 3 support
__bool__ = __nonzero__

Expand Down

0 comments on commit 0bfa3b5

Please sign in to comment.