Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

BUG: DatetimeIndex comparison handles NaT incorrectly #7529

Merged
merged 1 commit into from Jun 20, 2014
Jump to file or symbol
Failed to load files and symbols.
+116 −20
Split
View
@@ -237,6 +237,7 @@ Bug Fixes
- Bug in when writing Stata files where the encoding was ignored (:issue:`7286`)
+- Bug in ``DatetimeIndex`` comparison doesn't handle ``NaT`` properly (:issue:`7529`)
- Bug in passing input with ``tzinfo`` to some offsets ``apply``, ``rollforward`` or ``rollback`` resets ``tzinfo`` or raises ``ValueError`` (:issue:`7465`)
View
@@ -74,22 +74,35 @@ def wrapper(left, right):
return wrapper
-def _dt_index_cmp(opname):
+def _dt_index_cmp(opname, nat_result=False):
"""
Wrap comparison operations to convert datetime-like to datetime64
"""
def wrapper(self, other):
func = getattr(super(DatetimeIndex, self), opname)
- if isinstance(other, datetime):
+ if isinstance(other, datetime) or isinstance(other, compat.string_types):
other = _to_m8(other, tz=self.tz)
- elif isinstance(other, list):
- other = DatetimeIndex(other)
- elif isinstance(other, compat.string_types):
- other = _to_m8(other, tz=self.tz)
- elif not isinstance(other, (np.ndarray, ABCSeries)):
- other = _ensure_datetime64(other)
- result = func(other)
+ result = func(other)
+ if com.isnull(other):
+ result.fill(nat_result)
+ else:
+ if isinstance(other, list):
+ other = DatetimeIndex(other)
+ elif not isinstance(other, (np.ndarray, ABCSeries)):
+ other = _ensure_datetime64(other)
+ result = func(other)
+ if isinstance(other, Index):
+ o_mask = other.values.view('i8') == tslib.iNaT
+ else:
+ o_mask = other.view('i8') == tslib.iNaT
+
+ if o_mask.any():
+ result[o_mask] = nat_result
+
+ mask = self.asi8 == tslib.iNaT
+ if mask.any():
+ result[mask] = nat_result
return result.view(np.ndarray)
return wrapper
@@ -142,7 +155,7 @@ class DatetimeIndex(DatetimeIndexOpsMixin, Int64Index):
_arrmap = None
__eq__ = _dt_index_cmp('__eq__')
- __ne__ = _dt_index_cmp('__ne__')
+ __ne__ = _dt_index_cmp('__ne__', nat_result=True)
__lt__ = _dt_index_cmp('__lt__')
__gt__ = _dt_index_cmp('__gt__')
__le__ = _dt_index_cmp('__le__')
View
@@ -498,16 +498,11 @@ def dt64arr_to_periodarr(data, freq, tz):
# --- Period index sketch
-def _period_index_cmp(opname):
+def _period_index_cmp(opname, nat_result=False):
"""
Wrap comparison operations to convert datetime-like to datetime64
"""
def wrapper(self, other):
- if opname == '__ne__':
- fill_value = True
- else:
- fill_value = False
-
if isinstance(other, Period):
func = getattr(self.values, opname)
if other.freq != self.freq:
@@ -523,7 +518,7 @@ def wrapper(self, other):
mask = (com.mask_missing(self.values, tslib.iNaT) |
com.mask_missing(other.values, tslib.iNaT))
if mask.any():
- result[mask] = fill_value
+ result[mask] = nat_result
return result
else:
@@ -532,10 +527,10 @@ def wrapper(self, other):
result = func(other.ordinal)
if other.ordinal == tslib.iNaT:
- result.fill(fill_value)
+ result.fill(nat_result)
mask = self.values == tslib.iNaT
if mask.any():
- result[mask] = fill_value
+ result[mask] = nat_result
return result
return wrapper
@@ -595,7 +590,7 @@ class PeriodIndex(DatetimeIndexOpsMixin, Int64Index):
_allow_period_index_ops = True
__eq__ = _period_index_cmp('__eq__')
- __ne__ = _period_index_cmp('__ne__')
+ __ne__ = _period_index_cmp('__ne__', nat_result=True)
__lt__ = _period_index_cmp('__lt__')
__gt__ = _period_index_cmp('__gt__')
__le__ = _period_index_cmp('__le__')
@@ -2179,6 +2179,93 @@ def test_comparisons_coverage(self):
exp = rng == rng
self.assert_numpy_array_equal(result, exp)
+ def test_comparisons_nat(self):
+ fidx1 = pd.Index([1.0, np.nan, 3.0, np.nan, 5.0, 7.0])
+ fidx2 = pd.Index([2.0, 3.0, np.nan, np.nan, 6.0, 7.0])
+
+ didx1 = pd.DatetimeIndex(['2014-01-01', pd.NaT, '2014-03-01', pd.NaT,
+ '2014-05-01', '2014-07-01'])
+ didx2 = pd.DatetimeIndex(['2014-02-01', '2014-03-01', pd.NaT, pd.NaT,
+ '2014-06-01', '2014-07-01'])
+ darr = np.array([np.datetime64('2014-02-01 00:00Z'),
+ np.datetime64('2014-03-01 00:00Z'),
+ np.datetime64('nat'), np.datetime64('nat'),
+ np.datetime64('2014-06-01 00:00Z'),
+ np.datetime64('2014-07-01 00:00Z')])
+
+ if _np_version_under1p7:
+ # cannot test array because np.datetime('nat') returns today's date
+ cases = [(fidx1, fidx2), (didx1, didx2)]
+ else:
+ cases = [(fidx1, fidx2), (didx1, didx2), (didx1, darr)]
+
+ # Check pd.NaT is handles as the same as np.nan
+ for idx1, idx2 in cases:
+ result = idx1 < idx2
+ expected = np.array([True, False, False, False, True, False])
+ self.assert_numpy_array_equal(result, expected)
+ result = idx2 > idx1
+ expected = np.array([True, False, False, False, True, False])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 <= idx2
+ expected = np.array([True, False, False, False, True, True])
+ self.assert_numpy_array_equal(result, expected)
+ result = idx2 >= idx1
+ expected = np.array([True, False, False, False, True, True])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 == idx2
+ expected = np.array([False, False, False, False, False, True])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 != idx2
+ expected = np.array([True, True, True, True, True, False])
+ self.assert_numpy_array_equal(result, expected)
+
+ for idx1, val in [(fidx1, np.nan), (didx1, pd.NaT)]:
+ result = idx1 < val
+ expected = np.array([False, False, False, False, False, False])
+ self.assert_numpy_array_equal(result, expected)
+ result = idx1 > val
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 <= val
+ self.assert_numpy_array_equal(result, expected)
+ result = idx1 >= val
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 == val
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 != val
+ expected = np.array([True, True, True, True, True, True])
+ self.assert_numpy_array_equal(result, expected)
+
+ # Check pd.NaT is handles as the same as np.nan
+ for idx1, val in [(fidx1, 3), (didx1, datetime(2014, 3, 1))]:
+ result = idx1 < val
+ expected = np.array([True, False, False, False, False, False])
+ self.assert_numpy_array_equal(result, expected)
+ result = idx1 > val
+ expected = np.array([False, False, False, False, True, True])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 <= val
+ expected = np.array([True, False, True, False, False, False])
+ self.assert_numpy_array_equal(result, expected)
+ result = idx1 >= val
+ expected = np.array([False, False, True, False, True, True])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 == val
+ expected = np.array([False, False, True, False, False, False])
+ self.assert_numpy_array_equal(result, expected)
+
+ result = idx1 != val
+ expected = np.array([True, True, False, True, True, True])
+ self.assert_numpy_array_equal(result, expected)
+
def test_map(self):
rng = date_range('1/1/2000', periods=10)