[MRG] add support for lists of dictionaries to RandomizedSearchCV #14549
Awesome!
Add a what's new entry?
sklearn/model_selection/_search.py
Outdated
Dictionary with parameters names (string) as keys and distributions
or lists of parameters to try. Distributions must provide a ``rvs``
method for sampling (such as those from scipy.stats.distributions).
If a list is given, it is sampled uniformly.
If a list of dicts is given, for each parameter, one of the dicts
This is slightly unclear. It looks like a dict is first sampled uniformly, and then the parameters are sampled from that dict.
Otherwise LGTM.
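The two-stage reading described in the comment above (pick a dict uniformly, then sample each parameter from that dict) can be sketched as follows. This is a minimal standard-library illustration, not the code under review: `UniformFloat`, `sample_params`, and the example parameter spaces are hypothetical names introduced here, and `random.Random` stands in for the NumPy random state used in the PR.

```python
import random


class UniformFloat:
    """Hypothetical stand-in for a scipy.stats-style distribution."""

    def __init__(self, low, high):
        self.low = low
        self.high = high

    def rvs(self, random_state):
        # Draw one value; random_state is a random.Random instance here.
        return random_state.uniform(self.low, self.high)


def sample_params(param_distributions, n_iter, seed=0):
    """Yield n_iter parameter settings from a list of dicts (illustrative)."""
    rng = random.Random(seed)
    for _ in range(n_iter):
        # Stage 1: pick one dict uniformly at random.
        dist = rng.choice(param_distributions)
        # Stage 2: sample every parameter listed in that dict.
        params = {}
        for name, values in sorted(dist.items()):
            if hasattr(values, "rvs"):
                params[name] = values.rvs(random_state=rng)
            else:
                params[name] = rng.choice(list(values))
        yield params


# Two hypothetical parameter spaces with different keys.
spaces = [
    {"kernel": ["linear"], "C": UniformFloat(0.1, 10.0)},
    {"kernel": ["rbf"], "C": [1, 10], "gamma": UniformFloat(0.01, 1.0)},
]
samples = list(sample_params(spaces, n_iter=5))
```

Each sampled setting only contains keys from the dict that was picked in stage 1, so a `linear` sample never carries a `gamma` value.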
sklearn/model_selection/_search.py
Outdated
for _ in range(self.n_iter):
    dist = self.param_distributions[
        rnd.randint(len(self.param_distributions))]
This is an awkwardly NumPy way of expressing random.choice.
fixed
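For readers following the exchange above, the two ways of picking one dict can be compared side by side. This sketch uses the standard library's `random.Random` as a stand-in for the NumPy random state in the diff (where the counterparts would be `randint` and `choice`); the `param_distributions` list is a made-up example.

```python
import random

# Hypothetical list of parameter spaces.
param_distributions = [{"C": [1, 10]}, {"gamma": [0.1, 1.0]}]

rng = random.Random(0)
# Index-based selection, analogous to the reviewed line:
# draw an index, then look the element up.
by_index = param_distributions[rng.randrange(len(param_distributions))]

rng = random.Random(0)
# Direct selection: ask the generator for one element of the sequence.
by_choice = rng.choice(param_distributions)
```

Both forms return one of the dicts with equal probability; the second simply states the intent more directly.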
all_lists = all(
    all(not hasattr(v, "rvs") for v in dist.values())
    for dist in self.param_distributions)
rng = check_random_state(self.random_state)
renamed this to be more consistent with the rest of the library
Super nitpick, feel free to merge without addressing.
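The nested `all(...)` expression in the hunk above can be tried in isolation. The sketch below wraps it in a hypothetical helper, `is_all_lists`, and uses a made-up `FakeDistribution` class in place of a scipy.stats distribution; what the PR does with the resulting flag is not shown in this excerpt, so only the check itself is illustrated.

```python
class FakeDistribution:
    """Hypothetical object exposing an ``rvs`` method."""

    def rvs(self, random_state=None):
        return 0.5


def is_all_lists(param_distributions):
    # True only if no value in any dict exposes an ``rvs`` method,
    # i.e. every parameter is given as a plain collection of candidates.
    return all(
        all(not hasattr(v, "rvs") for v in dist.values())
        for dist in param_distributions
    )


only_lists = [{"C": [1, 10]}, {"gamma": [0.1, 1.0]}]
mixed = [{"C": [1, 10]}, {"gamma": FakeDistribution()}]

only_lists_flag = is_all_lists(only_lists)
mixed_flag = is_all_lists(mixed)
```

A single distribution anywhere in the list of dicts is enough to make the flag false.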
@@ -210,6 +210,9 @@ Changelog
  plot model scalability (see learning_curve example).
  :pr:`13938` by :user:`Hadrien Reboul <H4dr1en>`.

- |Enhancement| :class:`model_selection.RandomizedSearchCV` now accepts lists
  of parameter distributions. :pr:`14549` by `Andreas Müller`_.
Maybe "lists of dicts to sample from multiple parameter spaces"?
I'm unconvinced ;)
for key in dist:
    if (not isinstance(dist[key], Iterable)
            and not hasattr(dist[key], 'rvs')):
        raise TypeError('Parameter value is not iterable '
... must be an iterable or a distribution?
This is copied and pasted from ParameterGrid. I'm not sure your version is any clearer, and I think being semi-consistent between the two is good.
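The validation being discussed accepts a value if it is iterable or exposes `rvs`, and rejects everything else. A minimal sketch of that rule follows; the helper name `validate` and the error message are hypothetical (the message in the diff is truncated in this excerpt and is not reproduced here).

```python
from collections.abc import Iterable


def validate(param_distributions):
    """Raise TypeError for values that are neither iterable nor distributions."""
    for dist in param_distributions:
        for key in dist:
            value = dist[key]
            if not isinstance(value, Iterable) and not hasattr(value, "rvs"):
                # Hypothetical wording, not the message used in the PR.
                raise TypeError(
                    "Value for parameter {!r} is neither iterable nor a "
                    "distribution: {!r}".format(key, value)
                )


# A list of candidates passes silently.
validate([{"C": [1, 10]}])

# A bare scalar is rejected.
try:
    validate([{"C": 1}])
    rejected = False
except TypeError:
    rejected = True
```

Note that in this sketch a plain string also passes, because `str` is an `Iterable` in Python.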
OH YEAH!
Follow up on #12759 with a slightly simplified interface.
This makes the API of RandomizedSearchCV a superset of that of GridSearchCV, which makes it more convenient to use.