
Commit

Merge pull request #410 from larsmans/accept-matrix-input
ENH accept matrix input throughout
larsmans committed Oct 25, 2011
2 parents 59f1970 + 1f2549d commit 0c9cb49
Showing 55 changed files with 2,669 additions and 2,441 deletions.
23 changes: 10 additions & 13 deletions doc/developers/index.rst
@@ -275,10 +275,16 @@ Input validation
 ----------------

 The module ``sklearn.utils`` contains various functions for doing input
-validation/conversion. Sometimes, ``np.atleast_2d`` suffices for validation;
-in other cases, be sure to call ``safe_asanyarray``, ``atleast2d_or_csr`` or
-``as_float_array`` on any array-like argument passed to a scikit-learn API
-function.
+validation/conversion. Sometimes, ``np.asarray`` suffices for validation;
+do `not` use ``np.asanyarray`` or ``np.atleast_2d``, since those let NumPy's
+``np.matrix`` through, which has a different API
+(e.g., ``*`` means dot product on ``np.matrix``,
+but Hadamard product on ``np.ndarray``).
+
+In other cases, be sure to call ``safe_asarray``, ``atleast2d_or_csr``,
+``as_float_array`` or ``array2d`` on any array-like argument passed to a
+scikit-learn API function. The exact function to use depends mainly on whether
+``scipy.sparse`` matrices must be accepted.

 APIs of scikit-learn objects
@@ -430,15 +436,6 @@ you call ``fit`` a second time without taking any previous value into
 account: **fit should be idempotent**.


-Python tuples
-^^^^^^^^^^^^^
-
-In addition to numpy arrays, all methods should be able to accept
-Python tuples as arguments. In practice, this means you should call
-``numpy.asanyarray`` at the beginning at each public method that accepts
-arrays.


 Optional Arguments
 ^^^^^^^^^^^^^^^^^^
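
The behaviour the new paragraph warns about is easy to demonstrate; the snippet below is only an illustration of standard NumPy semantics, not part of the patch: ``np.asanyarray`` lets the ``np.matrix`` subclass through, ``np.asarray`` coerces it to a plain ``ndarray``, and ``*`` means different things on the two types.

    import numpy as np

    M = np.matrix([[1, 2], [3, 4]])
    A = np.array([[1, 2], [3, 4]])

    type(np.asanyarray(M))   # <class 'numpy.matrix'>: the subclass survives
    type(np.asarray(M))      # <class 'numpy.ndarray'>: coerced to a plain array

    M * M                    # matrix (dot) product: [[ 7, 10], [15, 22]]
    A * A                    # element-wise (Hadamard) product: [[ 1,  4], [ 9, 16]]
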
7 changes: 5 additions & 2 deletions sklearn/cluster/_feature_agglomeration.py
@@ -6,7 +6,9 @@
 # License: BSD 3 clause

 import numpy as np
+
 from ..base import TransformerMixin
+from ..utils import array2d


 ###############################################################################
@@ -31,6 +33,7 @@ def transform(self, X, pooling_func=np.mean):
             return an array of value of size M.
             Defaut is np.mean
         """
+        X = np.asarray(X)
         nX = []
         for l in np.unique(self.labels_):
             nX.append(pooling_func(X[:, self.labels_ == l], axis=1))
@@ -63,6 +66,6 @@ def inverse_transform(self, Xred):
                 X[self.labels_ == unil[i]] = Xred[i]
             else:
                 ncol = np.sum(self.labels_ == unil[i])
-                X[:, self.labels_ == unil[i]] = np.tile(np.atleast_2d(Xred
-                                                        [:, i]).T, ncol)
+                X[:, self.labels_ == unil[i]] = np.tile(array2d(Xred[:, i]).T,
+                                                        ncol)
         return X
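
For context, the loop in ``transform`` pools the feature columns that share a cluster label; a minimal NumPy equivalent on toy data (illustration only, not part of the patch):

    import numpy as np

    X = np.arange(12.0).reshape(3, 4)     # 3 samples, 4 features
    labels = np.array([0, 0, 1, 1])       # feature-cluster assignments (self.labels_)

    # one pooled column per cluster, with np.mean as the pooling function
    pooled = np.array([X[:, labels == l].mean(axis=1)
                       for l in np.unique(labels)]).T
    # pooled.shape == (3, 2)
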
7 changes: 2 additions & 5 deletions sklearn/cluster/affinity_propagation_.py
@@ -10,6 +10,7 @@
 import numpy as np

 from ..base import BaseEstimator
+from ..utils import as_float_array


 def affinity_propagation(S, p=None, convit=30, max_iter=200, damping=0.5,
@@ -49,11 +50,7 @@ def affinity_propagation(S, p=None, convit=30, max_iter=200, damping=0.5,
         Between Data Points", Science Feb. 2007
     """
-    if copy:
-        # Copy the affinity matrix to avoid modifying it inplace
-        S = np.array(S, copy=True, dtype=np.float)
-    else:
-        S = np.asanyarray(S, dtype=np.float)
+    S = as_float_array(S, copy=copy)

     n_points = S.shape[0]
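
``as_float_array`` replaces the inlined copy-or-convert branch above. A rough sketch of the intended behaviour, using a hypothetical helper name rather than the actual ``sklearn.utils`` implementation:

    import numpy as np

    def as_float_array_sketch(X, copy=True):
        """Coerce X to a float ndarray, copying only when requested
        or when a dtype conversion forces a copy anyway."""
        X = np.asarray(X)
        if X.dtype.kind == 'f' and not copy:
            return X                     # already float: reuse the buffer
        return X.astype(np.float64)      # converts, and always copies
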
2 changes: 1 addition & 1 deletion sklearn/cluster/hierarchical.py
@@ -65,7 +65,7 @@ def ward_tree(X, connectivity=None, n_components=None, copy=True):
     n_leaves : int
         The number of leaves in the tree
     """
-    X = np.asanyarray(X)
+    X = np.asarray(X)
     n_samples, n_features = X.shape
     if X.ndim == 1:
         X = np.reshape(X, (-1, 1))
4 changes: 2 additions & 2 deletions sklearn/cluster/k_means_.py
@@ -499,8 +499,8 @@ def __init__(self, k=8, init='k-means++', n_init=10, max_iter=300,
     def _check_data(self, X):
         """Verify that the number of samples given is larger than k"""
         if sp.issparse(X):
-            raise ValueError("K-Means does not support sparse input matrices.")
-        X = np.asanyarray(X)
+            raise TypeError("K-Means does not support sparse input matrices.")
+        X = np.asarray(X)
         if X.shape[0] < self.k:
             raise ValueError("n_samples=%d should be >= k=%d" % (
                 X.shape[0], self.k))
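
Assuming ``fit`` routes its input through ``_check_data`` (the parameter name ``k`` is taken from this file's constructor signature), the visible effect of the hunk is that sparse input now raises ``TypeError`` rather than ``ValueError``; a hedged usage sketch:

    import numpy as np
    from scipy import sparse
    from sklearn.cluster import KMeans

    KMeans(k=2).fit(np.random.rand(10, 3))            # dense input is fine
    try:
        KMeans(k=2).fit(sparse.csr_matrix(np.eye(3)))
    except TypeError as e:
        print(e)    # K-Means does not support sparse input matrices.
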
5 changes: 3 additions & 2 deletions sklearn/covariance/empirical_covariance_.py
@@ -15,6 +15,7 @@
 from scipy import linalg

 from ..base import BaseEstimator
+from ..utils import array2d
 from ..utils.extmath import fast_logdet as exact_logdet


@@ -52,7 +53,7 @@ def empirical_covariance(X, assume_centered=False):
         Empirical covariance (Maximum Likelihood Estimator)
     """
-    X = np.asanyarray(X)
+    X = np.asarray(X)
     if X.ndim == 1:
         X = np.atleast_2d(X).T

@@ -98,7 +99,7 @@ def _set_estimates(self, covariance):
         is computed.
         """
-        covariance = np.atleast_2d(covariance)
+        covariance = array2d(covariance)
         # set covariance
         self.covariance_ = covariance
         # set precision
2 changes: 1 addition & 1 deletion sklearn/covariance/robust_covariance.py
@@ -271,7 +271,7 @@ def fast_mcd(X, correction="empirical", reweight="rousseeuw"):
         the robust location and covariance estimates of the data set
     """
-    X = np.asanyarray(X)
+    X = np.asarray(X)
     if X.ndim <= 1:
         X = X.reshape((-1, 1))
     n_samples, n_features = X.shape
7 changes: 4 additions & 3 deletions sklearn/covariance/shrunk_covariance_.py
@@ -18,6 +18,7 @@
 import numpy as np

 from .empirical_covariance_ import empirical_covariance, EmpiricalCovariance
+from ..utils import array2d


 ###############################################################################
@@ -50,7 +51,7 @@ def shrunk_covariance(emp_cov, shrinkage=0.1):
     where mu = trace(cov) / n_features
     """
-    emp_cov = np.atleast_2d(emp_cov)
+    emp_cov = array2d(emp_cov)
     n_features = emp_cov.shape[0]

     mu = np.trace(emp_cov) / n_features
@@ -165,7 +166,7 @@ def ledoit_wolf(X, assume_centered=False):
     where mu = trace(cov) / n_features
     """
-    X = np.asanyarray(X)
+    X = np.asarray(X)
     # for only one feature, the result is the same whatever the shrinkage
     if X.ndim == 1:
         if not assume_centered:
@@ -303,7 +304,7 @@ def oas(X, assume_centered=False):
     where mu = trace(cov) / n_features
     """
-    X = np.asanyarray(X)
+    X = np.asarray(X)
     # for only one feature, the result is the same whatever the shrinkage
     if X.ndim == 1:
         if not assume_centered:
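
The ``where mu = trace(cov) / n_features`` lines in these docstrings refer to the usual convex-combination shrinkage estimate; assuming that formula, a small worked example:

    import numpy as np

    emp_cov = np.array([[2.0, 0.5],
                        [0.5, 1.0]])
    shrinkage = 0.1
    mu = np.trace(emp_cov) / emp_cov.shape[0]    # (2.0 + 1.0) / 2 = 1.5
    shrunk = (1. - shrinkage) * emp_cov + shrinkage * mu * np.eye(2)
    # array([[ 1.95,  0.45],
    #        [ 0.45,  1.05]])
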
7 changes: 1 addition & 6 deletions sklearn/covariance/tests/test_covariance.py
@@ -14,6 +14,7 @@
     fast_mcd, MCD

 X = datasets.load_iris().data
+X_1d = X[:, 0]
 n_samples, n_features = X.shape


@@ -40,7 +41,6 @@ def test_covariance():
     assert(np.amin(mahal_dist) > 50)

     # test with n_features = 1
-    X_1d = X[:, 0]
     cov = EmpiricalCovariance()
     cov.fit(X_1d)
     assert_array_almost_equal(empirical_covariance(X_1d), cov.covariance_, 4)
@@ -78,7 +78,6 @@ def test_shrunk_covariance():
     assert_array_almost_equal(empirical_covariance(X), cov.covariance_, 4)

     # test with n_features = 1
-    X_1d = X[:, 0]
     cov = ShrunkCovariance(shrinkage=0.3)
     cov.fit(X_1d)
     assert_array_almost_equal(empirical_covariance(X_1d), cov.covariance_, 4)
@@ -109,7 +108,6 @@ def test_ledoit_wolf():
     assert_array_almost_equal(scov.covariance_, lw.covariance_, 4)

     # test with n_features = 1
-    X_1d = X[:, 0]
     lw = LedoitWolf()
     lw.fit(X_1d, assume_centered=True)
     lw_cov_from_mle, lw_shinkrage_from_mle = ledoit_wolf(X_1d,
@@ -140,7 +138,6 @@ def test_ledoit_wolf():
     assert_array_almost_equal(scov.covariance_, lw.covariance_, 4)

     # test with n_features = 1
-    X_1d = X[:, 0]
     lw = LedoitWolf()
     lw.fit(X_1d)
     lw_cov_from_mle, lw_shinkrage_from_mle = ledoit_wolf(X_1d)
@@ -174,7 +171,6 @@ def test_oas():
     assert_array_almost_equal(scov.covariance_, oa.covariance_, 4)

     # test with n_features = 1
-    X_1d = X[:, 0]
     oa = OAS()
     oa.fit(X_1d, assume_centered=True)
     oa_cov_from_mle, oa_shinkrage_from_mle = oas(X_1d, assume_centered=True)
@@ -204,7 +200,6 @@ def test_oas():
     assert_array_almost_equal(scov.covariance_, oa.covariance_, 4)

     # test with n_features = 1
-    X_1d = X[:, 0]
     oa = OAS()
     oa.fit(X_1d)
     oa_cov_from_mle, oa_shinkrage_from_mle = oas(X_1d)
2 changes: 1 addition & 1 deletion sklearn/cross_validation.py
@@ -306,7 +306,7 @@ class StratifiedKFold(object):
     """

     def __init__(self, y, k, indices=False):
-        y = np.asanyarray(y)
+        y = np.asarray(y)
         n = y.shape[0]
         assert k > 0, ValueError('Cannot have number of folds k below 1.')
         assert k <= n, ValueError('Cannot have number of folds k=%d, '
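
With ``np.asarray``, ``y`` can be any array-like, including a plain Python list or tuple. A hedged usage sketch against the constructor signature shown above (the iteration protocol yielding train/test masks is assumed from the cross-validation API of this era):

    from sklearn.cross_validation import StratifiedKFold

    y = [0, 0, 1, 1, 0, 1]              # a plain list is accepted
    for train_mask, test_mask in StratifiedKFold(y, k=2):
        pass                            # boolean masks; indices=True would yield indices
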
8 changes: 4 additions & 4 deletions sklearn/datasets/base.py
@@ -211,8 +211,8 @@ def load_iris():
     target = np.empty((n_samples,), dtype=np.int)

     for i, ir in enumerate(data_file):
-        data[i] = np.asanyarray(ir[:-1], dtype=np.float)
-        target[i] = np.asanyarray(ir[-1], dtype=np.int)
+        data[i] = np.asarray(ir[:-1], dtype=np.float)
+        target[i] = np.asarray(ir[-1], dtype=np.int)

     return Bunch(data=data, target=target,
                  target_names=target_names,
@@ -350,8 +350,8 @@ def load_boston():
     feature_names = np.array(temp)

     for i, d in enumerate(data_file):
-        data[i] = np.asanyarray(d[:-1], dtype=np.float)
-        target[i] = np.asanyarray(d[-1], dtype=np.float)
+        data[i] = np.asarray(d[:-1], dtype=np.float)
+        target[i] = np.asarray(d[-1], dtype=np.float)

     return Bunch(data=data,
                  target=target,
4 changes: 2 additions & 2 deletions sklearn/datasets/samples_generator.py
@@ -9,7 +9,7 @@
 import numpy as np
 from scipy import linalg

-from ..utils import check_random_state
+from ..utils import array2d, check_random_state


 def make_classification(n_samples=100, n_features=20, n_informative=2,
@@ -395,7 +395,7 @@ def make_blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0,
         centers = generator.uniform(center_box[0], center_box[1],
                                     size=(centers, n_features))
     else:
-        centers = np.atleast_2d(centers)
+        centers = array2d(centers)
         n_features = centers.shape[1]

     X = []
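
The switch to ``array2d`` matters in the ``else`` branch, where user-supplied ``centers`` (a list of lists, or even an ``np.matrix``) are normalised to a 2-D ndarray. A hedged usage sketch, assuming the public ``make_blobs`` signature shown above:

    from sklearn.datasets import make_blobs

    # explicit centers given as a plain list of lists
    X, y = make_blobs(n_samples=10, centers=[[0, 0], [5, 5]], random_state=0)
    # X.shape == (10, 2); y labels each sample with the index of its blob
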
14 changes: 7 additions & 7 deletions sklearn/decomposition/dict_learning.py
@@ -15,8 +15,7 @@

 from ..base import BaseEstimator, TransformerMixin
 from ..externals.joblib import Parallel, delayed, cpu_count
-from ..utils import check_random_state
-from ..utils import gen_even_slices
+from ..utils import array2d, check_random_state, gen_even_slices
 from ..utils.extmath import fast_svd
 from ..linear_model import Lasso, orthogonal_mp_gram, lars_path


@@ -90,7 +89,8 @@ def sparse_encode(X, Y, gram=None, cov=None, algorithm='lasso_lars',
     linear_model.Lasso
     """
     alpha = float(alpha) if alpha is not None else None
-    X, Y = map(np.asanyarray, (X, Y))
+    X = np.asarray(X)
+    Y = np.asarray(Y)
     if Y.ndim == 1:
         Y = Y[:, np.newaxis]
     n_features = Y.shape[1]
@@ -688,7 +688,7 @@ def transform(self, X, y=None):
             Transformed data
         """
         # XXX : kwargs is not documented
-        X = np.atleast_2d(X)
+        X = array2d(X)
         n_samples, n_features = X.shape

         code = sparse_encode_parallel(
@@ -832,7 +832,7 @@ def fit(self, X, y=None):
             Returns the object itself
         """
         self.random_state = check_random_state(self.random_state)
-        X = np.asanyarray(X)
+        X = np.asarray(X)
         V, U, E = dict_learning(X, self.n_atoms, self.alpha,
                                 tol=self.tol, max_iter=self.max_iter,
                                 method=self.fit_algorithm,
@@ -968,7 +968,7 @@ def fit(self, X, y=None):
             Returns the instance itself.
         """
         self.random_state = check_random_state(self.random_state)
-        X = np.asanyarray(X)
+        X = np.asarray(X)
         U = dict_learning_online(X, self.n_atoms, self.alpha,
                                  n_iter=self.n_iter, return_code=False,
                                  method=self.fit_algorithm,
@@ -995,7 +995,7 @@ def partial_fit(self, X, y=None, iter_offset=0):
             Returns the instance itself.
         """
         self.random_state = check_random_state(self.random_state)
-        X = np.atleast_2d(X)
+        X = array2d(X)
         if hasattr(self, 'components_'):
             dict_init = self.components_
         else:
5 changes: 3 additions & 2 deletions sklearn/decomposition/fastica_.py
@@ -14,6 +14,7 @@
 from scipy import linalg

 from ..base import BaseEstimator
+from ..utils import array2d

 __all__ = ['fastica', 'FastICA']


@@ -121,7 +122,7 @@ def fastica(X, n_components=None, algorithm="parallel", whiten=True,
     Parameters
     ----------
-    X : (n, p) array of shape = [n_samples, n_features], optional
+    X : array-like, shape = [n_samples, n_features]
         Training vector, where n_samples is the number of samples and
         n_features is the number of features.
     n_components : int, optional
@@ -197,7 +198,7 @@ def fastica(X, n_components=None, algorithm="parallel", whiten=True,
         # make interface compatible with other decompositions
        warnings.warn("Please note: the interface of fastica has changed: "
                      "X is now assumed to be of shape [n_samples, n_features]")
-    X = X.T
+    X = array2d(X).T

     algorithm_funcs = {'parallel': _ica_par,
                        'deflation': _ica_def}