remove check_arrays stuff and old input validation
amueller committed Jul 20, 2014
1 parent 8dab222 commit 6e2a83b
Showing 86 changed files with 9,079 additions and 10,706 deletions.
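
The commit applies one migration pattern throughout: the old validation
helpers (``check_arrays``, ``array2d``, ``atleast2d_or_csr``, ``safe_asarray``)
are replaced by ``check_array`` and ``indexable``. A condensed before/after
view, with each pairing taken from the hunks below::

    # before                                     # after
    X = array2d(X)                               X = check_array(X)
    X = atleast2d_or_csr(X)                      X = check_array(X, 'csr')
    X = safe_asarray(X)                          X = check_array(X, ['csr', 'csc', 'coo'])
    X, = check_arrays(X, sparse_format='csr')    X = check_array(X, 'csr')
    X, y = check_arrays(X, y, ...)               X, y = indexable(X, y)  # cross-validation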
9 changes: 4 additions & 5 deletions doc/developers/index.rst
@@ -401,10 +401,9 @@ do *not* use ``np.asanyarray`` or ``np.atleast_2d``, since those let NumPy's
``np.matrix`` through, which has a different API
(e.g., ``*`` means dot product on ``np.matrix``,
but Hadamard product on ``np.ndarray``).

In other cases, be sure to call :func:`safe_asarray`, :func:`atleast2d_or_csr`,
:func:`as_float_array` or :func:`array2d` on any array-like argument passed to a
scikit-learn API function. The exact function to use depends mainly on whether
``scipy.sparse`` matrices must be accepted.
In other cases, be sure to call :func:`check_array` on any array-like argument
passed to a scikit-learn API function. The exact parameters to use depend
mainly on whether, and which, ``scipy.sparse`` matrices must be accepted.

For more information, refer to the :ref:`developers-utils` page.
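
For illustration, a minimal sketch of this convention (the helper name is
hypothetical; the call signature follows the ``check_array`` usage in the
hunks below)::

    import numpy as np
    from sklearn.utils import check_array

    def _validate_input(X):
        # Accept dense array-likes or CSR sparse matrices; other sparse
        # formats are converted to CSR, and values are coerced to float64.
        return check_array(X, 'csr', dtype=np.float64)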

@@ -490,7 +489,7 @@ E.g., if the function ``zero_one`` is renamed to ``zero_one_loss``,
we add the decorator ``deprecated`` (from ``sklearn.utils``)
to ``zero_one`` and call ``zero_one_loss`` from that function::

from ..utils import check_arrays, deprecated
from ..utils import deprecated

def zero_one_loss(y_true, y_pred, normalize=True):
# actual implementation
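
(The hunk is cut off here; a sketch of how the deprecated wrapper would
continue, with an illustrative message that is not part of the diff)::

    @deprecated("Function 'zero_one' was renamed to 'zero_one_loss' "
                "and will be removed in a future release.")
    def zero_one(y_true, y_pred, normalize=False):
        return zero_one_loss(y_true, y_pred, normalize)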
19 changes: 10 additions & 9 deletions doc/developers/utilities.rst
@@ -27,20 +27,21 @@ should be used when applicable.

- :func:`assert_all_finite`: Throw an error if array contains NaNs or Infs.

- :func:`safe_asarray`: Convert input to array or sparse matrix. Equivalent
to ``np.asarray``, but sparse matrices are passed through.

- :func:`as_float_array`: convert input to an array of floats. If a sparse
matrix is passed, a sparse matrix will be returned.

- :func:`array2d`: equivalent to ``np.atleast_2d``, but the ``order`` and
``dtype`` of the input are maintained.
- :func:`check_array`: convert input to a 2d array, raising an error on sparse
matrices by default. The accepted sparse matrix formats can be given
optionally, as can permission for 1d or nd arrays. Calls
:func:`assert_all_finite` by default (see the usage sketch after this list).

- :func:`atleast2d_or_csr`: equivalent to ``array2d``, but if a sparse matrix
is passed, will convert to csr format. Also calls ``assert_all_finite``.
- :func:`check_X_y`: check that ``X`` and ``y`` have consistent length, calling
:func:`check_array` on ``X`` and ``column_or_1d`` on ``y``. For multilabel
classification or multi-target regression, specify ``multi_output=True``, in
which case :func:`check_array` will be called on ``y`` as well.

- :func:`check_arrays`: check that all input arrays have consistent first
dimensions. This will work for an arbitrary number of arrays.
- :func:`indexable`: check that all input arrays have consistent length and
can be sliced or indexed using ``safe_indexing``. This is used to validate
input for cross-validation.

- :func:`warn_if_not_float`: Warn if the input is not a floating-point value;
the input ``X`` is assumed to have ``X.dtype``.
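
As a usage sketch of these validation helpers (sample data is illustrative;
the import locations and signatures follow this commit's validation module,
as an assumption)::

    import numpy as np
    from scipy import sparse
    from sklearn.utils import check_array, indexable
    from sklearn.utils.validation import check_X_y

    X = sparse.rand(10, 4, density=0.5, format='coo')
    y = np.arange(10)

    # check_array: validate X alone; COO input is converted to the
    # allowed CSR format.
    X_csr = check_array(X, 'csr')

    # check_X_y: validate X and y together, checking consistent length.
    X_csr, y = check_X_y(X, y, 'csr')

    # indexable: make the inputs safely sliceable for cross-validation.
    X_idx, y_idx = indexable(X, y)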
4 changes: 2 additions & 2 deletions sklearn/cluster/_feature_agglomeration.py
@@ -8,7 +8,7 @@
import numpy as np

from ..base import TransformerMixin
from ..utils import array2d
from ..utils import check_array

import warnings

@@ -48,7 +48,7 @@ def transform(self, X, pooling_func=None):
"removed in 0.18. Pass it to the constructor instead.", DeprecationWarning)
else:
pooling_func = self.pooling_func
X = array2d(X)
X = check_array(X)
nX = []
if len(self.labels_) != X.shape[1]:
raise ValueError("X has a different number of features than "
4 changes: 2 additions & 2 deletions sklearn/cluster/bicluster/spectral.py
@@ -24,7 +24,7 @@
from sklearn.utils.extmath import norm

from sklearn.utils.validation import assert_all_finite
from sklearn.utils.validation import check_arrays
from sklearn.utils.validation import check_array

from .utils import check_array_ndim

@@ -120,7 +120,7 @@ def fit(self, X):
X : array-like, shape (n_samples, n_features)
"""
X, = check_arrays(X, sparse_format='csr', dtype=np.float64)
X = check_array(X, 'csr', dtype=np.float64)
check_array_ndim(X)
self._check_parameters()
self._fit(X)
8 changes: 4 additions & 4 deletions sklearn/cluster/hierarchical.py
@@ -18,7 +18,7 @@
from ..externals.joblib import Memory
from ..externals import six
from ..metrics.pairwise import paired_distances, pairwise_distances
from ..utils import array2d, safe_asarray, check_arrays
from ..utils import check_array
from ..utils.sparsetools import connected_components

from . import _hierarchical
@@ -614,7 +614,7 @@ def fit(self, X):
-------
self
"""
X = array2d(X)
X = check_array(X)
memory = self.memory
if isinstance(memory, six.string_types):
memory = Memory(cachedir=memory, verbose=0)
@@ -685,7 +685,7 @@ def fit(self, X, y=None, **params):
-------
self
"""
X = safe_asarray(X)
X = check_array(X, ['csr', 'csc', 'coo'])
if not (len(X.shape) == 2 and X.shape[0] > 0):
raise ValueError('At least one sample is required to fit the '
'model. A data matrix of shape %s was given.'
@@ -838,5 +838,5 @@ def fit(self, X, y=None, **params):
-------
self
"""
X, = check_arrays(X)
X = check_array(X)
return Ward.fit(self, X.T, **params)
12 changes: 5 additions & 7 deletions sklearn/cluster/k_means_.py
@@ -22,9 +22,8 @@
from ..utils.sparsefuncs_fast import assign_rows_csr
from ..utils.sparsefuncs import mean_variance_axis0
from ..utils.fixes import astype
from ..utils import check_arrays
from ..utils import check_array
from ..utils import check_random_state
from ..utils import atleast2d_or_csr
from ..utils import as_float_array
from ..utils import gen_batches
from ..utils.random import choice
@@ -688,14 +687,14 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10, max_iter=300,

def _check_fit_data(self, X):
"""Verify that the number of samples given is larger than k"""
X = atleast2d_or_csr(X, dtype=np.float64)
X = check_array(X, 'csr', dtype=np.float64)
if X.shape[0] < self.n_clusters:
raise ValueError("n_samples=%d should be >= n_clusters=%d" % (
X.shape[0], self.n_clusters))
return X

def _check_test_data(self, X):
X = atleast2d_or_csr(X)
X = check_array(X, 'csr')
n_samples, n_features = X.shape
expected_n_features = self.cluster_centers_.shape[1]
if not n_features == expected_n_features:
@@ -1132,8 +1131,7 @@ def fit(self, X, y=None):
Coordinates of the data points to cluster
"""
random_state = check_random_state(self.random_state)
X = check_arrays(X, sparse_format="csr", copy=False,
check_ccontiguous=True, dtype=np.float64)[0]
X = check_array(X, "csr", order='C', dtype=np.float64)
n_samples, n_features = X.shape
if n_samples < self.n_clusters:
raise ValueError("Number of samples smaller than number "
@@ -1293,7 +1291,7 @@ def partial_fit(self, X, y=None):
Coordinates of the data points to cluster.
"""

X = check_arrays(X, sparse_format="csr", copy=False)[0]
X = check_array(X, "csr")
n_samples, n_features = X.shape
if hasattr(self.init, '__array__'):
self.init = np.ascontiguousarray(self.init, dtype=np.float64)
4 changes: 2 additions & 2 deletions sklearn/cluster/spectral.py
@@ -11,7 +11,7 @@

from ..base import BaseEstimator, ClusterMixin
from ..utils import check_random_state, as_float_array
from ..utils.validation import check_arrays
from ..utils.validation import check_array
from ..utils.extmath import norm
from ..metrics.pairwise import pairwise_kernels
from ..neighbors import kneighbors_graph
@@ -415,7 +415,7 @@ def fit(self, X):
OR, if affinity==`precomputed`, a precomputed affinity
matrix of shape (n_samples, n_samples)
"""
X, = check_arrays(X)
X = check_array(X, ['csr', 'csc', 'coo'])
if X.shape[0] == X.shape[1] and self.affinity != "precomputed":
warnings.warn("The spectral clustering API has changed. ``fit``"
"now constructs an affinity matrix from data. To use"
4 changes: 2 additions & 2 deletions sklearn/covariance/empirical_covariance_.py
@@ -16,7 +16,7 @@
from scipy import linalg

from ..base import BaseEstimator
from ..utils import array2d
from ..utils import check_array
from ..utils.extmath import fast_logdet, pinvh


@@ -122,7 +122,7 @@ def _set_covariance(self, covariance):
is computed.
"""
covariance = array2d(covariance)
covariance = check_array(covariance)
# set covariance
self.covariance_ = covariance
# set precision
4 changes: 2 additions & 2 deletions sklearn/covariance/shrunk_covariance_.py
@@ -19,7 +19,7 @@

from .empirical_covariance_ import empirical_covariance, EmpiricalCovariance
from ..externals.six.moves import xrange
from ..utils import array2d
from ..utils import check_array


# ShrunkCovariance estimator
@@ -51,7 +51,7 @@ def shrunk_covariance(emp_cov, shrinkage=0.1):
where mu = trace(cov) / n_features
"""
emp_cov = array2d(emp_cov)
emp_cov = check_array(emp_cov)
n_features = emp_cov.shape[0]

mu = np.trace(emp_cov) / n_features
28 changes: 7 additions & 21 deletions sklearn/cross_decomposition/pls_.py
@@ -6,7 +6,7 @@
# License: BSD 3 clause

from ..base import BaseEstimator, RegressorMixin, TransformerMixin
from ..utils import check_arrays
from ..utils import check_array, check_consistent_length
from ..externals import six

import warnings
@@ -229,15 +229,9 @@ def fit(self, X, Y):
"""

# copy since this will contain the residuals (deflated) matrices
X, Y = check_arrays(X, Y, dtype=np.float, copy=self.copy,
sparse_format='dense')

if X.ndim != 2:
raise ValueError('X must be a 2D array')
if Y.ndim == 1:
Y = Y.reshape((Y.size, 1))
if Y.ndim != 2:
raise ValueError('Y must be a 1D or a 2D array')
check_consistent_length(X, Y)
X = check_array(X, dtype=np.float, copy=self.copy)
Y = check_array(Y, dtype=np.float, copy=self.copy)

n = X.shape[0]
p = X.shape[1]
@@ -727,20 +721,12 @@ def __init__(self, n_components=2, scale=True, copy=True):

def fit(self, X, Y):
# copy since this will contain the centered data
X, Y = check_arrays(X, Y, dtype=np.float, copy=self.copy,
sparse_format='dense')
check_consistent_length(X, Y)
X = check_array(X, dtype=np.float, copy=self.copy)
Y = check_array(Y, dtype=np.float, copy=self.copy)

n = X.shape[0]
p = X.shape[1]

if X.ndim != 2:
raise ValueError('X must be a 2D array')

if n != Y.shape[0]:
raise ValueError(
'Incompatible shapes: X has %s samples, while Y '
'has %s' % (X.shape[0], Y.shape[0]))

if self.n_components < 1 or self.n_components > p:
raise ValueError('invalid number of components')

33 changes: 18 additions & 15 deletions sklearn/cross_validation.py
@@ -22,8 +22,8 @@
import scipy.sparse as sp

from .base import is_classifier, clone
from .utils import check_arrays, check_random_state, safe_indexing
from .utils.validation import _num_samples
from .utils import indexable, check_random_state, safe_indexing
from .utils.validation import _num_samples, check_array
from .externals.joblib import Parallel, delayed, logger
from .externals.six import with_metaclass
from .externals.six.moves import zip
@@ -1133,8 +1133,7 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
scores : array of float, shape=(len(list(cv)),)
Array of scores of the estimator for each run of the cross validation.
"""
X, y = check_arrays(X, y, sparse_format='csr', force_arrays=False,
allow_nans=True, allow_nd=True)
X, y = indexable(X, y)

cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
scorer = check_scoring(estimator, scoring=scoring)
@@ -1443,7 +1442,7 @@ def permutation_test_score(estimator, X, y, cv=None,
vol. 11
"""
X, y = check_arrays(X, y, sparse_format='csr', allow_nans=True)
X, y = indexable(X, y)
cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
scorer = check_scoring(estimator, scoring=scoring)
random_state = check_random_state(random_state)
@@ -1467,7 +1466,7 @@ def train_test_split(*arrays, **options):
def train_test_split(*arrays, **options):
"""Split arrays or matrices into random train and test subsets
Quick utility that wraps calls to ``check_arrays`` and
Quick utility that wraps input validation and
``next(iter(ShuffleSplit(n_samples)))`` and application to input
data into a single call for splitting (and optionally subsampling)
data in a one-liner.
@@ -1494,9 +1493,6 @@ def train_test_split(*arrays, **options):
random_state : int or RandomState
Pseudo-random number generator state used for random sampling.
dtype : a numpy dtype instance, None by default
Enforce a specific dtype.
Returns
-------
splitting : list of arrays, length=2 * len(arrays)
@@ -1539,15 +1535,22 @@ def train_test_split(*arrays, **options):
test_size = options.pop('test_size', None)
train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None)
options['sparse_format'] = 'csr'
options['allow_nans'] = True
if not "force_arrays" in options:
options["force_arrays"] = False
dtype = options.pop('dtype', None)
if dtype is not None:
warnings.warn("dtype option is ignored and will be removed in 0.17.")

force_arrays = options.pop('force_arrays', False)
if options:
raise TypeError("Invalid parameters passed: %s" % str(options))
if force_arrays:
warnings.warn("The force_arrays option is deprecated and will be "
"removed in 0.17.", DeprecationWarning)
arrays = [check_array(x, 'csr', ensure_2d=False, force_all_finite=False)
if x is not None else x for x in arrays]

if test_size is None and train_size is None:
test_size = 0.25

arrays = check_arrays(*arrays, **options)
arrays = indexable(*arrays)
n_samples = _num_samples(arrays[0])
cv = ShuffleSplit(n_samples, test_size=test_size,
train_size=train_size,
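
(For reference, a minimal usage sketch of ``train_test_split`` after this
change; the data is illustrative)::

    import numpy as np
    from sklearn.cross_validation import train_test_split

    X = np.arange(20).reshape((10, 2))
    y = np.arange(10)

    # Default split is 75% train / 25% test; random_state fixes the shuffle.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=0)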
4 changes: 2 additions & 2 deletions sklearn/datasets/samples_generator.py
@@ -14,7 +14,7 @@
import scipy.sparse as sp

from ..preprocessing import MultiLabelBinarizer
from ..utils import array2d, check_random_state
from ..utils import check_array, check_random_state
from ..utils import shuffle as util_shuffle
from ..utils.fixes import astype
from ..utils.random import sample_without_replacement
@@ -695,7 +695,7 @@ def make_blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0,
centers = generator.uniform(center_box[0], center_box[1],
size=(centers, n_features))
else:
centers = array2d(centers)
centers = check_array(centers)
n_features = centers.shape[1]

X = []
4 changes: 2 additions & 2 deletions sklearn/datasets/svmlight_format.py
@@ -27,7 +27,7 @@
from ..externals import six
from ..externals.six import u, b
from ..externals.six.moves import range, zip
from ..utils import atleast2d_or_csr
from ..utils import check_array


def load_svmlight_file(f, n_features=None, dtype=np.float64,
@@ -356,7 +356,7 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None):
raise ValueError("expected y of shape (n_samples,), got %r"
% (y.shape,))

Xval = atleast2d_or_csr(X)
Xval = check_array(X, 'csr')
if Xval.shape[0] != y.shape[0]:
raise ValueError("X.shape[0] and y.shape[0] should be the same, got"
" %r and %r instead." % (Xval.shape[0], y.shape[0]))
