11from operator import le , lt
22import textwrap
3+ from typing import TYPE_CHECKING , Optional , Tuple , Union , cast
34
45import numpy as np
56
1112 IntervalMixin ,
1213 intervals_to_interval_bounds ,
1314)
15+ from pandas ._typing import ArrayLike , Dtype
1416from pandas .compat .numpy import function as nv
1517from pandas .util ._decorators import Appender
1618
1719from pandas .core .dtypes .cast import maybe_convert_platform
1820from pandas .core .dtypes .common import (
1921 is_categorical_dtype ,
2022 is_datetime64_any_dtype ,
23+ is_dtype_equal ,
2124 is_float_dtype ,
25+ is_integer ,
2226 is_integer_dtype ,
2327 is_interval_dtype ,
2428 is_list_like ,
4549from pandas .core .indexers import check_array_indexer
4650from pandas .core .indexes .base import ensure_index
4751
52+ if TYPE_CHECKING :
53+ from pandas import Index
54+ from pandas .core .arrays import DatetimeArray , TimedeltaArray
55+
4856_interval_shared_docs = {}
4957
5058_shared_docs_kwargs = dict (
@@ -169,6 +177,17 @@ def __new__(
169177 left = data ._left
170178 right = data ._right
171179 closed = closed or data .closed
180+
181+ if dtype is None or data .dtype == dtype :
182+ # This path will preserve id(result._combined)
183+ # TODO: could also validate dtype before going to simple_new
184+ combined = data ._combined
185+ if copy :
186+ combined = combined .copy ()
187+ result = cls ._simple_new (combined , closed = closed )
188+ if verify_integrity :
189+ result ._validate ()
190+ return result
172191 else :
173192
174193 # don't allow scalars
@@ -186,83 +205,22 @@ def __new__(
186205 )
187206 closed = closed or infer_closed
188207
189- return cls ._simple_new (
190- left ,
191- right ,
192- closed ,
193- copy = copy ,
194- dtype = dtype ,
195- verify_integrity = verify_integrity ,
196- )
208+ closed = closed or "right"
209+ left , right = _maybe_cast_inputs (left , right , copy , dtype )
210+ combined = _get_combined_data (left , right )
211+ result = cls ._simple_new (combined , closed = closed )
212+ if verify_integrity :
213+ result ._validate ()
214+ return result
197215
198216 @classmethod
199- def _simple_new (
200- cls , left , right , closed = None , copy = False , dtype = None , verify_integrity = True
201- ):
217+ def _simple_new (cls , data , closed = "right" ):
202218 result = IntervalMixin .__new__ (cls )
203219
204- closed = closed or "right"
205- left = ensure_index (left , copy = copy )
206- right = ensure_index (right , copy = copy )
207-
208- if dtype is not None :
209- # GH 19262: dtype must be an IntervalDtype to override inferred
210- dtype = pandas_dtype (dtype )
211- if not is_interval_dtype (dtype ):
212- msg = f"dtype must be an IntervalDtype, got { dtype } "
213- raise TypeError (msg )
214- elif dtype .subtype is not None :
215- left = left .astype (dtype .subtype )
216- right = right .astype (dtype .subtype )
217-
218- # coerce dtypes to match if needed
219- if is_float_dtype (left ) and is_integer_dtype (right ):
220- right = right .astype (left .dtype )
221- elif is_float_dtype (right ) and is_integer_dtype (left ):
222- left = left .astype (right .dtype )
223-
224- if type (left ) != type (right ):
225- msg = (
226- f"must not have differing left [{ type (left ).__name__ } ] and "
227- f"right [{ type (right ).__name__ } ] types"
228- )
229- raise ValueError (msg )
230- elif is_categorical_dtype (left .dtype ) or is_string_dtype (left .dtype ):
231- # GH 19016
232- msg = (
233- "category, object, and string subtypes are not supported "
234- "for IntervalArray"
235- )
236- raise TypeError (msg )
237- elif isinstance (left , ABCPeriodIndex ):
238- msg = "Period dtypes are not supported, use a PeriodIndex instead"
239- raise ValueError (msg )
240- elif isinstance (left , ABCDatetimeIndex ) and str (left .tz ) != str (right .tz ):
241- msg = (
242- "left and right must have the same time zone, got "
243- f"'{ left .tz } ' and '{ right .tz } '"
244- )
245- raise ValueError (msg )
246-
247- # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
248- from pandas .core .ops .array_ops import maybe_upcast_datetimelike_array
249-
250- left = maybe_upcast_datetimelike_array (left )
251- left = extract_array (left , extract_numpy = True )
252- right = maybe_upcast_datetimelike_array (right )
253- right = extract_array (right , extract_numpy = True )
254-
255- lbase = getattr (left , "_ndarray" , left ).base
256- rbase = getattr (right , "_ndarray" , right ).base
257- if lbase is not None and lbase is rbase :
258- # If these share data, then setitem could corrupt our IA
259- right = right .copy ()
260-
261- result ._left = left
262- result ._right = right
220+ result ._combined = data
221+ result ._left = data [:, 0 ]
222+ result ._right = data [:, 1 ]
263223 result ._closed = closed
264- if verify_integrity :
265- result ._validate ()
266224 return result
267225
268226 @classmethod
@@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
397355 def from_arrays (cls , left , right , closed = "right" , copy = False , dtype = None ):
398356 left = maybe_convert_platform_interval (left )
399357 right = maybe_convert_platform_interval (right )
358+ if len (left ) != len (right ):
359+ raise ValueError ("left and right must have the same length" )
400360
401- return cls ._simple_new (
402- left , right , closed , copy = copy , dtype = dtype , verify_integrity = True
403- )
361+ closed = closed or "right"
362+ left , right = _maybe_cast_inputs (left , right , copy , dtype )
363+ combined = _get_combined_data (left , right )
364+
365+ result = cls ._simple_new (combined , closed )
366+ result ._validate ()
367+ return result
404368
405369 _interval_shared_docs ["from_tuples" ] = textwrap .dedent (
406370 """
@@ -506,19 +470,6 @@ def _validate(self):
506470 msg = "left side of interval must be <= right side"
507471 raise ValueError (msg )
508472
509- def _shallow_copy (self , left , right ):
510- """
511- Return a new IntervalArray with the replacement attributes
512-
513- Parameters
514- ----------
515- left : Index
516- Values to be used for the left-side of the intervals.
517- right : Index
518- Values to be used for the right-side of the intervals.
519- """
520- return self ._simple_new (left , right , closed = self .closed , verify_integrity = False )
521-
522473 # ---------------------------------------------------------------------
523474 # Descriptive
524475
@@ -546,18 +497,20 @@ def __len__(self) -> int:
546497
547498 def __getitem__ (self , key ):
548499 key = check_array_indexer (self , key )
549- left = self ._left [key ]
550- right = self ._right [key ]
551500
552- if not isinstance (left , (np .ndarray , ExtensionArray )):
553- # scalar
554- if is_scalar (left ) and isna (left ):
501+ result = self ._combined [key ]
502+
503+ if is_integer (key ):
504+ left , right = result [0 ], result [1 ]
505+ if isna (left ):
555506 return self ._fill_value
556507 return Interval (left , right , self .closed )
557- if np .ndim (left ) > 1 :
508+
509+ # TODO: need to watch out for incorrectly-reducing getitem
510+ if np .ndim (result ) > 2 :
558511 # GH#30588 multi-dimensional indexer disallowed
559512 raise ValueError ("multi-dimensional indexing not allowed" )
560- return self . _shallow_copy ( left , right )
513+ return type ( self ). _simple_new ( result , closed = self . closed )
561514
562515 def __setitem__ (self , key , value ):
563516 value_left , value_right = self ._validate_setitem_value (value )
@@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None):
651604
652605 left = self .left .fillna (value = value_left )
653606 right = self .right .fillna (value = value_right )
654- return self ._shallow_copy (left , right )
607+ combined = _get_combined_data (left , right )
608+ return type (self )._simple_new (combined , closed = self .closed )
655609
656610 def astype (self , dtype , copy = True ):
657611 """
@@ -693,7 +647,9 @@ def astype(self, dtype, copy=True):
693647 f"Cannot convert { self .dtype } to { dtype } ; subtypes are incompatible"
694648 )
695649 raise TypeError (msg ) from err
696- return self ._shallow_copy (new_left , new_right )
650+ # TODO: do astype directly on self._combined
651+ combined = _get_combined_data (new_left , new_right )
652+ return type (self )._simple_new (combined , closed = self .closed )
697653 elif is_categorical_dtype (dtype ):
698654 return Categorical (np .asarray (self ))
699655 elif isinstance (dtype , StringDtype ):
@@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat):
734690 raise ValueError ("Intervals must all be closed on the same side." )
735691 closed = closed .pop ()
736692
693+ # TODO: will this mess up on dt64tz?
737694 left = np .concatenate ([interval .left for interval in to_concat ])
738695 right = np .concatenate ([interval .right for interval in to_concat ])
739- return cls ._simple_new (left , right , closed = closed , copy = False )
696+ combined = _get_combined_data (left , right ) # TODO: 1-stage concat
697+ return cls ._simple_new (combined , closed = closed )
740698
741699 def copy (self ):
742700 """
@@ -746,11 +704,8 @@ def copy(self):
746704 -------
747705 IntervalArray
748706 """
749- left = self ._left .copy ()
750- right = self ._right .copy ()
751- closed = self .closed
752- # TODO: Could skip verify_integrity here.
753- return type (self ).from_arrays (left , right , closed = closed )
707+ combined = self ._combined .copy ()
708+ return type (self )._simple_new (combined , closed = self .closed )
754709
755710 def isna (self ) -> np .ndarray :
756711 return isna (self ._left )
@@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
843798 self ._right , indices , allow_fill = allow_fill , fill_value = fill_right
844799 )
845800
846- return self ._shallow_copy (left_take , right_take )
801+ combined = _get_combined_data (left_take , right_take )
802+ return type (self )._simple_new (combined , closed = self .closed )
847803
848804 def _validate_listlike (self , value ):
849805 # list-like of intervals
@@ -1170,10 +1126,7 @@ def set_closed(self, closed):
11701126 if closed not in VALID_CLOSED :
11711127 msg = f"invalid option for 'closed': { closed } "
11721128 raise ValueError (msg )
1173-
1174- return type (self )._simple_new (
1175- left = self ._left , right = self ._right , closed = closed , verify_integrity = False
1176- )
1129+ return type (self )._simple_new (self ._combined , closed = closed )
11771130
11781131 _interval_shared_docs [
11791132 "is_non_overlapping_monotonic"
@@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True):
13141267 @Appender (_extension_array_shared_docs ["repeat" ] % _shared_docs_kwargs )
13151268 def repeat (self , repeats , axis = None ):
13161269 nv .validate_repeat (tuple (), dict (axis = axis ))
1317- left_repeat = self .left .repeat (repeats )
1318- right_repeat = self .right .repeat (repeats )
1319- return self ._shallow_copy (left = left_repeat , right = right_repeat )
1270+ combined = self ._combined .repeat (repeats , 0 )
1271+ return type (self )._simple_new (combined , closed = self .closed )
13201272
13211273 _interval_shared_docs ["contains" ] = textwrap .dedent (
13221274 """
@@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values):
13991351 values = np .asarray (values )
14001352
14011353 return maybe_convert_platform (values )
1354+
1355+
1356+ def _maybe_cast_inputs (
1357+ left_orig : Union ["Index" , ArrayLike ],
1358+ right_orig : Union ["Index" , ArrayLike ],
1359+ copy : bool ,
1360+ dtype : Optional [Dtype ],
1361+ ) -> Tuple ["Index" , "Index" ]:
1362+ left = ensure_index (left_orig , copy = copy )
1363+ right = ensure_index (right_orig , copy = copy )
1364+
1365+ if dtype is not None :
1366+ # GH#19262: dtype must be an IntervalDtype to override inferred
1367+ dtype = pandas_dtype (dtype )
1368+ if not is_interval_dtype (dtype ):
1369+ msg = f"dtype must be an IntervalDtype, got { dtype } "
1370+ raise TypeError (msg )
1371+ dtype = cast (IntervalDtype , dtype )
1372+ if dtype .subtype is not None :
1373+ left = left .astype (dtype .subtype )
1374+ right = right .astype (dtype .subtype )
1375+
1376+ # coerce dtypes to match if needed
1377+ if is_float_dtype (left ) and is_integer_dtype (right ):
1378+ right = right .astype (left .dtype )
1379+ elif is_float_dtype (right ) and is_integer_dtype (left ):
1380+ left = left .astype (right .dtype )
1381+
1382+ if type (left ) != type (right ):
1383+ msg = (
1384+ f"must not have differing left [{ type (left ).__name__ } ] and "
1385+ f"right [{ type (right ).__name__ } ] types"
1386+ )
1387+ raise ValueError (msg )
1388+ elif is_categorical_dtype (left .dtype ) or is_string_dtype (left .dtype ):
1389+ # GH#19016
1390+ msg = (
1391+ "category, object, and string subtypes are not supported "
1392+ "for IntervalArray"
1393+ )
1394+ raise TypeError (msg )
1395+ elif isinstance (left , ABCPeriodIndex ):
1396+ msg = "Period dtypes are not supported, use a PeriodIndex instead"
1397+ raise ValueError (msg )
1398+ elif isinstance (left , ABCDatetimeIndex ) and not is_dtype_equal (
1399+ left .dtype , right .dtype
1400+ ):
1401+ left_arr = cast ("DatetimeArray" , left ._data )
1402+ right_arr = cast ("DatetimeArray" , right ._data )
1403+ msg = (
1404+ "left and right must have the same time zone, got "
1405+ f"'{ left_arr .tz } ' and '{ right_arr .tz } '"
1406+ )
1407+ raise ValueError (msg )
1408+
1409+ return left , right
1410+
1411+
1412+ def _get_combined_data (
1413+ left : Union ["Index" , ArrayLike ], right : Union ["Index" , ArrayLike ]
1414+ ) -> Union [np .ndarray , "DatetimeArray" , "TimedeltaArray" ]:
1415+ # For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
1416+ from pandas .core .ops .array_ops import maybe_upcast_datetimelike_array
1417+
1418+ left = maybe_upcast_datetimelike_array (left )
1419+ left = extract_array (left , extract_numpy = True )
1420+ right = maybe_upcast_datetimelike_array (right )
1421+ right = extract_array (right , extract_numpy = True )
1422+
1423+ lbase = getattr (left , "_ndarray" , left ).base
1424+ rbase = getattr (right , "_ndarray" , right ).base
1425+ if lbase is not None and lbase is rbase :
1426+ # If these share data, then setitem could corrupt our IA
1427+ right = right .copy ()
1428+
1429+ if isinstance (left , np .ndarray ):
1430+ assert isinstance (right , np .ndarray ) # for mypy
1431+ combined = np .concatenate (
1432+ [left .reshape (- 1 , 1 ), right .reshape (- 1 , 1 )],
1433+ axis = 1 ,
1434+ )
1435+ else :
1436+ left = cast (Union ["DatetimeArray" , "TimedeltaArray" ], left )
1437+ right = cast (Union ["DatetimeArray" , "TimedeltaArray" ], right )
1438+ combined = type (left )._concat_same_type (
1439+ [left .reshape (- 1 , 1 ), right .reshape (- 1 , 1 )],
1440+ axis = 1 ,
1441+ )
1442+ return combined
0 commit comments