From ba0558c901d8de70413ed628ce0ce573da8746a2 Mon Sep 17 00:00:00 2001 From: Stephen Hoover Date: Sat, 4 Feb 2017 16:12:05 -0600 Subject: [PATCH 1/4] ENH Add classes_ parameter to hyperparameter CV classes In ``BaseSearchCV`` (superclass of ``GridSearchCV`` and ``RandomizedSearchCV``), add a ``clases_`` parameter which surfaces the ``classes_`` parameter of the ``best_estimator_``. Other parts of the scikit-learn code (e.g. ``cross_val_predict``) as well as users expect this property to be present on fitted classifiers. --- sklearn/model_selection/_search.py | 5 +++++ sklearn/model_selection/tests/test_search.py | 10 ++++++++++ sklearn/model_selection/tests/test_validation.py | 16 +++++++++++++--- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index 566ec8c996c53..88b38ae0a0857 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -532,6 +532,11 @@ def inverse_transform(self, Xt): self._check_is_fitted('inverse_transform') return self.best_estimator_.transform(Xt) + @property + def classes_(self): + self._check_is_fitted("classes_") + return self.best_estimator_.classes_ + def fit(self, X, y=None, groups=None): """Run fit with all sets of parameters. diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 117b81a35ae2c..4b5df5b9fb777 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -72,6 +72,7 @@ def __init__(self, foo_param=0): def fit(self, X, Y): assert_true(len(X) == len(Y)) + self.classes_ = np.unique(Y) return self def predict(self, T): @@ -254,6 +255,15 @@ def test_grid_search_groups(): gs.fit(X, y) +def test_grid_search_classes_parameter(): + # Verify that the GridSearchCV can pass through the classes_ attribute + clf = MockClassifier() + grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) + + grid_search.fit(X, y) + assert_array_equal(grid_search.classes_, np.unique(y)) + + def test_trivial_cv_results_attr(): # Test search over a "grid" with only one point. # Non-regression test: grid_scores_ wouldn't be set by GridSearchCV. diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index c6ae5f3fdd18a..6dbee112454d3 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -62,6 +62,7 @@ from sklearn.datasets import make_multilabel_classification from sklearn.model_selection.tests.common import OneTimeSplitter +from sklearn.model_selection import GridSearchCV try: @@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction(): assert_array_almost_equal(preds_sparse, preds) -def test_cross_val_predict_with_method(): +def run_cross_val_predict_with_method(est): iris = load_iris() X, y = iris.data, iris.target X, y = shuffle(X, y, random_state=0) @@ -924,8 +925,6 @@ def test_cross_val_predict_with_method(): methods = ['decision_function', 'predict_proba', 'predict_log_proba'] for method in methods: - est = LogisticRegression() - predictions = cross_val_predict(est, X, y, method=method) assert_equal(len(predictions), len(y)) @@ -955,6 +954,17 @@ def test_cross_val_predict_with_method(): assert_array_equal(predictions, predictions_ystr) +def test_cross_val_predict_with_method(): + run_cross_val_predict_with_method(LogisticRegression()) + + +def test_gridsearchcv_cross_val_predict_with_method(): + est = GridSearchCV(LogisticRegression(random_state=42), + {'C': [0.1, 1]}, + cv=2) + run_cross_val_predict_with_method(est) + + def get_expected_predictions(X, y, cv, classes, est, method): expected_predictions = np.zeros([len(y), classes]) From e25d2e0784195e2fb87f217164b9726eeb38bead Mon Sep 17 00:00:00 2001 From: Stephen Hoover Date: Mon, 6 Feb 2017 09:16:54 -0600 Subject: [PATCH 2/4] Modify tests to address code review --- sklearn/model_selection/tests/test_search.py | 30 +++++++++++++++---- .../model_selection/tests/test_validation.py | 6 ++-- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 4b5df5b9fb777..64e45f815f36e 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -58,7 +58,7 @@ from sklearn.metrics import roc_auc_score from sklearn.preprocessing import Imputer from sklearn.pipeline import Pipeline -from sklearn.linear_model import SGDClassifier +from sklearn.linear_model import Ridge, SGDClassifier from sklearn.model_selection.tests.common import OneTimeSplitter @@ -255,13 +255,31 @@ def test_grid_search_groups(): gs.fit(X, y) -def test_grid_search_classes_parameter(): - # Verify that the GridSearchCV can pass through the classes_ attribute - clf = MockClassifier() - grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}) +def test_classes__property(): + # Test that classes_ property matches best_estimator_.classes_ + X = np.arange(100).reshape(10, 10) + y = np.array([0] * 5 + [1] * 5) + Cs = [.1, 1, 10] + + grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) + grid_search.fit(X, y) + assert_array_equal(grid_search.best_estimator_.classes_, + grid_search.classes_) + + # Test that regressors do not have a classes_ attribute + grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]}) + grid_search.fit(X, y) + assert_false(hasattr(grid_search, 'classes_')) + + # Test that the grid searcher has no classes_ attribute before it's fit + grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs}) + assert_false(hasattr(grid_search, 'classes_')) + # Test that the grid searcher has no classes_ attribute without a refit + grid_search = GridSearchCV(LinearSVC(random_state=0), + {'C': Cs}, refit=False) grid_search.fit(X, y) - assert_array_equal(grid_search.classes_, np.unique(y)) + assert_false(hasattr(grid_search, 'classes_')) def test_trivial_cv_results_attr(): diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 6dbee112454d3..cc6f5973a0b09 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -915,7 +915,7 @@ def test_cross_val_predict_sparse_prediction(): assert_array_almost_equal(preds_sparse, preds) -def run_cross_val_predict_with_method(est): +def check_cross_val_predict_with_method(est): iris = load_iris() X, y = iris.data, iris.target X, y = shuffle(X, y, random_state=0) @@ -955,14 +955,14 @@ def run_cross_val_predict_with_method(est): def test_cross_val_predict_with_method(): - run_cross_val_predict_with_method(LogisticRegression()) + check_cross_val_predict_with_method(LogisticRegression()) def test_gridsearchcv_cross_val_predict_with_method(): est = GridSearchCV(LogisticRegression(random_state=42), {'C': [0.1, 1]}, cv=2) - run_cross_val_predict_with_method(est) + check_cross_val_predict_with_method(est) def get_expected_predictions(X, y, cv, classes, est, method): From 8a64508ef3dbdcb81c62161423dc25db9b6b1e57 Mon Sep 17 00:00:00 2001 From: Stephen Hoover Date: Thu, 9 Feb 2017 16:29:35 -0600 Subject: [PATCH 3/4] DOC Update What's New --- doc/whats_new.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6be337bbe6765..7c923b1f1520c 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -68,7 +68,8 @@ Enhancements - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV` that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661` - by :user:`Alyssa Batula ` and :user:`Dylan Werner-Meier `. + and :issue:`8295` by :user:`Alyssa Batula `, + :user:`Dylan Werner-Meier `, and :user:`Stephen Hoover `. - The ``min_weight_fraction_leaf`` constraint in tree construction is now more efficient, taking a fast path to declare a node a leaf if its weight From 6b136a25eccfcfdde71c0a85ad1a092359a51a9b Mon Sep 17 00:00:00 2001 From: Stephen Hoover Date: Thu, 9 Feb 2017 16:57:41 -0600 Subject: [PATCH 4/4] DOC Fix What's New addition --- doc/whats_new.rst | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 7c923b1f1520c..7a93e8feee74a 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -66,10 +66,12 @@ Enhancements now uses significantly less memory when assigning data points to their nearest cluster center. :issue:`7721` by :user:`Jon Crall `. - - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV` - that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661` - and :issue:`8295` by :user:`Alyssa Batula `, - :user:`Dylan Werner-Meier `, and :user:`Stephen Hoover `. + - Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`, + :class:`model_selection.RandomizedSearchCV`, :class:`grid_search.GridSearchCV`, + and :class:`grid_search.RandomizedSearchCV` that matches the ``classes_`` + attribute of ``best_estimator_``. :issue:`7661` and :issue:`8295` + by :user:`Alyssa Batula `, :user:`Dylan Werner-Meier `, + and :user:`Stephen Hoover `. - The ``min_weight_fraction_leaf`` constraint in tree construction is now more efficient, taking a fast path to declare a node a leaf if its weight