Skip to content

Commit

Permalink
Fixed issue with Range.contains for scalars on unbounded ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
runfalk committed Jun 7, 2017
1 parent ed9a79c commit d56bc1d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 17 deletions.
26 changes: 18 additions & 8 deletions spans/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
import sys

from collections import namedtuple
Expand Down Expand Up @@ -322,14 +323,23 @@ def contains(self, other):
else:
return False
elif self.is_valid_scalar(other):
if self.lower_inc and self.upper_inc:
return self.lower <= other <= self.upper
elif self.lower_inc:
return self.lower <= other < self.upper
elif self.upper_inc:
return self.lower < other <= self.upper
else:
return self.lower < other < self.upper
# If the lower bounary is not unbound we can safely perform the
# comparison. Otherwise we'll try to compare a scalar to None, which
# is bad
is_within_lower = True
if not self.lower_inf:
lower_cmp = operator.le if self.lower_inc else operator.lt
is_within_lower = lower_cmp(self.lower, other)

# If the upper bounary is not unbound we can safely perform the
# comparison. Otherwise we'll try to compare a scalar to None, which
# is bad
is_within_upper = True
if not self.upper_inf:
upper_cmp = operator.ge if self.upper_inc else operator.gt
is_within_upper = upper_cmp(self.upper, other)

return is_within_lower and is_within_upper
else:
raise TypeError(
"Unsupported type to test for inclusion '{0.__class__.__name__}'".format(
Expand Down
22 changes: 13 additions & 9 deletions tests/test_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,21 +273,25 @@ def test_endswith_type_check():


@pytest.mark.parametrize("a, b", [
(intrange(1, 5), intrange(1, 5)),
(intrange(1, 10), intrange(1, 5)),
(intrange(1, 10), intrange(5, 10)),
(intrange(1, 5), 1),
(intrange(1, 5), 3),
(floatrange(1.0, 5.0), floatrange(1.0, 5.0)),
(floatrange(1.0, 10.0), floatrange(1.0, 5.0)),
(floatrange(1.0, 10.0), floatrange(5.0, 10.0)),
(floatrange(1.0, 5.0), 1.0),
(floatrange(1.0, 5.0), 3.0),
(floatrange(1.0), 3.0),
(floatrange(upper=5.0), 3.0),
])
def test_contains(a, b):
assert a.contains(b)


@pytest.mark.parametrize("a, b", [
(intrange(1, 5, lower_inc=False), intrange(1, 5)),
(intrange(1, 5), intrange(1, 5, upper_inc=True)),
(intrange(1, 5, lower_inc=False), 1),
(intrange(1, 5), 5),
(floatrange(1.0, 5.0, lower_inc=False), floatrange(1.0, 5.0)),
(floatrange(1.0, 5.0), floatrange(1.0, 5.0, upper_inc=True)),
(floatrange(1.0, 5.0, lower_inc=False), 1.0),
(floatrange(1.0, 5.0), 5.0),
(floatrange(1.0, lower_inc=False), 1.0),
(floatrange(upper=5.0), 5.0),
])
def test_not_contains(a, b):
assert not a.contains(b)
Expand Down

0 comments on commit d56bc1d

Please sign in to comment.