PERF Improve runtime for early stopping in HistGradientBoosting #26163
Conversation
A first round.
```python
# scoring is a predefined metric string
if isinstance(self.scoring, str):
    raw_predictions_small_train = raw_predictions[
        indices_small_train
    ]
```
Shouldn't we use the whole training set as in `self.scoring == "loss"`?
Yes, this is still the case: this branch is never reached when `self.scoring == "loss"`.

That being said, I can see how using `isinstance(..., str)` can lead to confusion. I updated this PR with 89c3dc0 (#26163) to use a more explicit variable name for "scorer is a predefined string".
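To illustrate the renaming being discussed, here is a hypothetical sketch (the variable names below are illustrative, not the PR's exact code) of how an explicit flag makes the intent of the check self-documenting:

```python
scoring = "accuracy"  # stand-in for self.scoring

# Before: the intent of the check is implicit.
use_subsample_scoring = isinstance(scoring, str)

# After: the name makes explicit that scoring is a predefined metric string;
# the "loss" case is handled on a separate branch and never reaches here.
scorer_is_predefined_string = isinstance(scoring, str) and scoring != "loss"
print(scorer_is_predefined_string)  # True for "accuracy"
```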
```python
if isinstance(self.scoring, str):
    raw_predictions_small_train = raw_predictions[
        indices_small_train
    ]
```
Again, shouldn't we use the whole training set as in `self.scoring == "loss"`?
```python
@contextmanager
def _patch_raw_predict(estimator, raw_predictions):
    """Context manager that patches _raw_predict to return raw_predictions."""
    orig_raw_predict = estimator._raw_predict
```
I need to digest this 😄
LGTM.

A second review should focus on `_patch_raw_predict` and the solution with the context manager.

We should also update the statement in the user guide saying that scorers in HGBT are much slower. Maybe it is better to wait for #26778.
I had missed this PR. I think it's a useful improvement, but I do not understand the call count assertion in the tests. Maybe there is a bug? The test seems correct (apart from a mistake in the inline comment), but I think it could be improved to test the actual change in this PR.
```python
)
hist.fit(X, y, sample_weight=sample_weight)

# For scorer is called three times per iteration. (2 x 3 = 6)
```
I am trying to investigate, but I don't understand this. I would have expected 3 in total: one for the baseline score call at line 614 and 2 for the score calls at the end of each iteration at line 789. There might be a bug.

EDIT: I forgot about validation. So it's two for the baseline (train + validation) and two per iteration (train + validation).
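The call-count arithmetic can be illustrated with a plain `unittest.mock.Mock` standing in for the scorer (the arguments below are placeholders, not the actual test fixtures):

```python
from unittest.mock import Mock

scorer = Mock(return_value=0.5)

# Baseline before boosting starts: one call on the training subsample,
# one on the validation set.
scorer("est", "X_train", "y_train")
scorer("est", "X_val", "y_val")

# Each boosting iteration then scores train and validation once each.
n_iterations = 2
for _ in range(n_iterations):
    scorer("est", "X_train", "y_train")
    scorer("est", "X_val", "y_val")

# 2 baseline calls + 2 calls per iteration x 2 iterations:
print(scorer.call_count)  # 6
```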
Actually this test should also check the contents of `hist.train_score_` and `hist.validation_score_`.

EDIT: a first version of this comment was wrong (I had made a local change to the code base that introduced a bug in the mock).
```diff
@@ -848,6 +850,37 @@ def test_early_stopping_on_test_set_with_warm_start():
     gb.fit(X, y)
 
 
+def test_early_stopping_with_sample_weights(monkeypatch):
```
This test is useful in itself, but it does not check that we do not call the original `hist._raw_predict` when doing the early stopping checks.

I think the test could be extended to monkeypatch with a second mocker to check that enabling early stopping does not induce extra calls to `hist._raw_predict` compared to fitting with a fixed number of iterations with early stopping disabled and with the default `scoring` parameter value.
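The suggested check could be built with a wrapping `Mock` used as a spy. Below is a generic sketch of that pattern with a dummy estimator (not scikit-learn's actual test code):

```python
from unittest.mock import Mock


class DummyEstimator:
    def _raw_predict(self, X):
        return [0.0] * len(X)


est = DummyEstimator()

# Wrap the original method so calls still work but are counted.
spy = Mock(wraps=est._raw_predict)
est._raw_predict = spy

# ... here the real test would run est.fit(...) with early stopping ...

# Assert no _raw_predict calls happened during the (omitted) fit.
assert spy.call_count == 0

# The wrapped method still behaves normally when it *is* called:
result = est._raw_predict([1, 2, 3])
assert result == [0.0, 0.0, 0.0]
assert spy.call_count == 1
```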
In edafefa, I added a mock to `_raw_predict` to make sure it is not called.

> compared to fitting with a fixed number of iterations with early stopping disabled and with the default scoring parameter value.

This is unclear to me. Assuming `warm_start=False`:

- With `scoring="loss"`, `_raw_predict` is not called.
- With `early_stopping` disabled, `_raw_predict` is not called.
- After this PR, `_raw_predict` is only called when `scoring` is a custom callable and early stopping is enabled. I added a new test in b4cc727 to assert this behavior.
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Thanks for the follow-up. LGTM once the black formatting problem is fixed.
Reference Issues/PRs
Fixes #25974
What does this implement/fix? Explain your changes.
This PR reuses the `raw_predictions` for the validation and training set when `scoring` is a predefined metric. I ran this benchmark and this PR is about 2-3x faster than `main`.

(Benchmark plots: `main` vs. this PR.)