
PERF Improve runtime for early stopping in HistGradientBoosting #26163

Merged

Conversation

thomasjpfan
Member

Reference Issues/PRs

Fixes #25974

What does this implement/fix? Explain your changes.

This PR reuses the raw_predictions for the validation and training sets when scoring is a predefined metric. I ran this benchmark, and this PR is roughly 2-3 times faster than main.

main

python bench_hist_early_stopping.py --problem regression
Runtime: 11.068207994001568

python bench_hist_early_stopping.py --problem classification
Runtime: 32.575281541998265

PR

python bench_hist_early_stopping.py --problem regression
Runtime: 2.9262633729995287

python bench_hist_early_stopping.py --problem classification
Runtime: 13.191749556999639
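The speedup comes from not re-running all the trees at each early-stopping check. A toy sketch (not the scikit-learn implementation; names are illustrative) of the idea: keep a running sum of raw predictions that each new tree updates in place, so each check costs one tree's prediction rather than all of them.

```python
# Toy sketch (NOT the scikit-learn implementation) of why reusing the running
# raw predictions makes early stopping cheap: each iteration adds one tree's
# contribution to a cached vector instead of re-predicting with all trees.
def fit_with_early_stopping(tree_predict_fns, X, y, loss_fn, n_iter_no_change=2):
    """Run a toy boosting loop, stopping when the loss stops improving."""
    raw_predictions = [0.0] * len(y)  # running sum, reused across iterations
    best_loss = float("inf")
    rounds_without_improvement = 0
    n_done = 0
    for predict_tree in tree_predict_fns:
        update = predict_tree(X)  # only the newest tree is evaluated
        raw_predictions = [r + u for r, u in zip(raw_predictions, update)]
        n_done += 1
        loss = loss_fn(y, raw_predictions)
        if loss < best_loss:
            best_loss, rounds_without_improvement = loss, 0
        else:
            rounds_without_improvement += 1
            if rounds_without_improvement >= n_iter_no_change:
                break  # early stop: no improvement for n_iter_no_change rounds
    return n_done, best_loss
```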

Member

@lorentzenchr lorentzenchr left a comment


A first round.

doc/whats_new/v1.3.rst (outdated; resolved)
Comment on lines 601 to 605
# scoring is a predefined metric string
if isinstance(self.scoring, str):
    raw_predictions_small_train = raw_predictions[
        indices_small_train
    ]
Member


Shouldn't we use the whole training set as in self.scoring == "loss"?

Member Author


Yes, this is still the case. This branch is never reached when self.scoring == "loss".

That being said, I can see how using isinstance(..., str) can lead to confusion. I updated this PR with 89c3dc0 (#26163) to use a more explicit variable name for "scorer is a predefined string".

Comment on lines 775 to 778
if isinstance(self.scoring, str):
    raw_predictions_small_train = raw_predictions[
        indices_small_train
    ]
Member


Again, shouldn't we now use the whole training set as in self.scoring == "loss"?

Comment on lines 84 to 87
@contextmanager
def _patch_raw_predict(estimator, raw_predictions):
    """Context manager that patches _raw_predict to return raw_predictions."""
    orig_raw_predict = estimator._raw_predict
Member


I need to digest this 😄
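For readers following along, the pattern under discussion is a standard monkey-patch context manager. A hedged sketch (simplified names, not the actual PR code): temporarily replace the estimator's _raw_predict so a scorer calling predict() reuses precomputed raw predictions, then restore the original method on exit.

```python
from contextlib import contextmanager


# Hedged sketch of the pattern quoted above (names and classes here are
# illustrative, not the PR's actual code): temporarily swap out _raw_predict
# so scoring reuses cached predictions, restoring the original on exit.
@contextmanager
def patch_raw_predict(estimator, raw_predictions):
    orig_raw_predict = estimator._raw_predict
    try:
        # Ignore the input X; hand back the cached predictions instead.
        estimator._raw_predict = lambda X, **kwargs: raw_predictions
        yield estimator
    finally:
        estimator._raw_predict = orig_raw_predict  # always restore


class ToyEstimator:
    """Stand-in estimator whose predict() routes through _raw_predict."""

    def _raw_predict(self, X):
        return [0.0] * len(X)  # pretend this is an expensive computation

    def predict(self, X):
        return self._raw_predict(X)
```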

Member

@lorentzenchr lorentzenchr left a comment


LGTM.
A second review should focus on _patch_raw_predict and the solution with the context manager.

@lorentzenchr lorentzenchr added the Waiting for Second Reviewer First reviewer is done, need a second one! label Jun 23, 2023
@lorentzenchr
Member

We should also change the statement in the user guide of scorers in HGBT being much slower. Maybe better wait for #26778.

@github-actions

github-actions bot commented Jul 24, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 25b7b9a.

Member

@ogrisel ogrisel left a comment


I had missed this PR. I think it's a useful improvement, but I do not understand the call count assertion in the tests. Maybe there is a bug? The test seems correct (apart from a mistake in the inline comment), but I think it could be improved to test the actual change in this PR.

)
hist.fit(X, y, sample_weight=sample_weight)

# For scorer is called three times per iteration. (2 x 3 = 6)
Member

@ogrisel ogrisel Nov 24, 2023


I am trying to investigate, but I don't understand this. I would have expected 3 in total: one for the baseline score call at line 614 and 2 for the score call at the end of each iteration at line 789.

There might be a bug.

EDIT: I forgot about validation. So it's two for the baseline (train + validation) and two per iteration (train + val).
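The call-count arithmetic from the two comments above can be written out explicitly (a sketch of the reviewer's reasoning, not scikit-learn code): with a validation set, the scorer runs once on train and once on validation at the baseline check, then once on each set per boosting iteration.

```python
# Sketch of the scorer call-count arithmetic discussed above: a baseline
# check scores each held set once, and every boosting iteration scores
# each held set once more.
def expected_scorer_calls(n_iterations, has_validation_set=True):
    sets_scored = 2 if has_validation_set else 1  # train, plus optional val
    baseline_calls = sets_scored
    per_iteration_calls = sets_scored * n_iterations
    return baseline_calls + per_iteration_calls
```

For a fit with 2 iterations and a validation set this gives 2 + 2 * 2 = 6, matching the assertion being debated.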

Member

@ogrisel ogrisel Nov 24, 2023


Actually this test should also check the contents of hist.train_score_ and hist.validation_score_.

EDIT: a first version of this comment was wrong (I did a local change to the code base that introduced a bug in the mock).

@@ -848,6 +850,37 @@ def test_early_stopping_on_test_set_with_warm_start():
gb.fit(X, y)


def test_early_stopping_with_sample_weights(monkeypatch):
Member


This test is useful in itself but it does not check that we do not call the original hist._raw_predict when doing the early stopping checks.

I think the test could be extended to monkeypatch with a second mock to check that enabling early stopping does not induce extra calls to hist._raw_predict compared to fitting with a fixed number of iterations with early stopping disabled and with the default scoring parameter value.

Member Author


In edafefa, I added a mock to _raw_predict to make sure it is not called.

compared to fitting with a fixed number of iterations with early stopping disabled and with the default scoring parameter value.

This is unclear to me. Assuming warm_start=False:

  • With scoring="loss", _raw_predict is not called.
  • With early_stopping disabled, _raw_predict is not called.
  • After this PR, _raw_predict is only called when scoring is a custom callable and early stopping is enabled. I added a new test in b4cc727 to assert this behavior.
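The testing pattern described in this exchange can be sketched with a wrapping mock: count calls to _raw_predict during fit and assert they only happen for a custom callable scorer. A hypothetical miniature (TinyBooster and its behavior are illustrative stand-ins, not the scikit-learn classes):

```python
from unittest import mock


# Hypothetical miniature of the test pattern discussed above: TinyBooster is
# an illustrative stand-in, not scikit-learn code. Only a callable scorer
# needs _raw_predict; predefined metrics reuse the running raw predictions.
class TinyBooster:
    def __init__(self, scoring="loss"):
        self.scoring = scoring

    def _raw_predict(self, X):
        return [0.0] * len(X)  # pretend this is an expensive computation

    def fit(self, X, y):
        if callable(self.scoring):
            # A custom callable scorer forces a real _raw_predict call.
            self.scoring(self._raw_predict(X), y)
        return self


def count_raw_predict_calls(est, X, y):
    """Fit with _raw_predict wrapped in a mock and return the call count."""
    with mock.patch.object(est, "_raw_predict", wraps=est._raw_predict) as m:
        est.fit(X, y)
    return m.call_count
```

Using `wraps=` keeps the real behavior while still recording calls, which is the usual way to assert "no extra calls" without changing what fit computes.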

@thomasjpfan thomasjpfan force-pushed the hist_gradient_string_early_stopping branch from 43f8ecb to b4cc727 Compare November 25, 2023 16:02
Member

@ogrisel ogrisel left a comment


Thanks for the follow-up. LGTM once the black formatting problem is fixed.

@ogrisel ogrisel merged commit 08b6157 into scikit-learn:main Nov 27, 2023
27 checks passed
Labels
module:ensemble Waiting for Second Reviewer First reviewer is done, need a second one!
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Improve early stopping of HGBT for scorers passed as string
3 participants