
ENH Introduces the __sklearn_clone__ protocol #24568

Merged: 8 commits merged into scikit-learn:main on Jan 19, 2023

Conversation

@thomasjpfan (Member):

Reference Issues/PRs

Related to #8370
Implementation of SLEP017: scikit-learn/enhancement_proposals#67

What does this implement/fix? Explain your changes.

This PR introduces the __sklearn_clone__ protocol. It also includes a test demonstrating one possible implementation of freezing:

from sklearn.base import BaseEstimator
from sklearn.utils.metaestimators import available_if


class FrozenEstimator(BaseEstimator):
    def __init__(self, estimator):
        self.estimator = estimator

    def __getattr__(self, name):
        # Delegate everything not defined here to the wrapped estimator.
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        # Freezing: cloning returns this very object, fitted state intact.
        return self

    def fit(self, *args, **kwargs):
        # No-op: the inner estimator is already fitted.
        return self

    @available_if(lambda self: hasattr(self.estimator, "transform"))
    def fit_transform(self, *args, **kwargs):
        # fit is a no-op, so fit_transform reduces to transform.
        return self.estimator.transform(*args, **kwargs)

The idea is similar to #9464, but it uses __getattr__ to forward all non-overridden attribute lookups to the inner estimator, so the inner estimator's attributes and methods remain available on the FrozenEstimator object.
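To make the protocol concrete, here is a minimal, self-contained sketch of how a `clone` function can dispatch to `__sklearn_clone__`. The names `clone` and `_clone_parametrized` mirror the PR; everything else (the `Frozen` and `Plain` classes, the simplified `get_params`) is illustrative and not scikit-learn's actual implementation.

```python
import copy


def _clone_parametrized(estimator):
    # Default behaviour: rebuild the estimator from its constructor params.
    params = estimator.get_params()
    return type(estimator)(**copy.deepcopy(params))


def clone(estimator):
    # Protocol hook: if the estimator defines __sklearn_clone__, defer to it.
    if hasattr(estimator, "__sklearn_clone__"):
        return estimator.__sklearn_clone__()
    return _clone_parametrized(estimator)


class Frozen:
    """Freezing via the protocol: cloning returns the same object."""

    def __init__(self, estimator):
        self.estimator = estimator

    def __sklearn_clone__(self):
        return self


class Plain:
    """An estimator without the hook: cloning rebuilds it from params."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}


plain = Plain(alpha=0.5)
frozen = Frozen(plain)
assert clone(frozen) is frozen        # frozen: cloning is the identity
fresh = clone(plain)
assert fresh is not plain             # plain: a fresh copy is built
assert fresh.alpha == 0.5
```

The key point the test above relies on is that `clone` returns whatever `__sklearn_clone__` returns, so returning `self` makes cloning a no-op.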

@thomasjpfan thomasjpfan changed the title ENH Adds __sklearn_clone__ protocol ENH Introduces the __sklearn_clone__ protocol Oct 3, 2022
@adrinjalali (Member):

We should also add it to the "how to write your estimator" guide.

@jnothman (Member) left a comment:

Needs docs, otherwise looking great!

    def fit(self, *args, **kwargs):
        return self

    @available_if(lambda self: hasattr(self.estimator, "transform"))
Member:

Not sure we need to use available_if here!

@thomasjpfan (Member, Author):

I agree. The available_if is not needed here. If we introduce a generic "FrozenEstimator", then the available_if would be useful.

That being said, the test does not require available_if, so I'll remove it.

    return _clone_parametrized(estimator, safe=safe)


def _clone_parametrized(estimator, *, safe=True):
Member:

Will this benefit from being public? Or will those who need it be able to use super?

@thomasjpfan (Member, Author):

I'm thinking third-party estimators would call super() in __sklearn_clone__, or call clone directly.

API-wise, I think it would be a little strange to have two clone methods.

Member:

It's not safe for a third-party estimator to call clone from its __sklearn_clone__. Yes, it could call BaseEstimator.__sklearn_clone__, but then it needs to import BaseEstimator (or copy-paste it).

Having a public clone_default_implementation or clone_parametrized wouldn't hurt...? But I don't know how important it is to support non-use of BaseEstimator.

@thomasjpfan (Member, Author):

Ah, I see. Yeah, it comes down to how much we want to support the non-use of BaseEstimator.

What do you think about adding a use_sklearn_clone_protocol parameter to clone? If use_sklearn_clone_protocol=False, then the default implementation is used.
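A hypothetical sketch of what that suggested keyword could look like; this parameter was only floated in the discussion and is not part of the merged implementation, and the `Est` class and simplified `_clone_parametrized` are stand-ins for illustration.

```python
import copy


def _clone_parametrized(estimator, *, safe=True):
    # Stand-in for the default parameter-based clone.
    return type(estimator)(**copy.deepcopy(estimator.get_params()))


def clone(estimator, *, safe=True, use_sklearn_clone_protocol=True):
    # Hypothetical keyword: opt out of the protocol to force the
    # default parameter-based clone, e.g. for estimators that do not
    # inherit from BaseEstimator.
    if use_sklearn_clone_protocol and hasattr(estimator, "__sklearn_clone__"):
        return estimator.__sklearn_clone__()
    return _clone_parametrized(estimator, safe=safe)


class Est:
    def __init__(self, c=1):
        self.c = c

    def get_params(self):
        return {"c": self.c}

    def __sklearn_clone__(self):
        return self  # frozen-style behaviour


e = Est(c=3)
assert clone(e) is e                                        # protocol honoured
assert clone(e, use_sklearn_clone_protocol=False) is not e  # protocol bypassed
```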

Member:

I thought we no longer care as much about supporting estimators that don't inherit from BaseEstimator?

Member:

Maybe let's keep it like this and see if someone complains that they can't use it?

@thomasjpfan (Member, Author):

I agree with keeping this PR as-is and developing a solution if a third-party estimator developer requires it.

@thomasjpfan thomasjpfan marked this pull request as ready for review January 9, 2023 02:13

    def fit_transform(self, *args, **kwargs):
        return self.fitted_transformer.transform(*args, **kwargs)

Member:

Might be worth noting that most implementations would call super().__sklearn_clone__().

@thomasjpfan (Member, Author):

For me, most estimators would not touch __sklearn_clone__ at all. In what use case would implementations call super().__sklearn_clone__()?
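For readers following along, the super() pattern under discussion can be sketched in plain Python. The classes below are illustrative stand-ins, not scikit-learn code: the subclass delegates the parameter copy to its parent, then adjusts the result.

```python
class Base:
    """Parent providing the default parameter-based clone."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}

    def __sklearn_clone__(self):
        # Default: a fresh instance built from constructor parameters.
        return type(self)(**self.get_params())


class Tagged(Base):
    """Subclass that customises cloning via super()."""

    def __sklearn_clone__(self):
        new = super().__sklearn_clone__()  # reuse the default clone
        new.tag = "cloned"                 # then do extra bookkeeping
        return new


t = Tagged(alpha=2.0)
c = t.__sklearn_clone__()
assert c is not t
assert c.alpha == 2.0 and c.tag == "cloned"
```

This is the kind of case where an override still wants the standard behaviour plus a small tweak, rather than replacing cloning wholesale as the frozen estimator does.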

@betatim (Member) commented Jan 9, 2023:

I think the documentation would be improved with a paragraph that explains what the call to __sklearn_clone__ has to return. From the frozen estimator example I can't work out what it should return. I had a quick read through scikit-learn/enhancement_proposals#67 and that also doesn't seem to specify what should be returned.

It is probably pretty simple once you know it and that is why it never comes up in the SLEP or in the docs here. But as someone who hasn't got the first idea, I feel lost.

@amueller (Member):

Fix merge conflicts?

@lorentzenchr lorentzenchr linked an issue Jan 13, 2023 that may be closed by this pull request
@adrinjalali adrinjalali merged commit 3f82f84 into scikit-learn:main Jan 19, 2023
@adrinjalali (Member):

Excited!

Linked issue that merging may close: __sklearn_clone__ protocol proposal. 5 participants.