Skip to content

Commit

Permalink
TYP: Improve typing interval inclusive (#47646)
Browse files Browse the repository at this point in the history
* TYP: Make typing of inclusive consistent

* Fix comparison

* Fix typing issues

* Try fixing pyright
  • Loading branch information
phofl committed Jul 9, 2022
1 parent e915b0a commit 5506476
Show file tree
Hide file tree
Showing 10 changed files with 52 additions and 42 deletions.
16 changes: 8 additions & 8 deletions pandas/_libs/interval.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import numpy.typing as npt

from pandas._libs import lib
from pandas._typing import (
IntervalClosedType,
IntervalInclusiveType,
Timedelta,
Timestamp,
)
Expand Down Expand Up @@ -56,25 +56,25 @@ class IntervalMixin:

def _warning_interval(
inclusive, closed
) -> tuple[IntervalClosedType, lib.NoDefault]: ...
) -> tuple[IntervalInclusiveType, lib.NoDefault]: ...

class Interval(IntervalMixin, Generic[_OrderableT]):
@property
def left(self: Interval[_OrderableT]) -> _OrderableT: ...
@property
def right(self: Interval[_OrderableT]) -> _OrderableT: ...
@property
def inclusive(self) -> IntervalClosedType: ...
def inclusive(self) -> IntervalInclusiveType: ...
@property
def closed(self) -> IntervalClosedType: ...
def closed(self) -> IntervalInclusiveType: ...
mid: _MidDescriptor
length: _LengthDescriptor
def __init__(
self,
left: _OrderableT,
right: _OrderableT,
inclusive: IntervalClosedType = ...,
closed: IntervalClosedType = ...,
inclusive: IntervalInclusiveType = ...,
closed: IntervalInclusiveType = ...,
) -> None: ...
def __hash__(self) -> int: ...
@overload
Expand Down Expand Up @@ -151,14 +151,14 @@ class Interval(IntervalMixin, Generic[_OrderableT]):

def intervals_to_interval_bounds(
intervals: np.ndarray, validate_closed: bool = ...
) -> tuple[np.ndarray, np.ndarray, str]: ...
) -> tuple[np.ndarray, np.ndarray, IntervalInclusiveType]: ...

class IntervalTree(IntervalMixin):
def __init__(
self,
left: np.ndarray,
right: np.ndarray,
inclusive: IntervalClosedType = ...,
inclusive: IntervalInclusiveType = ...,
leaf_size: int = ...,
) -> None: ...
@property
Expand Down
2 changes: 1 addition & 1 deletion pandas/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def closed(self) -> bool:

# Interval closed type
IntervalLeftRight = Literal["left", "right"]
IntervalClosedType = Union[IntervalLeftRight, Literal["both", "neither"]]
IntervalInclusiveType = Union[IntervalLeftRight, Literal["both", "neither"]]

# datetime and NaTType
DatetimeNaTType = Union[datetime, "NaTType"]
Expand Down
9 changes: 5 additions & 4 deletions pandas/core/arrays/arrow/_arrow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pyarrow

from pandas._typing import IntervalInclusiveType
from pandas.errors import PerformanceWarning
from pandas.util._decorators import deprecate_kwarg
from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -107,11 +108,11 @@ def to_pandas_dtype(self):

class ArrowIntervalType(pyarrow.ExtensionType):
@deprecate_kwarg(old_arg_name="closed", new_arg_name="inclusive")
def __init__(self, subtype, inclusive: str) -> None:
def __init__(self, subtype, inclusive: IntervalInclusiveType) -> None:
# attributes need to be set first before calling
# super init (as that calls serialize)
assert inclusive in VALID_CLOSED
self._closed = inclusive
self._closed: IntervalInclusiveType = inclusive
if not isinstance(subtype, pyarrow.DataType):
subtype = pyarrow.type_for_alias(str(subtype))
self._subtype = subtype
Expand All @@ -124,11 +125,11 @@ def subtype(self):
return self._subtype

@property
def inclusive(self) -> str:
def inclusive(self) -> IntervalInclusiveType:
return self._closed

