
Merge branch 'cross_val' of github.com:GaelVaroquaux/scikit-learn

commit 2138c515ba9fa88460bb753737b16ad1884e11f0 (merge of parents 7b12364 and fd60658), committed by GaelVaroquaux on Sep 16, 2010
scikits/learn/base.py
@@ -6,24 +6,45 @@
# License: BSD Style
import inspect
+import copy
import numpy as np
from .metrics import explained_variance
################################################################################
-def clone(estimator):
+def clone(estimator, safe=True):
""" Constructs a new estimator with the same parameters.
Clone does a deep copy of the model in an estimator
without actually copying attached data. It yields a new estimator
with the same parameters that has not been fit on any data.
+
+ Parameters
+ ----------
+ estimator: estimator object, or list, tuple or set of objects
+ The estimator or group of estimators to be cloned
+ safe: boolean, optional
+ If safe is False, clone will fall back to a deepcopy on objects
+ that are not estimators.
+
"""
+ estimator_type = type(estimator)
+ # XXX: not handling dictionaries
+ if estimator_type in (list, tuple, set, frozenset):
+ return estimator_type([clone(e, safe=safe) for e in estimator])
+ elif not hasattr(estimator, '_get_params'):
+ if not safe:
+ return copy.deepcopy(estimator)
+ else:
+ raise ValueError("Cannot clone object '%s' (type %s): "
+ "it does not seem to be a scikit-learn estimator as "
+ "it does not implement a '_get_params' methods."
+ % (repr(estimator), type(estimator)))
klass = estimator.__class__
new_object_params = estimator._get_params(deep=False)
for name, param in new_object_params.iteritems():
- if hasattr(param, '_get_params'):
- new_object_params[name] = clone(param)
+ new_object_params[name] = clone(param, safe=False)
new_object = klass(**new_object_params)
return new_object
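
For illustration, a minimal sketch of the new clone semantics; Stub is a hypothetical stand-in, not library code, implementing just the _get_params contract that clone relies on:

# Hypothetical stand-in estimator; only _get_params matters to clone().
class Stub(object):
    def __init__(self, C=1.0):
        self.C = C
        self.coef_ = None              # state that fit() would set

    def _get_params(self, deep=False):
        return {'C': self.C}

stub = Stub(C=10.0)
stub.coef_ = [1, 2, 3]                 # pretend the estimator was fit
fresh = clone(stub)                    # same parameters, no fitted state
assert fresh.C == 10.0 and fresh.coef_ is None

pair = clone([stub, stub])             # containers are cloned element-wise
assert pair[0] is not pair[1]

clone(42, safe=False)                  # non-estimators: deepcopy fallback
# clone(42) without safe=False raises ValueError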
@@ -108,7 +129,7 @@ def _get_param_names(cls):
args = []
return args
- def _get_params(self, deep=False):
+ def _get_params(self, deep=True):
""" Get parameters for the estimator
Parameters
@@ -220,3 +241,26 @@ def score(self, X, y):
z : float
"""
return explained_variance(self.predict(X), y)
+
+
+################################################################################
+# XXX: Temporary solution to figure out if an estimator is a classifier
+
+def _get_sub_estimator(estimator):
+ """ Returns the final estimator if there is any.
+ """
+ if hasattr(estimator, 'estimator'):
+ # GridSearchCV and other CV-tuned estimators
+ return _get_sub_estimator(estimator.estimator)
+ if hasattr(estimator, 'steps'):
+ # Pipeline
+ return _get_sub_estimator(estimator.steps[-1][1])
+ return estimator
+
+
+def is_classifier(estimator):
+ """ Returns True if the given estimator is (probably) a classifier.
+ """
+ estimator = _get_sub_estimator(estimator)
+ return isinstance(estimator, ClassifierMixin)
+
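
A sketch of the unwrapping with hypothetical wrapper classes; only the duck-typed attribute names 'estimator' and 'steps' matter to _get_sub_estimator:

class FakeCV(object):                  # stands in for GridSearchCV
    def __init__(self, estimator):
        self.estimator = estimator

class FakePipe(object):                # stands in for Pipeline
    def __init__(self, steps):
        self.steps = steps             # list of (name, estimator) pairs

inner = object()
nested = FakeCV(FakePipe([('clf', FakeCV(inner))]))
assert _get_sub_estimator(nested) is inner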
scikits/learn/cross_val.py
@@ -9,7 +9,7 @@
from math import ceil
import numpy as np
-from .base import ClassifierMixin
+from .base import is_classifier, clone
from .utils.extmath import factorial, combinations
from .externals.joblib import Parallel, delayed
@@ -485,9 +485,7 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None,
"""
n_samples = len(X)
if cv is None:
- if y is not None and (isinstance(estimator, ClassifierMixin)
- or (hasattr(estimator, 'estimator')
- and isinstance(estimator.estimator, ClassifierMixin))):
+ if y is not None and is_classifier(estimator):
cv = StratifiedKFold(y, k=3)
else:
cv = KFold(n_samples, k=3)
@@ -497,8 +495,10 @@ def cross_val_score(estimator, X, y=None, score_func=None, cv=None,
"should have a 'score' method. The estimator %s "
"does not." % estimator
)
+ # We clone the estimator to make sure that all the folds are
+ # independent, and that it is picklable.
scores = Parallel(n_jobs=n_jobs, verbose=verbose)(
- delayed(_cross_val_score)(estimator, X, y, score_func,
+ delayed(_cross_val_score)(clone(estimator), X, y, score_func,
train, test)
for train, test in cv)
return np.array(scores)
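
The parallel call above is equivalent to the following sequential sketch (assuming clone from scikits.learn.base, a cv yielding (train, test) folds, and the estimator's own score method); the point of the per-fold clone is that no fold sees another fold's fitted state:

def sequential_cross_val_score(estimator, X, y, cv):
    # Sequential sketch of the per-fold cloning done by the parallel call.
    scores = []
    for train, test in cv:
        fold_estimator = clone(estimator)   # fresh, unfitted copy per fold
        fold_estimator.fit(X[train], y[train])
        scores.append(fold_estimator.score(X[test], y[test]))
    return scores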
scikits/learn/grid_search.py
@@ -10,7 +10,7 @@
from .externals.joblib import Parallel, delayed
from .cross_val import KFold, StratifiedKFold
-from .base import BaseEstimator, ClassifierMixin, clone
+from .base import BaseEstimator, is_classifier, clone
try:
from itertools import product
@@ -187,9 +187,7 @@ def fit(self, X, y, cv=None, **kw):
estimator = self.estimator
if cv is None:
n_samples = len(X)
- if y is not None and (isinstance(estimator, ClassifierMixin)
- or (hasattr(estimator, 'estimator')
- and isinstance(estimator.estimator, ClassifierMixin))):
+ if y is not None and is_classifier(estimator):
cv = StratifiedKFold(y, k=3)
else:
cv = KFold(n_samples, k=3)
@@ -206,7 +204,8 @@ def fit(self, X, y, cv=None, **kw):
self.best_estimator = best_estimator
self.predict = best_estimator.predict
- self.score = best_estimator.score
+ if hasattr(best_estimator, 'score'):
+ self.score = best_estimator.score
# Store the computed scores
grid = iter_grid(self.param_grid)
@@ -216,6 +215,12 @@ def fit(self, X, y, cv=None, **kw):
return self
+ def score(self, X, y=None):
+ # This method is overridden during the fit if the best estimator
+ # found has a score method.
+ y_predicted = self.predict(X)
+ return -self.loss_func(y_predicted, y)
+
if __name__ == '__main__':
from scikits.learn.svm import SVC
from scikits.learn import datasets
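
The score override relies on Python's attribute lookup order: a bound method assigned on the instance during fit shadows the class-level fallback. A minimal sketch of that pattern, with hypothetical names rather than the library class:

class Searcher(object):
    def __init__(self, loss_func):
        self.loss_func = loss_func

    def fit(self, best_estimator):
        self.predict = best_estimator.predict
        if hasattr(best_estimator, 'score'):
            # Instance attribute: shadows the class-level score() below.
            self.score = best_estimator.score
        return self

    def score(self, X, y=None):
        # Fallback for best estimators without a score method; a loss is
        # lower-is-better, so negate it to match the score convention.
        return -self.loss_func(self.predict(X), y)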
scikits/learn/pipeline.py
@@ -101,7 +101,7 @@ def __init__(self, steps):
"'%s' (type %s) doesn't)" % (estimator, type(estimator))
)
- def _get_params(self, deep=False):
+ def _get_params(self, deep=True):
if not deep:
return super(Pipeline, self)._get_params(deep=False)
else:
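
With deep=True now the default, a pipeline reports both its steps and their prefixed sub-parameters. A hedged sketch of the expected contract (the same '<step>__<param>' convention the tests below exercise):

from scikits.learn.svm import SVC
from scikits.learn.pipeline import Pipeline

pipe = Pipeline([('svc', SVC())])
params = pipe._get_params()        # deep=True is now the default
assert 'svc' in params             # the step estimator itself
assert 'svc__C' in params          # its parameters, as '<step>__<param>'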
scikits/learn/tests/test_base.py
@@ -1,6 +1,10 @@
+
+# Author: Gael Varoquaux
+# License: BSD
+
from nose.tools import assert_true, assert_false, assert_equal, \
assert_raises
-from ..base import BaseEstimator, clone
+from ..base import BaseEstimator, clone, is_classifier
################################################################################
# A few test classes
@@ -74,7 +78,6 @@ def test_str():
def test_get_params():
-
test = T(K(), K())
assert_true('a__d' in test._get_params(deep=True))
@@ -84,3 +87,15 @@ def test_get_params():
assert test.a.d == 2
assert_raises(AssertionError, test._set_params, a__a=2)
+
+def test_is_classifier():
+ from ..svm import SVC
+ from ..pipeline import Pipeline
+ from ..grid_search import GridSearchCV
+ svc = SVC()
+ assert_true(is_classifier(svc))
+ assert_true(is_classifier(GridSearchCV(svc, {'C': [0.1, 1]})))
+ assert_true(is_classifier(Pipeline([('svc', svc)])))
+ assert_true(is_classifier(Pipeline([('svc_cv',
+ GridSearchCV(svc, {'C': [0.1, 1]}))])))
+
scikits/learn/tests/test_pipeline.py
@@ -2,7 +2,7 @@
Test the pipeline module.
"""
-from nose.tools import assert_raises, assert_equal
+from nose.tools import assert_raises, assert_equal, assert_false
from ..base import BaseEstimator, clone
from ..pipeline import Pipeline
@@ -56,4 +56,14 @@ def test_pipeline_init():
# Test clone
pipe2 = clone(pipe)
- assert_equal(pipe._get_params(), pipe2._get_params())
+ assert_false(pipe._named_steps['svc'] is pipe2._named_steps['svc'])
+
+ # Check that, apart from the estimators, the parameters are the same
+ params = pipe._get_params()
+ params2 = pipe2._get_params()
+ # Remove the estimators that were copied
+ params.pop('svc')
+ params.pop('anova')
+ params2.pop('svc')
+ params2.pop('anova')
+ assert_equal(params, params2)
