Skip to content

Commit

Permalink
Merge pull request #7529 from sinhrks/dtibool
Browse files Browse the repository at this point in the history
BUG: DatetimeIndex comparison handles NaT incorrectly
  • Loading branch information
jreback committed Jun 20, 2014
2 parents 441a1f2 + 589d30a commit f8e94fa
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 20 deletions.
1 change: 1 addition & 0 deletions doc/source/v0.14.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ Bug Fixes
- Bug when writing Stata files where the encoding was ignored (:issue:`7286`)


- Bug in ``DatetimeIndex`` comparison where ``NaT`` was not handled properly (:issue:`7529`)


- Bug where passing input with ``tzinfo`` to some offsets' ``apply``, ``rollforward`` or ``rollback`` resets ``tzinfo`` or raises ``ValueError`` (:issue:`7465`)
Expand Down
33 changes: 23 additions & 10 deletions pandas/tseries/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,35 @@ def wrapper(left, right):
return wrapper


def _dt_index_cmp(opname):
def _dt_index_cmp(opname, nat_result=False):
"""
Wrap comparison operations to convert datetime-like to datetime64
"""
def wrapper(self, other):
func = getattr(super(DatetimeIndex, self), opname)
if isinstance(other, datetime):
if isinstance(other, datetime) or isinstance(other, compat.string_types):
other = _to_m8(other, tz=self.tz)
elif isinstance(other, list):
other = DatetimeIndex(other)
elif isinstance(other, compat.string_types):
other = _to_m8(other, tz=self.tz)
elif not isinstance(other, (np.ndarray, ABCSeries)):
other = _ensure_datetime64(other)
result = func(other)
result = func(other)
if com.isnull(other):
result.fill(nat_result)
else:
if isinstance(other, list):
other = DatetimeIndex(other)
elif not isinstance(other, (np.ndarray, ABCSeries)):
other = _ensure_datetime64(other)
result = func(other)

if isinstance(other, Index):
o_mask = other.values.view('i8') == tslib.iNaT
else:
o_mask = other.view('i8') == tslib.iNaT

if o_mask.any():
result[o_mask] = nat_result

mask = self.asi8 == tslib.iNaT
if mask.any():
result[mask] = nat_result
return result.view(np.ndarray)

return wrapper
Expand Down Expand Up @@ -142,7 +155,7 @@ class DatetimeIndex(DatetimeIndexOpsMixin, Int64Index):
_arrmap = None

__eq__ = _dt_index_cmp('__eq__')
__ne__ = _dt_index_cmp('__ne__')
__ne__ = _dt_index_cmp('__ne__', nat_result=True)
__lt__ = _dt_index_cmp('__lt__')
__gt__ = _dt_index_cmp('__gt__')
__le__ = _dt_index_cmp('__le__')
Expand Down
15 changes: 5 additions & 10 deletions pandas/tseries/period.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,16 +498,11 @@ def dt64arr_to_periodarr(data, freq, tz):

# --- Period index sketch

def _period_index_cmp(opname):
def _period_index_cmp(opname, nat_result=False):
"""
Wrap comparison operations for period-like values, filling positions where either side is NaT with ``nat_result``
"""
def wrapper(self, other):
if opname == '__ne__':
fill_value = True
else:
fill_value = False

if isinstance(other, Period):
func = getattr(self.values, opname)
if other.freq != self.freq:
Expand All @@ -523,7 +518,7 @@ def wrapper(self, other):
mask = (com.mask_missing(self.values, tslib.iNaT) |
com.mask_missing(other.values, tslib.iNaT))
if mask.any():
result[mask] = fill_value
result[mask] = nat_result

return result
else:
Expand All @@ -532,10 +527,10 @@ def wrapper(self, other):
result = func(other.ordinal)

if other.ordinal == tslib.iNaT:
result.fill(fill_value)
result.fill(nat_result)
mask = self.values == tslib.iNaT
if mask.any():
result[mask] = fill_value
result[mask] = nat_result

return result
return wrapper
Expand Down Expand Up @@ -595,7 +590,7 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index):
_allow_period_index_ops = True

