Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Fixes visualization for nested meta-estimators #21310

Merged
merged 7 commits into from
Mar 14, 2022

Conversation

thomasjpfan
Copy link
Member

Reference Issues/PRs

Fixes #21267

What does this implement/fix? Explain your changes.

This PR changes the implementation a little to only work on estimators. This fixes the original issue:

from sklearn.gaussian_process.kernels import ExpSineSquared
from sklearn.kernel_ridge import KernelRidge
from sklearn.ensemble import VotingClassifier
from sklearn.model_selection import RandomizedSearchCV

from sklearn import set_config

set_config(display="diagram")
kernel_ridge = KernelRidge(kernel=ExpSineSquared())

kernel_ridge_tuned = RandomizedSearchCV(
    kernel_ridge,
    param_distributions={},
)
kernel_ridge_tuned

this PR

Screen Shot 2021-10-12 at 10 11 57 AM

main

Screen Shot 2021-10-12 at 10 06 30 AM

Any other comments?

There was also issues with nesting meta-estimators that is resolved:

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import VotingClassifier
from sklearn.ensemble import RandomForestClassifier

vc = VotingClassifier(
    [("log_reg", LogisticRegression()),
    ("rf", RandomForestClassifier())]
)

vc_tuned = RandomizedSearchCV(
    vc,
    param_distributions={},
)
vc_tuned

this PR

Screen Shot 2021-10-12 at 10 12 04 AM

main

Screen Shot 2021-10-12 at 10 12 17 AM

Comment on lines 104 to 110
for key, est in estimator.get_params().items()
if (
"__" not in key
and hasattr(est, "get_params")
and hasattr(est, "fit")
and callable(est.fit)
)
Copy link
Member

@jeremiedbb jeremiedbb Mar 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't get_params be called with deep=False ?
I guess it's what "__" not in key is suppose to mimic

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree setting deep=False works here. I updated the PR with your suggestion.

doc/whats_new/v1.1.rst Outdated Show resolved Hide resolved
@jeremiedbb jeremiedbb merged commit 6511ef0 into scikit-learn:main Mar 14, 2022
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Apr 6, 2022
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

HTML representation not really working with RandomizedSearchCV
3 participants