Skip to content

Commit

Permalink
API: ensure IntervalIndex.left/right are 64bit if numeric, part II (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
topper-123 committed Jan 10, 2023
1 parent 1d63474 commit 939d0ba
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 28 deletions.
4 changes: 3 additions & 1 deletion pandas/core/arrays/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1787,5 +1787,7 @@ def _maybe_convert_platform_interval(values) -> ArrayLike:
values = extract_array(values, extract_numpy=True)

if not hasattr(values, "dtype"):
return np.asarray(values)
values = np.asarray(values)
if is_integer_dtype(values) and values.dtype != np.int64:
values = values.astype(np.int64)
return values
25 changes: 25 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ArrayLike,
Dtype,
DtypeObj,
NumpyIndexT,
Scalar,
npt,
)
Expand Down Expand Up @@ -65,6 +66,7 @@
is_numeric_dtype,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_string_dtype,
is_timedelta64_dtype,
is_unsigned_integer_dtype,
Expand Down Expand Up @@ -412,6 +414,29 @@ def trans(x):
return result


def maybe_upcast_numeric_to_64bit(arr: NumpyIndexT) -> NumpyIndexT:
    """
    Cast an int/uint/float array that is not 64 bit to its 64-bit counterpart.

    Parameters
    ----------
    arr : ndarray or ExtensionArray

    Returns
    -------
    ndarray or ExtensionArray
        The result of ``arr.astype`` when the dtype is a signed integer,
        unsigned integer or float dtype other than int64/uint64/float64;
        otherwise ``arr`` itself, unchanged.
    """
    current = arr.dtype
    # Each numeric family is paired with the 64-bit dtype it is cast to.
    families = (
        (is_signed_integer_dtype, np.int64),
        (is_unsigned_integer_dtype, np.uint64),
        (is_float_dtype, np.float64),
    )
    for in_family, wide_dtype in families:
        if in_family(current) and current != wide_dtype:
            return arr.astype(wide_dtype)
    # Already 64 bit, or not an int/uint/float dtype at all.
    return arr


def maybe_cast_pointwise_result(
result: ArrayLike,
dtype: DtypeObj,
Expand Down
25 changes: 6 additions & 19 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
infer_dtype_from_scalar,
maybe_box_datetimelike,
maybe_downcast_numeric,
maybe_upcast_numeric_to_64bit,
)
from pandas.core.dtypes.common import (
ensure_platform_int,
Expand All @@ -59,8 +60,6 @@
is_number,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.missing import is_valid_na_for_dtype
Expand Down Expand Up @@ -342,8 +341,11 @@ def from_tuples(
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
@cache_readonly
def _engine(self) -> IntervalTree: # type: ignore[override]
# Build the IntervalTree from this index's left and right bounds.
# IntervalTree does not support numpy arrays unless they are 64 bit, so each
# bound is passed through _maybe_convert_i8 and then upcast to 64 bit.
left = self._maybe_convert_i8(self.left)
left = maybe_upcast_numeric_to_64bit(left)
right = self._maybe_convert_i8(self.right)
right = maybe_upcast_numeric_to_64bit(right)
return IntervalTree(left, right, closed=self.closed)

def __contains__(self, key: Any) -> bool:
Expand Down Expand Up @@ -520,13 +522,12 @@ def _maybe_convert_i8(self, key):
The original key if no conversion occurred, int if converted scalar,
Int64Index if converted list-like.
"""
original = key
if is_list_like(key):
key = ensure_index(key)
key = self._maybe_convert_numeric_to_64bit(key)
key = maybe_upcast_numeric_to_64bit(key)

if not self._needs_i8_conversion(key):
return original
return key

scalar = is_scalar(key)
if is_interval_dtype(key) or isinstance(key, Interval):
Expand Down Expand Up @@ -569,20 +570,6 @@ def _maybe_convert_i8(self, key):

return key_i8

def _maybe_convert_numeric_to_64bit(self, idx: Index) -> Index:
# IntervalTree only supports 64 bit numpy array
dtype = idx.dtype
if np.issubclass_(dtype.type, np.number):
return idx
elif is_signed_integer_dtype(dtype) and dtype != np.int64:
return idx.astype(np.int64)
elif is_unsigned_integer_dtype(dtype) and dtype != np.uint64:
return idx.astype(np.uint64)
elif is_float_dtype(dtype) and dtype != np.float64:
return idx.astype(np.float64)
else:
return idx

def _searchsorted_monotonic(self, label, side: Literal["left", "right"] = "left"):
if not self.is_non_overlapping_monotonic:
raise KeyError(
Expand Down
25 changes: 17 additions & 8 deletions pandas/tests/indexes/interval/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,27 +415,36 @@ def test_maybe_convert_i8_nat(self, breaks):
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize(
"breaks",
[np.arange(5, dtype="int64"), np.arange(5, dtype="float64")],
ids=lambda x: str(x.dtype),
"make_key",
[lambda breaks: breaks, list],
ids=["lambda", "list"],
)
def test_maybe_convert_i8_numeric(self, make_key, any_real_numpy_dtype):
# GH 20636
breaks = np.arange(5, dtype=any_real_numpy_dtype)
index = IntervalIndex.from_breaks(breaks)
key = make_key(breaks)

result = index._maybe_convert_i8(key)
expected = Index(key)
tm.assert_index_equal(result, expected)

@pytest.mark.parametrize(
"make_key",
[
IntervalIndex.from_breaks,
lambda breaks: Interval(breaks[0], breaks[1]),
lambda breaks: breaks,
lambda breaks: breaks[0],
list,
],
ids=["IntervalIndex", "Interval", "Index", "scalar", "list"],
ids=["IntervalIndex", "Interval", "scalar"],
)
def test_maybe_convert_i8_numeric(self, breaks, make_key):
def test_maybe_convert_i8_numeric_identical(self, make_key, any_real_numpy_dtype):
# GH 20636
breaks = np.arange(5, dtype=any_real_numpy_dtype)
index = IntervalIndex.from_breaks(breaks)
key = make_key(breaks)

# no conversion occurs for numeric
# check that _maybe_convert_i8 returns the key unchanged when it is an Interval or IntervalIndex
result = index._maybe_convert_i8(key)
assert result is key

Expand Down

0 comments on commit 939d0ba

Please sign in to comment.