Improve HGBT leaf values updates #22173
Comments
If we do this, we should take care to keep backwards compatibility and use the median definition of NumPy implemented here:

Using:

scikit-learn/sklearn/neighbors/_partition_nodes.pyx, lines 3 to 64 in ee524f4
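For context on the backwards-compatibility concern: NumPy's median averages the two middle values when the element count is even, while a single `nth_element`-style selection returns only one order statistic. A small NumPy sketch (standard NumPy only, not scikit-learn code):

```python
import numpy as np

a = np.array([3.0, 1.0, 4.0, 2.0])   # even number of elements
print(np.median(a))                   # 2.5: average of the two middle values

# An nth_element-style selection yields one order statistic at a time;
# np.partition is NumPy's counterpart and can place both middles at once.
k = len(a) // 2
p = np.partition(a, [k - 1, k])
print((p[k - 1] + p[k]) / 2)          # 2.5, matching np.median
```

So a C++ `std::nth_element`-based median must select both middle elements (or call the selection twice) to reproduce NumPy's definition on even-length inputs.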
Cython is used to provide an interface to this C++ algorithm. The side effect is that it has to modify an array somewhere; it can be a temporary array.

We should avoid mutating
I'm not sure this will answer your question, but the reason I wrote "this requires a Cython version of median()" is that, to parallelize the median computation over the leaves, we'd need to do something like:

```
# Cython code:
for leaf in prange(leaves):
    median(...)
```
Thank you for the update, @NicolasHug. We could
My concern is that adding one single C++ function

```
def median_over_leaves(n_leaves, indices, y_true, raw_prediction, sample_weight):
    """
    indices : list of sample_indices (ndarray)
        Note that each node has sample_indices of a different shape.
    """
    # initialise the array median_leaf
    for i in prange(n_leaves):
        tmp = y_true[indices[i]] - raw_prediction[indices[i]]
        median_leaf[i] = median_cpp(tmp)
```

As long as
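As a sanity check, here is a sequential NumPy stand-in for the pseudocode above (the argument layout and the use of `np.median` are illustrative assumptions; the Cython version would run the loop under `prange` with a GIL-free median such as the proposed `median_cpp`):

```python
import numpy as np

def median_over_leaves(indices_per_leaf, y_true, raw_prediction):
    """Per-leaf median of residuals.

    indices_per_leaf : list of int ndarrays, one per leaf
        (each leaf holds sample indices of a different shape).
    """
    median_leaf = np.empty(len(indices_per_leaf))
    for i, idx in enumerate(indices_per_leaf):  # prange(n_leaves) in Cython
        # `tmp` is a fresh temporary, so an in-place, nth_element-style
        # selection could safely mutate it without touching user data.
        tmp = y_true[idx] - raw_prediction[idx]
        median_leaf[i] = np.median(tmp)
    return median_leaf
```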
We can probably work around this "list of arrays of different sizes" issue by doing something similar to https://github.com/scikit-learn/scikit-learn/blob/5d7dc4ba327c138cf63be5cd9238200037c1eb13/sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx (in particular the
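That workaround can be sketched in plain NumPy (variable names are illustrative): concatenate the ragged per-leaf index arrays into one flat buffer plus an offsets array, which a typed `prange` loop can slice without handling Python objects:

```python
import numpy as np

# Hypothetical flattening of a "list of arrays of different sizes":
# all per-leaf sample indices go into one contiguous buffer, and an
# offsets array records where each leaf's slice starts and ends.
indices_per_leaf = [np.array([0, 3]), np.array([1, 2, 4])]
flat = np.concatenate(indices_per_leaf)
sizes = [len(idx) for idx in indices_per_leaf]
offsets = np.concatenate(([0], np.cumsum(sizes)))

for i in range(len(indices_per_leaf)):  # prange(n_leaves) in Cython
    leaf_indices = flat[offsets[i]:offsets[i + 1]]
```

With this layout, `flat` and `offsets` are plain typed memoryviews, which is what makes the loop body usable inside a `nogil` `prange` section.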
I'm not sure what you mean by that, could you clarify? Do you mean that the Cython code will be compiled to C++ instead of C?
Yes, at least that is my understanding. If we implement |
Using:

```shell
py-spy record --native -o profile.svg -f speedscope -- python ./hgbt.py
```

on:

```python
# hgbt.py
import os

import numpy as np

from sklearn.datasets import make_regression
from sklearn.ensemble import HistGradientBoostingRegressor

if os.path.exists("X.npy"):
    X = np.load("X.npy")
    y = np.load("y.npy")
else:
    X, y = make_regression(100_000, 100, n_informative=100)
    np.save("X.npy", X)
    np.save("y.npy", y)

clf = HistGradientBoostingRegressor(loss="absolute_error").fit(X, y)
```

And searching for

Hence, is it worth optimizing?
We'll probably see a stronger effect with more leaves, e.g.
I think that's correct. Do you think it would be a problem? We already inline some C++ code in some places, IIRC, so I think our tooling can support that.
It would be worth quickly checking the impact on the overall build time of scikit-learn and the size of the generated binary. But even if it has an impact on the build time, I suspect it won't be much for our CI if we properly configure ccache everywhere (which might still not be the case, as was recently discovered on the ARM64 builds).
@NicolasHug: using |
@jjerphan Does your profiling indicate other room for improvements? We could at least remove the TODO comment hinting at this issue. |
I think the best way to tell is to try optimizing it. From what I recall of my exploration of the OpenMP parallel sections (namely
Describe the workflow you want to enable
This issue references this TODO comment:

scikit-learn/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py, lines 65 to 67 in be89eb7
Describe your proposed solution
Some APIs (like std::nth_element) allow getting the median or a given quantile of a contiguous buffer in O(n), but they do need to mutate a data structure to sort (it can be the buffer, or another data structure if using a comparator). Could it be used there?
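To illustrate the mutation trade-off with NumPy's closest analogue of `std::nth_element` (standard NumPy, not scikit-learn code): the free function works on a copy, while the method variant reorders the buffer in place.

```python
import numpy as np

a = np.array([5.0, 1.0, 4.0, 2.0, 3.0])

b = np.partition(a, 2)  # works on a copy; `a` is left untouched
print(b[2])             # 3.0, the median element of 5 values

a.partition(2)          # in-place variant: reorders `a` itself
print(a[2])             # 3.0 again, but `a` has been mutated
```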
Moreover, why can't we compute the median in parallel here?
Additional context
Follows up on this discussion: https://github.com/scikit-learn/scikit-learn/pull/20811/files/4df17828e439ad09a7784e99cc3d8d956eb50fe0#r781036357
/cc @NicolasHug, who might be interested in following this issue as he initially worked on the update and authored this comment.