ENH set_params method on BaseEstimator, deprecate estimator params to fit #306

Status: Closed (wanted to merge 3 commits)

Conversation

7 participants
larsmans (Member) commented Aug 10, 2011

As proposed on the mailing list, here's a patch that introduces a public set_params method on all built-in estimators. The rationale is summarized in the commit message of c94ac99f08505073f7c46b004c202c1a54e897a0. Passing data-independent parameters to fit, partial_fit or fit_transform is now deprecated. set_params returns self, so it can be chained.
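A minimal sketch of the chaining this enables (illustrative estimator and parameters; module paths follow today's sklearn layout, not the scikits.learn layout of 2011):

    import numpy as np
    from sklearn.svm import SVC

    X = np.random.RandomState(0).randn(20, 3)
    y = np.array([0, 1] * 10)

    # set_params returns self, so configuration and fitting chain in one line
    clf = SVC().set_params(C=10.0, kernel="linear").fit(X, y)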

Some more remarks: estimators behave inconsistently in that some do _set_params first and then input validation, while others do these steps in the reverse order. I've removed the params keyword from MiniBatchKMeans' fit and partial_fit, since it hasn't featured in any release yet. (partial_fit accepted a params keyword arg but then ignored it, by the way.)

KMeans got a k parameter to fit. The remaining problem is DBSCAN, where I could not decide whether the metric, eps and min_samples parameters should be given to __init__ or fit. Maybe @robertlayton has an opinion on that?

Commit: ENH set_params method on BaseEstimator, deprecate estimator params to fit

Passing estimator parameters to fit leads to a situation where fit might half-fail: it has already set the parameters when input validation fails.
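A toy demonstration of that hazard (hypothetical estimator, not scikit-learn code): when fit applies the parameters before validating its input, a failed fit still leaves the estimator mutated.

    class Toy:
        def __init__(self, alpha=1.0):
            self.alpha = alpha

        def fit(self, X, **params):
            self.__dict__.update(params)     # old style: set params first...
            if len(X) == 0:
                raise ValueError("empty X")  # ...then input validation fails
            return self

    t = Toy()
    try:
        t.fit([], alpha=2.0)
    except ValueError:
        pass
    print(t.alpha)  # 2.0 -- the failed fit has already changed the estimator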
GaelVaroquaux (Member) commented Aug 10, 2011

On Wed, Aug 10, 2011 at 06:21:13AM -0700, larsmans wrote:

> KMeans got a k parameter to fit. The remaining problem is DBSCAN, where I could not decide whether the metric, eps and min_samples parameters should be given to __init__ or fit. Maybe @robertlayton has an opinion on that?

When in doubt, pass to init: parameters should be settable independent of fitting, unless they absolutely depend on the data.


GaelVaroquaux commented Aug 10, 2011
Actually, I had missed the commit using _set_params in GridSearchCV. I don't really like these lines (we tried to avoid the use of deepcopy, which can be a killer on big arrays, and to use clone instead).

That's for another pull request, but I think that in GridSearchCV the following lines:

    # update parameters of the classifier after a copy of its base structure
    clf = copy.deepcopy(base_clf)
    clf._set_params(**clf_params)

should be replaced by a clone.
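A sketch of what that replacement could look like; base_clf and clf_params are the names from the snippet above, and clone is taken from today's sklearn.base. clone rebuilds an unfitted estimator from its constructor parameters, without deep-copying fitted attributes such as large arrays:

    from sklearn.base import clone

    clf = clone(base_clf)          # fresh, unfitted copy of the base structure
    clf.set_params(**clf_params)   # then apply this grid point's parameters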


GaelVaroquaux commented Aug 10, 2011
Isn't that going to break, now that _check_data doesn't accept **params?


larsmans (Owner) replied Aug 10, 2011
Yep. Will patch.


GaelVaroquaux commented Aug 10, 2011
At least eps could go in the init.


larsmans (Owner) replied Aug 10, 2011
I believe eps can be estimated from the data; there's a method for that in the DBSCAN paper, although Robert didn't implement it because he has had bad experiences with its performance.


GaelVaroquaux (Member) commented Aug 10, 2011
In general that looks good at first glance. I wonder how badly we have broken backward compatibility :). Don't get me wrong, I think that it is a good change, I just don't want to have too many people having their code broken.

I'd like a second person to look at this before we merge.


larsmans (Member) commented Aug 11, 2011
Seems like the MiniBatchKMeans tests now fail, because they rely on the ability to pass parameters to fit. @pprett, could you have a look at this? I don't understand what test_mbkm_fixed_array_init_fit is testing for...


pprett (Member) commented Aug 11, 2011
@larsmans: I fixed the issue - test_mbkm_fixed_array_init_fit checks whether passing the init parameter via fit works properly (it overrides an existing init param set via the constructor).

I apologize, but I didn't follow your discussion concerning set_params closely - should I remove all fit parameters from my code? I'll have to go through the mailing list thread on this topic - in the meantime you can find the patch here: https://github.com/pprett/scikit-learn/tree/larsmans-set_params


larsmans (Member) commented Aug 11, 2011
The idea is that fit should no longer try to override the estimator parameters, i.e. those already handed to __init__. Parameters to fit should always pertain to the dataset being fitted on. (So random_state is not a fit parameter, but k is. I'm not sure about init, because it may take an n_samples-sized array...)


GaelVaroquaux (Member) commented Aug 11, 2011
Actually, I can summarize the reasons that @ogrisel, @agramfort and I felt we had to move away from fit parameters when we discussed this a while ago.

Mainly, we felt that:

a. From a machine learning point of view, we wanted the full estimation procedure to be defined in the object, to make it easier to e.g. compare strategies.

b. From a coding perspective, it is easier to keep track of arguments and to propagate them when they are attached to an object.

So the general rule is that anything that is separable from the data should be an estimator parameter. A possible rule of thumb: when you are doing cross-validation, if you need to change the parameter at each fold, it should be a fit parameter; if not, put it in the init.
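A hypothetical estimator illustrating that rule of thumb (the names are illustrative, not part of this patch):

    from sklearn.base import BaseEstimator

    class MyEstimator(BaseEstimator):
        def __init__(self, alpha=1.0):
            # alpha is data-independent: it stays fixed across CV folds,
            # so it belongs in __init__
            self.alpha = alpha

        def fit(self, X, y=None, sample_weight=None):
            # sample_weight is aligned with the rows of X, so it changes
            # with every fold and is passed to fit instead
            return self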


robertlayton (Member) commented Aug 12, 2011
For completeness' sake, I should implement the eps-setting method from the DBSCAN paper.

On the question posed to me, I don't really understand the use of having parameters passed to __init__. I'm inclined to go with what @GaelVaroquaux said and move away from this. However, if there is a reason for it, please let me know.


GaelVaroquaux (Member) commented Aug 12, 2011
On Thu, Aug 11, 2011 at 11:03:44PM -0700, robertlayton wrote:

> I don't really understand the use of having parameters passed to __init__.

All metaparameters that define how an algorithm or model behaves should be passed to init. That way you have an object that is ready for 'consumption'.

> I'm inclined to go with what @GaelVaroquaux said, to move away from this.

What I meant was not to get rid of the init parameters, just to pass in strings that would be converted to kernels during fit. Sorry if I was unclear.


robertlayton (Member) commented Aug 12, 2011
If I understand the problem correctly, parameters are generally given to __init__, and then, if data-dependent parameters are needed, they are given to fit.
To answer the question quickly for DBSCAN: eps is a fit parameter. min_points could be one as well, but I would imagine that metric is fairly static.

However, I'm a little confused as to why we don't just have everything in init then.
Perhaps my use case for objects in a workflow is different from the expected one. I would have expected that, for each tuple of dataset/algorithm/parameters (one experiment), a separate object of the algorithm, initialized with the given parameters, is created. From the sound of this, one would create an object for each algorithm with parameters consistent across all tested datasets, and then update the parameters on a per-dataset basis for all experiments using that algorithm. However, this isn't consistent, because attributes like centroids_ are set in k-means. Reusing an object for a different experiment would break the assumption that attributes belong to the object - they would instead belong to the object and the parameters given to it.


ogrisel (Member) commented Aug 12, 2011
Any parameter that can have a value prior to having access to the data should be an __init__ parameter. fit parameters should be restricted to directly data-dependent things. For instance, a Gram matrix or an affinity matrix that is precomputed from the data matrix X is data dependent. eps is not directly data dependent (although the optimal value is).

Any attribute that ends with _ is expected to be overridden when you call fit a second time.
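A small illustration of both rules with a present-day estimator (LinearRegression is just a convenient example):

    import numpy as np
    from sklearn.linear_model import LinearRegression

    X = np.arange(10, dtype=float).reshape(-1, 1)
    y = 2 * X.ravel() + 1

    lr = LinearRegression()   # everything known up front goes to __init__
    lr.fit(X, y)
    print(lr.coef_)           # trailing underscore: estimated from the data
    lr.fit(X[:5], y[:5])      # refitting overwrites coef_ entirely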


GaelVaroquaux (Member) commented Aug 12, 2011
On Fri, Aug 12, 2011 at 02:05:57AM -0700, ogrisel wrote:

> Any parameter that can have a value prior to having access to the data should be an __init__ parameter. fit parameters should be restricted to directly data-dependent things. For instance, a Gram matrix or an affinity matrix that is precomputed from the data matrix X is data dependent. eps is not directly data dependent (although the optimal value is).
>
> Any attribute that ends with _ is expected to be overridden when you call fit a second time.

I like both of these phrasings. Do you want to add them to the docs?

G


ogrisel (Member) commented Aug 12, 2011
Sure. Let me do it.


robertlayton (Member) commented Aug 12, 2011
Nice description. Given that, eps is not a data-dependent parameter, and neither is anything else for DBSCAN. While the eps value can be calculated from the dataset using the method in the original paper (coming soon), having eps=None would cover this.

On the k-means algorithm: if a specific set of centroids is given as the initialization, this could be a dataset-dependent parameter. Should init (not __init__) therefore be an optional fit parameter? If it's set to None, then whatever method is given to __init__ is used.


vene (Member) commented Aug 12, 2011
+1 for eps=None triggering data-based calculation. It's in the spirit of n_nonzero_coefs=None in sparse linear models, and other places too. I find it a sensible default.


ogrisel (Member) commented Aug 12, 2011
Yes, either eps=None or eps='auto' (there are already several estimators with parameters that accept the 'auto' keyword as a marker for data-driven tuning of the parameter).

For k-means this is indeed a bit special, since the algorithm does not have a global minimum and is sensitive to the initial conditions. So the init param could be data dependent or not (when set to 'k-means++' or 'random', for instance). I am a bit undecided whether we should make it a fit parameter as well.
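A sketch of the eps=None convention (hypothetical class; the stand-in heuristic below is not the estimation method from the DBSCAN paper):

    import numpy as np

    class ToyDBSCAN:
        def __init__(self, eps=None, min_samples=5):
            self.eps = eps
            self.min_samples = min_samples

        def fit(self, X):
            eps = self.eps
            if eps is None:
                # stand-in heuristic: median nearest-neighbour distance
                d = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(-1))
                np.fill_diagonal(d, np.inf)
                eps = float(np.median(d.min(axis=1)))
            self.eps_ = eps  # the data-driven value gets a trailing underscore
            return self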


larsmans (Member) commented Aug 12, 2011
Actually, there is a global minimum for the k-means problem: it is defined by a squared-error objective. k-means++ is explicitly designed with this goal in mind and has been proven to approach it to within a logarithmic bound.


Commit: Rm k param from KMeans.fit again

Two tests fail mysteriously...
larsmans (Member) commented Aug 12, 2011
Alright, I just pushed a new commit where the k parameter is gone from KMeans.fit and the init param is only available through the backward-compat call to _set_params, which should eventually go away. I also tried to clean up the tests, but I got some mysterious failures that I couldn't resolve; in fact, any attempt to fix test_mbk_means_fixed_array_init broke another test, test_sparse_mbk_means_callable_init (!).

@pprett, could you please have a look at this?


fabianp (Member) commented Aug 23, 2011
I changed the remaining usages of **params in ce9814b and pushed it all.


fabianp closed this Aug 23, 2011

naught101 added a commit to naught101/scikit-learn that referenced this pull request Mar 24, 2014

naught101 added a commit to naught101/scikit-learn that referenced this pull request Apr 7, 2016
