
EFF Speed-up MiniBatchDictionaryLearning by avoiding multiple validation #25493

Closed

Conversation

jeremiedbb
Member

Alternative to #25490

MiniBatchDictionaryLearning calls public functions that in turn call public classes. We end up validating the parameters and the input/dict twice per minibatch. When the batch size is large this has barely any impact, but for small batch sizes it can be very detrimental.

For instance, here's a profiling result in the extreme case batch_size=1:
[profiling screenshot: prof3]

This PR proposes to add a new config option that allows disabling parameter validation. It has the advantage over #25490 that it does not involve any refactoring, and it can be useful in other places where we call public functions/classes within estimators.

For the input validation I just set the assume_finite config, since the finiteness check is the most costly part of check_array. It's good enough for now, but it still leaves some overhead, and in the end I think we could add a config option to disable check_array entirely (we have check_input args in some places, but I think a config option would be more convenient).
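A minimal sketch of the assume_finite mechanism referred to above (the array shape is illustrative, not from the PR):

```python
import numpy as np
from sklearn import config_context
from sklearn.utils import check_array

X = np.random.rand(100_000, 50)

# Default: check_array scans X for NaN/inf, which is its most
# expensive step on clean numeric input.
X_checked = check_array(X)

# With assume_finite=True the finiteness scan is skipped; the rest of
# check_array (dtype, shape, contiguity handling) still runs.
with config_context(assume_finite=True):
    X_checked = check_array(X)
```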

Here's the profiling result with this PR:
[profiling screenshot: prof5]
There's still some overhead from check_array, but it's a lot more reasonable.

@jeremiedbb
Member Author

Side note: adding this config option would make me much more confident in continuing to add parameter validation to public functions, because we'd have a way to bypass it wherever we call such functions within our own code base.

@ogrisel
Member

ogrisel commented Jan 27, 2023

Irrespective of the minibatch dictionary learning case and the potentially expensive future parameter checks, func_sig = signature(func) alone can be non-trivial, and having this option can be useful for advanced users who want to run many repeated low-latency calls to the scikit-learn public API.
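A quick way to see the inspect.signature overhead referred to here (the function f and the iteration count are made up for illustration):

```python
import timeit
from inspect import signature

def f(a, b, *, c=None):
    return a

# Building a Signature object on every call dwarfs the cost of the
# call itself for a trivial function.
print("signature(f):", timeit.timeit(lambda: signature(f), number=100_000))
print("f(1, 2):     ", timeit.timeit(lambda: f(1, 2), number=100_000))
```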

@ogrisel
Member

ogrisel commented Jan 27, 2023

Can you please add a changelog entry to document this new option?

@@ -134,6 +136,12 @@ def set_config(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If True, parameter validation of public estimators and functions will be skipped.
Member

For me it wasn't immediately clear what "parameter validation of public estimators and functions" covers and what it doesn't cover. After a look at the code, I think it only covers the constructor arguments of estimators and the arguments passed to functions. Whereas before I looked at the code, I thought it might also disable validation of the input data (I think it doesn't really do that).

So I was thinking about alternative sentences. Here are a few ideas. Maybe you like some of them, or we can combine them.

  • If True, disable the validation of estimators' hyper-parameters and of arguments to helper functions. This can be useful for very small inputs, but it can lead to crashes.
  • If True, disable the validation of the hyper-parameter values of estimators and of arguments passed to other functions. This can be useful in situations like small input data, but it can also lead to misconfigurations going undetected.

Member Author

I adapted the description to be more explicit about what kind of validation is disabled. I did not keep the "small input data" part because validation can be problematic even with large data, for instance if it happens in a loop that processes samples one at a time.

Member

Sounds good. In my head I had implicitly set "dataset size == number of samples passed to predict", so the case you described is part of that. However, we have at least N=1 evidence that this "obvious" assumption isn't obvious to everyone :D So I'm ok with this. (I can't mark this as resolved, but if I could, I would.)

Member Author

"dataset size == number of samples passed to predict",

What I meant is that here predict might call a public function for each sample, which would be catastrophic.

@thomasjpfan
Member

Overall, there have been three ways to get around parameter validation:

  1. Creating a private function without validation and using it when we do not need validation.
  2. Adding a check_parameters keyword argument everywhere.
  3. Adding a global configuration (this PR).

After years of doing option 1, I now prefer option 3 (this PR).
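To make the trade-off concrete, here is a hypothetical sketch of options 1 and 3 side by side (compute, _compute, and the type check are made-up illustrations, not scikit-learn code; reading the flag from get_config assumes the skip_parameter_validation option introduced in this PR):

```python
from sklearn import config_context, get_config

def _compute(x):
    # Option 1: a private, validation-free twin of the public function.
    return 2 * x

def compute(x):
    # Option 3: the public function consults the global config and
    # skips its own validation when asked to.
    if not get_config()["skip_parameter_validation"]:
        if not isinstance(x, (int, float)):
            raise TypeError("x must be a number")
    return _compute(x)

# An internal caller in a hot loop bypasses the redundant check:
with config_context(skip_parameter_validation=True):
    results = [compute(i) for i in range(1000)]
```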

(resolved review thread on sklearn/base.py)
@jeremiedbb
Member Author

jeremiedbb commented Feb 10, 2023

After years of doing option 1, I now prefer option 3

I kind of arrived at the same conclusion. Option 1 becomes a nightmare when you have functions calling other functions in cascade and you want to keep a public version of each of these functions...

I still think it's a good thing to have a separation, not public/private this time but rather boilerplate/algorithm. I often find that the validation/preparation and the core computational part of an algorithm are too intertwined in our functions, making them hard to read and understand. So overall I think option 3 is more appropriate for easily avoiding unnecessary overhead, but option 1 is still on the table, when used reasonably, to keep the code base modular and maintainable.

@ogrisel
Member

See comments below. Other than that, can you please add a few smoke tests checking that it's possible to call a @validate_params-decorated function, and an estimator's fit, with parameter validation disabled and still get the same outcome as the original function call or estimator fit?

I don't think we need a common test for this. Just a test for an arbitrary choice of validated function and estimator would be enough.
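A rough sketch of what such a smoke test could look like (the choice of LogisticRegression and the toy data are arbitrary illustrations, assuming the skip_parameter_validation option from this PR):

```python
import numpy as np
from sklearn import config_context
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0, 1.0], [1.0, 0.0], [2.0, 2.0], [3.0, 1.0]])
y = np.array([0, 0, 1, 1])

# Fit once with validation on, once with it skipped.
clf_ref = LogisticRegression(random_state=0).fit(X, y)
with config_context(skip_parameter_validation=True):
    clf_skip = LogisticRegression(random_state=0).fit(X, y)

# Disabling validation must not change the fitted model.
assert np.allclose(clf_ref.coef_, clf_skip.coef_)
```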

with config_context(assume_finite=True, skip_parameter_validation=True):
    batch_cost = self._minibatch_step(
        X_batch, dictionary, self._random_state, i
    )
Member

I think I am still in favor of merging #25490 first.

But I am not opposed to merging this one as well, though no longer for the _minibatch_step call if #25490 is merged.

Member Author

#25490 only gets rid of one layer of validation (there are four validations in total), so even if it gets merged we'll still need the context manager around the _minibatch_step call.

Comment on lines 265 to 267
If True, disable the validation of the hyper-parameters of estimators and
arguments passed to public functions. It can save time in some situations but
can lead to potential crashes.
Member

Suggested change:

- If True, disable the validation of the hyper-parameters of estimators and
- arguments passed to public functions. It can save time in some situations but
- can lead to potential crashes.
+ If True, disable the validation of the hyper-parameters values in the fit
+ method of estimators and for arguments passed to public functions. It can
+ save time in some situations but can lead to low level crashes and
+ exceptions with confusing error messages.

@betatim
Member

betatim commented Feb 20, 2023

I like the technique of using with config_context to disable the checks for everything inside the context. It seems simpler/prettier than having to pass an additional parameter to all functions, or than the duplicate universes that private + public function combos lead to.

One thing I am not sure what to think of is that a library that changes the user-provided configuration at runtime feels a little like a Python script that modifies its sys.path (or, even worse, a library that does so). It is usually done with good intentions but also leads to all sorts of messes and unpredictable side-effects. You could argue that the only reason I'm thinking of "config" is because it is called config_context; I guess naming matters?

I was thinking the same when reading #25617 (comment); my first reaction is that the library shouldn't be enabling array dispatching, because it is a configuration flag that is "controlled by the user".

What do others think on this topic?


Related to my earlier comment: I think the flag introduced in this PR only affects the validation of the constructor parameters (things that can be set via set_params). But the docs also mention "public functions", which on a first read (after the weekend or some time away from this PR) makes me think it also affects X and y, etc. that are passed to the public methods of an estimator. Then after 5 minutes of thinking and staring at the code I realise that what it means is helper functions like cross_val_score and their "hyper-parameters" (i.e. not X, etc.). I don't have a good suggestion for how to phrase the docs, but maybe we can find an alternative formulation that doesn't lead people down this path of thinking (or maybe I'm the only one who gets lost at first).

@jeremiedbb
Member Author

I don't have a good suggestion for how to phrase the docs, but maybe we can find an alternative formulation that doesn't lead people down this path of thinking

The phrasing changed a little bit. Let me know if it's clearer.

@@ -134,6 +136,14 @@ def set_config(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If True, disable the validation of the hyper-parameters types and values in the
fit method of estimators and for arguments passed to public helper functions.
Member

I think it's important to state that check_array still runs:

For data parameters, such as X and y, only type validation is skipped and validation with check_array will continue to run.
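For instance (a hypothetical illustration assuming the option as merged; the estimator choice is arbitrary), data validation still rejects bad input even when parameter validation is skipped:

```python
import numpy as np
from sklearn import config_context
from sklearn.linear_model import LogisticRegression

X_bad = np.array([[np.nan, 1.0], [1.0, 0.0]])
y = np.array([0, 1])

with config_context(skip_parameter_validation=True):
    try:
        LogisticRegression().fit(X_bad, y)
    except ValueError as e:
        # check_array still runs on X and rejects the NaN.
        print(e)
```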

@jeremiedbb
Member Author

I opened an alternative PR, #25815, to go beyond this one and disable inner validation in a more general manner.
