Skip to content

Commit

Permalink
Fix for issue pandas-dev#57268 - ENH: Preserve input start/end type i…
Browse files Browse the repository at this point in the history
…n interval… (pandas-dev#57399)

* Fix for issue pandas-dev#57268 - ENH: Preserve input start/end type in interval_range

* issue pandas-dev#57268 - github actions resolution

* Use generated datatype from breaks

* Ruff - Pre-commit issue fix

* Fix for issue pandas-dev#57268 - floating point support

* int - float dtype compatability

* whatsnew documentation update

* OS based varaible access

* Fixing failed unit test cases

* pytest - interval passsed

* Python backwards compatability

* Pytest

* Fixing PyLint and mypy issues

* dtype specification

* Conditional statement simplification

* remove redundant code blocks

* Changing whatsnew to interval section

* Passing expected in parameterize

* Update doc/source/whatsnew/v3.0.0.rst

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>

* Update pandas/core/indexes/interval.py

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>

---------

Co-authored-by: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
  • Loading branch information
2 people authored and pmhatre1 committed May 7, 2024
1 parent 2fa5f22 commit 0dcfc51
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 3 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ Strings

Interval
^^^^^^^^
-
- Bug in :func:`interval_range` where start and end numeric types were always cast to 64 bit (:issue:`57268`)
-

Indexing
Expand Down
23 changes: 21 additions & 2 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,9 +1101,23 @@ def interval_range(
breaks: np.ndarray | TimedeltaIndex | DatetimeIndex

if is_number(endpoint):
dtype: np.dtype = np.dtype("int64")
if com.all_not_none(start, end, freq):
if (
isinstance(start, (float, np.float16))
or isinstance(end, (float, np.float16))
or isinstance(freq, (float, np.float16))
):
dtype = np.dtype("float64")
elif (
isinstance(start, (np.integer, np.floating))
and isinstance(end, (np.integer, np.floating))
and start.dtype == end.dtype
):
dtype = start.dtype
# 0.1 ensures we capture end
breaks = np.arange(start, end + (freq * 0.1), freq)
breaks = maybe_downcast_numeric(breaks, dtype)
else:
# compute the period/start/end if unspecified (at most one)
if periods is None:
Expand All @@ -1122,7 +1136,7 @@ def interval_range(
# expected "ndarray[Any, Any]" [
breaks = maybe_downcast_numeric(
breaks, # type: ignore[arg-type]
np.dtype("int64"),
dtype,
)
else:
# delegate to the appropriate range function
Expand All @@ -1131,4 +1145,9 @@ def interval_range(
else:
breaks = timedelta_range(start=start, end=end, periods=periods, freq=freq)

return IntervalIndex.from_breaks(breaks, name=name, closed=closed)
return IntervalIndex.from_breaks(
breaks,
name=name,
closed=closed,
dtype=IntervalDtype(subtype=breaks.dtype, closed=closed),
)
14 changes: 14 additions & 0 deletions pandas/tests/indexes/interval/test_interval_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,20 @@ def test_float_subtype(self, start, end, freq):
expected = "int64" if is_integer(start + end) else "float64"
assert result == expected

@pytest.mark.parametrize(
"start, end, expected",
[
(np.int8(1), np.int8(10), np.dtype("int8")),
(np.int8(1), np.float16(10), np.dtype("float64")),
(np.float32(1), np.float32(10), np.dtype("float32")),
(1, 10, np.dtype("int64")),
(1, 10.0, np.dtype("float64")),
],
)
def test_interval_dtype(self, start, end, expected):
result = interval_range(start=start, end=end).dtype.subtype
assert result == expected

def test_interval_range_fractional_period(self):
# float value for periods
expected = interval_range(start=0, periods=10)
Expand Down

0 comments on commit 0dcfc51

Please sign in to comment.