[MRG] FIX bug in nested set_params usage #9999

Merged


jnothman
Member

Issue where estimator is changed as well as its parameter: #9945 (comment)
jnothman added the Bug label Oct 25, 2017
jnothman changed the title from "FIX bug in nested set_params usage" to "[MRG] FIX bug in nested set_params usage" Oct 25, 2017
Member

@ogrisel ogrisel left a comment

nitpicks but otherwise +1

def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
Member

Why name it pipe rather than something like gridsearchcv?

    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
    pipe.set_params(estimator=SVC(), estimator__C=1.0)
Member

Could you please extend this test to do estimator__C=42.0 and then assert gridsearchcv.estimator.C == 42.0?

def test_set_params_updates_valid_params():
    # Check that set_params tries to set SVC().C, not
    # DecisionTreeClassifier().C
    pipe = GridSearchCV(DecisionTreeClassifier(), {})
Member

pipe may not be the best of names ;-)

@jnothman
Member Author

jnothman commented Oct 25, 2017 via email

@jnothman
Member Author

jnothman commented Oct 25, 2017 via email

@lesteve
Member

lesteve commented Oct 25, 2017

In the SO question, the OP mentions that this diff is working for him:

diff --git a/sklearn/base.py b/sklearn/base.py
index b653b7149..81c7e5dae 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -263,6 +263,7 @@ class BaseEstimator(object):
                 nested_params[key][sub_key] = value
             else:
                 setattr(self, key, value)
+                valid_params[key] = value
 
         for key, sub_params in nested_params.items():
             valid_params[key].set_params(**sub_params)

I checked that the test does pass with this patch, which looks simpler than your current change. You know this code much better than I do, so there may be a reason your change is written the way it is. If there is an edge case that the simpler patch does not cover, it would be great to add a test for it.
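
To make the effect of that one added line concrete, here is a minimal sketch using a toy stand-in for BaseEstimator. It is not scikit-learn's actual code: the classes ToyEstimator, Tree, SVM and Search, and the `_patched` switch, are hypothetical names introduced only for illustration. The loop structure follows the lines visible in the diff above; everything else is an assumption.

```python
from collections import defaultdict


class ToyEstimator:
    """Simplified stand-in for BaseEstimator (illustration only)."""

    _param_names = ()

    def get_params(self):
        return {name: getattr(self, name) for name in self._param_names}

    def set_params(self, _patched=True, **params):
        # Snapshot of the parameters taken *before* any of them is updated.
        valid_params = self.get_params()
        nested_params = defaultdict(dict)

        for key, value in params.items():
            key, delim, sub_key = key.partition("__")
            if key not in valid_params:
                raise ValueError(
                    f"Invalid parameter {key!r} for {type(self).__name__}"
                )
            if delim:
                # Nested parameter such as estimator__C: defer it.
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                if _patched:
                    # The line added by the patch: keep the snapshot in sync
                    # so nested params reach the *new* sub-estimator.
                    valid_params[key] = value

        for key, sub_params in nested_params.items():
            valid_params[key].set_params(_patched=_patched, **sub_params)
        return self


class Tree(ToyEstimator):
    _param_names = ("max_depth",)

    def __init__(self, max_depth=None):
        self.max_depth = max_depth


class SVM(ToyEstimator):
    _param_names = ("C",)

    def __init__(self, C=1.0):
        self.C = C


class Search(ToyEstimator):
    _param_names = ("estimator",)

    def __init__(self, estimator):
        self.estimator = estimator


# With the patch, the nested parameter lands on the replacement estimator.
search = Search(Tree())
search.set_params(estimator=SVM(), estimator__C=42.0)
print(search.estimator.C)  # 42.0

# Without it, C is sent to the stale Tree from the snapshot and is rejected.
try:
    Search(Tree()).set_params(_patched=False, estimator=SVM(), estimator__C=42.0)
except ValueError as exc:
    print("unpatched:", exc)
```

In this toy model the unpatched path fails because the nested loop looks up the sub-estimator in a snapshot that still holds the old object. Whether the real code has further edge cases is exactly the open question raised above.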

@jnothman
Member Author

jnothman commented Oct 25, 2017 via email

@marcus-voss

Hey, the OP of SO here. I could PR it, but as I am not very experienced with such open source work I was kinda intimidated by simply adding the PR. As @lesteve mentions, who knows what edge cases I may be breaking.

@jnothman
Member Author

jnothman commented Oct 25, 2017 via email

@TomDLT
Member

TomDLT commented Oct 25, 2017

LGTM as well, thanks @marcus-voss for the clean fix!

@marcus-voss

Hey @TomDLT and @jnothman, thanks for the kind words. For today, just finding the bug and indirectly providing the fix for such a great library has already made my day. Next time, I'll definitely consider opening the PR myself!

@ogrisel ogrisel merged commit 102620f into scikit-learn:master Oct 25, 2017
@ogrisel
Member

ogrisel commented Oct 25, 2017

Thanks all!
