Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Support For Interval __contains__ Other Interval (#46613) #47927

Merged
merged 12 commits into from Aug 15, 2022
Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.5.0.rst
Expand Up @@ -292,6 +292,7 @@ Other enhancements
- :class:`Series` reducers (e.g. ``min``, ``max``, ``sum``, ``mean``) will now successfully operate when the dtype is numeric and ``numeric_only=True`` is provided; previously this would raise a ``NotImplementedError`` (:issue:`47500`)
- :meth:`RangeIndex.union` now can return a :class:`RangeIndex` instead of a :class:`Int64Index` if the resulting values are equally spaced (:issue:`47557`, :issue:`43885`)
- :meth:`DataFrame.compare` now accepts an argument ``result_names`` to allow the user to specify the result's names of both left and right DataFrame which are being compared. This is by default ``'self'`` and ``'other'`` (:issue:`44354`)
- :class:`Interval` now supports checking whether one interval is contained by another interval (:issue:`46613`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support a ``copy`` argument. If ``False``, the underlying data is not copied in the returned object (:issue:`47934`)
- :meth:`DataFrame.set_index` now supports a ``copy`` keyword. If ``False``, the underlying data is not copied when a new :class:`DataFrame` is returned (:issue:`48043`)

Expand Down
11 changes: 9 additions & 2 deletions pandas/_libs/interval.pyi
Expand Up @@ -79,10 +79,17 @@ class Interval(IntervalMixin, Generic[_OrderableT]):
def __hash__(self) -> int: ...
@overload
def __contains__(
self: Interval[_OrderableTimesT], key: _OrderableTimesT
self: Interval[Timedelta], key: Timedelta | Interval[Timedelta]
) -> bool: ...
@overload
def __contains__(self: Interval[_OrderableScalarT], key: float) -> bool: ...
def __contains__(
self: Interval[Timestamp], key: Timestamp | Interval[Timestamp]
) -> bool: ...
@overload
def __contains__(
self: Interval[_OrderableScalarT],
key: _OrderableScalarT | Interval[_OrderableScalarT],
) -> bool: ...
@overload
def __add__(
self: Interval[_OrderableTimesT], y: Timedelta
Expand Down
16 changes: 14 additions & 2 deletions pandas/_libs/interval.pyx
Expand Up @@ -299,10 +299,12 @@ cdef class Interval(IntervalMixin):
>>> iv
Interval(0, 5, inclusive='right')

You can check if an element belongs to it
You can check if an element belongs to it, or if it contains another interval:

>>> 2.5 in iv
True
>>> pd.Interval(left=2, right=5, inclusive='both') in iv
True

You can test the bounds (``inclusive='right'``, so ``0 < x <= 5``):

Expand Down Expand Up @@ -412,7 +414,17 @@ cdef class Interval(IntervalMixin):

def __contains__(self, key) -> bool:
if _interval_like(key):
raise TypeError("__contains__ not defined for two intervals")
key_closed_left = key.inclusive in ('left', 'both')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @venaturum if you'd like to review if this aligns with piso

key_closed_right = key.inclusive in ('right', 'both')
if self.open_left and key_closed_left:
left_contained = self.left < key.left
else:
left_contained = self.left <= key.left
if self.open_right and key_closed_right:
right_contained = key.right < self.right
else:
right_contained = key.right <= self.right
return left_contained and right_contained
return ((self.left < key if self.open_left else self.left <= key) and
(key < self.right if self.open_right else key <= self.right))

Expand Down
4 changes: 0 additions & 4 deletions pandas/tests/scalar/interval/test_interval.py
Expand Up @@ -36,10 +36,6 @@ def test_contains(self, interval):
assert 1 in interval
assert 0 not in interval

msg = "__contains__ not defined for two intervals"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed by this enhancement.

with pytest.raises(TypeError, match=msg):
interval in interval

interval_both = Interval(0, 1, "both")
assert 0 in interval_both
assert 1 in interval_both
Expand Down
51 changes: 51 additions & 0 deletions pandas/tests/scalar/interval/test_ops.py
Expand Up @@ -66,3 +66,54 @@ def test_overlaps_invalid_type(self, other):
msg = f"`other` must be an Interval, got {type(other).__name__}"
with pytest.raises(TypeError, match=msg):
interval.overlaps(other)


class TestContains:
def test_contains_interval(self, inclusive_endpoints_fixture):
interval1 = Interval(0, 1, "both")
interval2 = Interval(0, 1, inclusive_endpoints_fixture)
assert interval1 in interval1
assert interval2 in interval2
assert interval2 in interval1
assert interval1 not in interval2 or inclusive_endpoints_fixture == "both"

def test_contains_infinite_length(self):
interval1 = Interval(0, 1, "both")
interval2 = Interval(float("-inf"), float("inf"), "neither")
assert interval1 in interval2
assert interval2 not in interval1

def test_contains_zero_length(self):
interval1 = Interval(0, 1, "both")
interval2 = Interval(-1, -1, "both")
interval3 = Interval(0.5, 0.5, "both")
assert interval2 not in interval1
assert interval3 in interval1
assert interval2 not in interval3 and interval3 not in interval2
assert interval1 not in interval2 and interval1 not in interval3

@pytest.mark.parametrize(
"type1",
[
(0, 1),
(Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)),
(Timedelta("0h"), Timedelta("1h")),
],
)
@pytest.mark.parametrize(
"type2",
[
(0, 1),
(Timestamp(2000, 1, 1, 0), Timestamp(2000, 1, 1, 1)),
(Timedelta("0h"), Timedelta("1h")),
],
)
def test_contains_mixed_types(self, type1, type2):
interval1 = Interval(*type1)
interval2 = Interval(*type2)
if type1 == type2:
assert interval1 in interval2
else:
msg = "^'<=' not supported between instances of"
with pytest.raises(TypeError, match=msg):
interval1 in interval2