@property
def closed(self):
def closed(self) -> IntervalInclusiveType:
warnings.warn(
"Attribute `closed` is deprecated in favor of `inclusive`.",
FutureWarning,
Expand Down
22 changes: 12 additions & 10 deletions pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from pandas._typing import (
ArrayLike,
Dtype,
IntervalClosedType,
IntervalInclusiveType,
NpDtype,
PositionalIndexer,
ScalarIndexer,
Expand Down Expand Up @@ -230,7 +230,7 @@ def ndim(self) -> Literal[1]:
def __new__(
cls: type[IntervalArrayT],
data,
inclusive: str | None = None,
inclusive: IntervalInclusiveType | None = None,
dtype: Dtype | None = None,
copy: bool = False,
verify_integrity: bool = True,
Expand Down Expand Up @@ -277,7 +277,7 @@ def _simple_new(
cls: type[IntervalArrayT],
left,
right,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
verify_integrity: bool = True,
Expand Down Expand Up @@ -431,7 +431,7 @@ def _from_factorized(
def from_breaks(
cls: type[IntervalArrayT],
breaks,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
) -> IntervalArrayT:
Expand Down Expand Up @@ -513,7 +513,7 @@ def from_arrays(
cls: type[IntervalArrayT],
left,
right,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
) -> IntervalArrayT:
Expand Down Expand Up @@ -586,7 +586,7 @@ def from_arrays(
def from_tuples(
cls: type[IntervalArrayT],
data,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
copy: bool = False,
dtype: Dtype | None = None,
) -> IntervalArrayT:
Expand Down Expand Up @@ -1364,15 +1364,15 @@ def overlaps(self, other):
# ---------------------------------------------------------------------

@property
def inclusive(self) -> IntervalClosedType:
def inclusive(self) -> IntervalInclusiveType:
"""
Whether the intervals are closed on the left-side, right-side, both or
neither.
"""
return self.dtype.inclusive

@property
def closed(self) -> IntervalClosedType:
def closed(self) -> IntervalInclusiveType:
"""
Whether the intervals are closed on the left-side, right-side, both or
neither.
Expand Down Expand Up @@ -1426,7 +1426,9 @@ def closed(self) -> IntervalClosedType:
),
}
)
def set_closed(self: IntervalArrayT, closed: IntervalClosedType) -> IntervalArrayT:
def set_closed(
self: IntervalArrayT, closed: IntervalInclusiveType
) -> IntervalArrayT:
warnings.warn(
"set_closed is deprecated and will be removed in a future version. "
"Use set_inclusive instead.",
Expand Down Expand Up @@ -1478,7 +1480,7 @@ def set_closed(self: IntervalArrayT, closed: IntervalClosedType) -> IntervalArra
}
)
def set_inclusive(
self: IntervalArrayT, inclusive: IntervalClosedType
self: IntervalArrayT, inclusive: IntervalInclusiveType
) -> IntervalArrayT:
if inclusive not in VALID_CLOSED:
msg = f"invalid option for 'inclusive': {inclusive}"
Expand Down
9 changes: 7 additions & 2 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pandas._typing import (
Dtype,
DtypeObj,
IntervalInclusiveType,
Ordered,
npt,
type_t,
Expand Down Expand Up @@ -1091,7 +1092,7 @@ class IntervalDtype(PandasExtensionDtype):
def __new__(
cls,
subtype=None,
inclusive: str_type | None = None,
inclusive: IntervalInclusiveType | None = None,
closed: None | lib.NoDefault = lib.no_default,
):
from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -1140,7 +1141,11 @@ def __new__(
"'inclusive' keyword does not match value "
"specified in dtype string"
)
inclusive = gd["inclusive"]
# Incompatible types in assignment (expression has type
# "Union[str, Any]", variable has type
# "Optional[Union[Literal['left', 'right'],
# Literal['both', 'neither']]]")
inclusive = gd["inclusive"] # type: ignore[assignment]

try:
subtype = pandas_dtype(subtype)
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
IgnoreRaise,
IndexKeyFunc,
IndexLabel,
IntervalClosedType,
IntervalInclusiveType,
JSONSerializable,
Level,
Manager,
Expand Down Expand Up @@ -8066,7 +8066,7 @@ def between_time(
end_time,
include_start: bool_t | lib.NoDefault = lib.no_default,
include_end: bool_t | lib.NoDefault = lib.no_default,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
axis=None,
) -> NDFrameT:
"""
Expand Down Expand Up @@ -8172,7 +8172,7 @@ def between_time(
left = True if include_start is lib.no_default else include_start
right = True if include_end is lib.no_default else include_end

inc_dict: dict[tuple[bool_t, bool_t], IntervalClosedType] = {
inc_dict: dict[tuple[bool_t, bool_t], IntervalInclusiveType] = {
(True, True): "both",
(True, False): "left",
(False, True): "right",
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from pandas._typing import (
Dtype,
DtypeObj,
IntervalClosedType,
IntervalInclusiveType,
IntervalLeftRight,
npt,
)
Expand Down Expand Up @@ -920,7 +920,7 @@ def date_range(
normalize: bool = False,
name: Hashable = None,
closed: Literal["left", "right"] | None | lib.NoDefault = lib.no_default,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
**kwargs,
) -> DatetimeIndex:
"""
Expand Down Expand Up @@ -1126,7 +1126,7 @@ def bdate_range(
weekmask=None,
holidays=None,
closed: IntervalLeftRight | lib.NoDefault | None = lib.no_default,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
**kwargs,
) -> DatetimeIndex:
"""
Expand Down
14 changes: 7 additions & 7 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pandas._typing import (
Dtype,
DtypeObj,
IntervalClosedType,
IntervalInclusiveType,
npt,
)
from pandas.errors import InvalidIndexError
Expand Down Expand Up @@ -198,7 +198,7 @@ class IntervalIndex(ExtensionIndex):
_typ = "intervalindex"

# annotate properties pinned via inherit_names
inclusive: IntervalClosedType
inclusive: IntervalInclusiveType
is_non_overlapping_monotonic: bool
closed_left: bool
closed_right: bool
Expand All @@ -217,7 +217,7 @@ class IntervalIndex(ExtensionIndex):
def __new__(
cls,
data,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
dtype: Dtype | None = None,
copy: bool = False,
name: Hashable = None,
Expand Down Expand Up @@ -266,7 +266,7 @@ def closed(self):
def from_breaks(
cls,
breaks,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand Down Expand Up @@ -302,7 +302,7 @@ def from_arrays(
cls,
left,
right,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand Down Expand Up @@ -337,7 +337,7 @@ def from_arrays(
def from_tuples(
cls,
data,
inclusive=None,
inclusive: IntervalInclusiveType | None = None,
name: Hashable = None,
copy: bool = False,
dtype: Dtype | None = None,
Expand Down Expand Up @@ -989,7 +989,7 @@ def interval_range(
periods=None,
freq=None,
name: Hashable = None,
inclusive: IntervalClosedType | None = None,
inclusive: IntervalInclusiveType | None = None,
) -> IntervalIndex:
"""
Return a fixed frequency IntervalIndex.
Expand Down
7 changes: 4 additions & 3 deletions pandas/io/formats/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Axis,
FilePath,
IndexLabel,
IntervalInclusiveType,
Level,
QuantileInterpolation,
Scalar,
Expand Down Expand Up @@ -3479,7 +3480,7 @@ def highlight_between(
axis: Axis | None = 0,
left: Scalar | Sequence | None = None,
right: Scalar | Sequence | None = None,
inclusive: str = "both",
inclusive: IntervalInclusiveType = "both",
props: str | None = None,
) -> Styler:
"""
Expand Down Expand Up @@ -3584,7 +3585,7 @@ def highlight_quantile(
q_left: float = 0.0,
q_right: float = 1.0,
interpolation: QuantileInterpolation = "linear",
inclusive: str = "both",
inclusive: IntervalInclusiveType = "both",
props: str | None = None,
) -> Styler:
"""
Expand Down Expand Up @@ -3969,7 +3970,7 @@ def _highlight_between(
props: str,
left: Scalar | Sequence | np.ndarray | NDFrame | None = None,
right: Scalar | Sequence | np.ndarray | NDFrame | None = None,
inclusive: bool | str = True,
inclusive: bool | IntervalInclusiveType = True,
) -> np.ndarray:
"""
Return an array of css props based on condition of data values within given range.
Expand Down
3 changes: 2 additions & 1 deletion pandas/util/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from pandas._typing import IntervalInclusiveType
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -487,7 +488,7 @@ def validate_endpoints(closed: str | None) -> tuple[bool, bool]:
return left_closed, right_closed


def validate_inclusive(inclusive: str | None) -> tuple[bool, bool]:
def validate_inclusive(inclusive: IntervalInclusiveType | None) -> tuple[bool, bool]:
"""
Check that the `inclusive` argument is among {"both", "neither", "left", "right"}.
Expand Down

0 comments on commit 5506476

Please sign in to comment.