Skip to content

Commit

Permalink
API/ERR: allow iterators in df.set_index & improve errors (#24984)
Browse files Browse the repository at this point in the history
  • Loading branch information
h-vetinari authored and jreback committed Feb 24, 2019
1 parent 183dc02 commit 5ae9b48
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.25.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Other Enhancements
- Indexing of ``DataFrame`` and ``Series`` now accepts zerodim ``np.ndarray`` (:issue:`24919`)
- :meth:`Timestamp.replace` now supports the ``fold`` argument to disambiguate DST transition times (:issue:`25017`)
- :meth:`DataFrame.at_time` and :meth:`Series.at_time` now support :meth:`datetime.time` objects with timezones (:issue:`24043`)
- :meth:`DataFrame.set_index` now works for instances of ``abc.Iterator``, provided their output is of the same length as the calling frame (:issue:`22484`, :issue:`24984`)
- :meth:`DatetimeIndex.union` now supports the ``sort`` argument. The behaviour of the sort parameter matches that of :meth:`Index.union` (:issue:`24994`)
-

Expand Down
2 changes: 2 additions & 0 deletions pandas/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def lfilter(*args, **kwargs):
reload = reload
Hashable = collections.abc.Hashable
Iterable = collections.abc.Iterable
Iterator = collections.abc.Iterator
Mapping = collections.abc.Mapping
MutableMapping = collections.abc.MutableMapping
Sequence = collections.abc.Sequence
Expand Down Expand Up @@ -199,6 +200,7 @@ def get_range_parameters(data):

Hashable = collections.Hashable
Iterable = collections.Iterable
Iterator = collections.Iterator
Mapping = collections.Mapping
MutableMapping = collections.MutableMapping
Sequence = collections.Sequence
Expand Down
43 changes: 41 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

from pandas import compat
from pandas.compat import (range, map, zip, lmap, lzip, StringIO, u,
PY36, raise_with_traceback,
PY36, raise_with_traceback, Iterator,
string_and_binary_types)
from pandas.compat.numpy import function as nv
from pandas.core.dtypes.cast import (
Expand Down Expand Up @@ -4025,7 +4025,8 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
This parameter can be either a single column key, a single array of
the same length as the calling DataFrame, or a list containing an
arbitrary combination of column keys and arrays. Here, "array"
encompasses :class:`Series`, :class:`Index` and ``np.ndarray``.
encompasses :class:`Series`, :class:`Index`, ``np.ndarray``, and
instances of :class:`abc.Iterator`.
drop : bool, default True
Delete columns to be used as the new index.
append : bool, default False
Expand Down Expand Up @@ -4104,6 +4105,32 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
if not isinstance(keys, list):
keys = [keys]

err_msg = ('The parameter "keys" may be a column key, one-dimensional '
'array, or a list containing only valid column keys and '
'one-dimensional arrays.')

missing = []
for col in keys:
if isinstance(col, (ABCIndexClass, ABCSeries, np.ndarray,
list, Iterator)):
# arrays are fine as long as they are one-dimensional
# iterators get converted to list below
if getattr(col, 'ndim', 1) != 1:
raise ValueError(err_msg)
else:
# everything else gets tried as a key; see GH 24969
try:
found = col in self.columns
except TypeError:
raise TypeError(err_msg + ' Received column of '
'type {}'.format(type(col)))
else:
if not found:
missing.append(col)

if missing:
raise KeyError('None of {} are in the columns'.format(missing))

if inplace:
frame = self
else:
Expand Down Expand Up @@ -4132,13 +4159,25 @@ def set_index(self, keys, drop=True, append=False, inplace=False,
elif isinstance(col, (list, np.ndarray)):
arrays.append(col)
names.append(None)
elif isinstance(col, Iterator):
arrays.append(list(col))
names.append(None)
# from here, col can only be a column label
else:
arrays.append(frame[col]._values)
names.append(col)
if drop:
to_remove.append(col)

if len(arrays[-1]) != len(self):
# check newest element against length of calling frame, since
# ensure_index_from_sequences would not raise for append=False.
raise ValueError('Length mismatch: Expected {len_self} rows, '
'received array of length {len_col}'.format(
len_self=len(self),
len_col=len(arrays[-1])
))

index = ensure_index_from_sequences(arrays, names)

if verify_integrity and not index.is_unique:
Expand Down
44 changes: 35 additions & 9 deletions pandas/tests/frame/test_alter_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ def test_set_index_pass_arrays(self, frame_of_index_cols,
# MultiIndex constructor does not work directly on Series -> lambda
# We also emulate a "constructor" for the label -> lambda
# also test index name if append=True (name is duplicate here for A)
@pytest.mark.parametrize('box2', [Series, Index, np.array, list,
@pytest.mark.parametrize('box2', [Series, Index, np.array, list, iter,
lambda x: MultiIndex.from_arrays([x]),
lambda x: x.name])
@pytest.mark.parametrize('box1', [Series, Index, np.array, list,
@pytest.mark.parametrize('box1', [Series, Index, np.array, list, iter,
lambda x: MultiIndex.from_arrays([x]),
lambda x: x.name])
@pytest.mark.parametrize('append, index_name', [(True, None),
Expand All @@ -195,6 +195,9 @@ def test_set_index_pass_arrays_duplicate(self, frame_of_index_cols, drop,
keys = [box1(df['A']), box2(df['A'])]
result = df.set_index(keys, drop=drop, append=append)

# if either box is iter, it has been consumed; re-read
keys = [box1(df['A']), box2(df['A'])]

# need to adapt first drop for case that both keys are 'A' --
# cannot drop the same column twice;
# use "is" because == would give ambiguous Boolean error for containers
Expand Down Expand Up @@ -253,25 +256,48 @@ def test_set_index_raise_keys(self, frame_of_index_cols, drop, append):
df.set_index(['A', df['A'], tuple(df['A'])],
drop=drop, append=append)

@pytest.mark.xfail(reason='broken due to revert, see GH 25085')
@pytest.mark.parametrize('append', [True, False])
@pytest.mark.parametrize('drop', [True, False])
@pytest.mark.parametrize('box', [set, iter, lambda x: (y for y in x)],
ids=['set', 'iter', 'generator'])
@pytest.mark.parametrize('box', [set], ids=['set'])
def test_set_index_raise_on_type(self, frame_of_index_cols, box,
drop, append):
df = frame_of_index_cols

msg = 'The parameter "keys" may be a column key, .*'
# forbidden type, e.g. set/iter/generator
# forbidden type, e.g. set
with pytest.raises(TypeError, match=msg):
df.set_index(box(df['A']), drop=drop, append=append)

# forbidden type in list, e.g. set/iter/generator
# forbidden type in list, e.g. set
with pytest.raises(TypeError, match=msg):
df.set_index(['A', df['A'], box(df['A'])],
drop=drop, append=append)

# MultiIndex constructor does not work directly on Series -> lambda
@pytest.mark.parametrize('box', [Series, Index, np.array, iter,
lambda x: MultiIndex.from_arrays([x])],
ids=['Series', 'Index', 'np.array',
'iter', 'MultiIndex'])
@pytest.mark.parametrize('length', [4, 6], ids=['too_short', 'too_long'])
@pytest.mark.parametrize('append', [True, False])
@pytest.mark.parametrize('drop', [True, False])
def test_set_index_raise_on_len(self, frame_of_index_cols, box, length,
drop, append):
# GH 24984
df = frame_of_index_cols # has length 5

values = np.random.randint(0, 10, (length,))

msg = 'Length mismatch: Expected 5 rows, received array of length.*'

# wrong length directly
with pytest.raises(ValueError, match=msg):
df.set_index(box(values), drop=drop, append=append)

# wrong length in list
with pytest.raises(ValueError, match=msg):
df.set_index(['A', df.A, box(values)], drop=drop, append=append)

def test_set_index_custom_label_type(self):
# GH 24969

Expand Down Expand Up @@ -341,7 +367,7 @@ def __repr__(self):

# missing key
thing3 = Thing(['Three', 'pink'])
msg = '.*' # due to revert, see GH 25085
msg = r"frozenset\(\{'Three', 'pink'\}\)"
with pytest.raises(KeyError, match=msg):
# missing label directly
df.set_index(thing3)
Expand All @@ -366,7 +392,7 @@ def __str__(self):
thing2 = Thing('Two', 'blue')
df = DataFrame([[0, 2], [1, 3]], columns=[thing1, thing2])

msg = 'unhashable type.*'
msg = 'The parameter "keys" may be a column key, .*'

with pytest.raises(TypeError, match=msg):
# use custom label directly
Expand Down

0 comments on commit 5ae9b48

Please sign in to comment.