
ENH Greatly reduces memory usage of histogram gradient boosting #18242

Closed

Conversation

@thomasjpfan (Member) commented Aug 23, 2020

Reference Issues/PRs

Resolves #18152
Closes #18163

What does this implement/fix? Explain your changes.

Uses a histogram pool to improve the memory usage of histogram gradient boosting. When running this script:

from sklearn.datasets import make_classification
from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier
from memory_profiler import memory_usage

X, y = make_classification(n_classes=2,
                           n_samples=10_000,
                           n_features=400,
                           random_state=0)

hgb = HistGradientBoostingClassifier(
    max_iter=100,
    max_leaf_nodes=127,
    learning_rate=.1,
    random_state=0,
    verbose=1,
)

mems = memory_usage((hgb.fit, (X, y)))
print(f"{max(mems):.2f}, {max(mems) - min(mems):.2f} MB")

I get:

This PR

Fit 100 trees in 19.477 s, (12700 total leaves)
Time spent computing histograms: 4.300s
Time spent finding best splits:  2.203s
Time spent applying splits:      0.627s
Time spent predicting:           0.016s
678.93, 486.12 MB

Master

Fit 100 trees in 20.472 s, (12700 total leaves)
Time spent computing histograms: 13.489s
Time spent finding best splits:  2.695s
Time spent applying splits:      2.583s
Time spent predicting:           0.016s
7140.17, 6947.62 MB

@jnothman (Member)

Nice work!

@lorentzenchr (Member) left a comment

Wow! This is more than a 10x memory usage improvement.
Having results from the other benchmark scripts would also be nice.

@glemaitre (Member)

@thomasjpfan I am surprised by the benchmark times. All the individual steps are slower in master (some by up to 4x), yet the overall execution time of the two runs is about the same?

@NicolasHug (Member) left a comment

Thanks @thomasjpfan
Could you please:

  • add tests for the HistogramsPool class, in particular checks for the used_indices and available_indices sets (a test sketch follows this list)
  • run the Higgs boson benchmark for both time and memory usage (with an appropriate number of repetitions where relevant, please)
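
A minimal sketch of what such a test could look like, assuming a hypothetical pool interface (a get()/release() API and plain Python sets for the two index collections; the actual HistogramsPool in this PR may differ):

import numpy as np

# Hypothetical stand-in for the pool under test; the real HistogramsPool
# in this PR may expose a different API.
class HistogramsPool:
    def __init__(self, n_features, n_bins):
        self.shape = (n_features, n_bins)
        self.pool = []                    # allocated histogram arrays
        self.used_indices = set()         # indices currently handed out
        self.available_indices = set()    # indices free for reuse

    def get(self):
        # Return the index of a histogram array, reusing a slot if possible.
        if self.available_indices:
            idx = self.available_indices.pop()
        else:
            self.pool.append(np.zeros(self.shape))
            idx = len(self.pool) - 1
        self.used_indices.add(idx)
        return idx

    def release(self, idx):
        # Mark a histogram slot as reusable for a later node.
        self.used_indices.remove(idx)
        self.available_indices.add(idx)


def test_histograms_pool_tracks_indices():
    pool = HistogramsPool(n_features=3, n_bins=8)
    first, second = pool.get(), pool.get()
    assert pool.used_indices == {first, second}
    assert pool.available_indices == set()

    pool.release(first)
    assert pool.used_indices == {second}
    assert pool.available_indices == {first}

    # A released slot is reused instead of growing the pool.
    assert pool.get() == first
    assert len(pool.pool) == 2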

@thomasjpfan (Member Author)

I simplified some of the logic in HistogramsPool to use weakref.ref. The TreeNodes now hold (weak) references to the histograms, while HistogramsPool "owns" them.
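
For illustration, the ownership pattern described here could look roughly like the sketch below (the names and structure are mine, not the PR's actual code): the pool keeps the strong references to the arrays, and a node stores only a weakref.ref to the histograms it uses.

import weakref
import numpy as np

class HistogramsPool:
    # Owns the histogram arrays; callers only receive weak references.
    def __init__(self, n_features, n_bins):
        self._histograms = []
        self._shape = (n_features, n_bins)

    def get(self):
        hist = np.zeros(self._shape)
        self._histograms.append(hist)   # strong reference stays in the pool
        return weakref.ref(hist)        # the node keeps only this

class TreeNode:
    def __init__(self, histograms_ref):
        self.histograms = histograms_ref    # weakref.ref into the pool

pool = HistogramsPool(n_features=4, n_bins=16)
node = TreeNode(pool.get())

hist = node.histograms()   # dereference; becomes None once the pool is gone
if hist is not None:
    hist[0, 0] += 1.0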

@NicolasHug (Member) commented Aug 23, 2020

I could be wrong but it seems to me that using weakrefs is not necessary: the goal of using weakref would be to avoid keeping a histogram object "alive" in a node if the Pool dies.

But the pool only dies at the very end (when fit ends). So I'm not sure this is enabling anything. Could you please confirm (or not) via new benchmarks?

@thomasjpfan (Member Author)

But the pool only dies at the very end (when fit ends). So I'm not sure this is enabling anything. Could you please confirm (or not) via new benchmarks?

You are correct. Weak references are not needed. The PR was updated to remove them, which simplifies things.

I ran the Higgs benchmark 5 times with n_trees=100, without the memory profiler (the memory profiler would slow down the timings):

PR

Time spent computing histograms: 21.6124 +/- 0.2307s
Time spent finding best splits:  0.1922 +/- 0.0076s
Time spent applying splits:      5.8666 +/- 0.1388s
Time spent predicting:           2.3112 +/- 0.010s
fitted in 50.126 +/- 0.660s

master

Time spent computing histograms: 22.277 +/- 0.085s
Time spent finding best splits:  0.2422 +/- 0.0020s
Time spent applying splits:      6.1414 +/- 0.0402s
Time spent predicting:           2.3508 +/- 0.0139s
fitted in 51.581 +/- 0.1897 s

Memory usage of fit over 5 runs: this PR 3392.7 +/- 2.7 MB, master 3541.2 +/- 1.9 MB.

@amueller (Member)

Awesome! My PR used a bit less memory, I think? Can you confirm or deny? This is certainly the cleaner solution.

@thomasjpfan (Member Author)

Using #18163 and running the script in the opening post, I get: 1040.11, 841.03 MB

@thomasjpfan (Member Author)

I optimized this further by no longer needing to fill the histogram arrays with zeros every time. The benchmark above now gives:

PR

Fit 100 trees in 9.955 s, (12700 total leaves)
Time spent computing histograms: 5.072s
Time spent finding best splits:  2.540s
Time spent applying splits:      0.698s
Time spent predicting:           0.017s
675.42, 481.73 MB

master

Fit 100 trees in 21.137 s, (12700 total leaves)
Time spent computing histograms: 13.262s
Time spent finding best splits:  3.054s
Time spent applying splits:      3.065s
Time spent predicting:           0.017s
7019.68, 6825.42 MB
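
A rough sketch of the zero-fill optimization mentioned above (illustrative NumPy only; the actual change is in the Cython histogram code): when the computation that follows writes every bin anyway, a reused buffer does not need to be cleared first.

import numpy as np

n_features, n_bins = 400, 256
rng = np.random.default_rng(0)

# np.zeros touches and clears every byte of the buffer on every call:
sibling_hist = np.zeros((n_features, n_bins))

# If every bin is overwritten anyway (e.g. with the "sibling = parent - left"
# subtraction trick), an uninitialized or stale buffer can be reused as-is:
sibling_hist = np.empty((n_features, n_bins))        # stale contents are fine
parent_hist = rng.random((n_features, n_bins))
left_hist = parent_hist * rng.random((n_features, n_bins))
np.subtract(parent_hist, left_hist, out=sibling_hist)  # writes every bin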

@thomasjpfan (Member Author) commented Aug 24, 2020

The Higgs benchmark with 100 trees:

python benchmarks/bench_hist_gradient_boosting_higgsboson.py --n-trees 100

PR

Fit 100 trees in 54.164 s, (3100 total leaves)
Time spent computing histograms: 23.580s
Time spent finding best splits:  0.218s
Time spent applying splits:      6.305s
Time spent predicting:           2.603s
fitted in 54.358s

memory from a separate run:
7139.24, 3356.80 MB

Master

Fit 100 trees in 54.347 s, (3100 total leaves)
Time spent computing histograms: 23.834s
Time spent finding best splits:  0.222s
Time spent applying splits:      6.348s
Time spent predicting:           2.588s
fitted in 54.518s

memory from a separate run:
7331.99, 3550.26 MB

@thomasjpfan (Member Author)

With 7067e9c (#18242), this PR is at:

Fit 100 trees in 9.217 s, (12700 total leaves)
Time spent computing histograms: 3.940s
Time spent finding best splits:  2.699s
Time spent applying splits:      0.757s
Time spent predicting:           0.020s
369.68, 156.93 MB

(using the snippet from the opening post)

@ogrisel (Member) commented Sep 3, 2020

I was concurrently applying this fix, addressing the remaining comments, and including the cyclic-reference cleanup, which is useful anyway.

Here is the final memory usage:

Fit 100 trees in 31.063 s, (12700 total leaves)
Time spent computing histograms: 23.325s
Time spent finding best splits:  5.056s
Time spent applying splits:      0.785s
Time spent predicting:           0.024s
284.12, 145.80 MB

@ogrisel (Member) commented Sep 4, 2020

I updated this PR to work on top of the recently merged #18334 that fixed the root cause of the memory efficiency issue. However, based on @thomasjpfan's tests on macOS, it seems that explicit memory management with the HistogramPool class is still useful for that platform: #18334 (comment).

@thomasjpfan can you please confirm that you get around a ~145 MB increment with this PR when running the reproducer snippet on macOS?

@lorentzenchr (Member) commented Sep 4, 2020

I ran the benchmark on my MacBook.

Edit:

  • Python 3.7.2
  • macOS Catalina
  • Intel i7 @ 2.70GHz (8th generation)
  • Compiler: Apple clang version 11.0.3

@ogrisel (Member) commented Sep 4, 2020

@lorentzenchr that's weird. You do not reproduce the macOS / Python GC underperformance @thomasjpfan observed on his machine. Maybe it also depends on the speed of the CPU. Your CPU with 1 core seems to be much faster than mine (with 2 physical cores)!

@ogrisel (Member) commented Sep 4, 2020

Actually, if I set OMP_NUM_THREADS=1 to run this benchmark, it runs slightly faster (28s instead of 31s), which means the sample size in this benchmark is too small to benefit from OpenMP parallelism. But at least parallelism is not too detrimental.

@NicolasHug (Member)

With a newly-updated master and this PR, I get the following result using the snippet at the top, on my laptop (4 threads):

  • Master : ~16sec, 135MB
  • This PR: ~15sec, 145MB

I observed a similar behaviour on this benchmark which creates even more histograms. These results are somewhat consistent with those of @lorentzenchr in #18242 (comment)

While the minor time difference can be explained by the reduced number of memory allocations, I find the increase in memory usage quite surprising.

@NicolasHug (Member) commented Sep 4, 2020

Also, I want to explain exactly what this PR does because it does have some impact on how memory usage evolves over time:

Let m_used_i be the total memory effectively used by the histograms at iteration i
Let m_alloc_i be the total memory allocated for the histograms at iteration i.

  • In master, m_alloc_i == m_used_i for all i
  • In this PR, m_alloc_i == max(m_used_j for j in [0, i]). So m_alloc_i >= m_used_i.

This difference can be observed by printing / plotting the mems list in the benchmarks: in master you'll observe some fluctuations, whereas in this PR you will see monotonically increasing values.

This might be detrimental in cases where later trees are smaller than the first ones, which is usually what I observed empirically: in this PR, the memory usage will never decrease even if an iteration needs less memory than a previous one. In master, it will decrease as expected.

This is another thing to take into account, on top of the upcoming macOS benchmarks from Thomas.
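
The m_alloc vs. m_used behaviour can be illustrated with a tiny simulation of a grow-only pool (the numbers are illustrative, not from the PR):

# Simulate a grow-only pool: slots are reused but never freed, so the
# allocated amount can only increase over iterations.
hists_needed_per_iteration = [10, 40, 25, 60, 20]   # e.g. nodes per tree

allocated = 0                                        # m_alloc_i
for i, used in enumerate(hists_needed_per_iteration):
    allocated = max(allocated, used)                 # grow only when needed
    print(f"iteration {i}: m_used={used:3d}  m_alloc={allocated:3d}")

# master:   m_alloc_i == m_used_i             (memory can shrink again)
# this PR:  m_alloc_i == max(m_used_0..i)     (memory never shrinks)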

@thomasjpfan (Member Author)

I am on a 2.3 GHz Intel i9-9880H (9th generation) CPU on macOS Catalina with Python 3.7, and I set OMP_NUM_THREADS=8.

master

Fit 100 trees in 13.369 s, (12700 total leaves)
Time spent computing histograms: 4.536s
Time spent finding best splits:  2.749s
Time spent applying splits:      0.739s
Time spent predicting:           0.016s
842.59, 648.93 MB

This PR

Fit 100 trees in 9.309 s, (12700 total leaves)
Time spent computing histograms: 4.166s
Time spent finding best splits:  2.707s
Time spent applying splits:      0.640s
Time spent predicting:           0.016s
365.10, 151.53 MB

I am a little lost trying to figure out why my results and @lorentzenchr's differ on macOS.

@lorentzenchr (Member)

I updated my benchmark runs on macOS, now with OMP_NUM_THREADS=4. In this parallel setting, I see a speed improvement from this PR similar to what @thomasjpfan reports, but I get quite different results for memory usage.

@NicolasHug (Member)

I'm really confused that OMP has such a drastic effect on total time in @lorentzenchr's benchmarks, especially on "Time spent computing histograms", which goes from 5s to 2.5s. None of the changes involved in this PR are OMP-related.

In @thomasjpfan's benchmark just above, OMP does not have a significant effect on the "Time spent computing histograms", as I would expect.

@ogrisel (Member) commented Sep 7, 2020

I'm really confused that OMP has such a drastic effect on total time in @lorentzenchr's benchmarks, especially on "Time spent computing histograms", which goes from 5s to 2.5s. None of the changes involved in this PR are OMP-related.

@lorentzenchr in your latest run of #18242 (comment), was your master branch up to date? Does it have the #18341 PR (parallel init) merged?

@ogrisel (Member) commented Sep 7, 2020

For the record, it seems that #14392 brings the same benefits (a slight speed improvement by reducing the number of malloc/free calls and the reliance on the Python GC, which is sometimes useful on macOS) while being simpler, by implementing the recycling logic in the histogram builder class itself.
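
For contrast, a very rough sketch of what "recycling inside the histogram builder" could look like (purely illustrative; see #14392 for the actual approach): the builder keeps one preallocated buffer and reuses its slots, instead of relying on an external pool object.

import numpy as np

class HistogramBuilder:
    # Sketch: the builder owns a preallocated buffer and recycles its slots.
    def __init__(self, n_features, n_bins, max_nodes):
        self._buffer = np.zeros((max_nodes, n_features, n_bins))
        self._next_slot = 0

    def compute_histograms(self, sample_indices):
        hist = self._buffer[self._next_slot]            # recycled slot
        self._next_slot = (self._next_slot + 1) % len(self._buffer)
        hist[:] = 0.0                                   # reset only this slot
        # ... accumulate gradients / hessians into hist here ...
        return hist

builder = HistogramBuilder(n_features=400, n_bins=256, max_nodes=127)
hist = builder.compute_histograms(sample_indices=None)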

@lorentzenchr (Member)

@ogrisel I updated my run of #18242 (comment); nothing really changed, except that this PR is now even faster when single-threaded (before 18s, now 16s).

@lorentzenchr (Member)

As #18152 and #18163 have been solved by #18334, shall we close this PR?

@NicolasHug (Member)

This PR might need a small merge/clean-up, but it is still "competing" with #14392 and we haven't really decided what to do (see the discussions there; basically we need more benchmarks). So I'd keep it open until then.

Base automatically changed from master to main January 22, 2021 10:53

Successfully merging this pull request may close these issues.

High memory usage in HistGradientBoostingClassifier