Conversation
This seems like a feature Tune itself could use. It'd be odd to limit it to just tune-sklearn. Great work!
tune_sklearn/tune_search.py (outdated)

```diff
@@ -265,6 +268,11 @@ class TuneSearchCV(TuneBaseSearchCV):
             determined by 'Pipeline.warm_start' or 'Pipeline.partial_fit'
             capabilities, which are by default not supported by standard
             SKlearn. Defaults to True.
+        stop_on_plateau (bool|dict|TrialPlateauStopper): Stop trials early if
```
Perhaps it would be a good idea to just let users pass their own Stopper instance? That way users could just extend the Stopper class for their own purposes.
So you mean instead of these arguments we just support a `stopper` argument or so, and document how to pass a `TrialPlateauStopper` for this use case?
Yeah, I believe that would be a good idea. That way users could define their own Stoppers, or import other Stoppers from Tune - and we would not need to add special support for each of them in tune-sklearn.
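For context, a custom stopper in the style Tune expects is just a callable with `__call__(trial_id, result)` (decide per trial) and `stop_all()` (end the whole run). Here is a minimal standalone sketch of that interface, with no Ray import so it is easy to test; the class name and its `max_iters` parameter are illustrative only, not part of any existing API:

```python
# Sketch of a user-defined stopper following the Stopper-style interface:
# __call__(trial_id, result) returns True to stop that trial,
# stop_all() returns True to terminate the entire experiment.
class MaxIterStopper:
    """Stop a trial once it has reported `max_iters` results."""

    def __init__(self, max_iters=10):
        self.max_iters = max_iters
        self._counts = {}  # results seen per trial_id

    def __call__(self, trial_id, result):
        self._counts[trial_id] = self._counts.get(trial_id, 0) + 1
        return self._counts[trial_id] >= self.max_iters

    def stop_all(self):
        return False
```

Because the search code only ever calls these two methods, any user-defined class of this shape could be passed through to Tune without tune-sklearn having to special-case it.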
I think that makes a lot of sense.
I refactored the changes. The Tune stoppers are in this PR: ray-project/ray#12750. A change I'd like to get your feedback on is that I introduced a "default metric". There might be a better way to achieve this, but this was straightforward to implement. Do you have any suggestions?
@krfricke Looks great! Just to clear up how `refit` interacts with multimetric scoring:

```python
score_dict = {"accuracy": accuracy_metric, "auc": auc_metric}

# Will throw an exception when fit is called: "When using multimetric
# scoring, refit must be the name of the scorer used to pick the best
# parameters. If not needed, set refit to False"
ts = TuneSearchCV(scoring=score_dict, refit=True)

# Correct usage: accuracy will be used as the objective value,
# the metric name being average_test_accuracy
ts = TuneSearchCV(scoring=score_dict, refit="accuracy")
```

That being said, the approach you have taken will of course work regardless of what that value is, without concern for its type. I can't think of a better one, and I believe that other sklearn wrappers use a similar approach as well.
BTW, we'll need to update the README too, I think.
Thanks for the explanation. I updated the README, but we will have to wait until ray-project/ray#12750 is merged so that the link works.
The PR is merged and I think the test errors are unrelated to this PR. |
That's right, we just pass stoppers to Ray Tune directly. |
With the `stop_on_plateau` parameter, trials can be stopped early if their score does not change over a number of iterations. If `True`, a default configuration will be used. If a `dict`, the parameters will be passed to the respective stopper class. It can also be an instantiated `TrialPlateauStopper` object.

I'm happy to add an example to the docs, but would like to get initial feedback/review first.
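To illustrate the plateau-detection idea itself (a standalone sketch, not tune-sklearn's or Tune's actual implementation; the class name, the `"score"` metric key, and the default values are all illustrative), a trial could be flagged for stopping once the standard deviation of its most recent scores falls below a threshold:

```python
from collections import defaultdict, deque
from statistics import pstdev

class PlateauCheck:
    """Flag a trial for stopping when the population standard deviation
    of its last `num_results` scores drops to `std` or below."""

    def __init__(self, metric="score", std=0.01, num_results=4):
        self.metric = metric
        self.std = std
        self.num_results = num_results
        # Per-trial sliding window of the most recent scores.
        self._scores = defaultdict(lambda: deque(maxlen=num_results))

    def __call__(self, trial_id, result):
        window = self._scores[trial_id]
        window.append(result[self.metric])
        if len(window) < self.num_results:
            return False  # not enough data points yet
        return pstdev(window) <= self.std
```

A `dict` form of `stop_on_plateau` would then simply map onto constructor arguments like `std` and `num_results`, while `True` would use the defaults.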
Things to consider:
Closes #98