New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] EHN add support for scalar, slice and mask in safe_indexing axis=0 #14475
Changes from all commits
c01385c
0e5c037
f5e08c4
bb4db91
8cd74db
9878ef1
2f6a0bd
d0f8d60
f95a228
1c81803
c8009a2
a80b33d
56a6759
9fb045d
0d46f7f
5dcf34f
70f0e02
7127b5a
2f96882
619fb05
b7539bd
92d1aaf
b1918e8
fe29402
6322f99
4d4cc2d
050932a
46f96a9
d880075
ae868e1
ead08d6
c7bb542
655a218
db80c5a
ce618fd
581402b
b8e98c2
11dff01
94edba2
ef4aa2c
395803a
a7d29f6
3624cc5
557aa43
7d5404a
698aef2
4f2fd8f
abc90d7
b99b5a2
3bf36b0
bd86ccd
bfb9fa2
c433494
7994cc9
5e29cc6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -3,6 +3,7 @@ | |||||||||||||||||
""" | ||||||||||||||||||
from collections.abc import Sequence | ||||||||||||||||||
from contextlib import contextmanager | ||||||||||||||||||
from itertools import compress | ||||||||||||||||||
from itertools import islice | ||||||||||||||||||
import numbers | ||||||||||||||||||
import platform | ||||||||||||||||||
|
@@ -180,199 +181,165 @@ def axis0_safe_slice(X, mask, len_mask): | |||||||||||||||||
return np.zeros(shape=(0, X.shape[1])) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def safe_indexing(X, indices, axis=0): | ||||||||||||||||||
"""Return rows, items or columns of X using indices. | ||||||||||||||||||
def _array_indexing(array, key, key_dtype, axis): | ||||||||||||||||||
"""Index an array or scipy.sparse consistently across NumPy version.""" | ||||||||||||||||||
if np_version < (1, 12) or issparse(array): | ||||||||||||||||||
# FIXME: Remove the check for NumPy when using >= 1.12 | ||||||||||||||||||
# check if we have an boolean array-likes to make the proper indexing | ||||||||||||||||||
if key_dtype == 'bool': | ||||||||||||||||||
key = np.asarray(key) | ||||||||||||||||||
return array[key] if axis == 0 else array[:, key] | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _pandas_indexing(X, key, key_dtype, axis): | ||||||||||||||||||
"""Index a pandas dataframe or a series.""" | ||||||||||||||||||
if hasattr(key, 'shape'): | ||||||||||||||||||
# Work-around for indexing with read-only key in pandas | ||||||||||||||||||
# FIXME: solved in pandas 0.25 | ||||||||||||||||||
key = np.asarray(key) | ||||||||||||||||||
key = key if key.flags.writeable else key.copy() | ||||||||||||||||||
# check whether we should index with loc or iloc | ||||||||||||||||||
indexer = X.iloc if key_dtype == 'int' else X.loc | ||||||||||||||||||
return indexer[:, key] if axis else indexer[key] | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _list_indexing(X, key, key_dtype): | ||||||||||||||||||
"""Index a Python list.""" | ||||||||||||||||||
if np.isscalar(key) or isinstance(key, slice): | ||||||||||||||||||
# key is a slice or a scalar | ||||||||||||||||||
return X[key] | ||||||||||||||||||
if key_dtype == 'bool': | ||||||||||||||||||
# key is a boolean array-like | ||||||||||||||||||
return list(compress(X, key)) | ||||||||||||||||||
# key is a integer array-like of key | ||||||||||||||||||
return [X[idx] for idx in key] | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _determine_key_type(key): | ||||||||||||||||||
"""Determine the data type of key. | ||||||||||||||||||
|
||||||||||||||||||
Parameters | ||||||||||||||||||
---------- | ||||||||||||||||||
X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series | ||||||||||||||||||
Data from which to sample rows, items or columns. | ||||||||||||||||||
indices : array-like | ||||||||||||||||||
- When ``axis=0``, indices need to be an array of integer. | ||||||||||||||||||
- When ``axis=1``, indices can be one of: | ||||||||||||||||||
- scalar: output is 1D, unless `X` is sparse. | ||||||||||||||||||
Supported data types for scalars: | ||||||||||||||||||
- integer: supported for arrays, sparse matrices and | ||||||||||||||||||
dataframes. | ||||||||||||||||||
- string (key-based): only supported for dataframes. | ||||||||||||||||||
- container: lists, slices, boolean masks: output is 2D. | ||||||||||||||||||
Supported data types for containers: | ||||||||||||||||||
- integer or boolean (positional): supported for | ||||||||||||||||||
arrays, sparse matrices and dataframes | ||||||||||||||||||
- string (key-based): only supported for dataframes. No keys | ||||||||||||||||||
other than strings are allowed. | ||||||||||||||||||
axis : int, default=0 | ||||||||||||||||||
The axis along which `X` will be subsampled. ``axis=0`` will select | ||||||||||||||||||
rows while ``axis=1`` will select columns. | ||||||||||||||||||
key : scalar, slice or array-like | ||||||||||||||||||
The key from which we want to infer the data type. | ||||||||||||||||||
|
||||||||||||||||||
Returns | ||||||||||||||||||
------- | ||||||||||||||||||
subset | ||||||||||||||||||
Subset of X on axis 0 or 1. | ||||||||||||||||||
|
||||||||||||||||||
Notes | ||||||||||||||||||
----- | ||||||||||||||||||
CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are | ||||||||||||||||||
not supported. | ||||||||||||||||||
dtype : {'int', 'str', 'bool', None} | ||||||||||||||||||
Returns the data type of key. | ||||||||||||||||||
""" | ||||||||||||||||||
if axis == 0: | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The main refactor here is to use different dispatcher depending on the type of |
||||||||||||||||||
return _safe_indexing_row(X, indices) | ||||||||||||||||||
elif axis == 1: | ||||||||||||||||||
return _safe_indexing_column(X, indices) | ||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"'axis' should be either 0 (to index rows) or 1 (to index " | ||||||||||||||||||
" column). Got {} instead.".format(axis) | ||||||||||||||||||
) | ||||||||||||||||||
err_msg = ("No valid specification of the columns. Only a scalar, list or " | ||||||||||||||||||
"slice of all integers or all strings, or boolean mask is " | ||||||||||||||||||
"allowed") | ||||||||||||||||||
|
||||||||||||||||||
dtype_to_str = {int: 'int', str: 'str', bool: 'bool', np.bool_: 'bool'} | ||||||||||||||||||
array_dtype_to_str = {'i': 'int', 'u': 'int', 'b': 'bool', 'O': 'str', | ||||||||||||||||||
'U': 'str', 'S': 'str'} | ||||||||||||||||||
|
||||||||||||||||||
def _array_indexing(array, key, axis=0): | ||||||||||||||||||
"""Index an array consistently across NumPy version.""" | ||||||||||||||||||
if axis not in (0, 1): | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"'axis' should be either 0 (to index rows) or 1 (to index " | ||||||||||||||||||
" column). Got {} instead.".format(axis) | ||||||||||||||||||
) | ||||||||||||||||||
if np_version < (1, 12) or issparse(array): | ||||||||||||||||||
# check if we have an boolean array-likes to make the proper indexing | ||||||||||||||||||
key_array = np.asarray(key) | ||||||||||||||||||
if np.issubdtype(key_array.dtype, np.bool_): | ||||||||||||||||||
key = key_array | ||||||||||||||||||
return array[key] if axis == 0 else array[:, key] | ||||||||||||||||||
|
||||||||||||||||||
if key is None: | ||||||||||||||||||
return None | ||||||||||||||||||
if isinstance(key, tuple(dtype_to_str.keys())): | ||||||||||||||||||
try: | ||||||||||||||||||
return dtype_to_str[type(key)] | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure this is the right behaviour for a scalar bool There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean? Note that in safe indexing, we are not using scalar bool to index. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Okay ;) |
||||||||||||||||||
except KeyError: | ||||||||||||||||||
raise ValueError(err_msg) | ||||||||||||||||||
if isinstance(key, slice): | ||||||||||||||||||
if key.start is None and key.stop is None: | ||||||||||||||||||
return None | ||||||||||||||||||
key_start_type = _determine_key_type(key.start) | ||||||||||||||||||
key_stop_type = _determine_key_type(key.stop) | ||||||||||||||||||
if key_start_type is not None and key_stop_type is not None: | ||||||||||||||||||
if key_start_type != key_stop_type: | ||||||||||||||||||
raise ValueError(err_msg) | ||||||||||||||||||
if key_start_type is not None: | ||||||||||||||||||
return key_start_type | ||||||||||||||||||
return key_stop_type | ||||||||||||||||||
if isinstance(key, list): | ||||||||||||||||||
unique_key = set(key) | ||||||||||||||||||
key_type = {_determine_key_type(elt) for elt in unique_key} | ||||||||||||||||||
if not key_type: | ||||||||||||||||||
return None | ||||||||||||||||||
if len(key_type) != 1: | ||||||||||||||||||
raise ValueError(err_msg) | ||||||||||||||||||
return key_type.pop() | ||||||||||||||||||
if hasattr(key, 'dtype'): | ||||||||||||||||||
try: | ||||||||||||||||||
return array_dtype_to_str[key.dtype.kind] | ||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is behaviour with dtype='O' where the array contains only ints or only bools correct? (How does numpy handle this?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NumPy will return One would need to try converting the array to bool and int first. However, I thought that we are usually There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This maintains the behavior on master where we use ('O', 'U', 'S') to mean string: scikit-learn/sklearn/utils/__init__.py Lines 311 to 318 in 68044b0
|
||||||||||||||||||
except KeyError: | ||||||||||||||||||
raise ValueError(err_msg) | ||||||||||||||||||
raise ValueError(err_msg) | ||||||||||||||||||
|
||||||||||||||||||
def _safe_indexing_row(X, indices): | ||||||||||||||||||
"""Return items or rows from X using indices. | ||||||||||||||||||
|
||||||||||||||||||
Allows simple indexing of lists or arrays. | ||||||||||||||||||
def safe_indexing(X, indices, axis=0): | ||||||||||||||||||
"""Return rows, items or columns of X using indices. | ||||||||||||||||||
|
||||||||||||||||||
Parameters | ||||||||||||||||||
---------- | ||||||||||||||||||
X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series | ||||||||||||||||||
Data from which to sample rows or items. | ||||||||||||||||||
indices : array-like of int | ||||||||||||||||||
Indices according to which X will be subsampled. | ||||||||||||||||||
Data from which to sample rows, items or columns. `list` are only | ||||||||||||||||||
supported when `axis=0`. | ||||||||||||||||||
indices : bool, int, str, slice, array-like | ||||||||||||||||||
- If `axis=0`, boolean and integer array-like, integer slice, | ||||||||||||||||||
and scalar integer are supported. | ||||||||||||||||||
- If `axis=1`: | ||||||||||||||||||
- to select a single column, `indices` can be of `int` type for | ||||||||||||||||||
all `X` types and `str` only for dataframe. The selected subset | ||||||||||||||||||
will be 1D, unless `X` is a sparse matrix in which case it will | ||||||||||||||||||
be 2D. | ||||||||||||||||||
- to select multiples columns, `indices` can be one of the | ||||||||||||||||||
following: `list`, `array`, `slice`. The type used in | ||||||||||||||||||
these containers can be one of the following: `int`, 'bool' and | ||||||||||||||||||
`str`. However, `str` is only supported when `X` is a dataframe. | ||||||||||||||||||
The selected subset will be 2D. | ||||||||||||||||||
axis : int, default=0 | ||||||||||||||||||
The axis along which `X` will be subsampled. `axis=0` will select | ||||||||||||||||||
rows while `axis=1` will select columns. | ||||||||||||||||||
|
||||||||||||||||||
Returns | ||||||||||||||||||
------- | ||||||||||||||||||
subset | ||||||||||||||||||
Subset of X on first axis. | ||||||||||||||||||
Subset of X on axis 0 or 1. | ||||||||||||||||||
|
||||||||||||||||||
Notes | ||||||||||||||||||
----- | ||||||||||||||||||
CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are | ||||||||||||||||||
not supported. | ||||||||||||||||||
""" | ||||||||||||||||||
if hasattr(X, "iloc"): | ||||||||||||||||||
# Work-around for indexing with read-only indices in pandas | ||||||||||||||||||
indices = np.asarray(indices) | ||||||||||||||||||
indices = indices if indices.flags.writeable else indices.copy() | ||||||||||||||||||
# Pandas Dataframes and Series | ||||||||||||||||||
try: | ||||||||||||||||||
return X.iloc[indices] | ||||||||||||||||||
except ValueError: | ||||||||||||||||||
# Cython typed memoryviews internally used in pandas do not support | ||||||||||||||||||
# readonly buffers. | ||||||||||||||||||
warnings.warn("Copying input dataframe for slicing.", | ||||||||||||||||||
DataConversionWarning) | ||||||||||||||||||
return X.copy().iloc[indices] | ||||||||||||||||||
elif hasattr(X, "shape"): | ||||||||||||||||||
if hasattr(X, 'take') and (hasattr(indices, 'dtype') and | ||||||||||||||||||
indices.dtype.kind == 'i'): | ||||||||||||||||||
# This is often substantially faster than X[indices] | ||||||||||||||||||
return X.take(indices, axis=0) | ||||||||||||||||||
else: | ||||||||||||||||||
return _array_indexing(X, indices, axis=0) | ||||||||||||||||||
else: | ||||||||||||||||||
return [X[idx] for idx in indices] | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _check_key_type(key, superclass): | ||||||||||||||||||
"""Check that scalar, list or slice is of a certain type. | ||||||||||||||||||
|
||||||||||||||||||
This is only used in _safe_indexing_column and _get_column_indices to check | ||||||||||||||||||
if the ``key`` (column specification) is fully integer or fully | ||||||||||||||||||
string-like. | ||||||||||||||||||
if indices is None: | ||||||||||||||||||
return X | ||||||||||||||||||
|
||||||||||||||||||
Parameters | ||||||||||||||||||
---------- | ||||||||||||||||||
key : scalar, list, slice, array-like | ||||||||||||||||||
The column specification to check. | ||||||||||||||||||
superclass : int or str | ||||||||||||||||||
The type for which to check the `key`. | ||||||||||||||||||
""" | ||||||||||||||||||
if isinstance(key, superclass): | ||||||||||||||||||
return True | ||||||||||||||||||
if isinstance(key, slice): | ||||||||||||||||||
return (isinstance(key.start, (superclass, type(None))) and | ||||||||||||||||||
isinstance(key.stop, (superclass, type(None)))) | ||||||||||||||||||
if isinstance(key, list): | ||||||||||||||||||
return all(isinstance(x, superclass) for x in key) | ||||||||||||||||||
if hasattr(key, 'dtype'): | ||||||||||||||||||
if superclass is int: | ||||||||||||||||||
return key.dtype.kind == 'i' | ||||||||||||||||||
elif superclass is bool: | ||||||||||||||||||
return key.dtype.kind == 'b' | ||||||||||||||||||
else: | ||||||||||||||||||
# superclass = str | ||||||||||||||||||
return key.dtype.kind in ('O', 'U', 'S') | ||||||||||||||||||
return False | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _safe_indexing_column(X, key): | ||||||||||||||||||
"""Get feature column(s) from input data X. | ||||||||||||||||||
if axis not in (0, 1): | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"'axis' should be either 0 (to index rows) or 1 (to index " | ||||||||||||||||||
" column). Got {} instead.".format(axis) | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
Supported input types (X): numpy arrays, sparse arrays and DataFrames. | ||||||||||||||||||
indices_dtype = _determine_key_type(indices) | ||||||||||||||||||
|
||||||||||||||||||
Supported key types (key): | ||||||||||||||||||
- scalar: output is 1D; | ||||||||||||||||||
- lists, slices, boolean masks: output is 2D. | ||||||||||||||||||
if axis == 0 and indices_dtype == 'str': | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"String indexing is not supported with 'axis=0'" | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
Supported key data types: | ||||||||||||||||||
- integer or boolean mask (positional): | ||||||||||||||||||
- supported for arrays, sparse matrices and dataframes. | ||||||||||||||||||
- string (key-based): | ||||||||||||||||||
- only supported for dataframes; | ||||||||||||||||||
- So no keys other than strings are allowed (while in principle you | ||||||||||||||||||
can use any hashable object as key). | ||||||||||||||||||
""" | ||||||||||||||||||
# check that X is a 2D structure | ||||||||||||||||||
if X.ndim != 2: | ||||||||||||||||||
if axis == 1 and X.ndim != 2: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"'X' should be a 2D NumPy array, 2D sparse matrix or pandas " | ||||||||||||||||||
"dataframe when indexing the columns (i.e. 'axis=1'). " | ||||||||||||||||||
"Got {} instead with {} dimension(s).".format(type(X), X.ndim) | ||||||||||||||||||
) | ||||||||||||||||||
# check whether we have string column names or integers | ||||||||||||||||||
if _check_key_type(key, int): | ||||||||||||||||||
column_names = False | ||||||||||||||||||
elif _check_key_type(key, str): | ||||||||||||||||||
column_names = True | ||||||||||||||||||
elif hasattr(key, 'dtype') and np.issubdtype(key.dtype, np.bool_): | ||||||||||||||||||
# boolean mask | ||||||||||||||||||
column_names = False | ||||||||||||||||||
if hasattr(X, 'loc'): | ||||||||||||||||||
# pandas boolean masks don't work with iloc, so take loc path | ||||||||||||||||||
column_names = True | ||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError("No valid specification of the columns. Only a " | ||||||||||||||||||
"scalar, list or slice of all integers or all " | ||||||||||||||||||
"strings, or boolean mask is allowed") | ||||||||||||||||||
|
||||||||||||||||||
if column_names: | ||||||||||||||||||
if hasattr(X, 'loc'): | ||||||||||||||||||
# pandas dataframes | ||||||||||||||||||
return X.loc[:, key] | ||||||||||||||||||
else: | ||||||||||||||||||
raise ValueError("Specifying the columns using strings is only " | ||||||||||||||||||
"supported for pandas DataFrames") | ||||||||||||||||||
if axis == 1 and indices_dtype == 'str' and not hasattr(X, 'loc'): | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
"Specifying the columns using strings is only supported for " | ||||||||||||||||||
"pandas DataFrames" | ||||||||||||||||||
) | ||||||||||||||||||
|
||||||||||||||||||
if hasattr(X, "iloc"): | ||||||||||||||||||
return _pandas_indexing(X, indices, indices_dtype, axis=axis) | ||||||||||||||||||
elif hasattr(X, "shape"): | ||||||||||||||||||
return _array_indexing(X, indices, indices_dtype, axis=axis) | ||||||||||||||||||
else: | ||||||||||||||||||
if hasattr(X, 'iloc'): | ||||||||||||||||||
# pandas dataframes | ||||||||||||||||||
return X.iloc[:, key] | ||||||||||||||||||
else: | ||||||||||||||||||
# numpy arrays, sparse arrays | ||||||||||||||||||
return _array_indexing(X, key, axis=1) | ||||||||||||||||||
return _list_indexing(X, indices, indices_dtype) | ||||||||||||||||||
|
||||||||||||||||||
|
||||||||||||||||||
def _get_column_indices(X, key): | ||||||||||||||||||
|
@@ -383,17 +350,22 @@ def _get_column_indices(X, key): | |||||||||||||||||
""" | ||||||||||||||||||
n_columns = X.shape[1] | ||||||||||||||||||
|
||||||||||||||||||
if (_check_key_type(key, int) | ||||||||||||||||||
or hasattr(key, 'dtype') and np.issubdtype(key.dtype, np.bool_)): | ||||||||||||||||||
key_dtype = _determine_key_type(key) | ||||||||||||||||||
|
||||||||||||||||||
if isinstance(key, list) and not key: | ||||||||||||||||||
# we get an empty list | ||||||||||||||||||
return [] | ||||||||||||||||||
elif key_dtype in ('bool', 'int'): | ||||||||||||||||||
# Convert key into positive indexes | ||||||||||||||||||
try: | ||||||||||||||||||
idx = safe_indexing(np.arange(n_columns), key) | ||||||||||||||||||
except IndexError as e: | ||||||||||||||||||
raise ValueError( | ||||||||||||||||||
'all features must be in [0, %d]' % (n_columns - 1) | ||||||||||||||||||
'all features must be in [0, {}] or [-{}, 0]' | ||||||||||||||||||
.format(n_columns - 1, n_columns) | ||||||||||||||||||
) from e | ||||||||||||||||||
return np.atleast_1d(idx).tolist() | ||||||||||||||||||
elif _check_key_type(key, str): | ||||||||||||||||||
elif key_dtype == 'str': | ||||||||||||||||||
try: | ||||||||||||||||||
all_columns = list(X.columns) | ||||||||||||||||||
except AttributeError: | ||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we use an explicit version comparison to take this codepath?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
then we need to make an explicit pandas import