Skip to content
27 changes: 23 additions & 4 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,12 +702,31 @@ def take(self, indices, allow_fill=False, fill_value=None):

@classmethod
def _concat_same_type(cls, to_concat):
    """
    Concatenate several datetime-like arrays that share a single dtype.

    Parameters
    ----------
    to_concat : sequence of array-like
        Every element must expose ``dtype``, ``asi8`` and ``freq``.

    Returns
    -------
    cls
        Array built with ``cls._simple_new`` from the concatenated ``asi8``
        values. ``freq`` is retained for period dtypes, and for other dtypes
        only when every non-empty piece has the first piece's (non-None)
        freq and each piece starts exactly one ``freq`` step after the
        previous piece ends; otherwise ``freq`` is None.

    Raises
    ------
    ValueError
        If the elements do not all share one dtype (compared by
        ``str(dtype)``). An empty ``to_concat`` also ends up here because
        the set of dtypes is then empty.
    """
    # do not pass tz to set because tzlocal cannot be hashed
    dtypes = {str(x.dtype) for x in to_concat}
    if len(dtypes) != 1:
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        raise ValueError("to_concat must have the same dtype (tz)", dtypes)

    obj = to_concat[0]
    dtype = obj.dtype

    values = np.concatenate([x.asi8 for x in to_concat])

    if is_period_dtype(dtype):
        new_freq = obj.freq
    else:
        # GH 3232: If the concat result is evenly spaced, we can retain the
        # original frequency
        new_freq = None
        # Empty pieces contribute no values, so they must not take part in
        # the freq comparison or the boundary check (indexing [-1] / [0]).
        non_empty = [x for x in to_concat if len(x)]

        if obj.freq is not None and all(x.freq == obj.freq for x in non_empty):
            pairs = zip(non_empty[:-1], non_empty[1:])
            if all(left[-1] + obj.freq == right[0] for left, right in pairs):
                new_freq = obj.freq

    return cls._simple_new(values, dtype=dtype, freq=new_freq)

def copy(self):
values = self.asi8.copy()
Expand Down
14 changes: 1 addition & 13 deletions pandas/core/indexes/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np

from pandas._libs import NaT, iNaT, join as libjoin, lib
from pandas._libs.algos import unique_deltas
from pandas._libs.tslibs import timezones
from pandas.compat.numpy import function as nv
from pandas.errors import AbstractMethodError
Expand Down Expand Up @@ -515,20 +514,9 @@ def _concat_same_dtype(self, to_concat, name):
Concatenate to_concat which has the same class.
"""

# do not pass tz to set because tzlocal cannot be hashed
if len({str(x.dtype) for x in to_concat}) != 1:
raise ValueError("to_concat must have the same tz")

new_data = type(self._data)._concat_same_type(to_concat)

if not is_period_dtype(self.dtype):
# GH 3232: If the concat result is evenly spaced, we can retain the
# original frequency
is_diff_evenly_spaced = len(unique_deltas(new_data.asi8)) == 1
if is_diff_evenly_spaced:
new_data._freq = self.freq

return type(self)._simple_new(new_data, name=name)
return self._simple_new(new_data, name=name)

def shift(self, periods=1, freq=None):
"""
Expand Down
43 changes: 21 additions & 22 deletions pandas/core/indexes/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,12 +697,28 @@ def _assert_can_do_setop(self, other):
if isinstance(other, PeriodIndex) and self.freq != other.freq:
raise raise_on_incompatible(self, other)

def intersection(self, other, sort=False):
def _setop(self, other, sort, opname: str):
    """
    Perform a set operation by dispatching to the Int64Index implementation.

    Parameters
    ----------
    other : Index or array-like
        Second operand. It is coerced with ``ensure_index`` and its ``asi8``
        values are what actually take part in the operation.
    sort : object
        Forwarded to ``_validate_sort_keyword`` and then, as ``sort=``, to
        the Int64Index method.
    opname : str
        Name of the Int64Index method to call, e.g. ``"intersection"``,
        ``"difference"`` or ``"_union"``.

    Returns
    -------
    type(self)
        Index wrapping the i8 result re-boxed with ``self.dtype``; its name
        comes from ``get_op_result_name(self, other)``.
    """
    self._validate_sort_keyword(sort)
    self._assert_can_do_setop(other)
    # The result name is computed from the caller-supplied ``other`` before
    # it is coerced to an Index, matching the original statement order.
    res_name = get_op_result_name(self, other)
    other = ensure_index(other)

    i8self = Int64Index._simple_new(self.asi8)
    i8other = Int64Index._simple_new(other.asi8)
    i8result = getattr(i8self, opname)(i8other, sort=sort)

    # Re-box the integer ordinals into this index's array type and dtype.
    parr = type(self._data)(np.asarray(i8result, dtype=np.int64), dtype=self.dtype)
    return type(self)._simple_new(parr, name=res_name)

def intersection(self, other, sort=False):
self._validate_sort_keyword(sort)
self._assert_can_do_setop(other)
other = ensure_index(other)

if self.equals(other):
return self._get_reconciled_name_object(other)

Expand All @@ -712,35 +728,24 @@ def intersection(self, other, sort=False):
other = other.astype("O")
return this.intersection(other, sort=sort)

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self.intersection(i8other, sort=sort)

result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
return result
return self._setop(other, sort, opname="intersection")

def difference(self, other, sort=None):
    """
    Return the elements of this index that are not in ``other``.

    Parameters
    ----------
    other : Index or array-like
        Coerced with ``ensure_index`` after the set-op compatibility check.
    sort : object, default None
        Checked by ``_validate_sort_keyword`` and forwarded to ``_setop``
        for the same-dtype path.

    Returns
    -------
    Index
        * an empty index of this type (keeping ``self.name``) when
          ``self.equals(other)``;
        * the object-dtype difference cast back to ``self.dtype`` when
          ``other`` is object dtype;
        * ``self`` unchanged when the dtypes differ;
        * otherwise the result of ``self._setop(..., opname="difference")``.
    """
    self._validate_sort_keyword(sort)
    self._assert_can_do_setop(other)
    other = ensure_index(other)

    if self.equals(other):
        # pass an empty PeriodArray with the appropriate dtype
        return type(self)._simple_new(self._data[:0], name=self.name)

    if is_object_dtype(other):
        return self.astype(object).difference(other).astype(self.dtype)

    if not is_dtype_equal(self.dtype, other.dtype):
        # Per the original branch: mismatched dtypes leave self untouched.
        return self

    return self._setop(other, sort, opname="difference")

def _union(self, other, sort):
if not len(other) or self.equals(other) or not len(self):
Expand All @@ -754,13 +759,7 @@ def _union(self, other, sort):
other = other.astype("O")
return this._union(other, sort=sort)

i8self = Int64Index._simple_new(self.asi8)
i8other = Int64Index._simple_new(other.asi8)
i8result = i8self._union(i8other, sort=sort)

res_name = get_op_result_name(self, other)
result = self._shallow_copy(np.asarray(i8result, dtype=np.int64), name=res_name)
return result
return self._setop(other, sort, opname="_union")

# ------------------------------------------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/test_datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def test_concat_same_type_invalid(self, datetime_index):
else:
other = arr.tz_localize(None)

with pytest.raises(AssertionError):
with pytest.raises(ValueError, match="to_concat must have the same"):
arr._concat_same_type([arr, other])

def test_concat_same_type_different_freq(self):
Expand Down