Skip to content

Commit

Permalink
dispatch Series[datetime64] comparison ops to DatetimeIndex (#19800)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jreback committed Mar 1, 2018
1 parent 9242248 commit 87fefe2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 31 deletions.
8 changes: 3 additions & 5 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,9 @@ def wrapper(self, other):
result = func(np.asarray(other))
result = com._values_from_object(result)

if isinstance(other, Index):
o_mask = other.values.view('i8') == libts.iNaT
else:
o_mask = other.view('i8') == libts.iNaT

# Make sure to pass an array to result[...]; indexing with
# Series breaks with older version of numpy
o_mask = np.array(isna(other))
if o_mask.any():
result[o_mask] = nat_result

Expand Down
34 changes: 18 additions & 16 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import numpy as np
import pandas as pd

from pandas._libs import (lib, index as libindex,
algos as libalgos, ops as libops)
from pandas._libs import algos as libalgos, ops as libops

from pandas import compat
from pandas.util._decorators import Appender
Expand Down Expand Up @@ -1127,24 +1126,20 @@ def na_op(x, y):
# integer comparisons

# we have a datetime/timedelta and may need to convert
assert not needs_i8_conversion(x)
mask = None
if (needs_i8_conversion(x) or
(not is_scalar(y) and needs_i8_conversion(y))):

if is_scalar(y):
mask = isna(x)
y = libindex.convert_scalar(x, com._values_from_object(y))
else:
mask = isna(x) | isna(y)
y = y.view('i8')
if not is_scalar(y) and needs_i8_conversion(y):
mask = isna(x) | isna(y)
y = y.view('i8')
x = x.view('i8')

try:
method = getattr(x, name, None)
if method is not None:
with np.errstate(all='ignore'):
result = getattr(x, name)(y)
result = method(y)
if result is NotImplemented:
raise TypeError("invalid type comparison")
except AttributeError:
else:
result = op(x, y)

if mask is not None and mask.any():
Expand Down Expand Up @@ -1174,6 +1169,14 @@ def wrapper(self, other, axis=None):
return self._constructor(res_values, index=self.index,
name=res_name)

if is_datetime64_dtype(self) or is_datetime64tz_dtype(self):
# Dispatch to DatetimeIndex to ensure identical
# Series/Index behavior
res_values = dispatch_to_index_op(op, self, other,
pd.DatetimeIndex)
return self._constructor(res_values, index=self.index,
name=res_name)

elif is_timedelta64_dtype(self):
res_values = dispatch_to_index_op(op, self, other,
pd.TimedeltaIndex)
Expand All @@ -1191,8 +1194,7 @@ def wrapper(self, other, axis=None):
elif isinstance(other, (np.ndarray, pd.Index)):
# do not check length of zerodim array
# as it will broadcast
if (not is_scalar(lib.item_from_zerodim(other)) and
len(self) != len(other)):
if other.ndim != 0 and len(self) != len(other):
raise ValueError('Lengths must match to compare')

res_values = na_op(self.values, np.asarray(other))
Expand Down
4 changes: 2 additions & 2 deletions pandas/tests/indexes/datetimes/test_partial_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from datetime import datetime, date
from datetime import datetime
import numpy as np
import pandas as pd
import operator as op
Expand Down Expand Up @@ -349,7 +349,7 @@ def test_loc_datetime_length_one(self):

@pytest.mark.parametrize('datetimelike', [
Timestamp('20130101'), datetime(2013, 1, 1),
date(2013, 1, 1), np.datetime64('2013-01-01T00:00', 'ns')])
np.datetime64('2013-01-01T00:00', 'ns')])
@pytest.mark.parametrize('op,expected', [
(op.lt, [True, False, False, False]),
(op.le, [True, True, False, False]),
Expand Down
17 changes: 17 additions & 0 deletions pandas/tests/series/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ def test_ser_cmp_result_names(self, names, op):


class TestTimestampSeriesComparison(object):
def test_dt64ser_cmp_date_invalid(self):
# GH#19800 datetime.date comparison raises to
# match DatetimeIndex/Timestamp. This also matches the behavior
# of stdlib datetime.datetime
ser = pd.Series(pd.date_range('20010101', periods=10), name='dates')
date = ser.iloc[0].to_pydatetime().date()
assert not (ser == date).any()
assert (ser != date).all()
with pytest.raises(TypeError):
ser > date
with pytest.raises(TypeError):
ser < date
with pytest.raises(TypeError):
ser >= date
with pytest.raises(TypeError):
ser <= date

def test_dt64ser_cmp_period_scalar(self):
ser = Series(pd.period_range('2000-01-01', periods=10, freq='D'))
val = Period('2000-01-04', freq='D')
Expand Down
23 changes: 15 additions & 8 deletions pandas/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import pandas.compat as compat
from pandas.core.dtypes.common import (
is_object_dtype, is_datetimetz,
is_object_dtype, is_datetimetz, is_datetime64_dtype,
needs_i8_conversion)
import pandas.util.testing as tm
from pandas import (Series, Index, DatetimeIndex, TimedeltaIndex,
Expand Down Expand Up @@ -296,14 +296,21 @@ def test_none_comparison(self):
# result = None != o # noqa
# assert result.iat[0]
# assert result.iat[1]
if (is_datetime64_dtype(o) or is_datetimetz(o)):
# Following DatetimeIndex (and Timestamp) convention,
# inequality comparisons with Series[datetime64] raise
with pytest.raises(TypeError):
None > o
with pytest.raises(TypeError):
o > None
else:
result = None > o
assert not result.iat[0]
assert not result.iat[1]

result = None > o
assert not result.iat[0]
assert not result.iat[1]

result = o < None
assert not result.iat[0]
assert not result.iat[1]
result = o < None
assert not result.iat[0]
assert not result.iat[1]

def test_ndarray_compat_properties(self):

Expand Down

0 comments on commit 87fefe2

Please sign in to comment.