Skip to content

Commit

Permalink
ENH: make "closed" part of IntervalDtype (#38394)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Jan 10, 2021
1 parent 2d08672 commit 839c1bd
Show file tree
Hide file tree
Showing 22 changed files with 252 additions and 102 deletions.
4 changes: 2 additions & 2 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,8 @@ def float_frame():
# ----------------------------------------------------------------
@pytest.fixture(
params=[
(Interval(left=0, right=5), IntervalDtype("int64")),
(Interval(left=0.1, right=0.5), IntervalDtype("float64")),
(Interval(left=0, right=5), IntervalDtype("int64", "right")),
(Interval(left=0.1, right=0.5), IntervalDtype("float64", "right")),
(Period("2012-01", freq="M"), "period[M]"),
(Period("2012-02-01", freq="D"), "period[D]"),
(
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/_arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __hash__(self):
def to_pandas_dtype(self):
import pandas as pd

return pd.IntervalDtype(self.subtype.to_pandas_dtype())
return pd.IntervalDtype(self.subtype.to_pandas_dtype(), self.closed)

# register the type with a dummy instance
_interval_type = ArrowIntervalType(pyarrow.int64(), "left")
Expand Down
35 changes: 23 additions & 12 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@
>>> pd.arrays.IntervalArray([pd.Interval(0, 1), pd.Interval(1, 5)])
<IntervalArray>
[(0, 1], (1, 5]]
Length: 2, closed: right, dtype: interval[int64]
Length: 2, closed: right, dtype: interval[int64, right]
It may also be constructed using one of the constructor
methods: :meth:`IntervalArray.from_arrays`,
Expand Down Expand Up @@ -222,6 +222,9 @@ def _simple_new(
):
result = IntervalMixin.__new__(cls)

if closed is None and isinstance(dtype, IntervalDtype):
closed = dtype.closed

closed = closed or "right"
left = ensure_index(left, copy=copy)
right = ensure_index(right, copy=copy)
Expand All @@ -238,6 +241,12 @@ def _simple_new(
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)

if dtype.closed is None:
# possibly loading an old pickle
dtype = IntervalDtype(dtype.subtype, closed)
elif closed != dtype.closed:
raise ValueError("closed keyword does not match dtype.closed")

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
Expand Down Expand Up @@ -279,9 +288,11 @@ def _simple_new(
# If these share data, then setitem could corrupt our IA
right = right.copy()

dtype = IntervalDtype(left.dtype, closed=closed)
result._dtype = dtype

result._left = left
result._right = right
result._closed = closed
if verify_integrity:
result._validate()
return result
Expand Down Expand Up @@ -343,7 +354,7 @@ def _from_factorized(cls, values, original):
>>> pd.arrays.IntervalArray.from_breaks([0, 1, 2, 3])
<IntervalArray>
[(0, 1], (1, 2], (2, 3]]
Length: 3, closed: right, dtype: interval[int64]
Length: 3, closed: right, dtype: interval[int64, right]
"""
),
}
Expand Down Expand Up @@ -414,7 +425,7 @@ def from_breaks(
>>> pd.arrays.IntervalArray.from_arrays([0, 1, 2], [1, 2, 3])
<IntervalArray>
[(0, 1], (1, 2], (2, 3]]
Length: 3, closed: right, dtype: interval[int64]
Length: 3, closed: right, dtype: interval[int64, right]
"""
),
}
Expand Down Expand Up @@ -473,7 +484,7 @@ def from_arrays(
>>> pd.arrays.IntervalArray.from_tuples([(0, 1), (1, 2)])
<IntervalArray>
[(0, 1], (1, 2]]
Length: 2, closed: right, dtype: interval[int64]
Length: 2, closed: right, dtype: interval[int64, right]
"""
),
}
Expand Down Expand Up @@ -553,7 +564,7 @@ def _shallow_copy(self, left, right):

@property
def dtype(self):
return IntervalDtype(self.left.dtype)
return self._dtype

@property
def nbytes(self) -> int:
Expand Down Expand Up @@ -1174,7 +1185,7 @@ def mid(self):
>>> intervals
<IntervalArray>
[(0, 1], (1, 3], (2, 4]]
Length: 3, closed: right, dtype: interval[int64]
Length: 3, closed: right, dtype: interval[int64, right]
"""
),
}
Expand Down Expand Up @@ -1203,7 +1214,7 @@ def closed(self):
Whether the intervals are closed on the left-side, right-side, both or
neither.
"""
return self._closed
return self.dtype.closed

_interval_shared_docs["set_closed"] = textwrap.dedent(
"""
Expand Down Expand Up @@ -1238,11 +1249,11 @@ def closed(self):
>>> index
<IntervalArray>
[(0, 1], (1, 2], (2, 3]]
Length: 3, closed: right, dtype: interval[int64]
Length: 3, closed: right, dtype: interval[int64, right]
>>> index.set_closed('both')
<IntervalArray>
[[0, 1], [1, 2], [2, 3]]
Length: 3, closed: both, dtype: interval[int64]
Length: 3, closed: both, dtype: interval[int64, both]
"""
),
}
Expand Down Expand Up @@ -1301,7 +1312,7 @@ def __array__(self, dtype: Optional[NpDtype] = None) -> np.ndarray:
left = self._left
right = self._right
mask = self.isna()
closed = self._closed
closed = self.closed

result = np.empty(len(left), dtype=object)
for i in range(len(left)):
Expand Down Expand Up @@ -1441,7 +1452,7 @@ def repeat(self, repeats, axis=None):
>>> intervals
<IntervalArray>
[(0, 1], (1, 3], (2, 4]]
Length: 3, closed: right, dtype: interval[int64]
Length: 3, closed: right, dtype: interval[int64, right]
"""
),
}
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def infer_dtype_from_scalar(val, pandas_dtype: bool = False) -> Tuple[DtypeObj,
dtype = PeriodDtype(freq=val.freq)
elif lib.is_interval(val):
subtype = infer_dtype_from_scalar(val.left, pandas_dtype=True)[0]
dtype = IntervalDtype(subtype=subtype)
dtype = IntervalDtype(subtype=subtype, closed=val.closed)

return dtype, val

Expand Down
61 changes: 51 additions & 10 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,37 +999,60 @@ class IntervalDtype(PandasExtensionDtype):
Examples
--------
>>> pd.IntervalDtype(subtype='int64')
interval[int64]
>>> pd.IntervalDtype(subtype='int64', closed='both')
interval[int64, both]
"""

name = "interval"
kind: str_type = "O"
str = "|O08"
base = np.dtype("O")
num = 103
_metadata = ("subtype",)
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
_metadata = (
"subtype",
"closed",
)
_match = re.compile(
r"(I|i)nterval\[(?P<subtype>[^,]+)(, (?P<closed>(right|left|both|neither)))?\]"
)
_cache: Dict[str_type, PandasExtensionDtype] = {}

def __new__(cls, subtype=None):
def __new__(cls, subtype=None, closed: Optional[str_type] = None):
from pandas.core.dtypes.common import is_string_dtype, pandas_dtype

if closed is not None and closed not in {"right", "left", "both", "neither"}:
raise ValueError("closed must be one of 'right', 'left', 'both', 'neither'")

if isinstance(subtype, IntervalDtype):
if closed is not None and closed != subtype.closed:
raise ValueError(
"dtype.closed and 'closed' do not match. "
"Try IntervalDtype(dtype.subtype, closed) instead."
)
return subtype
elif subtype is None:
# we are called as an empty constructor
# generally for pickle compat
u = object.__new__(cls)
u._subtype = None
u._closed = closed
return u
elif isinstance(subtype, str) and subtype.lower() == "interval":
subtype = None
else:
if isinstance(subtype, str):
m = cls._match.search(subtype)
if m is not None:
subtype = m.group("subtype")
gd = m.groupdict()
subtype = gd["subtype"]
if gd.get("closed", None) is not None:
if closed is not None:
if closed != gd["closed"]:
raise ValueError(
"'closed' keyword does not match value "
"specified in dtype string"
)
closed = gd["closed"]

try:
subtype = pandas_dtype(subtype)
Expand All @@ -1044,14 +1067,20 @@ def __new__(cls, subtype=None):
)
raise TypeError(msg)

key = str(subtype) + str(closed)
try:
return cls._cache[str(subtype)]
return cls._cache[key]
except KeyError:
u = object.__new__(cls)
u._subtype = subtype
cls._cache[str(subtype)] = u
u._closed = closed
cls._cache[key] = u
return u

@property
def closed(self):
return self._closed

@property
def subtype(self):
"""
Expand Down Expand Up @@ -1101,7 +1130,10 @@ def type(self):
def __str__(self) -> str_type:
if self.subtype is None:
return "interval"
return f"interval[{self.subtype}]"
if self.closed is None:
# Only partially initialized (e.g. no "closed" set yet); see GH#38394
return f"interval[{self.subtype}]"
return f"interval[{self.subtype}, {self.closed}]"

def __hash__(self) -> int:
# make myself hashable
Expand All @@ -1115,6 +1147,8 @@ def __eq__(self, other: Any) -> bool:
elif self.subtype is None or other.subtype is None:
# None should match any subtype
return True
elif self.closed != other.closed:
return False
else:
from pandas.core.dtypes.common import is_dtype_equal

Expand All @@ -1126,6 +1160,9 @@ def __setstate__(self, state):
# pickle -> need to set the settable private ones here (see GH26067)
self._subtype = state["subtype"]

# backward-compat: older pickles won't have a "closed" key
self._closed = state.pop("closed", None)

@classmethod
def is_dtype(cls, dtype: object) -> bool:
"""
Expand Down Expand Up @@ -1174,9 +1211,13 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
if not all(isinstance(x, IntervalDtype) for x in dtypes):
return None

closed = cast("IntervalDtype", dtypes[0]).closed
if not all(cast("IntervalDtype", x).closed == closed for x in dtypes):
return np.dtype(object)

from pandas.core.dtypes.cast import find_common_type

common = find_common_type([cast("IntervalDtype", x).subtype for x in dtypes])
if common == object:
return np.dtype(object)
return IntervalDtype(common)
return IntervalDtype(common, closed=closed)

0 comments on commit 839c1bd

Please sign in to comment.