Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

BUG: DTI.intersection doesnt preserve tz #7458

Merged
merged 1 commit into from Jun 14, 2014
Jump to file or symbol
Failed to load files and symbols.
+146 −45
Split
View
@@ -232,4 +232,5 @@ Bug Fixes
-
+- Bug in non-monotonic ``Index.union`` may preserve ``name`` incorrectly (:issue:`7458`)
+- Bug in ``DatetimeIndex.intersection`` doesn't preserve timezone (:issue:`4690`)
View
@@ -777,7 +777,8 @@ def take(self, indexer, axis=0):
"""
indexer = com._ensure_platform_int(indexer)
taken = self.view(np.ndarray).take(indexer)
- return self._constructor(taken, name=self.name)
+ return self._simple_new(taken, name=self.name, freq=None,
+ tz=getattr(self, 'tz', None))
def format(self, name=False, formatter=None, **kwargs):
"""
@@ -1075,7 +1076,10 @@ def intersection(self, other):
# duplicates
indexer = self.get_indexer_non_unique(other.values)[0].unique()
- return self.take(indexer)
+ taken = self.take(indexer)
+ if self.name != other.name:
+ taken.name = None
+ return taken
def diff(self, other):
"""
View
@@ -2,7 +2,6 @@
from datetime import datetime, timedelta
from pandas.compat import range, lrange, lzip, u, zip
-import sys
import operator
import pickle
import re
@@ -447,6 +446,33 @@ def test_intersection(self):
# non-iterable input
assertRaisesRegexp(TypeError, "iterable", first.intersection, 0.5)
+ idx1 = Index([1, 2, 3, 4, 5], name='idx')
+ # if target has the same name, it is preserved
+ idx2 = Index([3, 4, 5, 6, 7], name='idx')
+ expected2 = Index([3, 4, 5], name='idx')
+ result2 = idx1.intersection(idx2)
+ self.assertTrue(result2.equals(expected2))
+ self.assertEqual(result2.name, expected2.name)
+
+ # if target name is different, it will be reset
+ idx3 = Index([3, 4, 5, 6, 7], name='other')
+ expected3 = Index([3, 4, 5], name=None)
+ result3 = idx1.intersection(idx3)
+ self.assertTrue(result3.equals(expected3))
+ self.assertEqual(result3.name, expected3.name)
+
+ # non monotonic
+ idx1 = Index([5, 3, 2, 4, 1], name='idx')
+ idx2 = Index([4, 7, 6, 5, 3], name='idx')
+ result2 = idx1.intersection(idx2)
+ self.assertTrue(tm.equalContents(result2, expected2))
+ self.assertEqual(result2.name, expected2.name)
+
+ idx3 = Index([4, 7, 6, 5, 3], name='other')
+ result3 = idx1.intersection(idx3)
+ self.assertTrue(tm.equalContents(result3, expected3))
+ self.assertEqual(result3.name, expected3.name)
+
def test_union(self):
first = self.strIndex[5:20]
second = self.strIndex[:10]
View
@@ -900,9 +900,7 @@ def take(self, indices, axis=0):
maybe_slice = lib.maybe_indices_to_slice(com._ensure_int64(indices))
if isinstance(maybe_slice, slice):
return self[maybe_slice]
- indices = com._ensure_platform_int(indices)
- taken = self.values.take(indices, axis=axis)
- return self._simple_new(taken, self.name, None, self.tz)
+ return super(DatetimeIndex, self).take(indices, axis)
def unique(self):
"""
@@ -1125,6 +1123,12 @@ def __array_finalize__(self, obj):
self.name = getattr(obj, 'name', None)
self._reset_identity()
+ def _wrap_union_result(self, other, result):
+ name = self.name if self.name == other.name else None
+ if self.tz != other.tz:
+ raise ValueError('Passed item and index have different timezone')
+ return self._simple_new(result, name=name, freq=None, tz=self.tz)
+
def intersection(self, other):
"""
Specialized intersection for DatetimeIndex objects. May be much faster
View
@@ -1133,10 +1133,7 @@ def take(self, indices, axis=None):
"""
indices = com._ensure_platform_int(indices)
taken = self.values.take(indices, axis=axis)
- taken = taken.view(PeriodIndex)
- taken.freq = self.freq
- taken.name = self.name
- return taken
+ return self._simple_new(taken, self.name, freq=self.freq)
def append(self, other):
"""
@@ -2070,14 +2070,19 @@ def test_iteration(self):
self.assertEqual(result[0].freq, index.freq)
def test_take(self):
- index = PeriodIndex(start='1/1/10', end='12/31/12', freq='D')
+ index = PeriodIndex(start='1/1/10', end='12/31/12', freq='D', name='idx')
+ expected = PeriodIndex([datetime(2010, 1, 6), datetime(2010, 1, 7),
+ datetime(2010, 1, 9), datetime(2010, 1, 13)],
+ freq='D', name='idx')
- taken = index.take([5, 6, 8, 12])
+ taken1 = index.take([5, 6, 8, 12])
taken2 = index[[5, 6, 8, 12]]
- tm.assert_isinstance(taken, PeriodIndex)
- self.assertEqual(taken.freq, index.freq)
- tm.assert_isinstance(taken2, PeriodIndex)
- self.assertEqual(taken2.freq, index.freq)
+
+ for taken in [taken1, taken2]:
+ self.assertTrue(taken.equals(expected))
+ tm.assert_isinstance(taken, PeriodIndex)
+ self.assertEqual(taken.freq, index.freq)
+ self.assertEqual(taken.name, expected.name)
def test_joins(self):
index = period_range('1/1/2000', '1/20/2000', freq='D')
@@ -2467,6 +2467,25 @@ def test_delete_slice(self):
self.assertEqual(result.name, expected.name)
self.assertEqual(result.freq, expected.freq)
+ def test_take(self):
+ dates = [datetime(2010, 1, 6), datetime(2010, 1, 7),
+ datetime(2010, 1, 9), datetime(2010, 1, 13)]
+
+ for tz in [None, 'US/Eastern', 'Asia/Tokyo']:
+ idx = DatetimeIndex(start='1/1/10', end='12/31/12',
+ freq='D', tz=tz, name='idx')
+ expected = DatetimeIndex(dates, freq=None, name='idx', tz=tz)
+
+ taken1 = idx.take([5, 6, 8, 12])
+ taken2 = idx[[5, 6, 8, 12]]
+
+ for taken in [taken1, taken2]:
+ self.assertTrue(taken.equals(expected))
+ tm.assert_isinstance(taken, DatetimeIndex)
+ self.assertIsNone(taken.freq)
+ self.assertEqual(taken.tz, expected.tz)
+ self.assertEqual(taken.name, expected.name)
+
def test_map_bug_1677(self):
index = DatetimeIndex(['2012-04-25 09:30:00.393000'])
f = index.asof
@@ -3035,14 +3054,46 @@ def test_union(self):
self.assertEqual(df.index.values.dtype, np.dtype('M8[ns]'))
def test_intersection(self):
- rng = date_range('6/1/2000', '6/15/2000', freq='D')
- rng = rng.delete(5)
-
- rng2 = date_range('5/15/2000', '6/20/2000', freq='D')
- rng2 = DatetimeIndex(rng2.values)
-
- result = rng.intersection(rng2)
- self.assertTrue(result.equals(rng))
+ # GH 4690 (with tz)
+ for tz in [None, 'Asia/Tokyo']:
+ rng = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')
+
+ # if target has the same name, it is preserved
+ rng2 = date_range('5/15/2000', '6/20/2000', freq='D', name='idx')
+ expected2 = date_range('6/1/2000', '6/20/2000', freq='D', name='idx')
+
+ # if target name is different, it will be reset
+ rng3 = date_range('5/15/2000', '6/20/2000', freq='D', name='other')
+ expected3 = date_range('6/1/2000', '6/20/2000', freq='D', name=None)
+
+ result2 = rng.intersection(rng2)
+ result3 = rng.intersection(rng3)
+ for (result, expected) in [(result2, expected2), (result3, expected3)]:
+ self.assertTrue(result.equals(expected))
+ self.assertEqual(result.name, expected.name)
+ self.assertEqual(result.freq, expected.freq)
+ self.assertEqual(result.tz, expected.tz)
+
+ # non-monotonic
+ rng = DatetimeIndex(['2011-01-05', '2011-01-04', '2011-01-02', '2011-01-03'],
+ tz=tz, name='idx')
+
+ rng2 = DatetimeIndex(['2011-01-04', '2011-01-02', '2011-02-02', '2011-02-03'],
+ tz=tz, name='idx')
+ expected2 = DatetimeIndex(['2011-01-04', '2011-01-02'], tz=tz, name='idx')
+
+ rng3 = DatetimeIndex(['2011-01-04', '2011-01-02', '2011-02-02', '2011-02-03'],
+ tz=tz, name='other')
+ expected3 = DatetimeIndex(['2011-01-04', '2011-01-02'], tz=tz, name=None)
+
+ result2 = rng.intersection(rng2)
+ result3 = rng.intersection(rng3)
+ for (result, expected) in [(result2, expected2), (result3, expected3)]:
+ print(result, expected)
+ self.assertTrue(result.equals(expected))
+ self.assertEqual(result.name, expected.name)
+ self.assertIsNone(result.freq)
+ self.assertEqual(result.tz, expected.tz)
# empty same freq GH2129
rng = date_range('6/1/2000', '6/15/2000', freq='T')
@@ -3571,26 +3622,39 @@ def test_shift(self):
self.assertRaises(ValueError, idx.shift, 1)
def test_setops_preserve_freq(self):
- rng = date_range('1/1/2000', '1/1/2002')
-
- result = rng[:50].union(rng[50:100])
- self.assertEqual(result.freq, rng.freq)
-
- result = rng[:50].union(rng[30:100])
- self.assertEqual(result.freq, rng.freq)
-
- result = rng[:50].union(rng[60:100])
- self.assertIsNone(result.freq)
-
- result = rng[:50].intersection(rng[25:75])
- self.assertEqual(result.freqstr, 'D')
-
- nofreq = DatetimeIndex(list(rng[25:75]))
- result = rng[:50].union(nofreq)
- self.assertEqual(result.freq, rng.freq)
-
- result = rng[:50].intersection(nofreq)
- self.assertEqual(result.freq, rng.freq)
+ for tz in [None, 'Asia/Tokyo', 'US/Eastern']:
+ rng = date_range('1/1/2000', '1/1/2002', name='idx', tz=tz)
+
+ result = rng[:50].union(rng[50:100])
+ self.assertEqual(result.name, rng.name)
+ self.assertEqual(result.freq, rng.freq)
+ self.assertEqual(result.tz, rng.tz)
+
+ result = rng[:50].union(rng[30:100])
+ self.assertEqual(result.name, rng.name)
+ self.assertEqual(result.freq, rng.freq)
+ self.assertEqual(result.tz, rng.tz)
+
+ result = rng[:50].union(rng[60:100])
+ self.assertEqual(result.name, rng.name)
+ self.assertIsNone(result.freq)
+ self.assertEqual(result.tz, rng.tz)
+
+ result = rng[:50].intersection(rng[25:75])
+ self.assertEqual(result.name, rng.name)
+ self.assertEqual(result.freqstr, 'D')
+ self.assertEqual(result.tz, rng.tz)
+
+ nofreq = DatetimeIndex(list(rng[25:75]), name='other')
+ result = rng[:50].union(nofreq)
+ self.assertIsNone(result.name)
+ self.assertEqual(result.freq, rng.freq)
+ self.assertEqual(result.tz, rng.tz)
+
+ result = rng[:50].intersection(nofreq)
+ self.assertIsNone(result.name)
+ self.assertEqual(result.freq, rng.freq)
+ self.assertEqual(result.tz, rng.tz)
def test_min_max(self):
rng = date_range('1/1/2000', '12/31/2000')