WIP Grid search convenience class. #1034

Open
wants to merge 10 commits

6 participants

@amueller
scikit-learn member

This aims at closing #1020.

It introduces a new class to handle the output of GridSearchCV.
I don't like complexity but I haven't found a nice way to do this otherwise. If you have less complex solutions, please let me know.

Basically, it transforms the dicts that usually live in GridSearchCV.grid_scores_ (remember, this is in general a list of dictionaries) into a list of parameter names (which is sorted(param_grid.keys())) and an array in which
each axis corresponds to one parameter and the last axis corresponds to the folds.
I think this is already an improvement.

The reason I added the class is that I also want to marginalize over parameters: I want to be able to look at the results even if I have 5 parameters to adjust. Maximizing over multiple axes by hand is ugly, so I wanted the class to handle this.
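
Roughly, the marginalization boils down to something like this (just a numpy sketch of the idea, not the actual implementation; the array shape and parameter order are made up):

import numpy as np

# hypothetical score array with axes (max_depth, max_features, min_samples_split, fold)
scores = np.random.rand(7, 5, 4, 3)
mean_scores = scores.mean(axis=-1)  # average over folds -> shape (7, 5, 4)

def marginalize(mean_scores, axis, kind="max"):
    # move the axis of interest to the front, flatten the rest,
    # then collapse the remaining parameter axes with max or mean
    rolled = np.rollaxis(mean_scores, axis, 0)
    flat = rolled.reshape(rolled.shape[0], -1)
    return flat.max(axis=-1) if kind == "max" else flat.mean(axis=-1)

marginal = marginalize(mean_scores, axis=1)  # shape (5,): one value per max_features setting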

As always, any comments welcome.
I'll make an example to illustrate the usefulness now :)

Going back to WIP, as I think this should be designed with non-grid evaluations of estimators in mind.

@amueller
scikit-learn member

The current example is not polished and takes too long, but you can have a look to get the general idea.
Btw, I noticed that using standard deviations is not a good idea and I should use the binomial thing instead. Could be an option or something. Finishing for today.
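
Something like a Wilson score interval for a binomial proportion would be one candidate (just a sketch of the idea, not part of this PR):

import numpy as np

def wilson_interval(accuracy, n_samples, z=1.96):
    # Wilson score interval for a binomial proportion such as a fold accuracy
    denom = 1 + z ** 2 / n_samples
    center = (accuracy + z ** 2 / (2 * n_samples)) / denom
    half_width = (z / denom) * np.sqrt(accuracy * (1 - accuracy) / n_samples
                                       + z ** 2 / (4 * n_samples ** 2))
    return center - half_width, center + half_width

low, high = wilson_interval(0.9, n_samples=20)  # wide interval for a small test fold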

@amueller
scikit-learn member

@agramfort you seemed quite enthusiastic about the topic ;) Any opinions?

Btw, an alternative to the marginalized parameter plot would be a scatterplot showing all values along a given dimension. I don't know how to plot uncertainty there, though.
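
A rough sketch of that scatterplot idea (illustrative only, made-up data):

import numpy as np
import pylab as pl

# one point per grid setting: its max_depth value and its mean CV score
max_depth_values = np.repeat(np.arange(1, 8), 20)
mean_scores = np.random.rand(140)

pl.scatter(max_depth_values, mean_scores, alpha=0.3)
pl.xlabel("max_depth")
pl.ylabel("mean CV score")
pl.show()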

@amueller
scikit-learn member

The new example looks something like this:
[figure: grid_search_plots — marginalized parameter plots from the new example]

I'm not terribly happy with the graphs, but the run time shouldn't go up too much. Any ideas for a better example are welcome.

I'm quite happy with the code for the example, though. Also note how the SVM grid search example got a bit nicer.

@ogrisel
scikit-learn member

It would be great to measure the fit / predict times for each parameter and add them to the collected statistics and plots.

@amueller
scikit-learn member

I definitely won't have time to do that before the release.
What do you think of the current approach?
It would be easy to give the class another function to get the timings later.

@ogrisel
scikit-learn member

I don't have time to review in detail right now, but the overall idea certainly looks like a good approach to me.

I would like to have a similar "result class" for the classification_report utility function and have the __repr__ method compute the same string as the current implementation of the classification_report function.

@agramfort agramfort commented on an outdated diff
examples/plot_grid_search.py
((4 lines not shown))
+=====================================================
+
+Often one is faced with combining feature extraction, feature selection
+and classification into a complex pipeline.
+Each individual step usually has many tunable parameters. Finding the
+important parameters for a given task and picking robust settings is often
+hard.
+
+This example shows how to visualize results of a grid search with
+many interacting parameters.
+The ``DecisionTreeClassifier`` is a good model for a complex pipeline as there
+are many parameters to tweak, but often only a few have significant influence.
+"""
+print __doc__
+
+import matplotlib.pyplot as plt
@agramfort scikit-learn member

import pylab as pl

@agramfort agramfort commented on an outdated diff
sklearn/grid_search.py
((26 lines not shown))
+
+ def __init__(self, params, values, scores):
+ self.scores = scores
+ self.params = params
+ self.values = values
+
+ def mean(self):
+ """Returns mean scores over folds for the whole parameter grid."""
+ return np.mean(self.scores, axis=-1)
+
+ def std(self):
+ """Returns standard deviation of scores over folds for the whole
+ parameter grid."""
+ return np.std(self.scores, axis=-1)
+
+ def accumulated(self, param, kind="max"):
@agramfort scikit-learn member

accumulate or pool sounds better to me

@NelleV scikit-learn member
NelleV added a note

or accumulate_scores ?

@agramfort agramfort commented on an outdated diff
examples/svm/plot_rbf_parameters.py
@@ -105,12 +105,8 @@
pl.axis('tight')
# plot the scores of the grid
-# grid_scores_ contains parameter settings and scores
-score_dict = grid.grid_scores_
-
-# We extract just the scores
-scores = [x[1] for x in score_dict]
-scores = np.array(scores).reshape(len(C_range), len(gamma_range))
+results = grid.scores_
@agramfort scikit-learn member

I would name it cv_scores rather than results

@agramfort agramfort commented on an outdated diff
examples/plot_grid_search.py
((19 lines not shown))
+import matplotlib.pyplot as plt
+
+from sklearn.datasets import make_classification
+from sklearn.grid_search import GridSearchCV
+from sklearn.tree import DecisionTreeClassifier
+
+X, y = make_classification(n_samples=100, n_features=10)
+
+param_grid = {'max_depth': range(1, 8), 'min_samples_split': [1, 2, 3, 4],
+ 'max_features': [1, 3, 5, 8, 10]}
+
+grid_search = GridSearchCV(DecisionTreeClassifier(), param_grid=param_grid,
+ cv=5)
+grid_search.fit(X, y)
+
+results = grid_search.scores_
@agramfort scikit-learn member

cv_scores here too

@agramfort agramfort commented on an outdated diff
examples/plot_grid_search.py
((27 lines not shown))
+param_grid = {'max_depth': range(1, 8), 'min_samples_split': [1, 2, 3, 4],
+ 'max_features': [1, 3, 5, 8, 10]}
+
+grid_search = GridSearchCV(DecisionTreeClassifier(), param_grid=param_grid,
+ cv=5)
+grid_search.fit(X, y)
+
+results = grid_search.scores_
+
+fig, axes = plt.subplots(1, 3)
+axes = axes.ravel()
+for ax, param in zip(axes, results.params):
+ means, errors = results.accumulated(param, 'max')
+ ax.errorbar(results.values[param], means, yerr=errors)
+ ax.set_title(param)
+plt.show()
@agramfort scikit-learn member

can you tweak the figsize a bit so the text shows up well?

@amueller scikit-learn member

The text did show up for me but I changed the sizes a bit, so it should look better now.

@agramfort
scikit-learn member

Besides this, the design looks good to me.

@amueller
scikit-learn member

Thanks for the feedback. Glad you like the design :)

@NelleV NelleV commented on an outdated diff
examples/plot_grid_search.py
((25 lines not shown))
+X, y = make_classification(n_samples=100, n_features=10)
+
+param_grid = {'max_depth': range(1, 8), 'min_samples_split': [1, 2, 3, 4],
+ 'max_features': [1, 3, 5, 8, 10]}
+
+grid_search = GridSearchCV(DecisionTreeClassifier(), param_grid=param_grid,
+ cv=5)
+grid_search.fit(X, y)
+
+cv_scores = grid_search.scores_
+
+fig, axes = pl.subplots(1, 3)
+axes = axes.ravel()
+for ax, param in zip(axes, cv_scores.params):
+ means, errors = cv_scores.accumulated(param, 'max')
+ ax.errorbar(cv_scores.values[param], means, yerr=errors)
@NelleV scikit-learn member
NelleV added a note

As mentioned on IRC, I think boxplots would be better than errorbars.
I also think the labels should be set on both the X and Y axes.

@NelleV scikit-learn member
NelleV added a note

Also, I think the three plots should be scaled the same way.

@NelleV NelleV commented on the diff
examples/svm/plot_rbf_parameters.py
((11 lines not shown))
# draw heatmap of accuracy as a function of gamma and C
pl.figure(figsize=(8, 6))
pl.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95)
-pl.imshow(scores, interpolation='nearest', cmap=pl.cm.spectral)
+pl.imshow(cv_scores.mean(), interpolation='nearest', cmap=pl.cm.spectral)
@NelleV scikit-learn member
NelleV added a note

On this plot, I'd also set a title on the color bar.

@NelleV NelleV commented on an outdated diff
sklearn/grid_search.py
((20 lines not shown))
+ for one_grid in param_grid:
+ sorted_params = sorted(one_grid.keys())
+ # get the number of values for each parameter
+ grid_shape = [len(one_grid[k]) for k in sorted_params]
+ n_entries = np.prod(grid_shape)
+ grid_shape.append(n_folds)
+ # get scores
+ score_array = np.array(cv_scores[start:start + n_entries])
+ # reshape to fit the sequence of values
+ score_array = score_array.reshape(grid_shape)
+ self.scores_.append(ResultGrid(sorted_params, one_grid,
+ score_array))
+ start += n_entries
+
+ # often the list is just one grid. Make access easier
+ if len(self.scores_) is 1:
@NelleV scikit-learn member
NelleV added a note

You should never use ``is`` unless comparing to None.

@amueller
scikit-learn member

About the boxplots: that would mean I have to return all sample points, not only summarized statistics.
That could be done, but I don't see a huge benefit in using boxplots. Instead, I'd rather go for the correct statistical measure, which I still have to look up ;)

@amueller
scikit-learn member

I am not so happy about this any more. It assumes an actual grid. I think the object should be constructed such that it supports random sampling and other ways to try out parameters.
This opens up a whole new can of worms, though...
Because then you might want to feed a stored result object back into the search object to refine the search, or something like that... I'm not sure whether these should be addressed together.

What was the API idea for the random search? How do we report / store progress? You definitely want to be able to continue a search that was started before....

@ogrisel
scikit-learn member

Indeed, supporting @jaberg's randomized grid search would be neat (even if we decide to make it a non-default option). Hence the report object should be able to deal with that, and maybe also with the case where the user interrupts the grid search early (e.g. with a keyboard interrupt), in which case the report would hold an incomplete grid.
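
One way the report object could tolerate an incomplete grid would be to keep the scores in a masked array and leave unevaluated cells masked (only a sketch of the idea):

import numpy as np

# hypothetical (max_depth, max_features, fold) score grid, only partly evaluated
scores = np.ma.masked_all((7, 5, 3))
scores[:3] = np.random.rand(3, 5, 3)  # only the first few settings were run

mean_scores = scores.mean(axis=-1)    # masked statistics ignore the missing entries
n_evaluated = mean_scores.count()     # number of grid points that have results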

I think we should postpone the discussion for after the release.

@amueller
scikit-learn member

Yeah sure, this is not release-related. But I think I should do this together with #455 and maybe an interrupt-robust GridSearchCV or similar...

@agramfort
scikit-learn member
@amueller
scikit-learn member

sure, why not.

@amueller amueller referenced this pull request
Merged

MRG Feature stacker #1173

@amueller
scikit-learn member

Will look at @jaberg's JMLR paper before continuing this, and then hopefully do randomized search together with sensible visualization / analysis for both grid and random search.

@GaelVaroquaux
scikit-learn member
@jnothman
scikit-learn member

One nice thing about this feature is being able to examine variation in some parameters while taking the max or average over the other parameters. I would also like to see an argmax equivalent, which basically takes a parameter grid (or other space) and selects the parameter settings (i.e. returns an index array) corresponding to some sub-grid (choosing the max-scoring instance for each point).

With #1842, one could do something like:

# Get an index over the parameters of interest
ind_params, index = sklearn.grid_search.ParameterGrid(grid.param_grid).build_index(interesting_params, ravel=True)
# Reshape our results (assuming structured array output) and max out over remaining params
best = grid.grid_results_[index]['test_score'].argmax(axis=-1)
# Get back indices into original results (perhaps not the best way to do it)
index.flat[best.flat + (index.shape[-1] * np.arange(best.size))]

but it similarly doesn't apply nicely to non-grid shapes.

@amueller
scikit-learn member

I abandoned the PR because it is way out of date, but I would still like to see something like it.
It should be able to handle non-grid structure, though... if that is possible in a sensible way...

@jnothman
scikit-learn member

Yeah, I noticed. Arranging different parameters along different axes really doesn't work for non-grids, but can be really useful for grids, which is why #1842 proposes building such a structure within ParameterGrid, with the constraint that when it uses a list of grids, you can only index one at a time.

I've also privately implemented returning parameters as a record array, which means they can be more easily sliced and diced using numpy comparisons and indexing operations. (And if each distinct group of values for some fields can be assigned an integer, calling bincount twice can find the average score; the argmax score can also be found. Or there's pgnumpy.) But because it's possible to get parameter spaces where some parameters are simply not set for some points (and this will happen more often once users are able to replace pipeline steps), one needs to use a masked recarray, which comes with somewhat quirky behaviour that makes it not very user-friendly.
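
For illustration, such a masked structured array could look roughly like this (parameter names are made up; not an actual implementation):

import numpy as np
import numpy.ma as ma

dtype = [('C', float), ('gamma', float), ('n_estimators', int)]
data = np.array([(1.0, 0.1, 0),     # SVC setting: n_estimators is meaningless here
                 (10.0, 0.01, 0),   # another SVC setting
                 (0.0, 0.0, 100)],  # forest setting: C and gamma are meaningless here
                dtype=dtype)
mask = [(False, False, True),
        (False, False, True),
        (True, True, False)]
params = ma.MaskedArray(data, mask=mask)

set_gammas = params['gamma'].compressed()  # only the settings where gamma was actually set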

And perhaps, to make this data wrangling user-friendly, it should live within an accessory class, which could even provide things like a Wilcoxon signed-rank test between fold results under different parameters. Or perhaps that's all unnecessary bloat.
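
For instance, comparing the per-fold scores of two parameter settings is a one-liner on top of scipy (illustrative only):

from scipy.stats import wilcoxon

# per-fold scores of two parameter settings evaluated on the same CV folds
scores_a = [0.81, 0.79, 0.85, 0.80, 0.83]
scores_b = [0.78, 0.76, 0.84, 0.77, 0.80]

statistic, p_value = wilcoxon(scores_a, scores_b)  # small p-value: difference unlikely to be fold noise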

@jnothman
scikit-learn member

Here's my own tool to do some of these things: https://gist.github.com/jnothman/5480026
But I've realised this is more-or-less what the pandas project specialises in.

@amueller
scikit-learn member

Actually, some people in my lab use pandas for exactly that purpose. So maybe one goal would be to make it easy to create a DataFrame from the results? I still haven't gotten around to reviewing your stuff :-/ Wednesday is the next deadline...
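
Something along these lines would turn the grid_scores_ tuples into a DataFrame (a sketch assuming each entry is a (params, mean_score, fold_scores) tuple, as in the current grid_scores_ format):

import pandas as pd

# fake results in the shape of grid_scores_: (params_dict, mean_score, fold_scores)
grid_scores = [
    ({'max_depth': 1, 'max_features': 3}, 0.71, [0.70, 0.72, 0.71]),
    ({'max_depth': 3, 'max_features': 3}, 0.82, [0.80, 0.83, 0.83]),
    ({'max_depth': 3, 'max_features': 8}, 0.85, [0.84, 0.86, 0.85]),
]

rows = []
for params, mean_score, fold_scores in grid_scores:
    row = dict(params, mean_score=mean_score)
    row.update(('fold%d' % i, s) for i, s in enumerate(fold_scores))
    rows.append(row)

df = pd.DataFrame(rows)
best_by_depth = df.groupby('max_depth')['mean_score'].max()  # marginalizing becomes a one-liner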

@amueller
scikit-learn member

Btw, accumulating with the mean is actually a bad idea, as the variables are, err... beta distributed?

@jnothman
scikit-learn member

Okay, so let's say we had two attributes or functions to get back:

  • the parameters and the averaged search results as an mrecarray
  • the parameters and the per-fold results, with an additional fold-number field, as an mrecarray

(I.e., to avoid heterogeneous data types/shapes, we let users take either the raw fold scores or the aggregate scores, not both at the same time.)

These have the traditional table form and can be played with in numpy, or passed to Pandas, or dumped to spreadsheet / *SQL / mongodb.

@jnothman
scikit-learn member

What do you think of tabulate_results?

@GaelVaroquaux
scikit-learn member
Showing with 207 additions and 15 deletions.
  1. +46 −0 examples/plot_grid_search.py
  2. +11 −8 examples/svm/plot_rbf_parameters.py
  3. +118 −6 sklearn/grid_search.py
  4. +32 −1 sklearn/tests/test_grid_search.py
View
46 examples/plot_grid_search.py
@@ -0,0 +1,46 @@
+"""
+=====================================================
+Visualizing results of high dimensional grid searches
+=====================================================
+
+Often one is faced with combining feature extraction, feature selection
+and classification into a complex pipeline.
+Each individual step usually has many tunable parameters. Finding the
+important parameters for a given task and picking robust settings is often
+hard.
+
+This example shows how to visualize results of a grid search with
+many interacting parameters.
+The ``DecisionTreeClassifier`` is a good model for a complex pipeline as there
+are many parameters to tweak, but often only a few have significant influence.
+"""
+print __doc__
+
+import pylab as pl
+
+from sklearn.datasets import make_classification
+from sklearn.grid_search import GridSearchCV
+from sklearn.tree import DecisionTreeClassifier
+
+X, y = make_classification(n_samples=100, n_features=10, random_state=0)
+
+param_grid = {'max_depth': range(1, 8), 'min_samples_split': [1, 2, 3, 4],
+ 'max_features': [1, 3, 5, 8, 10]}
+
+grid_search = GridSearchCV(DecisionTreeClassifier(), param_grid=param_grid,
+ cv=5)
+grid_search.fit(X, y)
+
+cv_scores = grid_search.scores_
+
+fig, axes = pl.subplots(1, 3)
+axes = axes.ravel()
+for ax, param in zip(axes, cv_scores.params):
+ means, errors = cv_scores.accumulate(param, 'max')
+ ax.errorbar(cv_scores.values[param], means, yerr=errors)
+ ax.set_xlabel(param)
+ ax.set_ylabel("accuracy")
+ ax.set_ylim(0.6, 0.95)
+fig.set_size_inches((12, 4), forward=True)
+pl.subplots_adjust(left=0.07, right=0.95, bottom=0.15, wspace=0.26)
+pl.show()
View
19 examples/svm/plot_rbf_parameters.py
@@ -105,21 +105,24 @@
pl.axis('tight')
# plot the scores of the grid
-# grid_scores_ contains parameter settings and scores
-score_dict = grid.grid_scores_
-
-# We extract just the scores
-scores = [x[1] for x in score_dict]
-scores = np.array(scores).reshape(len(C_range), len(gamma_range))
+cv_scores = grid.scores_
# draw heatmap of accuracy as a function of gamma and C
pl.figure(figsize=(8, 6))
pl.subplots_adjust(left=0.05, right=0.95, bottom=0.15, top=0.95)
-pl.imshow(scores, interpolation='nearest', cmap=pl.cm.spectral)
+pl.imshow(cv_scores.mean(), interpolation='nearest', cmap=pl.cm.spectral)
pl.xlabel('gamma')
pl.ylabel('C')
-pl.colorbar()
+cb = pl.colorbar()
+cb.set_label("Accuracy")
pl.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
pl.yticks(np.arange(len(C_range)), C_range)
+fig, axes = pl.subplots(2, 1)
+for ax, param in zip(axes, cv_scores.params):
+ maxs, errors = cv_scores.accumulate(param, 'max')
+ ax.errorbar(np.arange(len(cv_scores.values[param])), maxs,
+ yerr=errors)
+ ax.set_title(param)
+
pl.show()
View
124 sklearn/grid_search.py
@@ -19,6 +19,86 @@
from .utils import check_arrays, safe_mask
+class ResultGrid(object):
+ """Provides easy access to grid search results.
+
+ This object is constructed by GridSearchCV and
+ provides an easy interface to evaluate the grid search
+ results.
+
+ Attributes
+ ----------
+ params : list of strings
+ The parameters adjusted during the grid search.
+ This is an alphabetical sorting of the keys
+ of the ``param_grid`` used in the GridSearchCV.
+ values: dict
+ This contains the values of the parameters
+ that were used during grid search.
+ scores: ndarray
+ Contains all the scores of all runs.
+ Each axis corresponds to the setting of one
+ parameter, in the order given in params.
+ The last axis corresponds to the folds.
+ """
+
+ def __init__(self, params, values, scores):
+ self.scores = scores
+ self.params = params
+ self.values = values
+
+ def mean(self):
+ """Returns mean scores over folds for the whole parameter grid."""
+ return np.mean(self.scores, axis=-1)
+
+ def std(self):
+ """Returns standard deviation of scores over folds for the whole
+ parameter grid."""
+ return np.std(self.scores, axis=-1)
+
+ def accumulate(self, param, kind="max"):
+ """Accumulates scores over all but one parameter.
+
+ Useful for grid searches over many parameters, where
+ the whole grid cannot easily be visualized.
+
+ Parameters
+ ----------
+ param: string
+ Name of the parameter not to accumulate over.
+ kind: string, 'mean' or 'max'
+ Operation that is used to accumulate over all parameters
+ except ``param``.
+
+ Returns
+ -------
+ scores: ndarray
+ 1d array of scores corresponding to the different settings
+ of ``param``.
+ errors: ndarray
+ 1d array of standard deviations of scores.
+ """
+ index = self.params.index(param)
+ # make interesting axis the first
+ n_values = len(self.values[param])
+ accumulated_mean = np.rollaxis(self.mean(), index, 0)
+ accumulated_mean = accumulated_mean.reshape(n_values, -1)
+ accumulated_std = np.rollaxis(self.std(), index, 0)
+ accumulated_std = accumulated_std.reshape(n_values, -1)
+ if kind == "mean":
+ accumulated_mean = np.mean(accumulated_mean, axis=-1)
+ accumulated_std = np.mean(accumulated_std, axis=-1)
+ elif kind == "max":
+ max_inds = np.argmax(accumulated_mean, axis=-1)
+ inds = np.indices(max_inds.shape)
+ accumulated_mean = accumulated_mean[inds, max_inds].ravel()
+ accumulated_std = accumulated_std[inds, max_inds].ravel()
+ else:
+ raise ValueError("kind must be 'mean' or 'max', got %s." %
+ str(kind))
+ return accumulated_mean, accumulated_std
+
+
class IterGrid(object):
"""Generators on the combination of the various parameter lists given
@@ -97,7 +177,6 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func,
else:
X_train = X[safe_mask(X, train)]
X_test = X[safe_mask(X, test)]
-
if y is not None:
y_test = y[safe_mask(y, test)]
y_train = y[safe_mask(y, train)]
@@ -150,8 +229,8 @@ def _check_param_grid(param_grid):
raise ValueError("Parameter values should be a list.")
if len(v) == 0:
- raise ValueError("Parameter values should be a non-empty "
- "list.")
+ raise ValueError("Parameter values should be "
+ "a non-empty list.")
def _has_one_grid_point(param_grid):
@@ -268,6 +347,10 @@ class GridSearchCV(BaseEstimator, MetaEstimatorMixin):
`best_params_` : dict
Parameter setting that gave the best results on the hold out data.
+ `scores_`: list of ResultGrid
+ For each dict in ``param_grid`` this holds a ``ResultGrid`` that
+ provides easy analysis of the grid search scores.
+
Notes
------
The parameters selected are those that maximize the score of the left out
@@ -435,9 +518,38 @@ def _fit(self, X, y):
self._best_estimator_ = best_estimator
self._set_methods()
- # Store the computed scores
- # XXX: the name is too specific, it shouldn't have
- # 'grid' in it. Also, we should be retrieving/storing variance
+ # param_grid can be a single dict or a list of dicts;
+ # normalize to a list for unified treatment
+ if hasattr(self.param_grid, 'items'):
+ # wrap dictionary in a singleton list
+ param_grid = [self.param_grid]
+ else:
+ param_grid = self.param_grid
+ # for each entry in the param_grid list, we build
+ # an array of scores.
+ # we don't know in advance how long each part is,
+ # so we keep a running offset :-/
+ start = 0
+ self.scores_ = []
+ for one_grid in param_grid:
+ sorted_params = sorted(one_grid.keys())
+ # get the number of values for each parameter
+ grid_shape = [len(one_grid[k]) for k in sorted_params]
+ n_entries = np.prod(grid_shape)
+ grid_shape.append(n_folds)
+ # get scores
+ score_array = np.array(cv_scores[start:start + n_entries])
+ # reshape to fit the sequence of values
+ score_array = score_array.reshape(grid_shape)
+ self.scores_.append(ResultGrid(sorted_params, one_grid,
+ score_array))
+ start += n_entries
+
+ # often the list is just one grid. Make access easier
+ if len(self.scores_) == 1:
+ self.scores_ = self.scores_[0]
+
+ # old interface
self.grid_scores_ = [
(clf_params, score, all_scores)
for clf_params, (score, _), all_scores
View
33 sklearn/tests/test_grid_search.py
@@ -12,6 +12,7 @@
from sklearn.grid_search import GridSearchCV
from sklearn.datasets.samples_generator import make_classification
from sklearn.svm import LinearSVC, SVC
+from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import f1_score, precision_score
from sklearn.cross_validation import KFold
@@ -49,7 +50,7 @@ def test_grid_search():
assert_equal(grid_search.best_estimator_.foo_param, 2)
for i, foo_i in enumerate([1, 2, 3]):
- assert_true(grid_search.grid_scores_[i][0] == {'foo_param': foo_i})
+ assert_equal(grid_search.grid_scores_[i][0], {'foo_param': foo_i})
# Smoke test the score:
grid_search.score(X, y)
@@ -225,3 +226,33 @@ def test_X_as_list():
cv = KFold(n=len(X), k=3)
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
grid_search.fit(X.tolist(), y).score(X, y)
+
+
+def test_result_grid():
+ # run a small grid search and test ResultGrid on it
+ clf = DecisionTreeClassifier()
+ X, y = make_classification()
+ param_grid = {'max_depth': np.arange(1, 5),
+ 'max_features': np.arange(1, 3)}
+ grid_search = GridSearchCV(clf, param_grid=param_grid)
+ grid_search.fit(X, y)
+ result = grid_search.scores_
+ assert_equal(result.mean().shape, (4, 2))
+ assert_equal(result.std().shape, (4, 2))
+ assert_equal(result.scores.shape, (4, 2, 3))
+ means, errs = result.accumulate('max_depth')
+ assert_equal(len(means), 4)
+ assert_equal(len(errs), 4)
+ assert_equal(len(result.values['max_depth']), 4)
+
+
+def test_list():
+ # test that grid search can handle a list of dicts as param_grid
+ # smoke test!
+ clf = DecisionTreeClassifier()
+ X, y = make_classification()
+ param_grid = [{'max_depth': np.arange(1, 5)},
+ {'max_features': np.arange(1, 3)}]
+ grid_search = GridSearchCV(clf, param_grid=param_grid)
+ grid_search.fit(X, y)
+ assert_equal(len(grid_search.scores_), 2)