ENH Add Categorical support for HistGradientBoosting #18394

Merged on Nov 16, 2020 (85 commits)

Commits

2b3e044  ENH Adds categorical support for hist gradient boosting (thomasjpfan, Sep 12, 2020)
662f5a7  TST Remove need for check (thomasjpfan, Sep 14, 2020)
8608549  STY Lint fixes (thomasjpfan, Sep 14, 2020)
cca3cc5  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Sep 15, 2020)
270895a  Cleaner condition for unknown values in binner.transform (NicolasHug, Sep 15, 2020)
692851e  Fix condition and error message in binner (NicolasHug, Sep 15, 2020)
0aaabf6  Slightly cleaned test + more thorough tests (NicolasHug, Sep 15, 2020)
e69d5ef  fixed test (NicolasHug, Sep 15, 2020)
af21dbc  Merge remote-tracking branch 'upstream/master' into cat_hgbt_256 (thomasjpfan, Sep 18, 2020)
6b12abf  MNT Moves make_known_categories (thomasjpfan, Sep 18, 2020)
dd4898b  CLN Removes scanning both directions (thomasjpfan, Sep 18, 2020)
6cea455  TST Adjust tests (thomasjpfan, Sep 18, 2020)
e8c5e6a  pep8 (NicolasHug, Sep 19, 2020)
75d7682  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Sep 19, 2020)
def2b69  Pass known categories to the BinMapper (#31) (NicolasHug, Sep 21, 2020)
f43d7c0  CLN Slight refactor (thomasjpfan, Sep 22, 2020)
a5d64ef  Merge remote-tracking branch 'upstream/master' into cat_hgbt_256 (thomasjpfan, Sep 22, 2020)
8d90f54  FIX Fixes benchmark (thomasjpfan, Sep 22, 2020)
d416dec  DOC Update example (thomasjpfan, Sep 22, 2020)
f8fc24b  CLN Renames bitset function to include memoryview (thomasjpfan, Sep 22, 2020)
c062097  BinMapper docs (NicolasHug, Sep 23, 2020)
672907c  slightly clearer test (NicolasHug, Sep 23, 2020)
3857dae  docs + remove unused parameter (NicolasHug, Sep 23, 2020)
3907007  Added validation test for BinMapper (NicolasHug, Sep 23, 2020)
30df335  cleaner _check_categories function and removed unused parameter (NicolasHug, Sep 23, 2020)
42822f0  use NO_CST instead of 0 (NicolasHug, Sep 23, 2020)
8b999bb  reduced diff + docs in main class (NicolasHug, Sep 23, 2020)
216452b  rename variable (NicolasHug, Sep 23, 2020)
3ab1e74  minor doc changes in splitter (NicolasHug, Sep 23, 2020)
c0a7795  minor cleaning to splitter code (NicolasHug, Sep 23, 2020)
52559dc  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Sep 24, 2020)
c3e2094  some cleaning, docs, and hopefully more descriptive names (NicolasHug, Sep 24, 2020)
bda7574  Faster and simpler binned to raw bitset conversion (NicolasHug, Sep 24, 2020)
d820f40  cleaner and more bitset tests (NicolasHug, Sep 24, 2020)
1d1ec9d  some cleaning and docs (NicolasHug, Sep 24, 2020)
2142324  Put back prange lol (NicolasHug, Sep 24, 2020)
5a9098c  Fixed comment and prediction consistency about missing_go_to_left (NicolasHug, Sep 24, 2020)
8e551f6  Avoid too many array accesses in predictions (NicolasHug, Sep 24, 2020)
664fdc6  pep8 (NicolasHug, Sep 24, 2020)
27a0870  forgot one renaming (NicolasHug, Sep 24, 2020)
5b6ccea  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Sep 28, 2020)
4e15d8f  reduce diff + comments (NicolasHug, Sep 28, 2020)
6ec1de5  updated and added more tests (NicolasHug, Sep 28, 2020)
75c53f3  more consistent handling of missing_go_to_left + more tests (NicolasHug, Sep 28, 2020)
51cb926  Test that compares different encoding strategies (NicolasHug, Sep 29, 2020)
4cacf43  remove perf comparison in other test (NicolasHug, Sep 29, 2020)
0958285  Fix MIN_CAT_SUPPORT so that it uses hessians instead of counts (NicolasHug, Sep 29, 2020)
193d6cc  Put back double scanning. Comments and tests to come. (NicolasHug, Sep 29, 2020)
2521e93  Comments + test on split finding (NicolasHug, Sep 30, 2020)
64c8cf5  A more complete example and User Guide (NicolasHug, Sep 30, 2020)
709f9f3  fixed doc rendering issues (NicolasHug, Sep 30, 2020)
eefddf5  Remove use of pandas in benchmark (NicolasHug, Sep 30, 2020)
71e2ce6  forgot to remove this (NicolasHug, Oct 1, 2020)
8008c33  Remove temporary variables in bitsets (NicolasHug, Oct 1, 2020)
2c37bb7  Comment about prediction time (NicolasHug, Oct 1, 2020)
8c774d8  Fixed dtype of f_idx_map + added stride info in predictor.pyx (NicolasHug, Oct 1, 2020)
f0873e1  Added new benchmark with only categorical features (NicolasHug, Oct 1, 2020)
485385f  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Oct 1, 2020)
7c063b5  Merge remote-tracking branch 'origin/lolol' into cat_hgbt_256 (NicolasHug, Oct 1, 2020)
8b298c6  Fixed predict computation time issues by passing 2d views (NicolasHug, Oct 1, 2020)
c654cfa  Fixed shape in docstring (NicolasHug, Oct 1, 2020)
7361b28  more latex (NicolasHug, Oct 1, 2020)
74216c2  slightly faster predictions by changing check orders (NicolasHug, Oct 1, 2020)
31fac1d  Another pass on the UG (NicolasHug, Oct 1, 2020)
1a178d7  Reduced diff, improved coverage, removed unsused bits (NicolasHug, Oct 2, 2020)
faad968  DOC Adds whats new (thomasjpfan, Oct 10, 2020)
3fb4fa9  STY remove comma (thomasjpfan, Oct 10, 2020)
66ad644  Merge remote-tracking branch 'upstream/master' into cat_hgbt_256 (thomasjpfan, Oct 10, 2020)
83ba38d  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Oct 14, 2020)
d72b6cd  Use # %% as a cell separator in HGBRT categorical example (ogrisel, Oct 15, 2020)
054d074  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (ogrisel, Oct 15, 2020)
9c0ebc5  typo (ogrisel, Oct 15, 2020)
078f804  More informative error message (ogrisel, Oct 16, 2020)
d5a3485  Expand categorical GBRT example (ogrisel, Oct 16, 2020)
8f41021  Apply suggestions from code review (ogrisel, Oct 16, 2020)
a7e2cb2  Move paragraph (ogrisel, Oct 16, 2020)
53f5f61  Merge branch 'master' of github.com:scikit-learn/scikit-learn into ca… (NicolasHug, Nov 13, 2020)
33f7645  addressed comments (NicolasHug, Nov 13, 2020)
0394102  added comment for bitsets (NicolasHug, Nov 13, 2020)
6b00a7b  Addressed comments (NicolasHug, Nov 15, 2020)
8c1d101  Addressed comments (NicolasHug, Nov 15, 2020)
833f655  Addressed comments (NicolasHug, Nov 16, 2020)
0896fcd  the the (NicolasHug, Nov 16, 2020)
90103d3  comments (NicolasHug, Nov 16, 2020)
0eccd4b  comment about cython (NicolasHug, Nov 16, 2020)
90 changes: 90 additions & 0 deletions benchmarks/bench_hist_gradient_boosting_adult.py
@@ -0,0 +1,90 @@
import argparse
from time import time

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test, target_test):
    if args.no_predict:
        return
    tic = time()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = time()
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")


data = fetch_openml(data_id=179, as_frame=False)  # adult dataset
X, y = data.data, data.target

n_features = X.shape[1]
n_categorical_features = len(data.categories)
n_numerical_features = n_features - n_categorical_features
print(f"Number of features: {n_features}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2,
                                                    random_state=0)

# Note: no need to use an OrdinalEncoder because categorical features are
# already clean
is_categorical = [name in data.categories for name in data.feature_names]
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

fit(est, X_train, y_train, 'sklearn')
predict(est, X_test, y_test)

if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use OHE
    categorical_features = [f_idx
                            for (f_idx, is_cat) in enumerate(is_categorical)
                            if is_cat]
    fit(est, X_train, y_train, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X_test, y_test)
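
# Example invocation, using the flags defined above, to compare against
# LightGBM on the same data:
#   python benchmarks/bench_hist_gradient_boosting_adult.py --lightgbm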
84 changes: 84 additions & 0 deletions benchmarks/bench_hist_gradient_boosting_categorical_only.py
@@ -0,0 +1,84 @@
import argparse
from time import time

from sklearn.preprocessing import KBinsDiscretizer
from sklearn.datasets import make_classification
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--n-features', type=int, default=20)
parser.add_argument('--n-cats', type=int, default=20)
parser.add_argument('--n-samples', type=int, default=10_000)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_features = args.n_features
n_categories = args.n_cats
n_samples = args.n_samples
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test):
    # We don't report accuracy or ROC because the dataset doesn't really make
    # sense: we treat ordered features as un-ordered categories.
    if args.no_predict:
        return
    tic = time()
    est.predict(data_test)
    toc = time()
    print(f"predicted in {toc - tic:.3f}s")


X, y = make_classification(n_samples=n_samples, n_features=n_features,
                           random_state=0)

X = KBinsDiscretizer(n_bins=n_categories, encode='ordinal').fit_transform(X)

print(f"Number of features: {n_features}")
print(f"Number of samples: {n_samples}")

is_categorical = [True] * n_features
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

fit(est, X, y, 'sklearn')
predict(est, X)

if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use OHE
    categorical_features = list(range(n_features))
    fit(est, X, y, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X)
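
# Example invocation, scaling up the synthetic data via the flags above:
#   python benchmarks/bench_hist_gradient_boosting_categorical_only.py \
#       --n-samples 100000 --n-cats 40 --lightgbm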
68 changes: 68 additions & 0 deletions doc/modules/ensemble.rst
@@ -1051,6 +1051,68 @@ multiplying the gradients (and the hessians) by the sample weights. Note that
the binning stage (specifically the quantiles computation) does not take the
weights into account.

.. _categorical_support_gbdt:

Categorical Features Support
----------------------------

:class:`HistGradientBoostingClassifier` and
:class:`HistGradientBoostingRegressor` have native support for categorical
features: they can consider splits on non-ordered, categorical data.

For datasets with categorical features, using the native categorical support
is often better than relying on one-hot encoding
(:class:`~sklearn.preprocessing.OneHotEncoder`), because one-hot encoding
requires more tree depth to achieve equivalent splits. It is also usually
better to rely on the native categorical support rather than to treat
categorical features as continuous (ordinal), which happens for ordinal-encoded
categorical data, since categories are nominal quantities where order does not
matter.

To enable categorical support, a boolean mask can be passed to the
`categorical_features` parameter, indicating which features are categorical.
In the following, the first feature is treated as categorical and the
second feature as numerical::

    >>> gbdt = HistGradientBoostingClassifier(categorical_features=[True, False])

Equivalently, one can pass a list of integers indicating the indices of the
categorical features::

    >>> gbdt = HistGradientBoostingClassifier(categorical_features=[0])

The cardinality of each categorical feature should be less than the `max_bins`
parameter, and each categorical feature is expected to be encoded in
`[0, max_bins - 1]`. To that end, it might be useful to pre-process the data
with an :class:`~sklearn.preprocessing.OrdinalEncoder` as done in
:ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py`.
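
As a minimal sketch of such a preprocessing step (the column layout is
hypothetical: the first two columns are assumed categorical and the rest
numerical; the column transformer outputs the encoded columns first, so the
indices passed to `categorical_features` match)::

    from sklearn.compose import make_column_transformer
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import OrdinalEncoder
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor

    model = make_pipeline(
        make_column_transformer((OrdinalEncoder(), [0, 1]),
                                remainder='passthrough'),
        HistGradientBoostingRegressor(categorical_features=[0, 1]))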

If there are missing values during training, the missing values will be
treated as a proper category. If there are no missing values during training,
then at prediction time, missing values are mapped to the child node that has
the most samples (just like for continuous features). When predicting,
categories that were not seen during fit time will be treated as missing
values.
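
As a toy illustration of this behavior, on synthetic data assumed purely for
the sake of the example::

    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier

    rng = np.random.RandomState(0)
    X = rng.randint(3, size=(100, 1)).astype(float)  # categories 0, 1, 2
    X[::10, 0] = np.nan  # some values are missing during training
    y = (np.nan_to_num(X[:, 0]) == 2).astype(int)
    clf = HistGradientBoostingClassifier(categorical_features=[0]).fit(X, y)
    # Category 3 was never seen during fit: it is treated as missing.
    clf.predict(np.array([[3.0], [np.nan]]))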

**Split finding with categorical features**: The canonical way of considering
categorical splits in a tree is to consider
all of the :math:`2^{K - 1} - 1` partitions, where :math:`K` is the number of
categories. This can quickly become prohibitive when :math:`K` is large.
Fortunately, since gradient boosting trees are always regression trees (even
for classification problems), there exists a faster strategy that can yield
equivalent splits. First, the categories of a feature are sorted according to
the variance of the target within each category `k`. Once the categories are
sorted, one can consider *continuous partitions*, i.e. treat the categories
as if they were ordered continuous values (see Fisher [Fisher1958]_ for a
formal proof). As a result, only :math:`K - 1` splits need to be considered
instead of :math:`2^{K - 1} - 1`. The initial sorting is a
:math:`\mathcal{O}(K \log(K))` operation, leading to a total complexity of
:math:`\mathcal{O}(K \log(K) + K)`, instead of :math:`\mathcal{O}(2^K)`.
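
The following sketch illustrates the idea in plain numpy (the actual
implementation works on binned data and histograms in Cython; the helper
name and the use of the per-category target mean as the sort key are
illustrative assumptions)::

    import numpy as np

    def best_categorical_split(cats, y):
        # Sort categories by their mean target value.
        categories = np.unique(cats)
        order = np.argsort([y[cats == c].mean() for c in categories])
        sorted_cats = categories[order]
        best_score, best_left = np.inf, None
        # Scan only the K - 1 contiguous partitions of the sorted
        # categories instead of all 2**(K - 1) - 1 subsets.
        for cut in range(1, len(sorted_cats)):
            left = sorted_cats[:cut]
            mask = np.isin(cats, left)
            # Squared-error impurity of the two children.
            score = ((y[mask] - y[mask].mean()) ** 2).sum()
            score += ((y[~mask] - y[~mask].mean()) ** 2).sum()
            if score < best_score:
                best_score, best_left = score, set(left)
        return best_left  # categories sent to the left child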

.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py`

.. _monotonic_cst_gbdt:

Monotonic Constraints
@@ -1092,6 +1154,10 @@ that the feature is supposed to have a positive / negative effect on the
probability to belong to the positive class. Monotonic constraints are not
supported for multiclass context.

.. note::
    Since categories are unordered quantities, it is not possible to enforce
    monotonic constraints on categorical features.

.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_ensemble_plot_monotonic_constraints.py`
@@ -1158,6 +1224,8 @@ Finally, many parts of the implementation of
.. [LightGBM] Ke et al. `"LightGBM: A Highly Efficient Gradient
   Boosting Decision Tree" <https://papers.nips.cc/paper/
   6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree>`_
.. [Fisher1958] Walter D. Fisher. `"On Grouping for Maximum Homogeneity"
   <http://www.csiss.org/SPACE/workshops/2004/SAC/files/fisher.pdf>`_

.. _voting_classifier:

5 changes: 5 additions & 0 deletions doc/whats_new/v0.24.rst
@@ -239,6 +239,11 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |MajorFeature| :class:`ensemble.HistGradientBoostingRegressor` and
  :class:`ensemble.HistGradientBoostingClassifier` now have native
  support for categorical features with the `categorical_features`
  parameter. :pr:`18394` by `Nicolas Hug`_ and `Thomas Fan`_.

- |Feature| :class:`ensemble.HistGradientBoostingRegressor` and
  :class:`ensemble.HistGradientBoostingClassifier` now support the
  method `staged_predict`, which allows monitoring of each stage.