Skip to content

Commit

Permalink
BUG: Index.union cannot handle array-likes
Browse files Browse the repository at this point in the history
  • Loading branch information
sinhrks committed May 31, 2015
1 parent 5852e72 commit dddc8ea
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 61 deletions.
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.17.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ Bug Fixes

- Bug in ``mean()`` where integer dtypes can overflow (:issue:`10172`)
- Bug where Panel.from_dict does not set dtype when specified (:issue:`10058`)
- Bug where ``Index.union`` raises ``AttributeError`` when passed array-likes. (:issue:`10149`)
- Bug where ``Timestamp``'s ``microsecond``, ``quarter``, ``dayofyear``, ``week`` and ``daysinmonth`` properties return ``np.int`` type, not built-in ``int``. (:issue:`10050`)
- Bug where ``NaT`` raises ``AttributeError`` when accessing the ``daysinmonth`` and ``dayofweek`` properties. (:issue:`10096`)

Expand All @@ -91,3 +92,4 @@ Bug Fixes


- Bug where infer_freq infers timerule (WOM-5XXX) unsupported by to_offset (:issue:`9425`)

86 changes: 44 additions & 42 deletions pandas/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,18 @@ def to_datetime(self, dayfirst=False):
return DatetimeIndex(self.values)

def _assert_can_do_setop(self, other):
if not com.is_list_like(other):
raise TypeError('Input must be Index or array-like')
return True

def _convert_can_do_setop(self, other):
if not isinstance(other, Index):
other = Index(other, name=self.name)
result_name = self.name
else:
result_name = self.name if self.name == other.name else None
return other, result_name

@property
def nlevels(self):
    """Number of levels; this implementation always reports 1."""
    return 1
Expand Down Expand Up @@ -1364,16 +1374,14 @@ def union(self, other):
-------
union : Index
"""
if not hasattr(other, '__iter__'):
raise TypeError('Input must be iterable.')
self._assert_can_do_setop(other)
other = _ensure_index(other)

if len(other) == 0 or self.equals(other):
return self

if len(self) == 0:
return _ensure_index(other)

self._assert_can_do_setop(other)
return other

if not is_dtype_equal(self.dtype,other.dtype):
this = self.astype('O')
Expand Down Expand Up @@ -1439,11 +1447,7 @@ def intersection(self, other):
-------
intersection : Index
"""
if not hasattr(other, '__iter__'):
raise TypeError('Input must be iterable!')

self._assert_can_do_setop(other)

other = _ensure_index(other)

if self.equals(other):
Expand Down Expand Up @@ -1492,18 +1496,12 @@ def difference(self, other):
>>> index.difference(index2)
"""

if not hasattr(other, '__iter__'):
raise TypeError('Input must be iterable!')
self._assert_can_do_setop(other)

if self.equals(other):
return Index([], name=self.name)

if not isinstance(other, Index):
other = np.asarray(other)
result_name = self.name
else:
result_name = self.name if self.name == other.name else None
other, result_name = self._convert_can_do_setop(other)

theDiff = sorted(set(self) - set(other))
return Index(theDiff, name=result_name)
Expand All @@ -1517,7 +1515,7 @@ def sym_diff(self, other, result_name=None):
Parameters
----------
other : array-like
other : Index or array-like
result_name : str
Returns
Expand Down Expand Up @@ -1545,13 +1543,10 @@ def sym_diff(self, other, result_name=None):
>>> idx1 ^ idx2
Int64Index([1, 5], dtype='int64')
"""
if not hasattr(other, '__iter__'):
raise TypeError('Input must be iterable!')

if not isinstance(other, Index):
other = Index(other)
result_name = result_name or self.name

self._assert_can_do_setop(other)
other, result_name_update = self._convert_can_do_setop(other)
if result_name is None:
result_name = result_name_update
the_diff = sorted(set((self.difference(other)).union(other.difference(self))))
return Index(the_diff, name=result_name)

Expand Down Expand Up @@ -5460,12 +5455,11 @@ def union(self, other):
>>> index.union(index2)
"""
self._assert_can_do_setop(other)
other, result_names = self._convert_can_do_setop(other)

if len(other) == 0 or self.equals(other):
return self

result_names = self.names if self.names == other.names else None

uniq_tuples = lib.fast_unique_multiple([self.values, other.values])
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)
Expand All @@ -5483,12 +5477,11 @@ def intersection(self, other):
Index
"""
self._assert_can_do_setop(other)
other, result_names = self._convert_can_do_setop(other)

if self.equals(other):
return self

result_names = self.names if self.names == other.names else None

self_tuples = self.values
other_tuples = other.values
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
Expand All @@ -5509,18 +5502,10 @@ def difference(self, other):
diff : MultiIndex
"""
self._assert_can_do_setop(other)
other, result_names = self._convert_can_do_setop(other)

if not isinstance(other, MultiIndex):
if len(other) == 0:
if len(other) == 0:
return self
try:
other = MultiIndex.from_tuples(other)
except:
raise TypeError('other must be a MultiIndex or a list of'
' tuples')
result_names = self.names
else:
result_names = self.names if self.names == other.names else None

if self.equals(other):
return MultiIndex(levels=[[]] * self.nlevels,
Expand All @@ -5537,15 +5522,32 @@ def difference(self, other):
return MultiIndex.from_tuples(difference, sortorder=0,
names=result_names)

def _assert_can_do_setop(self, other):
pass

def astype(self, dtype):
    """
    Return a shallow copy when an object dtype is requested.

    Parameters
    ----------
    dtype : anything accepted by ``np.dtype``

    Returns
    -------
    The result of ``self._shallow_copy()``.

    Raises
    ------
    TypeError
        If ``dtype`` does not resolve to an object dtype.
    """
    target = np.dtype(dtype)
    if is_object_dtype(target):
        return self._shallow_copy()
    message = ('Setting %s dtype to anything other than object '
               'is not supported' % self.__class__)
    raise TypeError(message)

def _convert_can_do_setop(self, other):
result_names = self.names

if not isinstance(other, MultiIndex):
if len(other) == 0:
other = MultiIndex(levels=[[]] * self.nlevels,
labels=[[]] * self.nlevels,
verify_integrity=False)
elif isinstance(other, Index):
result_names = self.names if self.names == other.names else None
else:
msg = 'other must be a MultiIndex or a list of tuples'
try:
other = MultiIndex.from_tuples(other)
except:
raise TypeError(msg)
else:
result_names = self.names if self.names == other.names else None
return other, result_names

def insert(self, loc, item):
"""
Make new MultiIndex inserting new item at location
Expand Down
Loading

0 comments on commit dddc8ea

Please sign in to comment.