[MRG+1] Added unit test for adding classes_ property to GridSearchCV, fixes #6298 #7661

Merged on Oct 20, 2016 (10 commits)
5 changes: 5 additions & 0 deletions doc/whats_new.rst
@@ -19,6 +19,11 @@ New features
Enhancements
............

- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
  that matches the ``classes_`` attribute of ``best_estimator_``. (`#7661
  <https://github.com/scikit-learn/scikit-learn/pull/7661>`_) by `Alyssa
  Batula`_ and `Dylan Werner-Meier`_.
Member

You need to add your names to the bottom of the file for the links to work. I'll do that.


- The ``min_weight_fraction_leaf`` constraint in tree construction is now
  more efficient, taking a fast path to declare a node a leaf if its weight
  is less than 2 * the minimum. Note that the constructed tree will be
8 changes: 6 additions & 2 deletions sklearn/grid_search.py
@@ -387,6 +387,10 @@ def __init__(self, estimator, scoring=None,
    def _estimator_type(self):
        return self.estimator._estimator_type

    @property
    def classes_(self):
        return self.best_estimator_.classes_

    def score(self, X, y=None):
        """Returns the score on the given data, if the estimator has been refit.

@@ -688,7 +692,7 @@ class GridSearchCV(BaseSearchCV):
- An iterable yielding train/test splits.

For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
other cases, :class:`sklearn.model_selection.KFold` is used.

@@ -900,7 +904,7 @@ class RandomizedSearchCV(BaseSearchCV):
- An iterable yielding train/test splits.

For integer/None inputs, if the estimator is a classifier and ``y`` is
either binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used. In all
other cases, :class:`sklearn.model_selection.KFold` is used.
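
The splitter-selection rule quoted in the docstring above can be sketched as a small decision function. This is a simplified illustration, not scikit-learn's implementation: `choose_splitter` is a hypothetical name, and the check for a "binary or multiclass" target is reduced to "every label is an int or a str", which is an assumption made only for this example.

```python
def choose_splitter(is_classifier, y):
    """Name the splitter the rule quoted above would select (simplified)."""
    # Simplifying assumption: a target made only of ints or strings counts
    # as binary/multiclass; anything else (e.g. floats) does not.
    discrete = all(isinstance(label, (int, str)) for label in y)
    if is_classifier and discrete:
        return "StratifiedKFold"
    return "KFold"


# A classifier with discrete labels gets the stratified splitter;
# a regressor with continuous targets falls back to plain KFold.
picked_for_classifier = choose_splitter(True, [0, 0, 1, 1])
picked_for_regressor = choose_splitter(False, [0.5, 1.2, 3.4])
```

Both conditions have to hold for stratification: a classifier fitted on a continuous target also falls through to the plain `KFold` branch in this sketch.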

18 changes: 18 additions & 0 deletions sklearn/tests/test_grid_search.py
@@ -42,6 +42,7 @@
from sklearn.metrics import f1_score
from sklearn.metrics import make_scorer
from sklearn.metrics import roc_auc_score
from sklearn.linear_model import Ridge

from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.exceptions import FitFailedWarning
@@ -785,3 +786,20 @@ def test_parameters_sampler_replacement():
    sampler = ParameterSampler(params_distribution, n_iter=7)
    samples = list(sampler)
    assert_equal(len(samples), 7)


def test_classes__property():
Member

It's okay to combine these into one test function

Member

And I'd prefer it.

    # Test that classes_ property matches best_estimator_.classes_
    X = np.arange(100).reshape(10, 10)
    y = np.array([0] * 5 + [1] * 5)
    Cs = [.1, 1, 10]

    grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
    grid_search.fit(X, y)
    assert_array_equal(grid_search.best_estimator_.classes_,
                       grid_search.classes_)

    # Test that regressors do not have a classes_ attribute
    grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
    grid_search.fit(X, y)
    assert_false(hasattr(grid_search, 'classes_'))
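
The regressor assertion above works because `classes_` is a property that forwards to `best_estimator_`: when the wrapped estimator has no `classes_`, the lookup raises `AttributeError`, and `hasattr` reports `False`. A minimal sketch of that mechanism, using hypothetical stand-in classes (`FakeClassifier`, `FakeRegressor`, `FakeSearch`) rather than scikit-learn itself:

```python
class FakeClassifier:
    """Hypothetical fitted classifier exposing classes_."""
    classes_ = [0, 1]


class FakeRegressor:
    """Hypothetical fitted regressor; it has no classes_."""


class FakeSearch:
    """Hypothetical stand-in for a fitted search object."""

    def __init__(self, best_estimator):
        self.best_estimator_ = best_estimator

    @property
    def classes_(self):
        # Forward to the refit estimator. If it lacks classes_, the
        # AttributeError propagates and hasattr() returns False.
        return self.best_estimator_.classes_


clf_search = FakeSearch(FakeClassifier())
reg_search = FakeSearch(FakeRegressor())
```

With these stand-ins, `clf_search.classes_` returns the classifier's labels, while `hasattr(reg_search, "classes_")` is false even though the property is defined on the class.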