Skip to content

Commit

Permalink
[MRG] Use resample to compute the small training set in HistGBT (#14194)
Browse files Browse the repository at this point in the history
  • Loading branch information
Johann Faouzi authored and ogrisel committed Jul 8, 2019
1 parent 90b04a3 commit db2342f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 10 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.22.rst
Expand Up @@ -60,6 +60,11 @@ Changelog
parameter called `warm_start` that enables warm starting. :pr:`14012` by
:user:`Johann Faouzi <johannfaouzi>`.

- |Enhancement| In :class:`ensemble.HistGradientBoostingClassifier`, the
training loss or score is now monitored on a class-wise stratified subsample
to preserve the class balance of the original training set. :pr:`14194`
by :user:`Johann Faouzi <johannfaouzi>`.

:mod:`sklearn.linear_model`
...........................

Expand Down
24 changes: 14 additions & 10 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
Expand Up @@ -5,8 +5,9 @@

import numpy as np
from timeit import default_timer as time
from ...base import BaseEstimator, RegressorMixin, ClassifierMixin
from ...utils import check_X_y, check_random_state, check_array
from ...base import (BaseEstimator, RegressorMixin, ClassifierMixin,
is_classifier)
from ...utils import check_X_y, check_random_state, check_array, resample
from ...utils.validation import check_is_fitted
from ...utils.multiclass import check_classification_targets
from ...metrics import check_scoring
Expand Down Expand Up @@ -386,15 +387,18 @@ def _get_small_trainset(self, X_binned_train, y_train, seed):
with scorers.
"""
subsample_size = 10000
rng = check_random_state(seed)
indices = np.arange(X_binned_train.shape[0])
if X_binned_train.shape[0] > subsample_size:
# TODO: not critical but stratify using resample()
indices = rng.choice(indices, subsample_size, replace=False)
X_binned_small_train = X_binned_train[indices]
y_small_train = y_train[indices]
X_binned_small_train = np.ascontiguousarray(X_binned_small_train)
return X_binned_small_train, y_small_train
indices = np.arange(X_binned_train.shape[0])
stratify = y_train if is_classifier(self) else None
indices = resample(indices, n_samples=subsample_size,
replace=False, random_state=seed,
stratify=stratify)
X_binned_small_train = X_binned_train[indices]
y_small_train = y_train[indices]
X_binned_small_train = np.ascontiguousarray(X_binned_small_train)
return X_binned_small_train, y_small_train
else:
return X_binned_train, y_train

def _check_early_stopping_scorer(self, X_binned_small_train, y_small_train,
X_binned_val, y_val):
Expand Down
Expand Up @@ -7,6 +7,7 @@
from sklearn.ensemble import HistGradientBoostingRegressor
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
from sklearn.utils import shuffle


X_classification, y_classification = make_classification(random_state=0)
Expand Down Expand Up @@ -190,3 +191,32 @@ def test_zero_division_hessians(data):
X, y = data
gb = HistGradientBoostingClassifier(learning_rate=100, max_iter=10)
gb.fit(X, y)


def test_small_trainset():
    # Make sure that the small trainset is stratified and has the expected
    # length (10k samples)
    n_samples = 20000
    # Expected length of the subsample returned by _get_small_trainset when
    # the training set is larger than that.
    expected_size = 10000
    original_distrib = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}
    rng = np.random.RandomState(42)
    X = rng.randn(n_samples).reshape(n_samples, 1)
    # round() rather than int(): a float product that lands just below an
    # integer must not silently drop a sample from that class.
    y = [[class_] * round(prop * n_samples)
         for class_, prop in original_distrib.items()]
    # Seed the shuffle so the label order does not depend on the global
    # NumPy random state and the test is reproducible.
    y = shuffle(np.concatenate(y), random_state=rng)
    gb = HistGradientBoostingClassifier()

    # Compute the small training set
    X_small, y_small = gb._get_small_trainset(X, y, seed=42)

    # Compute the class distribution in the small training set
    unique, counts = np.unique(y_small, return_counts=True)
    small_distrib = {class_: count / expected_size
                     for class_, count in zip(unique, counts)}

    # Test that the small training set has the expected length
    assert X_small.shape[0] == expected_size
    assert y_small.shape[0] == expected_size

    # Test that the class distributions in the whole dataset and in the small
    # training set are identical
    assert small_distrib == pytest.approx(original_distrib)

0 comments on commit db2342f

Please sign in to comment.