__eq__ = _period_index_cmp('__eq__')
__ne__ = _period_index_cmp('__ne__')
__ne__ = _period_index_cmp('__ne__', nat_result=True)
__lt__ = _period_index_cmp('__lt__')
__gt__ = _period_index_cmp('__gt__')
__le__ = _period_index_cmp('__le__')
Expand Down
87 changes: 87 additions & 0 deletions pandas/tseries/tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2179,6 +2179,93 @@ def test_comparisons_coverage(self):
exp = rng == rng
self.assert_numpy_array_equal(result, exp)

def test_comparisons_nat(self):
"""
Check that NaT in DatetimeIndex comparisons behaves like NaN in a
float Index: every ordering/equality comparison involving a missing
value is False, except ``!=`` which is True.
"""
# Float reference indexes; their NaN positions mirror the NaT positions
# of didx1 / didx2 below, so both kinds share the same expected results.
fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0])
fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0])

didx1 = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT,
'2014-05-01', '2014-07-01'])
didx2 = pd.DatetimeIndex(['2014-02-01', '2014-03-01', pd.NaT, pd.NaT,
'2014-06-01', '2014-07-01'])
# Same values as didx2, but as a raw numpy datetime64 array
darr = np.array([np.datetime64('2014-02-01 00:00Z'),
np.datetime64('2014-03-01 00:00Z'),
np.datetime64('nat'), np.datetime64('nat'),
np.datetime64('2014-06-01 00:00Z'),
np.datetime64('2014-07-01 00:00Z')])

if _np_version_under1p7:
# cannot test array because np.datetime64('nat') returns today's date
cases = [(fidx1, fidx2), (didx1, didx2)]
else:
cases = [(fidx1, fidx2), (didx1, didx2), (didx1, darr)]

# Check pd.NaT is handled the same as np.nan (index vs. index/array)
for idx1, idx2 in cases:
result = idx1 < idx2
expected = np.array([True, False, False, False, True, False])
self.assert_numpy_array_equal(result, expected)
result = idx2 > idx1
expected = np.array([True, False, False, False, True, False])
self.assert_numpy_array_equal(result, expected)

result = idx1 <= idx2
expected = np.array([True, False, False, False, True, True])
self.assert_numpy_array_equal(result, expected)
result = idx2 >= idx1
expected = np.array([True, False, False, False, True, True])
self.assert_numpy_array_equal(result, expected)

result = idx1 == idx2
expected = np.array([False, False, False, False, False, True])
self.assert_numpy_array_equal(result, expected)

result = idx1 != idx2
expected = np.array([True, True, True, True, True, False])
self.assert_numpy_array_equal(result, expected)

# Comparing against a missing scalar: all False, except != is all True
for idx1, val in [(fidx1, np.nan), (didx1, pd.NaT)]:
result = idx1 < val
expected = np.array([False, False, False, False, False, False])
self.assert_numpy_array_equal(result, expected)
result = idx1 > val
self.assert_numpy_array_equal(result, expected)

result = idx1 <= val
self.assert_numpy_array_equal(result, expected)
result = idx1 >= val
self.assert_numpy_array_equal(result, expected)

result = idx1 == val
self.assert_numpy_array_equal(result, expected)

result = idx1 != val
expected = np.array([True, True, True, True, True, True])
self.assert_numpy_array_equal(result, expected)

# Comparing against a non-missing scalar: the NaN/NaT entries
# (positions 1 and 3) are still False, except for != where they are True
for idx1, val in [(fidx1, 3), (didx1, datetime(2014, 3, 1))]:
result = idx1 < val
expected = np.array([True, False, False, False, False, False])
self.assert_numpy_array_equal(result, expected)
result = idx1 > val
expected = np.array([False, False, False, False, True, True])
self.assert_numpy_array_equal(result, expected)

result = idx1 <= val
expected = np.array([True, False, True, False, False, False])
self.assert_numpy_array_equal(result, expected)
result = idx1 >= val
expected = np.array([False, False, True, False, True, True])
self.assert_numpy_array_equal(result, expected)

result = idx1 == val
expected = np.array([False, False, True, False, False, False])
self.assert_numpy_array_equal(result, expected)

result = idx1 != val
expected = np.array([True, True, False, True, True, True])
self.assert_numpy_array_equal(result, expected)

def test_map(self):
rng = date_range('1/1/2000', periods=10)

Expand Down

0 comments on commit f8e94fa

Please sign in to comment.