Skip to content

Commit

Permalink
[MRG + 1] Removes estimator method check in cross_val_predict before …
Browse files Browse the repository at this point in the history
…fitting (#9641)

* Removes check in cross_val_predict that checks estimator method before fitting

* Adds regression test for issue #9639
  • Loading branch information
jrbourbeau authored and jnothman committed Aug 30, 2017
1 parent 506380b commit f3412f8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
5 changes: 0 additions & 5 deletions sklearn/model_selection/_validation.py
Expand Up @@ -637,11 +637,6 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,

cv = check_cv(cv, y, classifier=is_classifier(estimator))

# Ensure the estimator has implemented the passed decision function
if not callable(getattr(estimator, method)):
raise AttributeError('{} not implemented in estimator'
.format(method))

if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
le = LabelEncoder()
y = le.fit_transform(y)
Expand Down
9 changes: 8 additions & 1 deletion sklearn/model_selection/tests/test_validation.py
Expand Up @@ -51,7 +51,7 @@
from sklearn.metrics import r2_score
from sklearn.metrics.scorer import check_scoring

from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
Expand Down Expand Up @@ -1194,6 +1194,13 @@ def test_cross_val_predict_with_method():
check_cross_val_predict_with_method(LogisticRegression())


def test_cross_val_predict_method_checking():
# Regression test for issue #9639. Tests that cross_val_predict does not
# check estimator methods (e.g. predict_proba) before fitting
est = SGDClassifier(loss='log', random_state=2)
check_cross_val_predict_with_method(est)


def test_gridsearchcv_cross_val_predict_with_method():
est = GridSearchCV(LogisticRegression(random_state=42),
{'C': [0.1, 1]},
Expand Down

0 comments on commit f3412f8

Please sign in to comment.