New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG+1] fix _BaseComposition._set_params with nested parameters #9945

Merged
merged 12 commits into from Oct 18, 2017

Test that nested estimators get passed all params at once

  • Loading branch information...
jnothman committed Oct 18, 2017
commit dd1b79230a91fcb2ff882ed045ee61c46accbe3b
View
@@ -228,6 +228,24 @@ def test_set_params():
# bad__stupid_param=True)
def test_set_params_passes_all_parameters():
# Make sure all parameters are passed together to set_params
# of nested estimator. Regression test for #9944
class TestDecisionTree(DecisionTreeClassifier):
def set_params(self, **kwargs):
super(TestDecisionTree, self).set_params(**kwargs)
# expected_kwargs is in test scope
assert kwargs == expected_kwargs
return self
expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
for est in [Pipeline([('estimator', TestDecisionTree())]),
GridSearchCV(TestDecisionTree(), {})]:
est.set_params(estimator__max_depth=5,
estimator__min_samples_leaf=2)
def test_score_sample_weight():
rng = np.random.RandomState(0)
@@ -38,14 +38,17 @@ def _get_params(self, attr, deep=True):
def _set_params(self, attr, **params):
# Ensure strict ordering of parameter setting:
# 1. All steps
print("Pipeline.set_params", sorted(params.keys()))
if attr in params:
setattr(self, attr, params.pop(attr))
# 2. Step replacement
names, _ = zip(*getattr(self, attr))
for name in list(six.iterkeys(params)):
if '__' not in name and name in names:
print("replacing", name, 'with', params[name])
self._replace_estimator(attr, name, params.pop(name))
# 3. Step parameters and other initilisation arguments
print("passing on", params, 'to', self)
super(_BaseComposition, self).set_params(**params)
return self
ProTip! Use n and p to navigate between commits in a pull request.